# ReFusion / diffusion_cache_utils.py
# Uploaded by JinaLeejnl (commit "upload", fe7ffd3, verified).
import torch
from transformers.cache_utils import DynamicCache
from typing import Optional, List, Tuple, Dict, Any
class DiffusionDynamicCache(DynamicCache):
    """`DynamicCache` with helpers that act on every layer of the cache at once.

    The helpers read and write `self.key_cache` / `self.value_cache`, the
    per-layer lists maintained by the parent class.  Entries are indexed as
    4-D tensors here (`[:, :, indices, :]`); presumably the layout is
    (batch, heads, seq_len, head_dim) as in Transformers -- TODO confirm.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None):
        """Forward `num_hidden_layers` unchanged to `DynamicCache.__init__`."""
        super().__init__(num_hidden_layers)

    def full_update(
        self,
        new_kv: Tuple,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append new key/value states to every layer in a single call.

        Args:
            new_kv: Iterable of `(key, value)` tensor pairs, one per layer, in
                layer order.  Each pair is concatenated onto the stored tensors
                along `dim=-2`.
            cache_kwargs: Accepted but not used by this method.

        A layer that has no stored tensor yet (the cache is shorter than
        `new_kv`, or the slot holds an empty placeholder) is initialised with
        the new states instead of being concatenated, so the method also works
        on a freshly created cache.
        """
        for layer_idx, (key, val) in enumerate(new_kv):
            if len(self.key_cache) <= layer_idx:
                # First time this layer is seen: nothing to concatenate with.
                self.key_cache.append(key)
                self.value_cache.append(val)
                continue

            cached_key = self.key_cache[layer_idx]
            # NOTE(review): some Transformers versions fill skipped layers with
            # `[]` or an empty tensor; treat both as "no data yet".
            if not isinstance(cached_key, torch.Tensor) or cached_key.numel() == 0:
                self.key_cache[layer_idx] = key
                self.value_cache[layer_idx] = val
                continue

            self.key_cache[layer_idx] = torch.cat([cached_key, key], dim=-2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], val], dim=-2
            )

    def select_partial(
        self,
        indices: torch.Tensor,
    ) -> None:
        """Keep only the entries at `indices` along dim 2 of every cached tensor.

        Args:
            indices: Index applied as `tensor[:, :, indices, :]` to both the
                key and the value tensor of each layer (in place on the cache
                lists).  Presumably dim 2 is the sequence axis -- TODO confirm.
        """
        for layer_idx, cached_key in enumerate(self.key_cache):
            self.key_cache[layer_idx] = cached_key[:, :, indices, :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, indices, :]

    def batch_select_minibatch(self, indices: torch.Tensor) -> None:
        """Keep only the first `indices` entries along the batch dimension (dim 0).

        Despite its name, `indices` is used as a slice *stop*
        (`tensor[:indices, ...]`), not as a set of positions to gather, so it
        must be an int or a value convertible to one via `__index__` (for
        example a single-element integer tensor).
        """
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:indices, ...]