| import torch | |
| from transformers.cache_utils import DynamicCache | |
| from typing import Optional, List, Tuple, Dict, Any | |
class DiffusionDynamicCache(DynamicCache):
    """`DynamicCache` subclass with bulk (all-layer) update and selection helpers.

    Every helper operates directly on the per-layer ``self.key_cache`` and
    ``self.value_cache`` containers inherited from the parent class
    (assumed to be Python lists of tensors -- confirm against the installed
    ``transformers`` version).

    NOTE(review): the second-to-last tensor dimension is the one that is
    concatenated / filtered; presumably tensors are laid out as
    ``(batch, heads, seq_len, head_dim)`` -- confirm with the calling model.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        """Forward ``num_hidden_layers`` unchanged to ``DynamicCache.__init__``."""
        super().__init__(num_hidden_layers)

    @staticmethod
    def _slot_is_empty(entry: Any) -> bool:
        """Return True when a cache slot holds no data (empty tensor or empty sequence)."""
        if isinstance(entry, torch.Tensor):
            return entry.numel() == 0
        return len(entry) == 0

    def full_update(
        self,
        new_kv: Tuple,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Extend the cache of every layer with new key/value states in one call.

        Args:
            new_kv: Iterable of ``(key, value)`` pairs, one pair per layer, in
                layer order. Each tensor is concatenated to the existing entry
                along ``dim=-2``.
            cache_kwargs: Accepted for signature compatibility; not used here.
        """
        for layer_idx, (key, val) in enumerate(new_kv):
            if layer_idx >= len(self.key_cache):
                # Nothing stored for this layer yet: start it with the new
                # states instead of failing with IndexError on the lookup.
                self.key_cache.append(key)
                self.value_cache.append(val)
            elif self._slot_is_empty(self.key_cache[layer_idx]):
                # Placeholder slot (empty list / empty tensor): torch.cat
                # cannot take a list placeholder, so replace it outright.
                self.key_cache[layer_idx] = key
                self.value_cache[layer_idx] = val
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], val], dim=-2
                )

    def select_partial(
        self,
        indices: torch.Tensor,
    ) -> None:
        """Keep only the entries at ``indices`` along the second-to-last dimension.

        Applied in place to the key and value tensors of every layer.

        Args:
            indices: Index tensor used to index the second-to-last dimension
                of each cached tensor.
        """
        for layer_idx in range(len(self.key_cache)):
            # `[..., indices, :]` matches `[:, :, indices, :]` for 4-D tensors
            # and stays consistent with the `dim=-2` used by `full_update`.
            self.key_cache[layer_idx] = self.key_cache[layer_idx][..., indices, :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][..., indices, :]

    def batch_select_minibatch(self, indices: torch.Tensor) -> None:
        """Truncate every layer to its first ``indices`` entries along dimension 0.

        Despite the parameter name, ``indices`` is used as a slice bound
        (``tensor[:indices]``), so it must be an integer or a single-element
        integer tensor.

        NOTE(review): the upstream ``DynamicCache.batch_select_indices`` gathers
        with ``tensor[indices, ...]`` instead; confirm that slicing the leading
        entries is what callers of this method expect.
        """
        for layer_idx in range(len(self.key_cache)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:indices, ...]