"""
DataLoader builders for LexiMind.

Task-specific collators and factory functions for the summarization, emotion classification, and topic classification tasks.

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

from typing import Dict, List

import torch
from torch.utils.data import DataLoader

from .dataset import (
    EmotionDataset,
    EmotionExample,
    SummarizationDataset,
    SummarizationExample,
    TopicDataset,
    TopicExample,
)
from .tokenization import Tokenizer

# --------------- Collators ---------------


class SummarizationCollator:
    """Prepare encoder-decoder batches for seq2seq summarization."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        *,
        max_source_length: int | None = None,
        max_target_length: int | None = None,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __call__(self, batch: List[SummarizationExample]) -> Dict[str, torch.Tensor]:
        sources = [ex.source for ex in batch]
        targets = [ex.summary for ex in batch]

        src_enc = self.tokenizer.batch_encode(sources, max_length=self.max_source_length)
        tgt_enc = self.tokenizer.batch_encode(targets, max_length=self.max_target_length)

        ids = tgt_enc["input_ids"]
        mask = tgt_enc["attention_mask"]

        # Create labels for the loss: mask padding with -100 (cross-entropy's default ignore_index)
        labels = ids.clone()
        labels[mask == 0] = -100

        # Create decoder inputs from original ids (no -100)
        # prepare_decoder_inputs shifts right and adds BOS
        tgt_ids = self.tokenizer.prepare_decoder_inputs(ids)

        return {
            "src_ids": src_enc["input_ids"],
            "src_mask": src_enc["attention_mask"],
            "tgt_ids": tgt_ids,
            "labels": labels,
        }
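
# Illustrative usage (a sketch, not executed here; the keyword construction of
# SummarizationExample is assumed from the attribute access above):
#
#     collator = SummarizationCollator(tokenizer, max_source_length=512, max_target_length=128)
#     batch = collator([SummarizationExample(source="long article ...", summary="short summary ...")])
#     batch["src_ids"], batch["src_mask"]  # encoder inputs and padding mask
#     batch["tgt_ids"]                     # right-shifted, BOS-prefixed decoder inputs
#     batch["labels"]                      # targets with padding positions set to -100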


class EmotionCollator:
    """Prepare batches for multi-label emotion classification."""

    def __init__(
        self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None
    ) -> None:
        self.tokenizer = tokenizer
        self.binarizer = dataset.binarizer
        self.max_length = max_length

    def __call__(self, batch: List[EmotionExample]) -> Dict[str, torch.Tensor]:
        texts = [ex.text for ex in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        labels = torch.as_tensor(
            self.binarizer.transform([ex.emotions for ex in batch]),
            dtype=torch.float32,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }
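
# The emotion labels form a multi-hot float matrix via the dataset's
# MultiLabelBinarizer-style `binarizer`; e.g. with an assumed label order
# ["anger", "fear", "joy"], [["joy", "fear"]] maps to tensor([[0., 1., 1.]]).
# Float dtype keeps the batch directly usable with BCE-with-logits losses.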


class TopicCollator:
    """Prepare batches for single-label topic classification."""

    def __init__(
        self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
    ) -> None:
        self.tokenizer = tokenizer
        self.encoder = dataset.encoder
        self.max_length = max_length

    def __call__(self, batch: List[TopicExample]) -> Dict[str, torch.Tensor]:
        texts = [ex.text for ex in batch]
        encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
        labels = torch.as_tensor(
            self.encoder.transform([ex.topic for ex in batch]),
            dtype=torch.long,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }
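
# Topic labels are integer class indices (torch.long) from the dataset's
# LabelEncoder-style `encoder`, the dtype torch.nn.CrossEntropyLoss expects
# for single-label classification.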


# --------------- Factory Functions ---------------


def build_summarization_dataloader(
    dataset: SummarizationDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_source_length: int | None = None,
    max_target_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for summarization task."""
    collator = SummarizationCollator(
        tokenizer,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
    )
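
# Typical wiring (an illustrative sketch; `train_dataset`, `tokenizer`, and the
# training step are placeholders, not APIs defined in this module):
#
#     loader = build_summarization_dataloader(
#         train_dataset, tokenizer, batch_size=16,
#         max_source_length=512, max_target_length=128,
#     )
#     for batch in loader:
#         # batch["labels"] already carries -100 at padding positions, matching
#         # torch.nn.functional.cross_entropy's default ignore_index.
#         ...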


def build_emotion_dataloader(
    dataset: EmotionDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for emotion classification task."""
    collator = EmotionCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )


def build_topic_dataloader(
    dataset: TopicDataset,
    tokenizer: Tokenizer,
    *,
    batch_size: int,
    shuffle: bool = True,
    max_length: int | None = None,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create dataloader for topic classification task."""
    collator = TopicCollator(tokenizer, dataset, max_length=max_length)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )