import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, PreTrainedModel, MT5EncoderModel, PretrainedConfig
from transformers.file_utils import ModelOutput

from .configuration_biencoder import SpanModelConfig
from .metrics import compute_span_predictions
from .loss import BCELoss, FocalLoss


@dataclass
class SpanModelOutput(ModelOutput):
    """Outputs of :meth:`OtterBiEncoderModel.forward`.

    Attributes:
        loss: Weighted sum of the start, end and span losses. Only set when the
            model is in training mode.
        start_logits: Scaled start-token scores, shape ``(B, C, S)``: batch,
            type-encoder rows, token positions.
        end_logits: Scaled end-token scores, same layout as ``start_logits``.
        span_logits: Scaled span scores, shape ``(B, C, N)`` where ``N`` is the
            number of candidate spans per example.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    span_logits: torch.FloatTensor = None


def mlp(input_size, output_size, dropout):
    """Build a two-layer projection head: Linear -> Dropout -> ReLU -> Linear.

    Args:
        input_size: Feature size of the input.
        output_size: Feature size of both the hidden layer and the output.
        dropout: Dropout probability applied after the first Linear layer.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(input_size, output_size),
        torch.nn.Dropout(dropout),
        torch.nn.ReLU(),
        torch.nn.Linear(output_size, output_size),
    )


class OtterBiEncoderModel(PreTrainedModel):
    """Bi-encoder span model: one encoder for input tokens, one for type descriptions.

    Token states are projected into start, end and span representations. Type
    states are pooled and projected into the same space. Every score is a
    temperature-scaled cosine similarity between a token/span vector and a type
    vector.
    """

    config_class = SpanModelConfig

    def __init__(self, config, token_config=None, type_config=None):
        """Build the projection heads and both encoders.

        Args:
            config: Model configuration (``SpanModelConfig``).
            token_config: Optional pre-loaded config for the token encoder. When
                ``None`` it is loaded from ``config.token_encoder``.
            type_config: Optional pre-loaded config for the type encoder. When
                ``None`` it is loaded from ``config.type_encoder``.

        Raises:
            ValueError: If ``config.loss_fn`` is neither ``"focal"`` nor ``"bce"``.
        """
        super().__init__(config)
        self.config = config
        self.token_config = token_config
        self.type_config = type_config
        if self.token_config is None:
            self.token_config = AutoConfig.from_pretrained(config.token_encoder)
        if self.type_config is None:
            self.type_config = AutoConfig.from_pretrained(config.type_encoder)

        self.max_span_length = config.max_span_length
        self.dropout = torch.nn.Dropout(config.dropout)
        self.linear_hidden_size = config.linear_hidden_size
        self.config.pruned_heads = self.token_config.pruned_heads

        self.type_linear = mlp(self.type_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_start_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_end_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        # Span features = [start vector ; end vector ; width embedding].
        self.token_span_linear = mlp(
            config.linear_hidden_size * 2 + config.span_width_embedding_size,
            config.linear_hidden_size,
            config.dropout,
        )
        # NOTE(review): not used in forward(); kept so parameter names in saved
        # state dicts keep matching.
        self.fusion_linear = mlp(config.linear_hidden_size * 2, config.linear_hidden_size, config.dropout)
        # padding_idx=0 keeps row 0 at zero; presumably span length 0 marks padding spans.
        self.width_embedding = torch.nn.Embedding(
            config.max_span_length + 1, config.span_width_embedding_size, padding_idx=0
        )

        # Learnable log inverse temperatures, initialised to log(1 / init_temperature).
        self.start_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.end_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.span_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))

        # NOTE(review): called before the encoders are attached below, so any
        # weight init triggered here does not reach them. Presumably intentional;
        # confirm before reordering.
        self.post_init()

        if "mt5" in config.token_encoder:
            self.token_encoder = MT5EncoderModel(config=self.token_config)
        else:
            self.token_encoder = AutoModel.from_config(self.token_config)
        if "mt5" in config.type_encoder:
            self.type_encoder = MT5EncoderModel(config=self.type_config)
        else:
            self.type_encoder = AutoModel.from_config(self.type_config)

        if config.loss_fn == "focal":
            self.loss_fn = FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma)
        elif config.loss_fn == "bce":
            self.loss_fn = BCELoss()
        else:
            raise ValueError(f"Invalid loss function: {config.loss_fn}")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and LayerNorm to identity."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            # Non-affine LayerNorms (or bias=False) have no weight/bias tensors.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def gather_spans(self, hidden_states, span_indices):
        """Pick one hidden vector per span boundary.

        Args:
            hidden_states: Rank-3 tensor ``(B, S, H)``.
            span_indices: Rank-2 integer tensor ``(B, N)`` of positions along dim 1.

        Returns:
            Tensor ``(B, N, H)`` with ``out[b, n] = hidden_states[b, span_indices[b, n]]``.
        """
        hidden_size = hidden_states.shape[-1]
        expanded_indices = span_indices.unsqueeze(2).expand(-1, -1, hidden_size)
        return torch.gather(hidden_states, 1, expanded_indices)

    def _pool_type_embeddings(self, hidden_states, attention_mask=None):
        """Reduce type-encoder states ``(C, T, H)`` to one vector per row ``(C, H)``.

        With ``config.type_encoder_pooling == "mean"`` this is a masked mean when
        ``attention_mask`` is given, and a plain mean otherwise. Any other setting
        takes the first token's state.
        """
        if self.config.type_encoder_pooling != "mean":
            return hidden_states[:, 0, :]
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        # Accumulate in float32 so the 1e-9 clamp is representable and the sum is
        # accurate under half precision. Cast back afterwards: feeding a float32
        # tensor into half-precision Linear layers raises a dtype mismatch.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = torch.sum(hidden_states.float() * mask, dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return (summed / counts).to(hidden_states.dtype)

    @staticmethod
    def _as_pos_weight(value, reference):
        """Turn an optional config pos_weight into a tensor on ``reference``'s device/dtype."""
        if value is None:
            return None
        return torch.tensor(value, device=reference.device, dtype=reference.dtype)

    def forward(
        self,
        token_encoder_inputs: dict = None,
        type_encoder_inputs: dict = None,
        labels: dict = None,
    ):
        """Score every token and candidate span against every type-encoder row.

        Args:
            token_encoder_inputs: Keyword arguments for the token encoder.
            type_encoder_inputs: Keyword arguments for the type encoder. An
                ``"attention_mask"`` entry, if present, is used for mean pooling.
            labels: Always requires ``"span_lengths"`` and ``"span_subword_indices"``
                (last dim holds start/end positions), which build the span
                features. In training mode it also requires ``start_labels``,
                ``end_labels``, ``span_labels``, ``valid_start_mask``,
                ``valid_end_mask`` and ``valid_span_mask``.

        Returns:
            :class:`SpanModelOutput`. ``loss`` is populated only in training mode.

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        if labels is None:
            raise ValueError(
                "`labels` must be provided: it carries `span_lengths` and "
                "`span_subword_indices`, which are required to build span features."
            )

        token_embeds = self.token_encoder(**token_encoder_inputs)
        type_embeds = self.type_encoder(**type_encoder_inputs)
        token_output = token_embeds.last_hidden_state
        type_output = self._pool_type_embeddings(
            type_embeds.last_hidden_state, type_encoder_inputs.get("attention_mask")
        )

        # Project, then L2-normalise, so the einsums below are cosine similarities.
        token_start_output = F.normalize(self.dropout(self.token_start_linear(token_output)), dim=-1)
        token_end_output = F.normalize(self.dropout(self.token_end_linear(token_output)), dim=-1)
        type_output = F.normalize(self.dropout(self.type_linear(type_output)), dim=-1)

        start_scores = self.start_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_start_output, type_output)
        end_scores = self.end_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_end_output, type_output)

        span_indices = labels["span_subword_indices"]
        span_width_embeddings = self.width_embedding(labels["span_lengths"])
        span_hidden = torch.cat(
            [
                self.gather_spans(token_start_output, span_indices[:, :, 0]),
                self.gather_spans(token_end_output, span_indices[:, :, 1]),
                span_width_embeddings,
            ],
            dim=2,
        )
        token_span_output = F.normalize(self.dropout(self.token_span_linear(span_hidden)), dim=-1)
        span_scores = self.span_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_span_output, type_output)

        if not self.training:
            return SpanModelOutput(start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)

        start_loss = self.loss_fn(
            start_scores,
            labels["start_labels"],
            mask=labels["valid_start_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_start_pos_weight, start_scores),
        )
        end_loss = self.loss_fn(
            end_scores,
            labels["end_labels"],
            mask=labels["valid_end_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_end_pos_weight, end_scores),
        )
        span_loss = self.loss_fn(
            span_scores,
            labels["span_labels"],
            mask=labels["valid_span_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_span_pos_weight, span_scores),
        )
        loss = (
            self.config.start_loss_weight * start_loss
            + self.config.end_loss_weight * end_loss
            + self.config.span_loss_weight * span_loss
        )
        return SpanModelOutput(loss=loss, start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)

    def predict(self, batch: dict, threshold: float = 0.5):
        """Run a gradient-free forward pass and decode span predictions.

        Args:
            batch: Must contain ``token_encoder_inputs``, ``type_encoder_inputs``,
                ``labels`` (with ``valid_span_mask`` and ``span_subword_indices``)
                and ``id2label``.
            threshold: Decision threshold forwarded to ``compute_span_predictions``.

        Returns:
            Whatever ``compute_span_predictions`` returns.

        Note:
            This does not switch to eval mode. Call ``model.eval()`` first, or
            dropout stays active and the training-loss path runs.
        """
        with torch.no_grad():
            output = self.forward(
                token_encoder_inputs=batch["token_encoder_inputs"],
                type_encoder_inputs=batch["type_encoder_inputs"],
                labels=batch["labels"],
            )
        return compute_span_predictions(
            span_logits=output.span_logits.detach().cpu().numpy(),
            span_mask=batch["labels"]["valid_span_mask"].cpu().numpy(),
            span_mapping=batch["labels"]["span_subword_indices"].cpu().numpy(),
            id2label=batch["id2label"],
            threshold=threshold,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save the model, plus both encoder configs as JSON side files.

        The encoder configs are written to ``token_encoder_config.json`` and
        ``type_encoder_config.json``, which :meth:`from_pretrained` looks for.
        """
        save_directory = Path(save_directory)
        super().save_pretrained(save_directory, **kwargs)
        self.token_encoder.config.to_json_file(str(save_directory / "token_encoder_config.json"))
        self.type_encoder.config.to_json_file(str(save_directory / "type_encoder_config.json"))

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        **kwargs,
    ):
        """Load the model, picking up locally saved encoder configs when present.

        If ``token_encoder_config.json`` / ``type_encoder_config.json`` exist
        under ``pretrained_model_name_or_path``, they are loaded and passed to
        ``__init__`` as ``token_config`` / ``type_config``. Otherwise
        ``__init__`` falls back to the names stored in the model config.
        """
        if pretrained_model_name_or_path is not None:
            base_path = Path(pretrained_model_name_or_path)
            side_configs = (
                ("token_config", base_path / "token_encoder_config.json"),
                ("type_config", base_path / "type_encoder_config.json"),
            )
            for kwarg_name, config_path in side_configs:
                if config_path.exists():
                    kwargs[kwarg_name] = AutoConfig.from_pretrained(str(config_path))
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on both encoders.

        Args:
            gradient_checkpointing_kwargs: Optional kwargs forwarded to each
                encoder's ``gradient_checkpointing_enable``. ``Trainer`` passes
                this keyword, so the override must accept it. It is only
                forwarded when set, so older Transformers versions keep working.
        """
        for encoder in (self.token_encoder, self.type_encoder):
            if gradient_checkpointing_kwargs is None:
                encoder.gradient_checkpointing_enable()
            else:
                encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on both encoders (counterpart of enable)."""
        self.token_encoder.gradient_checkpointing_disable()
        self.type_encoder.gradient_checkpointing_disable()