# Page header residue: JeffreyXiang -- "Finalize" (commit a1e3f5f)
from typing import *
import torch
import torch.nn as nn
class RotaryPositionEmbedder(nn.Module):
    """
    Rotary position embedding (RoPE) over multi-axis positions.

    Every position axis gets ``freq_dim = head_dim // 2 // dim`` rotation
    frequencies. ``forward`` turns positions into unit-modulus complex phases
    of width ``head_dim // 2``; ``apply_rotary_embedding`` rotates features
    with them.

    Args:
        head_dim (int): per-head feature size; must be even.
        dim (int): number of position axes (last dim of ``indices``).
        rope_freq (Tuple[float, float]): ``(scale, base)``; the k-th frequency
            is ``scale / base ** (k / freq_dim)``.
    """

    def __init__(
        self,
        head_dim: int,
        dim: int = 3,
        rope_freq: Tuple[float, float] = (1.0, 10000.0)
    ):
        super().__init__()
        assert head_dim % 2 == 0, "Head dim must be divisible by 2"
        self.head_dim = head_dim
        self.dim = dim
        self.rope_freq = rope_freq
        # Number of complex channels (frequencies) available to each axis.
        self.freq_dim = head_dim // 2 // dim
        # Kept as a plain float32 tensor attribute (not a parameter/buffer).
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices (torch.Tensor): [M] flat tensor of positions

        Returns:
            torch.Tensor: [M, freq_dim] complex tensor ``exp(i * index * freq)``
        """
        # The moved copy is stored back so later calls on the same device
        # do not repeat the transfer.
        self.freqs = self.freqs.to(indices.device)
        angles = torch.outer(indices, self.freqs)
        return torch.polar(torch.ones_like(angles), angles)

    @staticmethod
    def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """
        Rotate consecutive feature pairs of ``x`` by ``phases``.

        Args:
            x (torch.Tensor): [..., N, H, D] features; the last dim is read as
                D // 2 (real, imag) pairs
            phases (torch.Tensor): [..., N, D // 2] complex phases; a singleton
                axis is inserted before the last dim so they broadcast over H

        Returns:
            torch.Tensor: rotated features, cast back to ``x.dtype``
        """
        # Explicit sizes instead of -1: -1 cannot be inferred for empty tensors.
        x_complex = torch.view_as_complex(
            x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
        )
        x_rotated = x_complex * phases.unsqueeze(-2)
        x_embed = torch.view_as_real(x_rotated).reshape(
            *x_rotated.shape[:-1], x_rotated.shape[-1] * 2
        )
        return x_embed.to(x.dtype)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices (torch.Tensor): [..., N, C] tensor of spatial positions

        Returns:
            torch.Tensor: [..., N, head_dim // 2] complex phases. The first
            ``C * freq_dim`` channels hold each axis' phases in order; any
            remaining channels are 1 + 0j (no rotation).
        """
        assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
        phases = self._get_phases(indices.reshape(-1))
        # Explicit width keeps this valid when N == 0 or freq_dim == 0,
        # where a -1 dimension would be ambiguous and raise.
        phases = phases.reshape(*indices.shape[:-1], self.dim * self.freq_dim)
        padn = self.head_dim // 2 - phases.shape[-1]
        if padn > 0:
            # Identity phases for channels not covered by any axis.
            identity = torch.ones(
                *phases.shape[:-1], padn, dtype=phases.dtype, device=phases.device
            )
            phases = torch.cat([phases, identity], dim=-1)
        return phases