Instructions to use whoisjones/otter-bi-mmbert with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use whoisjones/otter-bi-mmbert with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="whoisjones/otter-bi-mmbert", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("whoisjones/otter-bi-mmbert", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import os | |
| from dataclasses import dataclass | |
| from typing import Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModel, AutoConfig, PreTrainedModel, MT5EncoderModel, PretrainedConfig | |
| from transformers.file_utils import ModelOutput | |
| from pathlib import Path | |
| from .configuration_biencoder import SpanModelConfig | |
| from .metrics import compute_span_predictions | |
| from .loss import BCELoss, FocalLoss | |
@dataclass
class SpanModelOutput(ModelOutput):
    """Output container returned by ``OtterBiEncoderModel.forward``.

    ``ModelOutput`` subclasses must be dataclasses: the fields below are only
    registered as output keys/attributes through the ``@dataclass`` decorator.
    Every field defaults to ``None`` so partially-filled outputs are allowed.
    """

    # Weighted sum of the start, end and span losses; None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None
    # Temperature-scaled start-boundary scores for every (type, position) pair.
    start_logits: Optional[torch.FloatTensor] = None
    # Temperature-scaled end-boundary scores for every (type, position) pair.
    end_logits: Optional[torch.FloatTensor] = None
    # Temperature-scaled scores for every (type, candidate span) pair.
    span_logits: Optional[torch.FloatTensor] = None
def mlp(input_size, output_size, dropout):
    """Build a two-layer projection head: Linear -> Dropout -> ReLU -> Linear.

    Args:
        input_size: Feature size of the incoming tensor.
        output_size: Width of both the hidden layer and the output.
        dropout: Dropout probability applied right after the first linear layer.

    Returns:
        A ``torch.nn.Sequential`` applying the four layers in the order above.
    """
    layers = [
        torch.nn.Linear(input_size, output_size),
        torch.nn.Dropout(dropout),
        torch.nn.ReLU(),
        torch.nn.Linear(output_size, output_size),
    ]
    return torch.nn.Sequential(*layers)
class OtterBiEncoderModel(PreTrainedModel):
    """Bi-encoder span model that scores token positions and spans against type descriptions.

    ``token_encoder`` embeds the input text and ``type_encoder`` embeds one
    description per candidate type. Both are projected by small MLP heads,
    L2-normalised, and compared with learned-temperature dot products. This
    yields start, end and span logits laid out as ``(batch, num_types, positions)``
    (the ``"BSH,CH->BCS"`` einsums in ``forward``).
    """

    config_class = SpanModelConfig

    def __init__(self, config, token_config=None, type_config=None):
        """Create the projection heads, the two encoders and the loss function.

        Args:
            config: ``SpanModelConfig`` with head hyper-parameters, the encoder
                names (``token_encoder`` / ``type_encoder``) and loss settings.
            token_config: Config of the token encoder. Loaded with
                ``AutoConfig.from_pretrained(config.token_encoder)`` when ``None``.
            type_config: Config of the type encoder. Loaded with
                ``AutoConfig.from_pretrained(config.type_encoder)`` when ``None``.

        Raises:
            ValueError: If ``config.loss_fn`` is neither ``"focal"`` nor ``"bce"``.
        """
        super().__init__(config)
        self.config = config
        self.token_config = token_config
        self.type_config = type_config
        if self.token_config is None:
            self.token_config = AutoConfig.from_pretrained(config.token_encoder)
        if self.type_config is None:
            self.type_config = AutoConfig.from_pretrained(config.type_encoder)

        self.max_span_length = config.max_span_length
        self.dropout = torch.nn.Dropout(config.dropout)
        self.linear_hidden_size = config.linear_hidden_size
        self.config.pruned_heads = self.token_config.pruned_heads

        # Projection heads mapping encoder outputs into a shared linear_hidden_size space.
        self.type_linear = mlp(self.type_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_start_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        self.token_end_linear = mlp(self.token_config.hidden_size, config.linear_hidden_size, config.dropout)
        # Span head input is [start vector; end vector; span-width embedding].
        self.token_span_linear = mlp(
            config.linear_hidden_size * 2 + config.span_width_embedding_size,
            config.linear_hidden_size,
            config.dropout,
        )
        # NOTE(review): fusion_linear is not used in forward(); presumably kept so
        # existing checkpoints still load — confirm before removing.
        self.fusion_linear = mlp(config.linear_hidden_size * 2, config.linear_hidden_size, config.dropout)
        # Span lengths index this table directly (max_span_length + 1 rows); row 0 is padding.
        self.width_embedding = torch.nn.Embedding(
            config.max_span_length + 1, config.span_width_embedding_size, padding_idx=0
        )
        # Learned log inverse temperatures; forward() multiplies similarities by their exp().
        self.start_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.end_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))
        self.span_logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / config.init_temperature))

        # post_init() runs before the encoders below exist, so any weight
        # initialisation it performs cannot reach them.
        self.post_init()

        if "mt5" in config.token_encoder:
            self.token_encoder = MT5EncoderModel(config=self.token_config)
        else:
            self.token_encoder = AutoModel.from_config(self.token_config)
        if "mt5" in config.type_encoder:
            self.type_encoder = MT5EncoderModel(config=self.type_config)
        else:
            self.type_encoder = AutoModel.from_config(self.type_config)

        if config.loss_fn == "focal":
            self.loss_fn = FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma)
        elif config.loss_fn == "bce":
            self.loss_fn = BCELoss()
        else:
            raise ValueError(f"Invalid loss function: {config.loss_fn}")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02), zero biases, reset LayerNorm to identity."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            # A LayerNorm may have no affine parameters (elementwise_affine=False / bias=False).
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def gather_spans(self, hidden_states, span_indices):
        """Pick one hidden vector per span position.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` tensor.
            span_indices: ``(batch, num_spans)`` integer positions along ``seq_len``.

        Returns:
            ``(batch, num_spans, hidden)`` tensor with
            ``out[b, s] = hidden_states[b, span_indices[b, s]]``.
        """
        _, _, hidden_size = hidden_states.shape
        index = span_indices.unsqueeze(2).expand(-1, -1, hidden_size)
        return torch.gather(hidden_states, 1, index)

    def forward(
        self,
        token_encoder_inputs: dict = None,
        type_encoder_inputs: dict = None,
        labels: dict = None
    ):
        """Score token positions (and, given span metadata, candidate spans) against every type.

        Args:
            token_encoder_inputs: Keyword arguments for ``token_encoder``.
            type_encoder_inputs: Keyword arguments for ``type_encoder``; its
                optional ``"attention_mask"`` is used for mean pooling.
            labels: Span metadata and targets. ``"span_lengths"`` and
                ``"span_subword_indices"`` (``[..., 0]`` start, ``[..., 1]`` end
                positions) are needed for span logits. The ``*_labels`` and
                ``valid_*_mask`` entries are read only for the training loss.

        Returns:
            ``SpanModelOutput`` with start/end logits and span logits. ``span_logits``
            is ``None`` when ``labels`` is ``None``. ``loss`` is set only in training
            mode with labels.
        """
        token_output = self.token_encoder(**token_encoder_inputs).last_hidden_state
        type_hidden = self.type_encoder(**type_encoder_inputs).last_hidden_state
        type_output = self._pool_type_output(type_hidden, type_encoder_inputs.get("attention_mask"))

        token_start_output = F.normalize(self.dropout(self.token_start_linear(token_output)), dim=-1)
        token_end_output = F.normalize(self.dropout(self.token_end_linear(token_output)), dim=-1)
        type_output = F.normalize(self.dropout(self.type_linear(type_output)), dim=-1)

        # (batch, seq, hidden) x (types, hidden) -> (batch, types, seq)
        start_scores = self.start_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_start_output, type_output)
        end_scores = self.end_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_end_output, type_output)

        # Candidate spans are only described by `labels`; without them only
        # start/end logits can be produced.
        span_scores = None
        if labels is not None:
            span_scores = self._score_spans(token_start_output, token_end_output, type_output, labels)

        if labels is not None and self.training:
            loss = self._compute_span_loss(start_scores, end_scores, span_scores, labels)
            return SpanModelOutput(loss=loss, start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)
        return SpanModelOutput(start_logits=start_scores, end_logits=end_scores, span_logits=span_scores)

    def _pool_type_output(self, hidden_states, attention_mask):
        """Reduce type-encoder states to one vector per type (masked mean, plain mean, or first token)."""
        if self.config.type_encoder_pooling != "mean":
            return hidden_states[:, 0, :]
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = torch.sum(hidden_states * mask, dim=1)
        # Clamp avoids division by zero for fully-masked rows.
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / counts

    def _score_spans(self, token_start_output, token_end_output, type_output, labels):
        """Build span vectors from boundary vectors plus width embeddings and score them per type."""
        span_indices = labels["span_subword_indices"]
        span_width_embeddings = self.width_embedding(labels["span_lengths"])
        span_hidden = torch.cat(
            [
                self.gather_spans(token_start_output, span_indices[:, :, 0]),
                self.gather_spans(token_end_output, span_indices[:, :, 1]),
                span_width_embeddings,
            ],
            dim=2,
        )
        token_span_output = F.normalize(self.dropout(self.token_span_linear(span_hidden)), dim=-1)
        return self.span_logit_scale.exp() * torch.einsum("BSH,CH->BCS", token_span_output, type_output)

    @staticmethod
    def _as_pos_weight(value, like):
        """Return ``value`` as a tensor on ``like``'s device and dtype, or ``None`` when unset."""
        if value is None:
            return None
        return torch.tensor(value, device=like.device, dtype=like.dtype)

    def _compute_span_loss(self, start_scores, end_scores, span_scores, labels):
        """Return the config-weighted sum of the start, end and span losses."""
        start_loss = self.loss_fn(
            start_scores,
            labels["start_labels"],
            mask=labels["valid_start_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_start_pos_weight, start_scores),
        )
        end_loss = self.loss_fn(
            end_scores,
            labels["end_labels"],
            mask=labels["valid_end_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_end_pos_weight, end_scores),
        )
        span_loss = self.loss_fn(
            span_scores,
            labels["span_labels"],
            mask=labels["valid_span_mask"],
            pos_weight=self._as_pos_weight(self.config.bce_span_pos_weight, span_scores),
        )
        return (
            self.config.start_loss_weight * start_loss
            + self.config.end_loss_weight * end_loss
            + self.config.span_loss_weight * span_loss
        )

    def predict(self, batch: dict, threshold: float = 0.5):
        """Run inference on ``batch`` and decode span predictions above ``threshold``.

        The model is switched to eval mode for the forward pass, so dropout is off
        and no loss is attempted. Every submodule's previous mode is restored
        afterwards.

        Args:
            batch: Dict with ``"token_encoder_inputs"``, ``"type_encoder_inputs"``,
                ``"labels"`` (must contain ``"valid_span_mask"`` and
                ``"span_subword_indices"``) and ``"id2label"``.
            threshold: Decision threshold passed to ``compute_span_predictions``.

        Returns:
            Whatever ``compute_span_predictions`` returns for the batch.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(
                    token_encoder_inputs=batch["token_encoder_inputs"],
                    type_encoder_inputs=batch["type_encoder_inputs"],
                    labels=batch["labels"],
                )
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        return compute_span_predictions(
            span_logits=output.span_logits.detach().cpu().numpy(),
            span_mask=batch["labels"]["valid_span_mask"].cpu().numpy(),
            span_mapping=batch["labels"]["span_subword_indices"].cpu().numpy(),
            id2label=batch["id2label"],
            threshold=threshold,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save the checkpoint via the base class, plus both encoder configs.

        The encoder configs are written as ``token_encoder_config.json`` and
        ``type_encoder_config.json``. ``from_pretrained`` then picks them up from a
        local directory instead of resolving ``config.token_encoder`` /
        ``config.type_encoder`` again.
        """
        save_directory = Path(save_directory)
        super().save_pretrained(save_directory, **kwargs)
        # ``is_main_process=False`` (a PreTrainedModel.save_pretrained option) marks
        # non-main ranks; don't let them write the extra files concurrently.
        if not kwargs.get("is_main_process", True):
            return
        self.token_encoder.config.to_json_file(str(save_directory / "token_encoder_config.json"))
        self.type_encoder.config.to_json_file(str(save_directory / "type_encoder_config.json"))

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        **kwargs,
    ):
        """Load a checkpoint, reusing encoder configs saved next to it when present.

        If ``pretrained_model_name_or_path`` is a local directory containing
        ``token_encoder_config.json`` / ``type_encoder_config.json``, those configs
        are loaded and passed to ``__init__`` as ``token_config`` / ``type_config``.
        Explicitly supplied ``token_config`` / ``type_config`` kwargs take precedence.
        Everything else is delegated to ``PreTrainedModel.from_pretrained``.
        """
        if pretrained_model_name_or_path is not None:
            base_path = Path(pretrained_model_name_or_path)
            for kwarg_name, file_name in (
                ("token_config", "token_encoder_config.json"),
                ("type_config", "type_encoder_config.json"),
            ):
                config_path = base_path / file_name
                if kwargs.get(kwarg_name) is None and config_path.exists():
                    kwargs[kwarg_name] = AutoConfig.from_pretrained(str(config_path))
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing in both encoders.

        Args:
            gradient_checkpointing_kwargs: Optional kwargs forwarded to each
                encoder's ``gradient_checkpointing_enable``. ``transformers.Trainer``
                passes this keyword. When ``None``, the encoders are called with no
                arguments, as before.
        """
        extra = {}
        if gradient_checkpointing_kwargs is not None:
            extra["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
        self.token_encoder.gradient_checkpointing_enable(**extra)
        self.type_encoder.gradient_checkpointing_enable(**extra)