Spaces:
Running on CPU Upgrade

whisky-wheel / lib /bert_regressor_utils.py
ziem-io's picture
Update: Refactor code
9a2984d
raw
history blame
9.58 kB
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset
import numpy as np
from .bert_regressor import BertMultiHeadRegressor
###################################################################################
# Konstante Liste der acht Aromen-Kategorien für Whisky-Tasting-Notes.
# Diese wird von Modellen und Evaluierungsfunktionen verwendet.
TARGET_COLUMNS = [
"grainy",
"grassy",
"fragrant",
"fruity",
"peated",
"woody",
"winey",
"off-notes"
]
###################################################################################
COLORS = {
"grainy": "#FFF3B0",
"grassy": "#C4F0C5",
"fragrant": "#F3C4FB",
"fruity": "#FFD6B0",
"peated": "#CFCFCF",
"woody": "#EAD6C7",
"winey": "#F7B7A3",
"off-notes": "#D6E4F0",
"quantifiers": "#ff8083"
}
ICONS = {
"grainy": "🌾",
"grassy": "🌿",
"fragrant": "🌸",
"fruity": "🍋",
"peated": "🔥",
"woody": "🌲",
"winey": "🍷",
"off-notes": "☠️"
}
###################################################################################
class WhiskyDataset(Dataset):
def __init__(self, texts, targets, tokenizer, max_len):
self.texts = texts
self.targets = targets
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, item):
text = str(self.texts[item])
target = self.targets[item]
# Einheitliche Tokenisierung über Hilfsfunktion
encoding = tokenize_input(text, self.tokenizer)
return {
'input_ids': encoding['input_ids'].squeeze(),
'attention_mask': encoding['attention_mask'].squeeze(),
'targets': torch.tensor(target, dtype=torch.float)
}
###################################################################################
def get_device(prefer_mps=True, verbose=True):
"""
Gibt das beste verfügbare Torch-Device zurück (MPS, CUDA oder CPU).
Args:
prefer_mps (bool): Ob bei Apple-Geräten 'mps' (Metal Performance Shaders) bevorzugt werden soll.
verbose (bool): Ob das erkannte Device ausgegeben werden soll.
Returns:
torch.device: Das beste verfügbare Gerät für das Training.
"""
if prefer_mps and torch.backends.mps.is_available():
device = torch.device("mps")
name = "Apple GPU (MPS)"
elif torch.cuda.is_available():
device = torch.device("cuda")
name = torch.cuda.get_device_name(device)
else:
device = torch.device("cpu")
name = "CPU"
if verbose:
print(f"✅ Verwendetes Gerät: {name} ({device})")
return device
###################################################################################
def tokenize_input(texts, tokenizer, max_len=256):
"""
Einheitliche Tokenisierung für Training und Inferenz.
Args:
texts (str or List[str]): Eingabetext(e).
tokenizer (PreTrainedTokenizer): z. B. BertTokenizer.
Returns:
dict: Dictionary mit PyTorch-Tensoren (input_ids, attention_mask).
"""
return tokenizer(
texts,
truncation=True,
padding='max_length',
max_length=max_len,
return_tensors='pt'
)
###################################################################################
def load_model_and_tokenizer(model_name, model_path):
"""
Ladefunktion für BertMultiHeadRegressor.
Args:
model_name (str): Name des vortrainierten BERT-Modells (z. B. 'bert-base-uncased').
model_path (str): Pfad zur gespeicherten Modellzustandsdatei (.pt).
Returns:
model (nn.Module): Geladenes Modell im Eval-Modus.
tokenizer (BertTokenizer): Passender Tokenizer.
device (torch.device): Verwendetes Rechengerät (CPU oder GPU).
"""
# Gerät automatisch ermitteln (GPU/CPU)
device = get_device()
# Modellzustand und Konfiguration laden
checkpoint = torch.load(model_path, map_location=device)
config = checkpoint["model_config"]
# Modell initialisieren
model = BertMultiHeadRegressor(
pretrained_model_name=config["pretrained_model_name"],
n_heads=config["n_heads"],
unfreeze_from=config["unfreeze_from"],
dropout=config["dropout"]
)
# Gewichtungen laden und Modell auf Gerät verschieben
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval() # Wechselt in den Inferenzmodus
# Lädt den passenden Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer, device
###################################################################################
def predict_flavours(review_text, model, tokenizer, device, max_len=256):
# Modell in den Evaluierungsmodus setzen (kein Dropout etc.)
model.eval()
# Eingabetext tokenisieren und als Tensoren zurückgeben
encoding = tokenize_input(
review_text,
tokenizer
)
# Tokens auf das richtige Device verschieben
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# Inferenz ohne Gradientenberechnung (Effizienz)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask) # shape: [1, 8]
prediction = outputs.cpu().numpy().flatten() # [8] – flach machen
prediction = np.clip(prediction, 0.0, 4.0)
# In ein Dictionary umwandeln (z. B. {"fruity": 2.1, "peated": 3.8, ...})
result = {
flavour: round(float(score), 2)
for flavour, score in zip(TARGET_COLUMNS, prediction)
}
return result
###################################################################################
def predict_is_review(review_text, model, tokenizer, device, max_len=256, threshold=0.5):
# Modell in den Evaluierungsmodus setzen (kein Dropout etc.)
model.eval()
# Eingabetext tokenisieren und als Tensoren zurückgeben
encoding = tokenize_input(
review_text,
tokenizer
)
# Tokens auf das richtige Device verschieben
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.cpu().numpy()) # <--- Zeigt die rohen Logits
probs = torch.sigmoid(outputs) # [1, 1]
prob = float(probs.squeeze().cpu().numpy()) # Skalar
return {
"is_review": prob >= threshold,
"probability": round(prob, 4)
}
###################################################################################
# Globale Variable, initial auf None
_nlp = None
def _load_spacy_model():
# Lädt das Modell nur, wenn es noch nicht da ist (Singleton Pattern).
global _nlp
if _nlp is not None:
return _nlp
import spacy
from spacy.language import Language
model = spacy.load("en_core_web_sm")
# Custom Component registrieren
if "custom_boundaries" not in model.pipe_names:
# Die Funktion muss lokal definiert oder global verfügbar sein.
# Da @Language.component global registriert, ist es sicherer,
# den Decorator-Namen zu prüfen.
if not Language.has_factory("custom_boundaries"):
@Language.component("custom_boundaries")
def set_custom_boundaries(doc):
for token in doc[:-1]:
if token.text in [";", ":"]:
doc[token.i + 1].is_sent_start = True
return doc
model.add_pipe("custom_boundaries", before="parser")
_nlp = model
return _nlp
###################################################################################
def text_to_sentences(text):
# Öffentliche Funktion. Lädt Spacy automatisch beim ersten Aufruf.
if not text:
return []
# Hier passiert die Magie: Laden erst bei Bedarf
nlp = _load_spacy_model()
doc = nlp(text)
sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
return sentences
###################################################################################
def cleanup_tasting_note(text, model, tokenizer, device, threshold=0.5):
good_sentences = []
scored_sentences = []
has_review = False
has_noise = False
sentences = text_to_sentences(text)
for sentence in sentences:
if not sentence:
continue
result = predict_is_review(sentence, model, tokenizer, device)
score = round(result["probability"], 3)
is_note = score > threshold
scored_sentences.append({
"is_note": is_note,
"score": score,
"sentence": sentence
})
if is_note:
good_sentences.append(sentence)
has_review = True
else:
has_noise = True
new_text = " ".join(good_sentences)
# ✅ Status bestimmen
if has_review and has_noise:
review_status = "mixed"
elif has_review:
review_status = "review_only"
elif has_noise:
review_status = "noise_only"
else:
review_status = "noise_only" # leerer Text → effektiv kein Review
return new_text, scored_sentences, review_status
###################################################################################