# ClickBaitRaterCZ01 / model.py
# LbejchJakub's picture
# Upload 4 files
# 193fd12 verified
# =========================
# model.py
# =========================
import re
import torch
import numpy as np
import pandas as pd
from transformers import (
AutoModelForPreTraining,
AutoTokenizer,
pipeline,
)
import streamlit as st # Přidáme pro cachování
# =========================
# CONFIG (same as in your setup)
# =========================
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face model id loaded by load_models() via AutoModelForPreTraining / AutoTokenizer.
ELECTRA_MODEL = "Seznam/small-e-czech"
# Hugging Face model id used by load_models() for the "text-classification" pipeline.
CLF_MODEL = "Stremie/xlm-roberta-base-clickbait"
# RTD score thresholds: >= RTD_CLICKBAIT_TH -> "CLICKBAIT",
# >= RTD_BORDERLINE_TH -> "BORDERLINE", otherwise "NOT".
RTD_CLICKBAIT_TH = 0.20
RTD_BORDERLINE_TH = 0.15
# Classifier three-way thresholds on the clickbait probability:
# >= CLF_CLICK_TH -> "CLICKBAIT", <= CLF_NOT_TH -> "NOT", in between -> "BORDERLINE".
CLF_CLICK_TH = 0.65
CLF_NOT_TH = 0.35
# Thresholds on the combined score: >= COMB_CLICK_TH -> "CLICKBAIT",
# <= COMB_NOT_TH -> "NOT", in between -> "BORDERLINE".
COMB_CLICK_TH = 0.60
COMB_NOT_TH = 0.40
# =========================
# LOAD MODELS (with caching)
# =========================
# @st.cache_resource is used so that the models are loaded only once.
@st.cache_resource
def load_models():
    """Load the ELECTRA discriminator, its tokenizer and the clickbait classifier.

    The function is wrapped in ``st.cache_resource`` so the models are meant
    to be loaded only once.

    Returns:
        A tuple ``(disc, tok, clf)``: the ELECTRA model (moved to ``DEVICE``
        and switched to eval mode), its tokenizer, and a text-classification
        pipeline whose label mapping is guaranteed to contain the names
        ``"NOT"`` and ``"CLICKBAIT"``.
    """
    print("Načítám modely...")

    disc = AutoModelForPreTraining.from_pretrained(ELECTRA_MODEL)
    disc = disc.to(DEVICE).eval()
    tok = AutoTokenizer.from_pretrained(ELECTRA_MODEL)

    # The pipeline expects a device index: 0 for the first GPU, -1 for CPU.
    pipeline_device = 0 if DEVICE == "cuda" else -1
    clf = pipeline(
        "text-classification",
        model=CLF_MODEL,
        device=pipeline_device,
    )

    # ---- Robust label mapping for the classifier ----
    # If the checkpoint does not expose both expected label names (compared
    # case-insensitively), install a fixed two-class mapping.
    config = clf.model.config
    current_mapping = getattr(config, "id2label", {}) or {}
    known_names = {str(name).upper() for name in current_mapping.values()}
    if not {"CLICKBAIT", "NOT"}.issubset(known_names):
        config.id2label = {0: "NOT", 1: "CLICKBAIT"}
        config.label2id = {"NOT": 0, "CLICKBAIT": 1}

    print("Modely načteny.")
    return disc, tok, clf
# Helper functions (rtd_token_scores_batch, classify_supervised, etc.),
# copied WITHOUT CHANGES from the original script.
# The rest of the functions from that script follow below.
@torch.no_grad()
def rtd_token_scores_batch(texts, disc, tok, batch_size=32):
    """Compute per-token replaced-token-detection probabilities for each text.

    Texts are tokenized in batches with padding and truncation, run through
    the discriminator, and the sigmoid of the logits is taken per token.

    Padding positions are removed from every row using the tokenizer's
    attention mask, so each returned array holds only the real tokens of its
    text. Without this, shorter texts in a batch would carry scores of
    padding tokens and their last real token would not be at index -1.

    Args:
        texts: Sliceable sequence of input strings.
        disc: Discriminator model; called as ``disc(**encoding)`` and expected
            to return an object with a ``logits`` attribute.
        tok: Tokenizer; called with ``return_tensors="pt"``, ``padding=True``
            and ``truncation=True``.
        batch_size: Number of texts processed per forward pass.

    Returns:
        A list with one 1-D NumPy array of probabilities per input text.
    """
    # Send inputs to the device the model actually lives on when it exposes
    # one; otherwise use the module-level default.
    device = getattr(disc, "device", None)
    if device is None:
        device = DEVICE

    all_scores = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = tok(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        out = disc(**enc)
        probs = torch.sigmoid(out.logits).detach().cpu().numpy()

        if "attention_mask" in enc:
            # Boolean mask works for both left- and right-padded batches.
            keep = enc["attention_mask"].detach().cpu().numpy().astype(bool)
            all_scores.extend(row[mask] for row, mask in zip(probs, keep))
        else:
            all_scores.extend(probs)
    return all_scores
def clickbait_score_rtd_from_probs(probs, k_top: int = 5) -> float:
    """Reduce per-token probabilities to a single score in ``[0, 1]``.

    The first and last positions are dropped when at least two values are
    present (presumably the tokenizer's leading/trailing special tokens --
    confirm against the tokenizer in use), and the mean of the ``k_top``
    largest remaining values is returned.

    Args:
        probs: 1-D array-like of per-token probabilities (NumPy array, list,
            tuple, ...).
        k_top: How many of the highest values to average. Values below 1 are
            treated as 1; values above the number of available tokens use
            all of them.

    Returns:
        The clamped mean of the top values, or ``0.0`` when no tokens remain.
    """
    values = np.atleast_1d(np.asarray(probs))
    core = values[1:-1] if len(values) >= 2 else values
    if core.size == 0:
        return 0.0

    # Clamp k into [1, core.size]: k == 0 would make the ``[-k:]`` slice
    # select every element, and a negative k would select an arbitrary tail.
    k = max(1, min(k_top, core.size))
    topk = np.partition(core, -k)[-k:]
    score = float(np.mean(topk))
    return max(0.0, min(1.0, score))
def rtd_label_from_score(p: float) -> str:
    """Translate an RTD score into a three-way label.

    Returns ``"CLICKBAIT"`` when ``p`` reaches ``RTD_CLICKBAIT_TH``,
    ``"BORDERLINE"`` when it reaches ``RTD_BORDERLINE_TH``, and ``"NOT"``
    otherwise.
    """
    # Thresholds are checked in the same order as a chain of ifs would be.
    ordered_rules = (
        (RTD_CLICKBAIT_TH, "CLICKBAIT"),
        (RTD_BORDERLINE_TH, "BORDERLINE"),
    )
    for threshold, label in ordered_rules:
        if p >= threshold:
            return label
    return "NOT"
def _normalize_label_to_index(lbl, LABEL2ID):
if isinstance(lbl, int): return lbl
s = str(lbl)
if s in LABEL2ID: return LABEL2ID[s]
m = re.search(r"(\d+)$", s)
if m: return int(m.group(1))
return None
def classify_supervised(texts, clf):
    """Run the supervised clickbait classifier over a collection of texts.

    Each text is converted to a stripped string (missing values become an
    empty string) and passed to the pipeline with all class scores requested.

    Args:
        texts: Iterable of headlines; entries may be missing (None/NaN).
        clf: Text-classification pipeline whose ``model.config`` provides
            ``id2label`` and ``label2id``.

    Returns:
        A list with one dict per text containing ``clf_prob_clickbait``,
        ``clf_prob_not``, ``clf_label`` (binary), ``clf_label_3way`` and
        ``clf_margin`` (absolute difference of the two probabilities).
    """
    model_config = clf.model.config
    ID2LABEL = model_config.id2label
    LABEL2ID = model_config.label2id

    cleaned = ["" if not pd.notna(t) else str(t).strip() for t in texts]
    predictions = clf(cleaned, top_k=None, truncation=True, max_length=256)

    rows = []
    for candidates in predictions:
        # Probabilities stay at 0.0 when the corresponding class is absent.
        p_click = p_not = 0.0
        for candidate in candidates:
            raw_label = candidate["label"]
            class_idx = _normalize_label_to_index(raw_label, LABEL2ID)
            if class_idx is None:
                continue
            class_name = ID2LABEL.get(class_idx, str(raw_label)).upper()
            if class_name == "CLICKBAIT":
                p_click = float(candidate["score"])
            elif class_name == "NOT":
                p_not = float(candidate["score"])

        # Three-way label is driven by the clickbait probability alone.
        if p_click >= CLF_CLICK_TH:
            three_way = "CLICKBAIT"
        elif p_click <= CLF_NOT_TH:
            three_way = "NOT"
        else:
            three_way = "BORDERLINE"

        rows.append({
            "clf_prob_clickbait": p_click,
            "clf_prob_not": p_not,
            # Ties go to "CLICKBAIT".
            "clf_label": "CLICKBAIT" if p_click >= p_not else "NOT",
            "clf_label_3way": three_way,
            "clf_margin": abs(p_click - p_not),
        })
    return rows
# =========================
# MAIN PROCESSING FUNCTION
# =========================
def _final_label_from_combined(score: float) -> str:
    """Map a combined score to the final three-way label."""
    if score >= COMB_CLICK_TH:
        return "CLICKBAIT"
    if score <= COMB_NOT_TH:
        return "NOT"
    return "BORDERLINE"


def process_headlines(headlines: list[str], k_top: int = 5) -> pd.DataFrame:
    """Score a list of headlines and return a DataFrame with the results.

    Every headline is scored by the RTD discriminator and by the supervised
    classifier; the two are blended as
    ``0.85 * clf_prob_clickbait + 0.15 * rtd_score`` and the blend is mapped
    to a final label.

    Args:
        headlines: Headlines to evaluate. Missing entries (None/NaN) and
            entries that are empty or whitespace-only are treated as blank.
        k_top: Number of highest token scores averaged into the RTD score.

    Returns:
        A DataFrame with the columns "Titulek", "Výsledek",
        "Kombinované skóre", "Pravděpodobnost clickbaitu" and "RTD skóre",
        one row per input headline. An empty DataFrame is returned when
        there are no headlines or all of them are blank.
    """
    if not headlines:
        return pd.DataFrame()

    # Normalize once so both models see the same text; this also makes the
    # blank check cover "" and None, which str.isspace() does not.
    cleaned = [str(h).strip() if pd.notna(h) else "" for h in headlines]
    if not any(cleaned):
        return pd.DataFrame()

    disc, tok, clf = load_models()

    # RTD
    rtd_probs_all = rtd_token_scores_batch(cleaned, disc, tok, batch_size=32)
    rtd_scores = [clickbait_score_rtd_from_probs(p, k_top=k_top) for p in rtd_probs_all]
    rtd_labels = [rtd_label_from_score(p) for p in rtd_scores]

    # Supervised
    df_sup = pd.DataFrame(classify_supervised(cleaned, clf))

    # Assemble the results; the original headline text is kept for display.
    df_out = pd.DataFrame({"Titulek": headlines})
    df_out["rtd_score"] = rtd_scores
    df_out["rtd_label"] = rtd_labels
    df_out = pd.concat([df_out, df_sup], axis=1)
    df_out["combined_score"] = (
        0.85 * df_out["clf_prob_clickbait"] + 0.15 * df_out["rtd_score"]
    )
    df_out["final_label"] = [
        _final_label_from_combined(s) for s in df_out["combined_score"]
    ]

    # Select and rename the columns for readability.
    final_cols = {
        "Titulek": "Titulek",
        "final_label": "Výsledek",
        "combined_score": "Kombinované skóre",
        "clf_prob_clickbait": "Pravděpodobnost clickbaitu",
        "rtd_score": "RTD skóre",
    }
    return df_out[list(final_cols)].rename(columns=final_cols)