Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,460 Bytes
388d03f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import DINOv3ViTModel
import numpy as np
from PIL import Image
class DinoV2FeatureExtractor:
    """
    Feature extractor for DINOv2 models.

    Loads a backbone from the ``facebookresearch/dinov2`` torch.hub repository
    and returns layer-normalised token features for a batch of images.
    """

    def __init__(self, model_name: str, image_size: int = 518):
        """
        Args:
            model_name: Entry-point name passed to ``torch.hub.load``.
            image_size: Side length in pixels that PIL inputs are resized to.
                Tensor inputs are used at whatever resolution they arrive in.
        """
        self.model_name = model_name
        self.image_size = image_size
        self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
        self.model.eval()
        # Per-channel normalisation applied to every batch before the forward pass.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @property
    def device(self) -> torch.device:
        """Device the model's parameters currently live on."""
        return next(self.model.parameters()).device

    def to(self, device):
        """Move the underlying model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the underlying model to the default CUDA device."""
        self.model.cuda()

    def cpu(self):
        """Move the underlying model to the CPU."""
        self.model.cpu()

    def _to_batch(self, image) -> torch.Tensor:
        """Validate ``image`` and return it as a (B, C, H, W) tensor on the model's device."""
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
            batch = image
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            frames = []
            for pil_image in image:
                resized = pil_image.resize((self.image_size, self.image_size), Image.LANCZOS)
                # HWC uint8 -> float32 in [0, 1] -> CHW tensor.
                pixels = np.array(resized.convert('RGB')).astype(np.float32) / 255
                frames.append(torch.from_numpy(pixels).permute(2, 0, 1).float())
            batch = torch.stack(frames)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")
        # Follow the model instead of forcing CUDA, so the extractor also works
        # after `cpu()` / `to(device)` and on machines without a GPU.
        return batch.to(self.device)

    @torch.no_grad()
    def __call__(self, image: "Union[torch.Tensor, List[Image.Image]]") -> torch.Tensor:
        """
        Extract features from the image.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
                PIL images are resized to ``image_size`` x ``image_size`` and scaled to [0, 1];
                tensors are only normalised, so they are presumably expected in [0, 1] already
                (TODO confirm with callers).

        Returns:
            A tensor of shape (B, N, D) where N is the number of tokens returned by the model
            and D is the feature dimension.
            NOTE(review): whether N contains only patch tokens or also class/register
            tokens depends on the backbone's 'x_prenorm' output -- confirm.

        Raises:
            AssertionError: If a tensor is not 4-D or a list holds non-PIL items.
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        batch = self.transform(self._to_batch(image))
        # The model is called with `is_training=True` and indexed by 'x_prenorm';
        # going by the key name these are the tokens before the model's final norm.
        features = self.model(batch, is_training=True)['x_prenorm']
        # Parameter-free layer norm over the feature dimension.
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens
class DinoV3FeatureExtractor:
    """
    Feature extractor for DINOv3 models.

    Wraps a ``DINOv3ViTModel`` and returns layer-normalised token features
    for a batch of images.
    """

    def __init__(self, model_name: str, image_size: int = 512):
        """
        Args:
            model_name: Identifier passed to ``DINOv3ViTModel.from_pretrained``.
            image_size: Side length in pixels that PIL inputs are resized to.
                Tensor inputs are used at whatever resolution they arrive in.
        """
        self.model_name = model_name
        self.model = DINOv3ViTModel.from_pretrained(model_name)
        self.model.eval()
        self.image_size = image_size
        # Per-channel normalisation applied to every batch before the forward pass.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @property
    def device(self) -> torch.device:
        """Device the model's parameters currently live on."""
        return next(self.model.parameters()).device

    def to(self, device):
        """Move the underlying model to ``device``."""
        self.model.to(device)

    def cuda(self):
        """Move the underlying model to the default CUDA device."""
        self.model.cuda()

    def cpu(self):
        """Move the underlying model to the CPU."""
        self.model.cpu()

    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Run the embedding layer and every transformer layer, then layer-normalise.

        Args:
            image: Already-normalised image batch; it is cast to the dtype and
                device of the patch-embedding weights before use.

        Returns:
            The final hidden states after a parameter-free layer norm over the
            feature dimension.
        """
        weight = self.model.embeddings.patch_embeddings.weight
        # Match both dtype and device of the model so direct callers cannot
        # trigger a device mismatch.
        image = image.to(device=weight.device, dtype=weight.dtype)
        hidden_states = self.model.embeddings(image, bool_masked_pos=None)
        position_embeddings = self.model.rope_embeddings(image)
        for layer_module in self.model.layer:
            hidden_states = layer_module(
                hidden_states,
                position_embeddings=position_embeddings,
            )
        return F.layer_norm(hidden_states, hidden_states.shape[-1:])

    def _to_batch(self, image) -> torch.Tensor:
        """Validate ``image`` and return it as a (B, C, H, W) tensor on the model's device."""
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
            batch = image
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            frames = []
            for pil_image in image:
                resized = pil_image.resize((self.image_size, self.image_size), Image.LANCZOS)
                # HWC uint8 -> float32 in [0, 1] -> CHW tensor.
                pixels = np.array(resized.convert('RGB')).astype(np.float32) / 255
                frames.append(torch.from_numpy(pixels).permute(2, 0, 1).float())
            batch = torch.stack(frames)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")
        # Follow the model instead of forcing CUDA, so the extractor also works
        # after `cpu()` / `to(device)` and on machines without a GPU.
        return batch.to(self.device)

    @torch.no_grad()
    def __call__(self, image: "Union[torch.Tensor, List[Image.Image]]") -> torch.Tensor:
        """
        Extract features from the image.

        Args:
            image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
                PIL images are resized to ``image_size`` x ``image_size`` and scaled to [0, 1];
                tensors are only normalised, so they are presumably expected in [0, 1] already
                (TODO confirm with callers).

        Returns:
            A tensor of shape (B, N, D) where N is the number of tokens produced by the model's
            embedding layer and D is the feature dimension.
            NOTE(review): N may include non-patch tokens (e.g. class/register tokens)
            depending on the embedding layer -- confirm.

        Raises:
            AssertionError: If a tensor is not 4-D or a list holds non-PIL items.
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        batch = self.transform(self._to_batch(image))
        features = self.extract_features(batch)
        return features
|