| import torch |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
|
|
| from .configuration_histaug import HistaugConfig |
| from .histaug_model import HistaugModel |
|
|
class HistaugPretrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``HistaugModel``.

    The wrapped model is built from a ``HistaugConfig``, switched to eval
    mode, and has all of its parameters frozen (``requires_grad = False``)
    at construction time.
    """

    config_class = HistaugConfig

    def __init__(self, config: HistaugConfig, *model_args, **model_kwargs):
        """Build the wrapped ``HistaugModel`` from ``config`` and freeze it.

        Args:
            config: Configuration supplying the ``HistaugModel`` constructor
                arguments.
            *model_args: Accepted for signature compatibility; not used.
            **model_kwargs: Extra keyword arguments forwarded unchanged to
                the ``HistaugModel`` constructor.
        """
        super().__init__(config)

        self.histaug = HistaugModel(
            input_dim=config.input_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            use_transform_pos_embeddings=config.use_transform_pos_embeddings,
            positional_encoding_type=config.positional_encoding_type,
            final_activation=config.final_activation,
            embedding_type=config.embedding_type,
            chunk_size=config.chunk_size,
            transforms=config.transforms,
            **model_kwargs,
        )

        # Run the PreTrainedModel post-construction hook before freezing.
        self.post_init()

        # The wrapped model is used frozen: eval mode, no gradients.
        # NOTE(review): nothing here stops a later `.train()` call on the
        # wrapper from flipping the submodule back to training mode.
        self.histaug.eval()
        for param in self.histaug.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor, aug_params, **kwargs) -> torch.Tensor:
        """Run the wrapped ``HistaugModel`` on ``x``.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            aug_params: Augmentation parameters dict as expected by
                ``HistaugModel``.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                wrapped model.

        Returns:
            Whatever the wrapped ``HistaugModel`` returns for these inputs.
        """
        return self.histaug(x, aug_params, **kwargs)

    def sample_aug_params(
        self,
        batch_size: int,
        device: torch.device = None,
        mode: str = "wsi_wise",
    ):
        """Proxy to ``HistaugModel.sample_aug_params``.

        Args:
            batch_size: Forwarded as ``batch_size``.
            device: Device to sample on. When ``None``, the device holding
                this model's parameters is used, so the sampled parameters
                sit alongside the model by default.
            mode: Sampling mode, forwarded unchanged.

        Returns:
            The value returned by ``HistaugModel.sample_aug_params``.
        """
        # Compare against None explicitly: a truthiness test would discard
        # valid falsy devices such as the integer index 0.
        if device is None:
            # Follow the model instead of guessing "cuda if available", which
            # mismatches a CPU-resident model on a GPU machine.
            device = self.device
        return self.histaug.sample_aug_params(
            batch_size=batch_size, device=device, mode=mode
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the model and configuration to ``save_directory``.

        Pure pass-through to ``PreTrainedModel.save_pretrained``.
        """
        super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args,
        **kwargs,
    ):
        """Load a model from a pretrained checkpoint.

        Pure pass-through to ``PreTrainedModel.from_pretrained``; returns
        whatever the parent implementation returns.
        """
        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )