#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title   : maotao_predict.py.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/10/23 16:02
des: Amino-acid -> CDS prediction pipeline. Builds per-species model inputs
     from an Excel sheet, runs inference, and assembles the predicted
     fragments into one Excel sheet with a column per species.
"""
import os
import shutil
import time

import pandas as pd

from inference import inference
from model.assemble_fragment import assemble_fragments
from model.codon_attr import Codon
from model.sliding_windows import process_nucleotide_sequences
from model.tools import get_pretraining_args

# Short species codes used in file names / model inputs -> names used as
# column headers in the exported sheet.
_SPECIES_FULL_NAMES = {
    'Ec': 'Escherichia coli',
    'Human': 'Homo sapiens (Human)',
    'Pic': 'Pichia angusta',
    'Sac': 'Saccharomyces cerevisiae',
    'mouse': 'Mus musculus (Mouse)',
}

# Column order of the species in the exported sheet.
_EXPORT_SPECIES_COLUMNS = [
    'Homo sapiens (Human)',
    'Mus musculus (Mouse)',
    'Escherichia coli',
    'Saccharomyces cerevisiae',
    'Pichia angusta',
]


def check_path(dirout, file=False):
    """Create directory *dirout* (including parents) if it is missing.

    Args:
        dirout: Directory to create, or a file path when *file* is true.
        file: When true, *dirout* is treated as a file path and its parent
            directory is created instead.
    """
    if file:
        dirout = os.path.dirname(dirout)
    # An empty name means "the current directory": nothing to create.
    if not dirout or os.path.exists(dirout):
        return
    print('make dir -p ' + dirout)
    # exist_ok tolerates another process creating the directory between the
    # existence check above and this call; real failures (e.g. permissions)
    # still propagate instead of being reported as a harmless race.
    os.makedirs(dirout, exist_ok=True)


def translate(nucleotide_seq):
    """Translate a nucleotide sequence into an amino-acid string.

    'T' is replaced by 'U' before the lookup in ``Codon.CODON_TO_AA``. The
    sequence is read in steps of three; any triplet missing from the table
    (including a short trailing chunk) is rendered as '_'.

    Args:
        nucleotide_seq: Nucleotide string to translate.

    Returns:
        One character per triplet of the input.
    """
    rna = nucleotide_seq.replace('T', 'U')
    return ''.join(
        Codon.CODON_TO_AA.get(rna[start:start + 3], '_')
        for start in range(0, len(rna), 3)
    )


def process_inputs(fin=None, dirout=None, codon_table=None):
    """Build the model input table ``<dirout>/TS.csv`` from an Excel sheet.

    For every protein in *fin* and every target species, the sequence
    returned by ``Codon.cai_opt_codon`` is computed and cut into fragments
    with ``process_nucleotide_sequences`` (max_nn_length=1200, step=300,
    pad_char='_').  Each fragment becomes one row of the output table.

    Args:
        fin: Excel file with at least the columns 'id' and 'RefSeq_aa'.
        dirout: Directory that receives 'TS.csv'; created if missing.
        codon_table: Path template containing a ``{species}`` placeholder,
            e.g. '.../codon_usage_{species}.csv'.
    """
    check_path(dirout)
    table_path = os.path.join(dirout, 'TS.csv')

    proteins = pd.read_excel(fin)[['id', 'RefSeq_aa']]
    species_list = ['mouse', 'Ec', 'Sac', 'Pic', 'Human']
    print(f'loading {len(proteins)} AA from {fin}\nprepare inputs for generating to CDS for expression in {species_list}')

    # Load every codon table up front so a missing table fails before any
    # output is written.
    codon_instance = {
        species: Codon(codon_table.format(species=species), rna=False)
        for species in species_list
    }

    # One frame per species, built with assign() so the frame read from the
    # sheet is never mutated through a column slice.
    per_species = [
        proteins.assign(
            species=species,
            cai_best_nn=[
                codon_instance[species].cai_opt_codon(aa)
                for aa in proteins['RefSeq_aa']
            ],
        )
        for species in species_list
    ]
    pd.concat(per_species, ignore_index=True).to_csv(table_path, index=False)

    # The table is deliberately read back from disk so ids and sequences are
    # parsed exactly as CSV values before they are passed on.
    data = pd.read_csv(table_path)

    # NOTE(review): fragments are presumably dicts carrying 'truncated_nn'
    # plus the entries of meta_dict -- confirm against
    # process_nucleotide_sequences.
    fragments = [
        fragment
        for seq, _id, species in zip(data['cai_best_nn'], data['id'], data['species'])
        for fragment in process_nucleotide_sequences(
            seq, max_nn_length=1200, step=300, pad_char='_',
            meta_dict={'_id': _id, 'species': species})
    ]
    expanded_data = pd.DataFrame(fragments)
    expanded_data['truncated_aa'] = expanded_data['truncated_nn'].apply(translate)
    expanded_data = expanded_data.rename(columns={'truncated_nn': 'cai_best_nn'})

    # Overwrites the intermediate per-protein table with the per-fragment one.
    expanded_data.to_csv(table_path, mode='w', index=False, header=True)
    print(f'process {len(expanded_data)} data and saving to {dirout}/TS.csv')


def process_result(fin=None, fpred=None, ftruncated=None, fout=None):
    """Assemble predicted fragments and export one sheet with a column per species.

    Predictions (*fpred*) are merged with the fragment table (*ftruncated*)
    on their shared columns, grouped by ('_id', 'species'), and each group is
    turned into one sequence by ``assemble_fragments``.  The result is
    pivoted to one row per id and joined with the reference proteins.

    Args:
        fin: Excel file with at least the columns 'id' and 'RefSeq_aa'.
        fpred: CSV with the model predictions.
        ftruncated: CSV with the fragment table written by ``process_inputs``.
        fout: Path of the Excel file to write (openpyxl engine).

    Raises:
        KeyError: If a required column or one of the five species is missing.
    """
    df = pd.read_excel(fin)
    df_info = pd.read_csv(fpred).merge(pd.read_csv(ftruncated))

    assembled = pd.DataFrame(
        [
            [_id, species, assemble_fragments(group)]
            for (_id, species), group in df_info.groupby(by=['_id', 'species'])
        ],
        columns=['_id', 'species', 'seq'],
    )
    assembled['species'] = assembled['species'].replace(_SPECIES_FULL_NAMES)

    df_wide = assembled.pivot(index=['_id'], columns='species', values='seq')
    df_wide = df_wide.reset_index()  # turn the index back into a column
    df_wide = df_wide.rename(columns={'_id': 'id'})

    # Only the id, the reference protein and the five species columns are
    # exported.  A back-translation column computed here earlier was never
    # part of this selection, so it is no longer built.
    export_columns = ['id', 'RefSeq_aa'] + _EXPORT_SPECIES_COLUMNS
    df_wide = df[['id', 'RefSeq_aa']].merge(df_wide, on=['id'])[export_columns]
    df_wide.to_excel(fout, index=False, engine='openpyxl')


def predict(fin, dirout):
    """Run the full pipeline: prepare inputs, run inference, assemble results.

    Intermediate files are written below ``<dirout>/tmp/AA2CDS_data/`` and
    the final sheet to ``<dirout>/Tests.xlsx``.

    Args:
        fin: Excel file with at least the columns 'id' and 'RefSeq_aa'.
        dirout: Output directory.
    """
    # NOTE(review): presumably an argparse parser, so this also consumes the
    # command line of the running script -- confirm in model.tools.
    parser = get_pretraining_args()
    args = parser.parse_args()

    tmp_dir = dirout + '/tmp/'
    check_path(tmp_dir)

    # Configuration handed to inference(); the values are passed through
    # unchanged because inference() is defined elsewhere.
    args.downstream_data_path = tmp_dir
    args.predict = True
    args.out_dir = tmp_dir
    args.task = 'AA2CDS_data/'
    args.mlm_pretrained_model_path = 'checkpoint/AA2CDS.pth'

    work_dir = os.path.join(args.downstream_data_path, args.task)
    check_path(work_dir)

    fpred = f'{work_dir}/TS_pred.csv'
    ftruncated = f'{work_dir}/TS.csv'
    fout = f'{dirout}/Tests.xlsx'

    # process inputs
    process_inputs(fin=fin, dirout=os.path.dirname(fpred), codon_table=args.codon_table_path)

    # predict
    inference(args)

    # assemble
    process_result(fin=fin, fpred=fpred, ftruncated=ftruncated, fout=fout)


def main():
    """Run the round-3 experiment and print start/stop time and duration."""
    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()

    # Round 1 (for reference): fin = 'example/Tests.xlsx',
    # dirout = os.path.abspath('example/out'); afterwards
    # {dirout}/Tests.xlsx was copied to ./Tests.xlsx.

    # Round 2, for experiment.
    fin = 'example/Tests_round3.xlsx'
    dirout = os.path.abspath('example/out_round3')
    # Start from a clean output directory.  rmtree avoids building a shell
    # command from a path; a missing directory is not an error.
    shutil.rmtree(dirout, ignore_errors=True)
    predict(fin, dirout)

    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)


if __name__ == '__main__':
    main()