|
|
import os |
|
|
import math |
|
|
from pathlib import Path |
|
|
import sys |
|
|
from contextlib import contextmanager |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence |
|
|
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM |
|
|
from lightning.pytorch import seed_everything |
|
|
# ---- Reproducibility ------------------------------------------------------
# Single source of truth for the seed: it is passed to lightning's
# seed_everything below and to the split RNG in
# make_distribution_matched_split, so the two can never drift apart.
RANDOM_SEED = 1986
seed_everything(RANDOM_SEED)

# ---- I/O paths ------------------------------------------------------------
CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")

# Every split CSV, report and dataset directory is written under here.
OUT_ROOT = Path(
    "./Classifier_Weight/training_data_cleaned/binding_affinity"
)

# ---- WT (ESM-2) encoder ---------------------------------------------------
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022  # tokenizer truncation length (max_length) for ESM inputs
WT_BATCH = 32      # sequences per forward pass for pooled ESM embeddings

# ---- SMILES (PeptideCLM) encoder ------------------------------------------
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768  # tokenizer truncation length for SMILES inputs
SMI_BATCH = 128    # SMILES strings per forward pass for pooled embeddings

# ---- Split settings -------------------------------------------------------
TRAIN_FRAC = 0.80       # fraction of each affinity bin assigned to "train"
AFFINITY_Q_BINS = 30    # quantile bins used to stratify the split

# ---- Input CSV column names -----------------------------------------------
COL_SEQ1 = "seq1"                       # target sequence
COL_SEQ2 = "seq2"                       # binder sequence ('X' marks unnatural residues)
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

# ---- Runtime --------------------------------------------------------------
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

QUIET = True      # when True, log() does not print to stdout
USE_TQDM = False  # when True, pbar() wraps iterables in tqdm
LOG_FILE = None   # path for log() to append to; None disables file logging
|
|
|
|
|
def log(msg: str):
    """Emit a log message to LOG_FILE (if set) and to stdout (unless QUIET).

    The file copy has trailing whitespace stripped and exactly one newline
    appended; the parent directory is created on demand. The stdout copy is
    the message exactly as given.
    """
    if LOG_FILE is not None:
        target = Path(LOG_FILE)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a") as handle:
            handle.write(msg.rstrip() + "\n")

    if QUIET:
        return
    print(msg)
|
|
|
|
|
def pbar(it, **kwargs):
    """Return ``it`` wrapped in a tqdm progress bar when USE_TQDM is on.

    With USE_TQDM off, the iterable is returned untouched and ``kwargs`` are
    ignored.
    """
    if not USE_TQDM:
        return it
    return tqdm(it, **kwargs)
|
|
|
|
|
@contextmanager
def section(title: str):
    """Bracket a pipeline stage with opening and closing log banners.

    The closing banner is logged only when the ``with`` body completes
    normally; if the body raises, the exception propagates out of the
    ``yield`` and the closing banner is skipped.
    """
    opening = f"\n=== {title} ==="
    closing = f"=== done: {title} ==="
    log(opening)
    yield
    log(closing)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def has_uaa(seq: str) -> bool:
    """Return True if ``seq`` (coerced with ``str``) contains an 'X'/'x'.

    'X' is treated as the marker for an unnatural amino acid.
    """
    normalized = str(seq).upper()
    return normalized.count("X") > 0
|
|
|
|
|
def affinity_to_class(a: float) -> str:
    """Map a numeric affinity to a coarse class label.

    ``a >= 9.0`` -> "High", ``7.0 <= a < 9.0`` -> "Moderate", anything else
    (including NaN, for which both comparisons are False) -> "Low".
    """
    for threshold, label in ((9.0, "High"), (7.0, "Moderate")):
        if a >= threshold:
            return label
    return "Low"
|
|
|
|
|
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    """Assign each row to "train" or "val" so both splits share the affinity distribution.

    Steps:
      * coerce COL_AFF to numeric and drop rows where it is missing;
      * add an ``affinity_class`` column via ``affinity_to_class``;
      * stratify on AFFINITY_Q_BINS quantile bins of COL_AFF (stored in
        ``aff_bin``), falling back to ``affinity_class`` if ``pd.qcut`` fails;
      * inside each stratum, shuffle with a RandomState seeded by RANDOM_SEED
        and send the first ``floor(len * TRAIN_FRAC)`` rows to "train" and the
        rest to "val".

    Args:
        df: Input frame containing COL_AFF. It is not modified.

    Returns:
        A new frame with a reset index and the added columns
        ``affinity_class``, ``aff_bin`` and ``split``. Any row not reached by
        the stratum loop defaults to "train".
    """
    df = df.copy()

    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)

    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)

    try:
        df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
    except Exception:
        # Deliberate best-effort: if quantile binning is impossible for this
        # data, stratify on the coarse 3-way affinity classes instead.
        df["aff_bin"] = df["affinity_class"]
    strat_col = "aff_bin"

    rng = np.random.RandomState(RANDOM_SEED)

    df["split"] = None
    for _, g in df.groupby(strat_col, observed=True):
        # Explicit copy: Index.to_numpy() may hand back a view of the index's
        # own buffer (read-only under pandas Copy-on-Write), and rng.shuffle
        # permutes in place.
        idx = g.index.to_numpy(copy=True)
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
        df.loc[idx[:n_train], "split"] = "train"
        df.loc[idx[n_train:], "split"] = "val"

    df["split"] = df["split"].fillna("train")
    return df
|
|
|
|
|
def _summ(x): |
|
|
x = np.asarray(x, dtype=float) |
|
|
x = x[~np.isnan(x)] |
|
|
if len(x) == 0: |
|
|
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan} |
|
|
return { |
|
|
"n": int(len(x)), |
|
|
"mean": float(np.mean(x)), |
|
|
"std": float(np.std(x)), |
|
|
"p50": float(np.quantile(x, 0.50)), |
|
|
"p95": float(np.quantile(x, 0.95)), |
|
|
} |
|
|
|
|
|
def _len_stats(seqs): |
|
|
lens = np.asarray([len(str(s)) for s in seqs], dtype=float) |
|
|
if len(lens) == 0: |
|
|
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan} |
|
|
return { |
|
|
"n": int(len(lens)), |
|
|
"mean": float(lens.mean()), |
|
|
"std": float(lens.std()), |
|
|
"p50": float(np.quantile(lens, 0.50)), |
|
|
"p95": float(np.quantile(lens, 0.95)), |
|
|
} |
|
|
|
|
|
def verify_split_before_embedding(
    df2: pd.DataFrame,
    affinity_col: str,
    split_col: str,
    seq_col: str,
    iptm_col: str,
    aff_class_col: str = "affinity_class",
    aff_bins: int = 30,
    save_report_prefix: str | None = None,
    verbose: bool = False,
):
    """Log summary statistics comparing the train and val halves of a split.

    Meant to run before the expensive embedding stage so that a skewed split
    is visible early. Works on a copy; ``df2`` itself is not modified.

    Args:
        df2: Frame whose ``split_col`` holds "train"/"val" (NaN tolerated).
        affinity_col: Affinity column; coerced to numeric (invalid -> NaN).
        split_col: Name of the split column.
        seq_col: Column whose ``str`` lengths are summarised per split.
        iptm_col: ipTM column; only coerced to numeric here.
        aff_class_col: Class-label column used for the verbose/report class
            tables and as the fallback binning when ``pd.qcut`` fails.
        aff_bins: Number of quantile bins for the bin-proportion comparison.
        save_report_prefix: If given, write ``<prefix>.stats.csv`` and
            ``<prefix>.class_prop.csv`` (parent directory created).
        verbose: Also print class count/proportion tables (only if not QUIET).

    Raises:
        AssertionError: if ``split_col`` is missing, holds values other than
            "train"/"val", or no affinity value is numeric.
    """
    df2 = df2.copy()
    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")

    assert split_col in df2.columns, f"Missing split col: {split_col}"
    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."

    # Quantile bins mirror the stratification; fall back to the coarse class
    # labels when qcut cannot build bins for this data.
    try:
        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
    except Exception:
        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)

    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
    va = df2[df2[split_col] == "val"].reset_index(drop=True)

    tr_aff = _summ(tr[affinity_col].to_numpy())
    va_aff = _summ(va[affinity_col].to_numpy())
    tr_len = _len_stats(tr[seq_col].tolist())
    va_len = _len_stats(va[seq_col].tolist())

    # (split x bin) table of within-split bin proportions. Built with
    # unstack instead of groupby(level=0).apply(...): since pandas 2.0 that
    # apply prepends the group key to the result index, which made the train
    # and val bin labels disjoint and turned max_bin_diff into "largest bin
    # share" rather than a train/val difference. Reindexing the rows also
    # keeps an empty split from raising KeyError (its proportions become 0).
    bin_counts = (
        df2.groupby([split_col, "_aff_bin_dbg"], observed=True)
        .size()
        .unstack(fill_value=0)
        .reindex(index=["train", "val"], fill_value=0)
    )
    split_totals = bin_counts.sum(axis=1).replace(0, np.nan)
    bin_props = bin_counts.div(split_totals, axis=0).fillna(0.0)
    max_bin_diff = float((bin_props.loc["train"] - bin_props.loc["val"]).abs().max())

    msg = (
        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
        f"max_bin_diff={max_bin_diff:.4f}"
    )
    log(msg)

    if verbose and (not QUIET):
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
        print("\n[verbose] affinity_class counts:\n", class_ct)
        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))

    if save_report_prefix is not None:
        out = Path(save_report_prefix)
        out.parent.mkdir(parents=True, exist_ok=True)

        stats_df = pd.DataFrame([
            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
        ])
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()

        # Append the suffix to the full name: with_suffix() would replace any
        # dotted component already present in the prefix (e.g. "report.v2").
        stats_df.to_csv(out.with_name(out.name + ".stats.csv"), index=False)
        class_prop.to_csv(out.with_name(out.name + ".class_prop.csv"), index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    """Mean-pool a model's last hidden state over each sequence's attention mask.

    Args:
        seqs: List of sequences (sliced in chunks of ``batch_size``).
        tokenizer: Callable tokenizer returning ``input_ids``/``attention_mask``
            tensors (padding + truncation to ``max_length``).
        model: Model whose output exposes ``last_hidden_state``; inputs are
            moved to DEVICE.
        batch_size: Sequences per forward pass.
        max_length: Tokenizer truncation length.

    Returns:
        np.ndarray with one pooled row per input sequence, in order. Every
        position with attention_mask == 1 is averaged, which includes any
        special tokens the tokenizer inserts (unlike ``wt_unpooled_one``,
        which drops CLS/EOS). An empty ``seqs`` yields a
        ``(0, model.config.hidden_size)`` float32 array instead of an
        ``np.vstack`` error.
    """
    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        h = model(**inputs).last_hidden_state

        # Cast the integer mask to the hidden-state dtype so the masked sum and
        # the token count are both floating point, without relying on
        # int/float promotion inside clamp().
        attn = inputs["attention_mask"].unsqueeze(-1).to(h.dtype)
        summed = (h * attn).sum(dim=1)
        denom = attn.sum(dim=1).clamp(min=1e-9)
        embs.append((summed / denom).detach().cpu().numpy())

    if not embs:
        # np.vstack([]) raises; return a well-shaped empty matrix instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.vstack(embs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    """Return per-token hidden states for a single sequence, minus CLS/EOS.

    The sequence is tokenized unpadded (truncated to ``max_length``) and run
    through ``model`` on DEVICE. Positions are kept when the attention mask is
    set and the token id differs from ``cls_id``/``eos_id`` (either check is
    skipped when the id is None).

    Returns:
        np.ndarray of float16 with one row per kept token.
    """
    encoded = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    hidden = model(**encoded).last_hidden_state[0]
    token_ids = encoded["input_ids"][0]
    keep = encoded["attention_mask"][0].bool().clone()

    for special_id in (cls_id, eos_id):
        if special_id is None:
            continue
        keep &= (token_ids != special_id)

    return hidden[keep].detach().cpu().to(torch.float16).numpy()
|
|
|
|
|
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
    """Embed WT target/binder pairs token-by-token and save them as an HF Dataset.

    ``df_split`` must provide the columns ``target_sequence`` (seq1),
    ``sequence`` (the WT binder, seq2), ``label``, ``affinity_class``,
    COL_AFF and COL_WT_IPTM.

    Each saved row carries the metadata plus, for both the target and the
    binder, a per-token float16 embedding matrix (CLS/EOS removed by
    ``wt_unpooled_one``), an all-ones int8 attention mask and the token count.

    Args:
        df_split: Rows to embed.
        out_dir: Directory the dataset is saved to (created if missing).
        tokenizer: Tokenizer exposing ``cls_token_id`` / ``eos_token_id``.
        model: Model exposing ``config.hidden_size``; used for both sequences.

    Returns:
        The generated ``datasets.Dataset``.
    """
    special_cls = tokenizer.cls_token_id
    special_eos = tokenizer.eos_token_id
    width = model.config.hidden_size

    def token_matrix():
        # Variable number of rows, each a fixed-width float16 vector.
        return HFSequence(HFSequence(Value("float16"), length=width))

    schema = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": token_matrix(),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": token_matrix(),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_WT_IPTM: Value("float32"),
        # COL_AFF is currently "affinity", so this re-declares the key above.
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for rec in pbar(df.itertuples(index=False), total=len(df)):
            target = str(rec.target_sequence).strip()
            binder = str(rec.sequence).strip()

            label = float(rec.label)
            affinity = float(getattr(rec, COL_AFF))
            aff_class = str(rec.affinity_class)

            raw_iptm = getattr(rec, COL_WT_IPTM)
            iptm = np.nan if pd.isna(raw_iptm) else float(raw_iptm)

            target_rows = wt_unpooled_one(target, tokenizer, model, special_cls, special_eos, max_length=WT_MAX_LEN).tolist()
            binder_rows = wt_unpooled_one(binder, tokenizer, model, special_cls, special_eos, max_length=WT_MAX_LEN).tolist()
            n_target = len(target_rows)
            n_binder = len(binder_rows)

            yield {
                "target_sequence": target,
                "sequence": binder,
                "label": np.float32(label),
                "affinity": np.float32(affinity),
                "affinity_class": aff_class,

                "target_embedding": target_rows,
                "target_attention_mask": [1] * n_target,
                "target_length": int(n_target),

                "binder_embedding": binder_rows,
                "binder_attention_mask": [1] * n_binder,
                "binder_length": int(n_binder),

                # np.float32 of NaN is NaN, so missing ipTM stays missing.
                COL_WT_IPTM: np.float32(iptm),
                COL_AFF: np.float32(affinity),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=schema)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds
|
|
|
|
|
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer):
    """Embed (protein target, SMILES binder) pairs per token and save an HF Dataset.

    ``df_split`` must provide ``target_sequence`` (seq1), ``sequence`` (the
    binder SMILES string), ``label``, ``affinity_class``, COL_AFF and
    COL_SMI_IPTM.

    Targets are embedded with the ESM model via ``wt_unpooled_one`` (CLS/EOS
    removed). Binders are embedded one at a time with the RoFormer via
    ``smiles_embed_batch_return_both`` (special tokens removed). Both are
    stored as float16 per-token matrices with masks and lengths.

    Raises:
        ValueError: if the RoFormer config has neither ``hidden_size`` nor
            ``dim``.

    Returns:
        The generated ``datasets.Dataset``, also saved to ``out_dir``.
    """
    target_cls = wt_tokenizer.cls_token_id
    target_eos = wt_tokenizer.eos_token_id
    target_width = wt_model_unpooled.config.hidden_size

    # Accept either config attribute name for the binder model's width.
    binder_width = None
    for attr in ("hidden_size", "dim"):
        binder_width = getattr(smi_roformer.config, attr, None)
        if binder_width is not None:
            break
    if binder_width is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    schema = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=target_width)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=binder_width)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_SMI_IPTM: Value("float32"),
        # COL_AFF is currently "affinity", so this re-declares the key above.
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for rec in pbar(df.itertuples(index=False), total=len(df)):
            target = str(rec.target_sequence).strip()
            binder = str(rec.sequence).strip()

            label = float(rec.label)
            affinity = float(getattr(rec, COL_AFF))
            aff_class = str(rec.affinity_class)

            raw_iptm = getattr(rec, COL_SMI_IPTM)
            iptm = np.nan if pd.isna(raw_iptm) else float(raw_iptm)

            target_rows = wt_unpooled_one(
                target, wt_tokenizer, wt_model_unpooled, target_cls, target_eos, max_length=WT_MAX_LEN
            ).tolist()
            n_target = len(target_rows)

            # Single-item "batch": only the per-token outputs are used.
            _, per_token, per_mask, per_len = smiles_embed_batch_return_both(
                [binder], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            binder_rows = per_token[0].tolist()
            n_binder = int(per_len[0])
            binder_mask = [int(v) for v in per_mask[0].astype(np.int8).tolist()]

            yield {
                "target_sequence": target,
                "sequence": binder,
                "label": np.float32(label),
                "affinity": np.float32(affinity),
                "affinity_class": aff_class,

                "target_embedding": target_rows,
                "target_attention_mask": [1] * n_target,
                "target_length": int(n_target),

                "binder_embedding": binder_rows,
                "binder_attention_mask": binder_mask,
                "binder_length": int(n_binder),

                # np.float32 of NaN is NaN, so missing ipTM stays missing.
                COL_SMI_IPTM: np.float32(iptm),
                COL_AFF: np.float32(affinity),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=schema)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_special_ids(tokenizer_obj):
    """Collect the tokenizer's special-token ids as a sorted, de-duplicated list.

    Looks up ``pad``, ``cls``, ``sep``, ``bos``, ``eos`` and ``mask``
    ``*_token_id`` attributes; missing attributes and None values are skipped.
    """
    found = set()
    for role in ("pad", "cls", "sep", "bos", "eos", "mask"):
        token_id = getattr(tokenizer_obj, f"{role}_token_id", None)
        if token_id is not None:
            found.add(token_id)
    return sorted(found)
|
|
|
|
|
@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    """Embed a batch of SMILES strings, returning pooled and per-token outputs.

    Tokens count as "content" when the attention mask is set and the id is
    not one of ``get_special_ids(tokenizer_obj)``.

    Returns:
        pooled: np.ndarray, one row per input — the mean of its content-token
            hidden states (all zeros if it has no content tokens, thanks to the
            1e-9 clamp on the denominator).
        token_emb_list: list of float16 arrays with the content-token hidden
            states per input.
        mask_list: list of all-ones int8 arrays, one per input.
        lengths: list of content-token counts per input.
    """
    encoded = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    hidden = model_roformer(input_ids=ids, attention_mask=mask).last_hidden_state

    content = mask.bool()
    specials = get_special_ids(tokenizer_obj)
    if specials:
        special_tensor = torch.tensor(specials, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            content = content & (~torch.isin(ids, special_tensor))
        else:
            # Manual fallback for torch builds without torch.isin.
            hit = torch.zeros_like(content)
            for special_id in specials:
                hit |= (ids == special_id)
            content = content & (~hit)

    weights = content.unsqueeze(-1).float()
    totals = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    pooled = (totals / counts).detach().cpu().numpy()

    per_item = [row[keep] for row, keep in zip(hidden, content)]
    token_emb_list = [emb.detach().cpu().to(torch.float16).numpy() for emb in per_item]
    lengths = [int(emb.shape[0]) for emb in per_item]
    mask_list = [np.ones((n,), dtype=np.int8) for n in lengths]

    return pooled, token_emb_list, mask_list, lengths
|
|
|
|
|
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    """Run ``smiles_embed_batch_return_both`` over ``seqs`` in chunks and concatenate.

    Args:
        seqs: List of SMILES strings (sliced in chunks of ``batch_size``).
        tokenizer_obj / model_roformer: Passed straight through per batch.
        batch_size: Strings per forward pass.
        max_length: Tokenizer truncation length.

    Returns:
        (pooled, token_embs, masks, lengths). ``pooled`` stacks the per-batch
        pooled arrays (one row per input). The other three are flat lists with
        one entry per input, in input order. An empty ``seqs`` yields a
        ``(0, hidden)`` float32 array and three empty lists instead of an
        ``np.vstack`` error.
    """
    pooled_all, token_emb_all, mask_all, lengths_all = [], [], [], []

    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            batch, tokenizer_obj, model_roformer, max_length
        )
        pooled_all.append(pooled)
        token_emb_all.extend(tok_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)

    if not pooled_all:
        # np.vstack([]) raises; use the model width when the config exposes
        # it, otherwise fall back to zero columns.
        cfg = getattr(model_roformer, "config", None)
        width = getattr(cfg, "hidden_size", None) or 0
        return np.empty((0, width), dtype=np.float32), token_emb_all, mask_all, lengths_all

    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
|
|
|
|
|
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    """Load ESM and compute pooled target embeddings for a train/val view pair.

    Each view must have a ``target_sequence`` column. For each split it
    returns the stacked embeddings (row order = view order) and a
    sequence -> embedding dict; for repeated sequences the last occurrence
    wins.

    Returns:
        (tokenizer, model, train_emb, val_emb, train_map, val_map)
    """
    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    def embed_targets(view: pd.DataFrame):
        seqs = view["target_sequence"].astype(str).tolist()
        embs = wt_pooled_embeddings(seqs, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN)
        return embs, dict(zip(seqs, embs))

    wt_train_tgt_emb, train_map = embed_targets(wt_view_train)
    wt_val_tgt_emb, val_map = embed_targets(wt_view_val)

    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the full binding-affinity preprocessing pipeline.

    Stages (each wrapped in a logged ``section``):
      1. Load CSV_PATH, validate required columns, strip string columns,
         de-duplicate, and coerce affinity to numeric.
      2. Build the WT subset (binder = seq2 without 'X' residues) and the
         SMILES subset (binder = REACT_SMILES when seq2 contains 'X', else
         Fasta2SMILES; requires a numeric SMILES ipTM).
      3. Split each subset independently, save the split metadata CSVs and
         run ``verify_split_before_embedding`` on both.
      4. Compute pooled ESM target embeddings for every split of both subsets.
      5. Save pooled WT-binder and SMILES-binder DatasetDicts.
      6. Save unpooled (per-token) WT and SMILES paired DatasetDicts.

    All outputs are written under OUT_ROOT.

    Raises:
        ValueError: if the CSV lacks any required column.
    """
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)

        # Validate the schema first so a missing column raises this explicit
        # error rather than a KeyError from drop_duplicates below.
        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
        missing = [c for c in need if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Strip surrounding whitespace from string cells; NaN/non-strings pass through.
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
        # Routed through log() so it honours QUIET / LOG_FILE like other messages.
        log(f"Rows after dedup on {DEDUP_COLS} : {len(df)}")

        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    with section("prepare wt/smiles subsets"):
        # Both branches need a real target sequence. Test the raw column:
        # astype(str) turns NaN into the literal "nan", which survives every
        # later string filter and would be embedded as if it were a protein.
        has_target = df[COL_SEQ1].notna() & (df[COL_SEQ1].astype(str).str.strip() != "")
        df = df[has_target].reset_index(drop=True)

        # WT branch: binder = seq2 amino acids. Missing binders are removed
        # before the str conversion (same "nan" issue); rows without affinity
        # or containing an 'X' (unnatural residue) are dropped as well.
        df_wt = df[df[COL_SEQ2].notna()].copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"] != ""]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES branch: needs affinity and a numeric SMILES ipTM.
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)

        # Binders with unnatural residues use REACT_SMILES, the rest Fasta2SMILES.
        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        # astype(str) above maps missing SMILES to these literals.
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        # Normalise one branch into the column layout shared by all datasets.
        # Missing sequences are filtered on the raw columns: after astype(str)
        # they would read "nan" and a later dropna could no longer catch them.
        out = df_in[df_in[COL_SEQ1].notna() & df_in[binder_seq_col].notna()].copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out[(out["target_sequence"] != "") & (out["sequence"] != "")]
        out = out.dropna(subset=["label"]).reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    with section("WT unpooled paired embeddings + save"):
        wt_tok_unpooled = wt_tok
        wt_esm_unpooled = wt_esm

        # NOTE(review): each builder already saves to <out>/train and <out>/val;
        # the DatasetDict save below presumably rewrites those same split
        # directories (plus the dict metadata), so the data is written twice.
        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    with section("SMILES unpooled paired embeddings + save"):
        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
|
|
|
|
|
|
|
|
# Script entry point: run the whole pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":


    main()
|
|
|