# otter-bi-mmbert / modeling_biencoder.py
# Uploaded by whoisjones with huggingface_hub (commit 61e7826, verified).
import os
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, PreTrainedModel, MT5EncoderModel, PretrainedConfig
from transformers.file_utils import ModelOutput
from pathlib import Path
from .configuration_biencoder import SpanModelConfig
from .metrics import compute_span_predictions
from .loss import BCELoss, FocalLoss
@dataclass
class SpanModelOutput(ModelOutput):
    """Output container returned by the span model's ``forward``.

    The field order is part of the interface and must not be changed.
    """

    # Weighted sum of the start, end and span losses; left as None when
    # forward() does not compute a loss.
    loss: Optional[torch.FloatTensor] = None
    # Scores of every token as a span start, laid out "BCS" per the einsum
    # in forward().
    start_logits: Optional[torch.FloatTensor] = None
    # Scores of every token as a span end, same layout as start_logits.
    end_logits: Optional[torch.FloatTensor] = None
    # Scores of every candidate span, laid out "BCS" with S indexing the
    # candidate spans.
    span_logits: Optional[torch.FloatTensor] = None
def mlp(input_size, output_size, dropout):
    """Build the two-layer projection head used throughout the model.

    The stack is ``Linear -> Dropout -> ReLU -> Linear``; both linear layers
    produce ``output_size`` features.

    Args:
        input_size: Number of input features of the first linear layer.
        output_size: Number of features produced by both linear layers.
        dropout: Dropout probability applied after the first linear layer.

    Returns:
        A ``torch.nn.Sequential`` holding the four layers in the order above.
    """
    # The positions inside the Sequential determine the parameter names
    # ("0.weight", "3.weight", ...), so the layer order must stay as is.
    layers = [
        torch.nn.Linear(input_size, output_size),
        torch.nn.Dropout(dropout),
        torch.nn.ReLU(),
        torch.nn.Linear(output_size, output_size),
    ]
    return torch.nn.Sequential(*layers)
class OtterBiEncoderModel(PreTrainedModel):
    """Bi-encoder that scores tokens and candidate spans against type embeddings.

    A token encoder embeds the input text and a separate type encoder embeds
    the candidate types. Both are projected to ``config.linear_hidden_size``
    features, L2-normalised, and compared with a dot product that is scaled by
    a learned temperature, giving start, end and span logits.
    """

    config_class = SpanModelConfig

    def __init__(self, config, token_config=None, type_config=None):
        """Build the projection heads, both encoders and the loss function.

        Args:
            config: Model configuration (``SpanModelConfig``).
            token_config: Configuration of the token encoder. When ``None`` it
                is loaded via ``AutoConfig.from_pretrained(config.token_encoder)``.
            type_config: Configuration of the type encoder. When ``None`` it
                is loaded via ``AutoConfig.from_pretrained(config.type_encoder)``.

        Raises:
            ValueError: If ``config.loss_fn`` is neither ``"focal"`` nor ``"bce"``.
        """
        super().__init__(config)
        self.config = config
        self.token_config = token_config
        self.type_config = type_config
        if self.token_config is None:
            self.token_config = AutoConfig.from_pretrained(config.token_encoder)
        if self.type_config is None:
            self.type_config = AutoConfig.from_pretrained(config.type_encoder)

        self.max_span_length = config.max_span_length
        self.dropout = torch.nn.Dropout(config.dropout)
        self.linear_hidden_size = config.linear_hidden_size
        self.config.pruned_heads = self.token_config.pruned_heads

        # Projection heads. Registration order is kept stable because it
        # determines parameter order and random-initialisation order.
        self.type_linear = mlp(self.type_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_start_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_end_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        # A span is represented as [start projection ; end projection ; width embedding].
        self.token_span_linear = mlp(
            config.linear_hidden_size * 2 + config.span_width_embedding_size,
            config.linear_hidden_size,
            config.dropout,
        )
        # Not used by forward() in this file; kept so the parameter set
        # (and therefore saved checkpoints) stays unchanged.
        self.fusion_linear = mlp(config.linear_hidden_size * 2, config.linear_hidden_size, config.dropout)
        self.width_embedding = torch.nn.Embedding(
            config.max_span_length + 1,
            config.span_width_embedding_size,
            padding_idx=0,
        )

        # Learned log-temperatures; forward() multiplies the similarities by exp(scale).
        self.start_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.end_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.span_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))

        # NOTE(review): post_init() is invoked before the encoders are attached,
        # presumably so the custom initialisation only touches the modules
        # above -- confirm before moving this call.
        self.post_init()

        if "mt5" in config.token_encoder:
            self.token_encoder = MT5EncoderModel(config=self.token_config)
        else:
            self.token_encoder = AutoModel.from_config(self.token_config)

        if "mt5" in config.type_encoder:
            self.type_encoder = MT5EncoderModel(config=self.type_config)
        else:
            self.type_encoder = AutoModel.from_config(self.type_config)

        if config.loss_fn == "focal":
            self.loss_fn = FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma)
        elif config.loss_fn == "bce":
            self.loss_fn = BCELoss()
        else:
            raise ValueError(f"Invalid loss function: {config.loss_fn}")

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from N(0, 0.02); linear biases
        and the embedding padding row are zeroed; LayerNorm is reset to
        weight 1 / bias 0.
        """
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            # LayerNorm may be built without a bias (bias=False) or without
            # any affine parameters (elementwise_affine=False).
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def gather_spans(self, hidden_states, span_indices):
        """Select one hidden vector per span index.

        Args:
            hidden_states: Rank-3 tensor; positions are indexed along dim 1.
            span_indices: Rank-2 integer tensor of positions into dim 1 of
                ``hidden_states``.

        Returns:
            Tensor with the leading two dimensions of ``span_indices`` and the
            last dimension of ``hidden_states``.
        """
        _, _, H = hidden_states.shape
        expanded_indices = span_indices.unsqueeze(2).expand(-1, -1, H)
        return torch.gather(hidden_states, 1, expanded_indices)

    @staticmethod
    def _as_pos_weight(value, reference):
        """Return ``value`` as a tensor on ``reference``'s device/dtype, or ``None``."""
        if value is None:
            return None
        return torch.tensor(value, device=reference.device, dtype=reference.dtype)

    def forward(
        self,
        token_encoder_inputs: dict = None,
        type_encoder_inputs: dict = None,
        labels: dict = None
    ):
        """Score tokens (and, when ``labels`` is given, spans) against the types.

        Args:
            token_encoder_inputs: Keyword arguments forwarded to the token encoder.
            type_encoder_inputs: Keyword arguments forwarded to the type encoder.
                Its ``"attention_mask"`` entry, if present and not ``None``, is
                used for masked mean pooling.
            labels: Optional dict. ``"span_lengths"`` and
                ``"span_subword_indices"`` are needed to build span
                representations; the ``*_labels`` and ``valid_*_mask`` entries
                are needed to compute the loss.

        Returns:
            ``SpanModelOutput``. ``loss`` is set only when ``labels`` is given
            and the module is in training mode. ``span_logits`` is ``None``
            when ``labels`` is ``None`` because the candidate spans are
            described by ``labels``.
        """
        token_embeds = self.token_encoder(**token_encoder_inputs)
        type_embeds = self.type_encoder(**type_encoder_inputs)
        token_output = token_embeds.last_hidden_state
        type_hidden = type_embeds.last_hidden_state

        if self.config.type_encoder_pooling == "mean":
            # .get(): a missing mask is treated like an explicit None.
            attention_mask = type_encoder_inputs.get("attention_mask")
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(type_hidden.size()).float()
                summed = torch.sum(type_hidden * mask, dim=1)
                # Clamp so an all-zero mask does not divide by zero.
                counts = torch.clamp(mask.sum(dim=1), min=1e-9)
                type_output = summed / counts
            else:
                type_output = type_hidden.mean(dim=1)
        else:
            # Any other pooling setting takes the first position.
            type_output = type_hidden[:, 0, :]

        token_start_output = F.normalize(self.dropout(self.token_start_linear(token_output)), dim=-1)
        token_end_output = F.normalize(self.dropout(self.token_end_linear(token_output)), dim=-1)
        type_output = F.normalize(self.dropout(self.type_linear(type_output)), dim=-1)

        # Inputs are unit-normalised, so these are temperature-scaled cosine
        # similarities with layout (batch, type, position).
        start_scores = self.start_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_start_output, type_output)
        end_scores = self.end_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_end_output, type_output)

        span_scores = None
        if labels is not None:
            span_indices = labels["span_subword_indices"]
            span_hidden = torch.cat(
                [
                    self.gather_spans(token_start_output, span_indices[:, :, 0]),
                    self.gather_spans(token_end_output, span_indices[:, :, 1]),
                    self.width_embedding(labels["span_lengths"]),
                ],
                dim=2
            )
            token_span_output = F.normalize(self.dropout(self.token_span_linear(span_hidden)), dim=-1)
            span_scores = self.span_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_span_output, type_output)

        if labels is not None and self.training:
            start_loss = self.loss_fn(
                start_scores,
                labels["start_labels"],
                mask=labels["valid_start_mask"],
                pos_weight=self._as_pos_weight(self.config.bce_start_pos_weight, start_scores)
            )
            end_loss = self.loss_fn(
                end_scores,
                labels["end_labels"],
                mask=labels["valid_end_mask"],
                pos_weight=self._as_pos_weight(self.config.bce_end_pos_weight, end_scores)
            )
            span_loss = self.loss_fn(
                span_scores,
                labels["span_labels"],
                mask=labels["valid_span_mask"],
                pos_weight=self._as_pos_weight(self.config.bce_span_pos_weight, span_scores)
            )
            loss = (
                self.config.start_loss_weight * start_loss
                + self.config.end_loss_weight * end_loss
                + self.config.span_loss_weight * span_loss
            )
            return SpanModelOutput(loss=loss, start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)

        return SpanModelOutput(start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)

    def predict(self, batch: dict, threshold: float = 0.5):
        """Run the model without gradients and decode span predictions.

        This method does not switch the module to eval mode; call
        ``model.eval()`` beforehand if dropout should be disabled.

        Args:
            batch: Dict with ``"token_encoder_inputs"``,
                ``"type_encoder_inputs"``, ``"labels"`` (providing
                ``"valid_span_mask"`` and ``"span_subword_indices"``) and
                ``"id2label"``.
            threshold: Passed through to ``compute_span_predictions``.

        Returns:
            Whatever ``compute_span_predictions`` returns.
        """
        labels = batch["labels"]
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            output = self(
                token_encoder_inputs=batch["token_encoder_inputs"],
                type_encoder_inputs=batch["type_encoder_inputs"],
                labels=labels
            )
        return compute_span_predictions(
            # Cast to float32 first: NumPy cannot represent bfloat16 tensors.
            span_logits=output.span_logits.detach().float().cpu().numpy(),
            span_mask=labels["valid_span_mask"].cpu().numpy(),
            span_mapping=labels["span_subword_indices"].cpu().numpy(),
            id2label=batch["id2label"],
            threshold=threshold
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save the model and write both encoder configs next to it.

        In addition to what the parent ``save_pretrained`` writes, this stores
        ``token_encoder_config.json`` and ``type_encoder_config.json`` in
        ``save_directory`` so ``from_pretrained`` can rebuild the encoders
        from a local directory.
        """
        save_directory = Path(save_directory)
        super().save_pretrained(save_directory, **kwargs)
        self.token_encoder.config.to_json_file(str(save_directory / "token_encoder_config.json"))
        self.type_encoder.config.to_json_file(str(save_directory / "type_encoder_config.json"))

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        **kwargs,
    ):
        """Load the model, reusing encoder configs stored beside the checkpoint.

        If ``pretrained_model_name_or_path`` is a local directory containing
        ``token_encoder_config.json`` / ``type_encoder_config.json``, those
        are loaded and passed on as ``token_config`` / ``type_config``. Only
        the local filesystem is probed; otherwise ``__init__`` resolves the
        encoder configs from ``config.token_encoder`` / ``config.type_encoder``.
        """
        # None is allowed by the signature; Path(None) would raise TypeError.
        if pretrained_model_name_or_path is not None:
            base_path = Path(pretrained_model_name_or_path)
            side_configs = (
                ("token_config", "token_encoder_config.json"),
                ("type_config", "type_encoder_config.json"),
            )
            for kwarg_name, filename in side_configs:
                config_path = base_path / filename
                if config_path.exists():
                    kwargs[kwarg_name] = AutoConfig.from_pretrained(config_path)

        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on both encoders.

        Args:
            gradient_checkpointing_kwargs: Optional settings forwarded to each
                encoder's ``gradient_checkpointing_enable``. Accepted so that
                callers passing this keyword do not fail.
        """
        # Forward the keyword only when given, so encoders whose method takes
        # no arguments keep working.
        extra = {}
        if gradient_checkpointing_kwargs is not None:
            extra["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
        self.token_encoder.gradient_checkpointing_enable(**extra)
        self.type_encoder.gradient_checkpointing_enable(**extra)