"""
DataLoader builders for LexiMind.
Task-specific collators and DataLoader factory functions for summarization,
emotion classification, and topic classification.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
from typing import Dict, List
import torch
from torch.utils.data import DataLoader
from .dataset import (
EmotionDataset,
EmotionExample,
SummarizationDataset,
SummarizationExample,
TopicDataset,
TopicExample,
)
from .tokenization import Tokenizer
# --------------- Collators ---------------
class SummarizationCollator:
"""Prepare encoder-decoder batches for seq2seq summarization."""
def __init__(
self,
tokenizer: Tokenizer,
*,
max_source_length: int | None = None,
max_target_length: int | None = None,
) -> None:
self.tokenizer = tokenizer
self.max_source_length = max_source_length
self.max_target_length = max_target_length
def __call__(self, batch: List[SummarizationExample]) -> Dict[str, torch.Tensor]:
sources = [ex.source for ex in batch]
targets = [ex.summary for ex in batch]
src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
ids = tgt_enc["input_ids"]
mask = tgt_enc["attention_mask"]
# Build labels for the loss: copy the target ids and set padded positions to
# -100, the default ignore_index of torch.nn.CrossEntropyLoss, so padding
# contributes no gradient.
labels = ids.clone()
labels[mask == 0] = -100
# Create decoder inputs from original ids (no -100)
# prepare_decoder_inputs shifts right and adds BOS
tgt_ids = self.tokenizer.prepare_decoder_inputs(ids)
return {
"src_ids": src_enc["input_ids"],
"src_mask": src_enc["attention_mask"],
"tgt_ids": tgt_ids,
"labels": labels,
}
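# Example (illustrative sketch, not executed): what a collated summarization
# batch looks like. `tok` is a hypothetical Tokenizer instance; field names
# match SummarizationExample as used above.
#
#     collate = SummarizationCollator(tok, max_source_length=512, max_target_length=128)
#     batch = collate([SummarizationExample(source="A long article ...", summary="A short summary.")])
#     batch["src_ids"], batch["src_mask"]  # encoder inputs and padding mask
#     batch["tgt_ids"]                     # right-shifted decoder inputs, starting with BOS
#     batch["labels"]                      # target ids with pad positions set to -100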
class EmotionCollator:
"""Prepare batches for multi-label emotion classification."""
def __init__(
self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None
) -> None:
self.tokenizer = tokenizer
self.binarizer = dataset.binarizer
self.max_length = max_length
def __call__(self, batch: List[EmotionExample]) -> Dict[str, torch.Tensor]:
texts = [ex.text for ex in batch]
encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
labels = torch.as_tensor(
self.binarizer.transform([ex.emotions for ex in batch]),
dtype=torch.float32,
)
return {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"labels": labels,
}
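# Example (sketch): assuming EmotionDataset.binarizer is a fitted
# scikit-learn-style MultiLabelBinarizer over classes ("anger", "joy",
# "sadness") — an assumption, as is `tok` — one example tagged
# ["joy", "sadness"] collates to a float multi-hot row:
#
#     collate = EmotionCollator(tok, train_dataset, max_length=256)
#     batch = collate([EmotionExample(text="Bittersweet news.", emotions=["joy", "sadness"])])
#     batch["labels"]  # tensor([[0., 1., 1.]]), ready for BCEWithLogitsLoss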
class TopicCollator:
"""Prepare batches for single-label topic classification."""
def __init__(
self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
) -> None:
self.tokenizer = tokenizer
self.encoder = dataset.encoder
self.max_length = max_length
def __call__(self, batch: List[TopicExample]) -> Dict[str, torch.Tensor]:
texts = [ex.text for ex in batch]
encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
labels = torch.as_tensor(
self.encoder.transform([ex.topic for ex in batch]),
dtype=torch.long,
)
return {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"labels": labels,
}
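# Example (sketch): assuming TopicDataset.encoder is a fitted scikit-learn-style
# LabelEncoder (an assumption, as is `tok`), single-label topics collate to
# integer class indices:
#
#     collate = TopicCollator(tok, train_dataset, max_length=256)
#     batch = collate([TopicExample(text="Markets rallied today.", topic="business")])
#     batch["labels"]  # e.g. tensor([0]), dtype long, ready for CrossEntropyLoss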
# --------------- Factory Functions ---------------
def build_summarization_dataloader(
dataset: SummarizationDataset,
tokenizer: Tokenizer,
*,
batch_size: int,
shuffle: bool = True,
max_source_length: int | None = None,
max_target_length: int | None = None,
num_workers: int = 0,
pin_memory: bool = False,
) -> DataLoader:
"""Create dataloader for summarization task."""
collator = SummarizationCollator(
tokenizer,
max_source_length=max_source_length,
max_target_length=max_target_length,
)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collator,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=num_workers > 0, # Keep workers alive between epochs
)
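# Example (sketch): typical training-time construction; `tok`, `train_dataset`,
# and `model` are hypothetical stand-ins.
#
#     loader = build_summarization_dataloader(
#         train_dataset, tok, batch_size=16,
#         max_source_length=512, max_target_length=128,
#         num_workers=4, pin_memory=True,  # pin_memory speeds host-to-GPU copies
#     )
#     for batch in loader:
#         logits = model(batch["src_ids"], batch["src_mask"], batch["tgt_ids"])
#         ...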
def build_emotion_dataloader(
dataset: EmotionDataset,
tokenizer: Tokenizer,
*,
batch_size: int,
shuffle: bool = True,
max_length: int | None = None,
num_workers: int = 0,
pin_memory: bool = False,
) -> DataLoader:
"""Create dataloader for emotion classification task."""
collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collator,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=num_workers > 0,
)
def build_topic_dataloader(
dataset: TopicDataset,
tokenizer: Tokenizer,
*,
batch_size: int,
shuffle: bool = True,
max_length: int | None = None,
num_workers: int = 0,
pin_memory: bool = False,
) -> DataLoader:
"""Create dataloader for topic classification task."""
collator = TopicCollator(tokenizer, dataset, max_length=max_length)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collator,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=num_workers > 0,
)
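# Example (sketch): the classification factories mirror the summarization one;
# `tok` and the datasets are hypothetical. For evaluation splits, pass
# shuffle=False so metrics are computed over a stable ordering.
#
#     emotion_loader = build_emotion_dataloader(emotion_train, tok, batch_size=32, max_length=256)
#     topic_loader = build_topic_dataloader(topic_val, tok, batch_size=64, max_length=256, shuffle=False)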