#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : maotao_inference.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/10/23 16:49
des: Run AA2CDS inference with a pretrained maotao model, optionally under
     torch.distributed (DDP) when launched with RANK/LOCAL_RANK env vars.
"""
import os
import sys
import time
from contextlib import nullcontext

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

from model.tools import get_pretraining_args, find_unused_parameters
from train import sft_process_maotao, init_config, maotao

# DDP launchers (torchrun) export RANK; its presence marks a distributed run.
ddp = int(os.environ.get("RANK", -1)) != -1
print('Setting running environment')


def Logger(*content):
    """Print only on rank 0 (or always, when not running under DDP).

    Short-circuit on ``not ddp`` keeps dist.get_rank() from being called
    before the process group is initialised.
    """
    if not ddp or dist.get_rank() == 0:
        print(*content)


def init_distributed_mode(ddp=True):
    """Initialise the NCCL process group from torchrun environment variables.

    Sets the module-level globals ``ddp_local_rank`` and ``DEVICE`` and pins
    the CUDA device for this process.

    Returns:
        (ddp_local_rank, DEVICE) on success; None when ``ddp`` is falsy
        (callers only invoke this when ddp is true, so the None path is
        effectively unused).
    """
    print("init distributed mode,ddp=", ddp)
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
    print('init distributed mode, ddp_rank:', ddp_rank,
          'ddp_local_rank:', ddp_local_rank,
          'ddp_world_size:', ddp_world_size)
    return ddp_local_rank, DEVICE


def inference(args):
    """Run the AA2CDS downstream task with a pretrained checkpoint.

    Args:
        args: argparse Namespace from get_pretraining_args(), expected to
              carry at least: predict, task, device (string at entry),
              out_dir, downstream_data_path, mlm_pretrained_model_path,
              arg_overrides['data'], n_layers, use_moe, batch_size, epochs.

    Side effects: creates output directories, mutates several ``args``
    fields (device, seq_len, save_dir, out_dir, downstream_data_path),
    and delegates the actual training/inference to sft_process_maotao().
    """
    sft = maotao()
    # In predict mode the task comes from the CLI; otherwise default it.
    task = args.task if args.predict else 'AA2CDS_data'

    # args.device is still a string here ("cuda:0" / "cpu"); autocast only
    # makes sense on CUDA.
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    # Single-process defaults; DDP overrides them per local rank.
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        print('init distributed mode')
        ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp)
    # NOTE(review): source formatting was ambiguous on whether this assignment
    # was inside the ``if ddp`` branch; placed after it so args.device is
    # always a torch.device — confirm against the original layout.
    args.device = torch.device(DEVICE)
    Logger('args.device:', args.device)
    Logger('setting args', args)

    max_seq_len = 1200
    args.seq_len = max_seq_len
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)

    lm_config, tokenizer = init_config(
        args.arg_overrides['data'] + '/small_dict.txt',
        args.n_layers,
        max_seq_len,
    )
    lm_config.use_moe = args.use_moe

    # 3. benchmark downstream tasks
    args.out_dir = os.path.abspath(args.out_dir)
    os.makedirs(args.out_dir, exist_ok=True)
    model_dir = args.out_dir
    Logger(f'model_dir:{model_dir}')

    args.save_dir = os.path.abspath(args.save_dir)
    args.downstream_data_path = os.path.abspath(args.downstream_data_path)
    Logger('args.downstream_data_path:', args.downstream_data_path)

    out_ckp = os.path.abspath(args.save_dir + '/AA2CDS.pth')
    in_ckp = args.mlm_pretrained_model_path

    ckp, final_metrics, _ = sft_process_maotao(
        max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
        ddp_local_rank=ddp_local_rank, args=args,
        ckp=in_ckp, out_ckp=out_ckp,
        lm_config=lm_config, tokenizer=tokenizer,
        Logger=Logger, task=task, sft=sft, require_ckp=True,
    )


if __name__ == '__main__':
    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()
    parser = get_pretraining_args()
    args = parser.parse_args()
    # Hard-coded run configuration for this script's single use case.
    args.task = 'AA2CDS_data'
    args.predict = True
    args.mlm_pretrained_model_path = 'checkpoint/AA2CDS.pth'
    args.out_dir = 'example/out_TR_TS'
    inference(args)
    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)