File size: 1,236 Bytes
fe7ffd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from transformers.cache_utils import DynamicCache
from typing import Optional, List, Tuple, Dict, Any

class DiffusionDynamicCache(DynamicCache):
    """``DynamicCache`` subclass with helpers that edit all layers at once.

    Every helper works directly on the per-layer ``self.key_cache`` and
    ``self.value_cache`` lists, which this class does not define itself and
    expects the parent class to provide.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None):
        """Forward ``num_hidden_layers`` unchanged to the parent constructor."""
        # NOTE(review): the parent constructor's signature has changed between
        # transformers releases -- confirm it still takes this as its first
        # positional argument for the pinned version.
        super().__init__(num_hidden_layers)

    def full_update(
        self,
        new_kv: Tuple,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append new key/value states to every layer in a single call.

        Args:
            new_kv: Iterable of ``(key, value)`` pairs, one per layer, in
                layer order. Each tensor is concatenated onto the tensor
                already cached for that layer along dim ``-2``.
            cache_kwargs: Accepted for signature compatibility; not used.
        """
        for layer_idx, (key, val) in enumerate(new_kv):
            if layer_idx >= len(self.key_cache):
                # Nothing is cached for this layer yet (e.g. first call on an
                # empty cache): store the states as-is rather than failing
                # with IndexError on the lookup below.
                self.key_cache.append(key)
                self.value_cache.append(val)
                continue
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], val], dim=-2
            )

    def select_partial(
        self,
        indices: torch.Tensor,
    ) -> None:
        """Keep only the entries selected by ``indices`` along dim 2.

        Each cached key/value tensor is replaced, in place in the cache
        lists, by ``tensor[:, :, indices, :]``; the tensors are therefore
        indexed as 4-dimensional.

        Args:
            indices: Index applied to dim 2 of every cached tensor.
        """
        # Only existing list slots are reassigned, so iterating the lists
        # while writing back into them is safe.
        cached = zip(self.key_cache, self.value_cache)
        for layer_idx, (key, val) in enumerate(cached):
            self.key_cache[layer_idx] = key[:, :, indices, :]
            self.value_cache[layer_idx] = val[:, :, indices, :]

    def batch_select_minibatch(self, indices: torch.Tensor) -> None:
        """Keep only the leading ``indices`` entries along the first (batch) dim.

        ``indices`` is used as the stop of a slice (``tensor[:indices, ...]``),
        not as a set of rows to gather. It must therefore be usable as a
        slice bound: an ``int`` or a single-element integer tensor.

        Args:
            indices: Number of leading entries to keep in every cached
                key/value tensor.
        """
        cached = zip(self.key_cache, self.value_cache)
        for layer_idx, (key, val) in enumerate(cached):
            self.key_cache[layer_idx] = key[:indices, ...]
            self.value_cache[layer_idx] = val[:indices, ...]