| """ | |
| Tokenizer facade for LexiMind. | |
| Wraps HuggingFace tokenizers with a simplified interface that handles | |
| special token management, batch encoding, and T5-specific conventions | |
| for decoder input preparation. | |
| Author: Oliver Perrin | |
| Date: December 2025 | |
| """ | |

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Sequence, cast

import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase


@dataclass
class TokenizerConfig:
    pretrained_model_name: str = "google/flan-t5-base"
    max_length: int = 512
    padding: str = "max_length"
    truncation: bool = True
    lower: bool = False


class Tokenizer:
    """Lightweight façade over a HuggingFace tokenizer."""

    def __init__(self, config: TokenizerConfig | None = None) -> None:
        cfg = config or TokenizerConfig()
        self.config = cfg
        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            cfg.pretrained_model_name
        )
        self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)

        # T5 uses different special tokens than BART:
        #   T5:   pad=0, eos=1, no explicit bos
        #   BART: bos=0, pad=1, eos=2
        # For T5 the decoder start is conventionally pad_token_id (0), so when
        # neither bos_token_id nor an explicit decoder_start_token_id is
        # available we fall back to pad.
        eos_id = self._tokenizer.eos_token_id
        bos_id = self._tokenizer.bos_token_id
        if bos_id is not None:
            self._bos_token_id = self._resolve_id(bos_id)
        elif (
            hasattr(self._tokenizer, "decoder_start_token_id")
            and self._tokenizer.decoder_start_token_id is not None
        ):
            self._bos_token_id = self._resolve_id(self._tokenizer.decoder_start_token_id)
        else:
            # T5 convention: use pad_token_id as decoder start.
            self._bos_token_id = self._pad_token_id
        self._eos_token_id = self._resolve_id(
            eos_id if eos_id is not None else self._tokenizer.sep_token_id
        )
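        # With the default google/flan-t5-base checkpoint this resolves to
        # pad=0 and eos=1, with bos falling back to pad (0) as decoder start.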

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        return self._tokenizer

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def vocab_size(self) -> int:
        vocab = getattr(self._tokenizer, "vocab_size", None)
        if vocab is None:
            raise RuntimeError("Tokenizer must expose vocab_size")
        return int(vocab)

    @staticmethod
    def _resolve_id(value) -> int:
        # Normalize a special-token id that may be missing or wrapped in a list.
        if value is None:
            raise ValueError("Tokenizer is missing required special token ids")
        if isinstance(value, (list, tuple)):
            value = value[0]
        return int(value)

    def encode(self, text: str) -> List[int]:
        content = text.lower() if self.config.lower else text
        return cast(
            List[int],
            self._tokenizer.encode(
                content,
                max_length=self.config.max_length,
                truncation=self.config.truncation,
                padding=self.config.padding,
            ),
        )

    def encode_batch(self, texts: Sequence[str]) -> List[List[int]]:
        normalized = (text.lower() if self.config.lower else text for text in texts)
        encoded = self._tokenizer.batch_encode_plus(
            list(normalized),
            max_length=self.config.max_length,
            padding=self.config.padding,
            truncation=self.config.truncation,
            return_attention_mask=False,
            return_tensors=None,
        )
        return cast(List[List[int]], encoded["input_ids"])

    def batch_encode(
        self, texts: Sequence[str], *, max_length: int | None = None
    ) -> dict[str, torch.Tensor]:
        normalized = [text.lower() if self.config.lower else text for text in texts]
        encoded = self._tokenizer(
            normalized,
            padding=self.config.padding,
            truncation=self.config.truncation,
            max_length=max_length or self.config.max_length,
            return_tensors="pt",
        )
        input_ids = cast(torch.Tensor, encoded["input_ids"])
        attention_mask = cast(torch.Tensor, encoded["attention_mask"])
        if input_ids.dtype != torch.long:
            input_ids = input_ids.to(dtype=torch.long)
        if attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.to(dtype=torch.bool)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def decode(self, token_ids: Iterable[int]) -> str:
        return cast(str, self._tokenizer.decode(list(token_ids), skip_special_tokens=True))

    def decode_batch(self, sequences: Sequence[Sequence[int]]) -> List[str]:
        prepared = [list(seq) for seq in sequences]
        return cast(List[str], self._tokenizer.batch_decode(prepared, skip_special_tokens=True))

    def prepare_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
        """Shift decoder labels right to create input ids prefixed by BOS.

        Example: labels  [[ a,  b, eos, pad]]
                 becomes [[bos, a,   b, eos]]

        Note: labels are copied verbatim, so they are assumed to be padded
        with pad_token_id rather than -100.
        """
        bos = self.bos_token_id
        pad = self.pad_token_id
        decoder_inputs = torch.full_like(labels, pad)
        decoder_inputs[:, 0] = bos
        decoder_inputs[:, 1:] = labels[:, :-1]
        return decoder_inputs
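

if __name__ == "__main__":
    # Minimal smoke test / usage sketch, not part of the public API.
    # Assumes the google/flan-t5-base tokenizer files can be downloaded
    # (or are already cached locally).
    tok = Tokenizer()
    batch = tok.batch_encode(["summarize: studies have shown ..."], max_length=16)
    print(batch["input_ids"].shape)        # torch.Size([1, 16])
    print(batch["attention_mask"].dtype)   # torch.bool
    decoder_inputs = tok.prepare_decoder_inputs(batch["input_ids"])
    print(decoder_inputs[0].tolist()[:5])  # starts with the decoder-start id (0)
    print(tok.decode(batch["input_ids"][0].tolist()))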