Spaces:
Running on CPU Upgrade

whisky-wheel / lib /bert_regressor.py
ziem-io's picture
New: Implement cleanup functionality
9b9fd34
raw
history blame
5.8 kB
import torch
import torch.nn as nn
from transformers import AutoModel
###################################################################################
# Erweiterte Regressorklasse: Ein gemeinsamer Encoder, aber mehrere unabhängige Köpfe
class BertMultiHeadRegressor(nn.Module):
"""
Mehrkopf-Regression auf einem beliebigen HF-Encoder (BERT/RoBERTa/DeBERTa/ModernBERT).
- Gemeinsamer Encoder
- n unabhängige Regressionsköpfe (je 1 Wert)
- Robustes Pooling (Pooler wenn vorhanden, sonst maskiertes Mean)
- Partielles Unfreezen ab `unfreeze_from`
"""
def __init__(self, pretrained_model_name: str,
n_heads: int = 8,
unfreeze_from: int = 8,
dropout: float = 0.1):
super().__init__()
# Beliebigen Encoder laden
self.encoder = AutoModel.from_pretrained(
pretrained_model_name,
low_cpu_mem_usage=False # vermeidet accelerate-Abhängigkeit zur Init
)
hidden_size = self.encoder.config.hidden_size
# Erst alles einfrieren …
for p in self.encoder.parameters():
p.requires_grad = False
# … dann Layer ab `unfreeze_from` freigeben (falls vorhanden)
# Die meisten Encoder haben `.encoder.layer`
encoder_block = getattr(self.encoder, "encoder", None)
layers = getattr(encoder_block, "layer", None)
if layers is not None:
for layer in layers[unfreeze_from:]:
for p in layer.parameters():
p.requires_grad = True
else:
# Fallback: wenn kein klassisches Lagen-Array existiert, nichts tun
pass
self.dropout = nn.Dropout(dropout)
self.heads = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(n_heads)])
def _pool(self, outputs, attention_mask):
"""
Robustes Pooling:
- Wenn pooler_output vorhanden: nutzen (BERT/RoBERTa)
- Sonst: maskiertes Mean-Pooling über last_hidden_state (z. B. DeBERTaV3)
"""
pooler = getattr(outputs, "pooler_output", None)
if pooler is not None:
return pooler # [B, H]
last_hidden = outputs.last_hidden_state # [B, T, H]
mask = attention_mask.unsqueeze(-1).float() # [B, T, 1]
summed = (last_hidden * mask).sum(dim=1) # [B, H]
denom = mask.sum(dim=1).clamp(min=1e-6) # [B, 1]
return summed / denom
def forward(self, input_ids, attention_mask, token_type_ids=None):
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids if token_type_ids is not None else None,
return_dict=True
)
pooled = self._pool(outputs, attention_mask) # [B, H]
pooled = self.dropout(pooled)
preds = [head(pooled) for head in self.heads] # n × [B, 1]
return torch.cat(preds, dim=1) # [B, n_heads]
###################################################################################
class BertBinaryClassifier(nn.Module):
def __init__(self, pretrained_model_name='distilbert-base-uncased', unfreeze_from=4, dropout=0.3):
super(BertBinaryClassifier, self).__init__()
# Modell laden (funktioniert für BERT und DistilBERT)
self.bert = AutoModel.from_pretrained(pretrained_model_name)
# Speichern, ob es DistilBERT ist (wichtig für forward pass und layer-Zugriff)
# Wir prüfen einfach, ob 'distilbert' im Namen vorkommt
self.is_distilbert = 'distilbert' in pretrained_model_name.lower()
# Alle Layer zunächst einfrieren
for param in self.bert.parameters():
param.requires_grad = False
# --- SMART UNFREEZE LOGIK ---
if self.is_distilbert:
# DistilBERT Struktur: transformer.layer
# Hinweis: DistilBERT hat nur 6 Layer. unfreeze_from sollte also z.B. 3 oder 4 sein.
layers_to_unfreeze = self.bert.transformer.layer[unfreeze_from:]
else:
# Standard BERT Struktur: encoder.layer
layers_to_unfreeze = self.bert.encoder.layer[unfreeze_from:]
# Die ausgewählten Layer wieder "auftauen"
for layer in layers_to_unfreeze:
for param in layer.parameters():
param.requires_grad = True
# Dropout
self.dropout = nn.Dropout(dropout)
# Klassifikationskopf
self.classifier = nn.Linear(self.bert.config.hidden_size, 1)
def forward(self, input_ids, attention_mask):
# Forward pass durch das Base Model
# Bei DistilBERT darf man keine token_type_ids übergeben, AutoModel regelt das meist,
# aber wir übergeben hier explizit nur ids und mask.
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
# --- SMART POOLING LOGIK ---
if self.is_distilbert:
# DistilBERT hat keinen pooler_output.
# Wir nehmen den [CLS] Token (Index 0) vom last_hidden_state
# Shape: [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
pooled_output = outputs.last_hidden_state[:, 0]
else:
# Standard BERT hat einen pooler_output (bereits durch Tanh aktiviert)
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
pooled_output = outputs.pooler_output
else:
# Fallback, falls ein BERT-Modell ohne Pooler genutzt wird
pooled_output = outputs.last_hidden_state[:, 0]
# Dropout
dropped = self.dropout(pooled_output)
# Klassifikation
logits = self.classifier(dropped)
return logits