"""
DataLoader builders for LexiMind.
Task-specific collators and factory functions for summarization, emotion, and topic.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
from typing import Dict, List
import torch
from torch.utils.data import DataLoader
from .dataset import (
EmotionDataset,
EmotionExample,
SummarizationDataset,
SummarizationExample,
TopicDataset,
TopicExample,
)
from .tokenization import Tokenizer
# --------------- Collators ---------------
class SummarizationCollator:
    """Prepare encoder-decoder batches for seq2seq summarization."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        *,
        max_source_length: int | None = None,
        max_target_length: int | None = None,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __call__(self, batch: List[SummarizationExample]) -> Dict[str, torch.Tensor]:
        sources = [ex.source for ex in batch]
        targets = [ex.summary for ex in batch]
        src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
        tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)
        ids = tgt_enc["input_ids"]
        mask = tgt_enc["attention_mask"]
        # Labels for the loss: padding positions are set to -100, the default
        # ignore_index of torch.nn.CrossEntropyLoss, so they contribute nothing.
        labels = ids.clone()
        labels[mask == 0] = -100
        # Decoder inputs are built from the original ids (never from the -100 labels);
        # prepare_decoder_inputs shifts the sequence right and prepends BOS.
        tgt_ids = self.tokenizer.prepare_decoder_inputs(ids)
        return {
            "src_ids": src_enc["input_ids"],
            "src_mask": src_enc["attention_mask"],
            "tgt_ids": tgt_ids,
            "labels": labels,
        }
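
# Illustrative note (not executed): for a batch of two SummarizationExample items, the
# collator above returns a dict of tensors shaped roughly as follows, where S and T are
# the padded source/target lengths produced by the tokenizer:
#   {
#       "src_ids":  [2, S] token ids,
#       "src_mask": [2, S] attention mask,
#       "tgt_ids":  [2, T] decoder inputs (shifted right, BOS first),
#       "labels":   [2, T] targets with padding positions set to -100,
#   }
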
class EmotionCollator:
    """Prepare batches for multi-label emotion classification."""

    def __init__(
        self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None
    ) -> None:
        self.tokenizer = tokenizer
        self.binarizer = dataset.binarizer
        self.max_length = max_length

    def __call__(self, batch: List[EmotionExample]) -> Dict[str, torch.Tensor]:
        texts = [ex.text for ex in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        labels = torch.as_tensor(
            self.binarizer.transform([ex.emotions for ex in batch]),
            dtype=torch.float32,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }
class TopicCollator:
    """Prepare batches for single-label topic classification."""

    def __init__(
        self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
    ) -> None:
        self.tokenizer = tokenizer
        self.encoder = dataset.encoder
        self.max_length = max_length

    def __call__(self, batch: List[TopicExample]) -> Dict[str, torch.Tensor]:
        texts = [ex.text for ex in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        labels = torch.as_tensor(
            self.encoder.transform([ex.topic for ex in batch]),
            dtype=torch.long,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }
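
# Label formats differ between the two classifiers above. Assuming the dataset's binarizer
# and encoder follow sklearn's MultiLabelBinarizer / LabelEncoder conventions, EmotionCollator
# emits multi-hot float32 vectors (one column per emotion, suitable for BCEWithLogitsLoss),
# while TopicCollator emits integer class indices (suitable for CrossEntropyLoss).
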
# --------------- Factory Functions ---------------


def build_summarization_dataloader(
    dataset: SummarizationDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_source_length: int | None = None,
    max_target_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for summarization task."""
    collator = SummarizationCollator(
        tokenizer,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
    )
def build_emotion_dataloader(
    dataset: EmotionDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for emotion classification task."""
    collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )
def build_topic_dataloader(
    dataset: TopicDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for topic classification task."""
    collator = TopicCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )
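
# Minimal smoke-test sketch (optional, runs only when the module is executed directly).
# It exercises the target-side tensor handling the summarization collator relies on,
# without touching the project Tokenizer: pad id 1 and BOS id 0 are stand-in values here,
# and in real use the right-shift is delegated to Tokenizer.prepare_decoder_inputs.
if __name__ == "__main__":
    toy_ids = torch.tensor([[5, 6, 7, 1, 1], [8, 9, 1, 1, 1]])   # toy target ids, 1 = pad
    toy_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])  # matching attention mask
    toy_labels = toy_ids.clone()
    toy_labels[toy_mask == 0] = -100                             # ignored by the loss
    # Hand-rolled shift-right with BOS=0, mirroring what prepare_decoder_inputs produces.
    toy_decoder_inputs = torch.cat([torch.zeros_like(toy_ids[:, :1]), toy_ids[:, :-1]], dim=1)
    print("labels:", toy_labels.tolist())
    print("decoder inputs:", toy_decoder_inputs.tolist())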