Spaces:
Running on CPU Upgrade

File size: 5,802 Bytes
bb3d05e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbe7bbd
 
 
 
bb3d05e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b9fd34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import torch.nn as nn
from transformers import AutoModel

###################################################################################

# Erweiterte Regressorklasse: Ein gemeinsamer Encoder, aber mehrere unabhängige Köpfe
class BertMultiHeadRegressor(nn.Module):
    """
    Mehrkopf-Regression auf einem beliebigen HF-Encoder (BERT/RoBERTa/DeBERTa/ModernBERT).
    - Gemeinsamer Encoder
    - n unabhängige Regressionsköpfe (je 1 Wert)
    - Robustes Pooling (Pooler wenn vorhanden, sonst maskiertes Mean)
    - Partielles Unfreezen ab `unfreeze_from`
    """
    def __init__(self, pretrained_model_name: str,
                 n_heads: int = 8,
                 unfreeze_from: int = 8,
                 dropout: float = 0.1):
        super().__init__()

        # Beliebigen Encoder laden
        self.encoder = AutoModel.from_pretrained(
            pretrained_model_name,
            low_cpu_mem_usage=False  # vermeidet accelerate-Abhängigkeit zur Init
        )
        hidden_size = self.encoder.config.hidden_size

        # Erst alles einfrieren …
        for p in self.encoder.parameters():
            p.requires_grad = False

        # … dann Layer ab `unfreeze_from` freigeben (falls vorhanden)
        # Die meisten Encoder haben `.encoder.layer`
        encoder_block = getattr(self.encoder, "encoder", None)
        layers = getattr(encoder_block, "layer", None)
        if layers is not None:
            for layer in layers[unfreeze_from:]:
                for p in layer.parameters():
                    p.requires_grad = True
        else:
            # Fallback: wenn kein klassisches Lagen-Array existiert, nichts tun
            pass

        self.dropout = nn.Dropout(dropout)
        self.heads = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(n_heads)])

    def _pool(self, outputs, attention_mask):
        """
        Robustes Pooling:
        - Wenn pooler_output vorhanden: nutzen (BERT/RoBERTa)
        - Sonst: maskiertes Mean-Pooling über last_hidden_state (z. B. DeBERTaV3)
        """
        pooler = getattr(outputs, "pooler_output", None)
        if pooler is not None:
            return pooler  # [B, H]

        last_hidden = outputs.last_hidden_state  # [B, T, H]
        mask = attention_mask.unsqueeze(-1).float()  # [B, T, 1]
        summed = (last_hidden * mask).sum(dim=1)     # [B, H]
        denom = mask.sum(dim=1).clamp(min=1e-6)      # [B, 1]
        return summed / denom

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids if token_type_ids is not None else None,
            return_dict=True
        )
        pooled = self._pool(outputs, attention_mask)    # [B, H]
        pooled = self.dropout(pooled)
        preds = [head(pooled) for head in self.heads]   # n × [B, 1]
        return torch.cat(preds, dim=1)                  # [B, n_heads]

###################################################################################

class BertBinaryClassifier(nn.Module):
    def __init__(self, pretrained_model_name='distilbert-base-uncased', unfreeze_from=4, dropout=0.3):
        super(BertBinaryClassifier, self).__init__()

        # Modell laden (funktioniert für BERT und DistilBERT)
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        
        # Speichern, ob es DistilBERT ist (wichtig für forward pass und layer-Zugriff)
        # Wir prüfen einfach, ob 'distilbert' im Namen vorkommt
        self.is_distilbert = 'distilbert' in pretrained_model_name.lower()

        # Alle Layer zunächst einfrieren
        for param in self.bert.parameters():
            param.requires_grad = False

        # --- SMART UNFREEZE LOGIK ---
        if self.is_distilbert:
            # DistilBERT Struktur: transformer.layer
            # Hinweis: DistilBERT hat nur 6 Layer. unfreeze_from sollte also z.B. 3 oder 4 sein.
            layers_to_unfreeze = self.bert.transformer.layer[unfreeze_from:]
        else:
            # Standard BERT Struktur: encoder.layer
            layers_to_unfreeze = self.bert.encoder.layer[unfreeze_from:]

        # Die ausgewählten Layer wieder "auftauen"
        for layer in layers_to_unfreeze:
            for param in layer.parameters():
                param.requires_grad = True

        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Klassifikationskopf
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        # Forward pass durch das Base Model
        # Bei DistilBERT darf man keine token_type_ids übergeben, AutoModel regelt das meist,
        # aber wir übergeben hier explizit nur ids und mask.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # --- SMART POOLING LOGIK ---
        if self.is_distilbert:
            # DistilBERT hat keinen pooler_output.
            # Wir nehmen den [CLS] Token (Index 0) vom last_hidden_state
            # Shape: [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
            pooled_output = outputs.last_hidden_state[:, 0]
        else:
            # Standard BERT hat einen pooler_output (bereits durch Tanh aktiviert)
            if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                pooled_output = outputs.pooler_output
            else:
                # Fallback, falls ein BERT-Modell ohne Pooler genutzt wird
                pooled_output = outputs.last_hidden_state[:, 0]

        # Dropout
        dropped = self.dropout(pooled_output)

        # Klassifikation
        logits = self.classifier(dropped)

        return logits