import math

import torch


class RopeND:
    """N-dimensional rotary position embedding (RoPE) that splits the head
    dimension across axes and caches per-axis cos/sin tables."""

    def __init__(self, head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1],
                 bases=[1000, 1000, 1000], auto_base=True, cache_longer=1):
        self.nd = nd
        self.head_dim = head_dim
        self.max_lens = max_lens
        self.nd_split = nd_split
        # Give each axis an even number of channels, proportional to its share
        # in nd_split (e.g. head_dim=64, nd_split=[2, 1, 1] -> [32, 16, 16]).
        self.split_dims = [2 * s * (head_dim // 2 // sum(nd_split)) for s in nd_split]
        assert sum(self.split_dims) == head_dim, \
            f"split_dims {self.split_dims} must sum to head_dim {head_dim}"
        self.auto_base = auto_base
        if auto_base:
            # Derive each base from its axis length: round 8 * max_len / pi up
            # to the next multiple of 100.
            self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
            print(f"Bases for rope: {self.bases}")
        else:
            self.bases = bases
        # The cos/sin caches cover cache_longer * max_len positions per axis.
        self.cache_longer = cache_longer

    def generated_cos_sin_mix2d(self, max_len, dim, device, base=1000):
        # Inverse frequencies span base**0 down to base**-1 over dim // 2
        # channels; note the exponent is normalized by the full head_dim,
        # not by this axis's slice of it.
        inv_freq = 1.0 / (base ** (torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
                                                  device=device).float() / self.head_dim))
        assert inv_freq.size(0) * 2 == dim, f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"

        # Outer product of positions and frequencies; the table is cache_longer
        # times longer than max_len so positions beyond max_len can be served.
        t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        freqs = torch.cat([freqs, freqs], dim=1)
        return freqs.cos().to(torch.float), freqs.sin().to(torch.float)

    def generate_pos_embs_mix2d(self, position_ids, device=None):
        if device is None:
            device = position_ids.device

        # Expect position_ids of shape (nd, L); promote a 1-D input to 2-D.
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        cos_emb_all, sin_emb_all = [], []
        for i in range(self.nd):
            dim_i = self.split_dims[i]
            base_i = self.bases[i]
            max_len_i = self.max_lens[i]
            # Lazily build and cache the cos/sin tables for axis i on first use.
            if not hasattr(self, f"cos_{i}"):
                _cos, _sin = self.generated_cos_sin_mix2d(max_len=max_len_i, dim=dim_i,
                                                          device=device, base=base_i)
                setattr(self, f"cos_{i}", _cos)
                setattr(self, f"sin_{i}", _sin)
            # Gather the rows for this axis's coordinates: (L, dim_i).
            cos_emb_all.append(getattr(self, f"cos_{i}")[position_ids[i, :], :])
            sin_emb_all.append(getattr(self, f"sin_{i}")[position_ids[i, :], :])
        # Concatenate the per-axis slices along the channel dim: (L, head_dim).
        cos_emb = torch.cat(cos_emb_all, dim=-1)
        sin_emb = torch.cat(sin_emb_all, dim=-1)
        return cos_emb, sin_emb

    def __call__(self, q, k, position_ids):
        """Apply the N-D rotary embedding to query and key tensors.

        Args:
            q, k: tensors of shape (N, N_head, L, C), i.e. (batch, heads, seq_len, head_dim).
            position_ids: integer tensor of shape (nd, L), one row of coordinates per axis.
        """
        cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin):
            """Applies Rotary Position Embedding to the query and key tensors.

            Args:
                q (`torch.Tensor`): The query tensor.
                k (`torch.Tensor`): The key tensor.
                cos (`torch.Tensor`): The cosine part of the rotary embedding.
                sin (`torch.Tensor`): The sine part of the rotary embedding.

            Returns:
                `tuple(torch.Tensor)` comprising the query and key tensors rotated
                with the rotary position embedding.
            """
            # Broadcast (L, C) -> (1, 1, L, C) to match q and k.
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            # Rotate in float32 for numerical stability, then cast back.
            dtype = q.dtype
            q = q.to(torch.float)
            k = k.to(torch.float)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)
            return q_embed.to(dtype), k_embed.to(dtype)

        q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
        return q, k
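

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # module). It assumes a flattened (time, height, width) token grid, with
    # one row of coordinates per axis in position_ids.
    torch.manual_seed(0)
    T, H, W = 4, 8, 8
    rope = RopeND(head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1])

    # Build per-axis coordinates for every token in the flattened grid.
    t_ids, h_ids, w_ids = torch.meshgrid(
        torch.arange(T), torch.arange(H), torch.arange(W), indexing="ij"
    )
    position_ids = torch.stack(
        [t_ids.flatten(), h_ids.flatten(), w_ids.flatten()], dim=0
    )  # (3, T*H*W)

    q = torch.randn(2, 8, T * H * W, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 8, T * H * W, 64)
    q_rot, k_rot = rope(q, k, position_ids)
    print(q_rot.shape, k_rot.shape)  # both torch.Size([2, 8, 256, 64])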