"""
Tokenizer facade for LexiMind.
Wraps HuggingFace tokenizers with a simplified interface that handles
special token management, batch encoding, and T5-specific conventions
for decoder input preparation.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, List, Sequence, cast
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase


@dataclass
class TokenizerConfig:
    """Settings for loading and applying the wrapped HuggingFace tokenizer."""

    pretrained_model_name: str = "google/flan-t5-base"
    max_length: int = 512
    padding: str = "max_length"
    truncation: bool = True
    lower: bool = False


class Tokenizer:
    """Lightweight façade over a HuggingFace tokenizer."""

    def __init__(self, config: TokenizerConfig | None = None) -> None:
        cfg = config or TokenizerConfig()
        self.config = cfg
        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            cfg.pretrained_model_name
        )
        self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)

        # T5 uses different special tokens than BART:
        #   T5:   pad=0, eos=1, no explicit bos
        #   BART: bos=0, pad=1, eos=2
        # Pick a decoder-start ("bos") id from whatever the tokenizer exposes:
        # prefer an explicit bos_token_id, then decoder_start_token_id, and
        # otherwise fall back to pad_token_id, which is the T5 convention.
        eos_id = self._tokenizer.eos_token_id
        bos_id = self._tokenizer.bos_token_id
        if bos_id is not None:
            self._bos_token_id = self._resolve_id(bos_id)
        elif (
            hasattr(self._tokenizer, "decoder_start_token_id")
            and self._tokenizer.decoder_start_token_id is not None
        ):
            self._bos_token_id = self._resolve_id(self._tokenizer.decoder_start_token_id)
        else:
            # T5 convention: use pad_token_id as decoder start.
            self._bos_token_id = self._pad_token_id
        self._eos_token_id = self._resolve_id(
            eos_id if eos_id is not None else self._tokenizer.sep_token_id
        )
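        # Illustrative outcome (an assumption based on the default
        # "google/flan-t5-base" checkpoint, not checked at runtime): T5-style
        # tokenizers report pad_token_id=0 and eos_token_id=1 with no
        # bos_token_id, so self._bos_token_id resolves to 0 via the pad
        # fallback above.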

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        return self._tokenizer

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def vocab_size(self) -> int:
        vocab = getattr(self._tokenizer, "vocab_size", None)
        if vocab is None:
            raise RuntimeError("Tokenizer must expose vocab_size")
        return int(vocab)

    @staticmethod
    def _resolve_id(value) -> int:
        """Normalize a special-token id that may be missing or wrapped in a list."""
        if value is None:
            raise ValueError("Tokenizer is missing required special token ids")
        if isinstance(value, (list, tuple)):
            value = value[0]
        return int(value)

    def encode(self, text: str) -> List[int]:
        """Encode a single string to token ids, padded/truncated per the config."""
        content = text.lower() if self.config.lower else text
        return cast(
            List[int],
            self._tokenizer.encode(
                content,
                max_length=self.config.max_length,
                truncation=self.config.truncation,
                padding=self.config.padding,
            ),
        )

    def encode_batch(self, texts: Sequence[str]) -> List[List[int]]:
        """Encode many strings, returning plain Python lists of token ids."""
        normalized = (text.lower() if self.config.lower else text for text in texts)
        encoded = self._tokenizer.batch_encode_plus(
            list(normalized),
            max_length=self.config.max_length,
            padding=self.config.padding,
            truncation=self.config.truncation,
            return_attention_mask=False,
            return_tensors=None,
        )
        return cast(List[List[int]], encoded["input_ids"])

    def batch_encode(
        self, texts: Sequence[str], *, max_length: int | None = None
    ) -> dict[str, torch.Tensor]:
        """Encode texts into padded input_ids and a boolean attention_mask (PyTorch tensors)."""
        normalized = [text.lower() if self.config.lower else text for text in texts]
        encoded = self._tokenizer(
            normalized,
            padding=self.config.padding,
            truncation=self.config.truncation,
            max_length=max_length or self.config.max_length,
            return_tensors="pt",
        )
        input_ids = cast(torch.Tensor, encoded["input_ids"])
        attention_mask = cast(torch.Tensor, encoded["attention_mask"])
        if input_ids.dtype != torch.long:
            input_ids = input_ids.to(dtype=torch.long)
        if attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.to(dtype=torch.bool)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def decode(self, token_ids: Iterable[int]) -> str:
        return cast(str, self._tokenizer.decode(list(token_ids), skip_special_tokens=True))

    def decode_batch(self, sequences: Sequence[Sequence[int]]) -> List[str]:
        prepared = [list(seq) for seq in sequences]
        return cast(List[str], self._tokenizer.batch_decode(prepared, skip_special_tokens=True))

    def prepare_decoder_inputs(self, labels: torch.Tensor) -> torch.Tensor:
        """Shift decoder labels right to create decoder input ids prefixed by BOS.

        Assumes padded label positions already use pad_token_id (not -100);
        masked-out values would otherwise be copied into the decoder inputs.
        """
        bos = self.bos_token_id
        pad = self.pad_token_id
        decoder_inputs = torch.full_like(labels, pad)
        decoder_inputs[:, 0] = bos
        decoder_inputs[:, 1:] = labels[:, :-1]
        return decoder_inputs
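

# A minimal usage sketch (illustrative only; assumes the default
# "google/flan-t5-base" checkpoint can be downloaded and that pad/eos resolve
# to 0/1 as described above). It round-trips a short string and shows the
# right shift performed by prepare_decoder_inputs.
if __name__ == "__main__":
    tok = Tokenizer(TokenizerConfig(max_length=16))

    batch = tok.batch_encode(["LexiMind summarizes long documents."])
    print(batch["input_ids"].shape)  # torch.Size([1, 16]) with padding="max_length"
    print(tok.decode(batch["input_ids"][0].tolist()))

    # The shift drops the final position and prefixes the BOS/pad id.
    labels = batch["input_ids"]
    decoder_inputs = tok.prepare_decoder_inputs(labels)
    print(labels[0, :5].tolist(), "->", decoder_inputs[0, :5].tolist())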