#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title     : ray_search.py
project   : minimind_RiboUTR
Created by: julse
Created on: 2025/9/15 16:58
des: https://docs.ray.io/en/latest/tune/index.html
"""
import random
from collections import defaultdict
import sys
import os
import time
import pandas as pd
import numpy as np

from model.codon_tables import AA_str
from model.model_exp import MiniMindLM_Maotao
from utils.ernie_rna.dataset import MaotaoDataset
from fairseq.data import Dictionary

username = os.environ['HOME']

import argparse
import math
import warnings

# get the directory containing this script
script_dir = os.path.dirname(os.path.abspath(__file__))
# change the working directory to the script directory
os.chdir(script_dir)

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split
from contextlib import nullcontext

from model.model_ribo import MiniMindLM
from model.LMConfig import LMConfig, LMaoTaoConfig
from model.tools import EarlyStopping, get_pretraining_args, init_config, ddp_broadcast_early_stopping, compute_metrics_dict
from utils.ernie_rna.dataset_dst import RNADataset
from src.utils import load_pretrained_ernierna

warnings.filterwarnings('ignore')

ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
print('Setting running environment')


def copy_current_code(path):
    if not ddp or ddp_local_rank == 0:
        os.makedirs(path, exist_ok=True)
        os.system(f'cp -r {script_dir}/*.py {path}')
        os.system(f'cp -r {script_dir}/model {path}')
        os.system(f'cp -r {script_dir}/utils {path}')
        with open(os.path.join(path, 'run.sh'), 'w') as f:
            gpu_count = torch.cuda.device_count()
            f.write('#!/bin/bash\n')
            cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT_SET')
            f.write(f'# CUDA_VISIBLE_DEVICES={cuda_visible}\n')
            f.write(f'# RANK={os.environ.get("RANK", -1)}\n')
            f.write(f'# LOCAL_RANK={os.environ.get("LOCAL_RANK", -1)}\n')
            f.write(f'# WORLD_SIZE={os.environ.get("WORLD_SIZE", -1)}\n')
            f.write('cd ' + os.path.abspath(os.path.dirname(os.path.abspath(__file__))) + '\n')
            f.write(' \\\n'.join([str(sys.executable)] + sys.argv))


def Logger(*content):
    if not ddp or dist.get_rank() == 0:
        print(*content)


def init_distributed_mode(ddp=True):
    print("init distributed mode,ddp=", ddp)
    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
    print('init distributed mode, ddp_rank:', ddp_rank, 'ddp_local_rank:', ddp_local_rank, 'ddp_world_size:', ddp_world_size)
    return ddp_local_rank, DEVICE


def all_gather(tensor, world_size):
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensor_list, tensor)
    return torch.cat(tensor_list)


class DotDefaultDict(defaultdict):
    def __getattr__(self, name):
        if name in self:
            return self[name]
        # keep defaultdict's default-value generation behaviour
        return super().__getattribute__(name)

    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__


def metric_monitor(all_preds, all_labels, epoch, prefix, start_time, ddp=None, wandb=None, loss_fct=None, Logger=None, cls='regression'):
    # merge the results accumulated on the current process
    if isinstance(all_preds, list):
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
    '''The dataloaders already handle the split: TR is sharded across GPUs, while VL and TS run in full on every GPU.'''
    # # gather data across processes
# Logger('before 跨进程收集数据,', len(all_labels),prefix,ddp) # if ddp: # # world_size = dist.get_world_size() # all_preds = all_gather(all_preds, world_size) # all_labels = all_gather(all_labels, world_size) # Logger('after 跨进程收集数据,',len(all_labels)) loss_mask = all_labels != 1 all_preds = all_preds[loss_mask] all_labels = all_labels[loss_mask] epoch_loss = loss_fct(all_preds, all_labels) # print(epoch_loss) all_preds = all_preds.cpu().numpy() all_labels = all_labels.cpu().numpy() epoch_loss = epoch_loss.cpu().item() ans = compute_metrics_dict(all_preds, all_labels,cls=cls) # ans = compute_metrics_dict(np.array(all_preds), np.array(all_labels),cls='binary') wandb_ans = dict(zip([f"{prefix}_{k}" for k in ans.keys()], ans.values())) wandb_ans[f"{prefix}_epoch_loss"] = epoch_loss if wandb: wandb.log(wandb_ans) # Logger(f"Epoch {epoch} - {prefix} Loss: {epoch_loss:.4f} "+', '.join([f"{k}: {v:.4f}" for k,v in ans.items()])) # Logger(f"{prefix} time: {time.time() - start_time:.2f}s") return wandb_ans class maotao(): def step_loss(self,target_nn=None,tgt_te=None, masked_logits_list=None,nn_prob=None, res=None,loss_fct=None,loss_mse=None,args=None): loss_mask = target_nn != 1 tgt_te = tgt_te.view(-1) te_loss = loss_mse(res.te, tgt_te) # res.logits = res.logits + nn_prob + masked_logits # # frame 1,2 is the best cai loss = torch.tensor(0, dtype=torch.float32, device=args.device) for i in range(masked_logits_list.size(1)): masked_logits = masked_logits_list[:, i, ...] loss += self.calculate_loss(res.logits + nn_prob + masked_logits, loss_mask, target_nn, loss_fct) loss += self.calculate_loss(res.logits, loss_mask, target_nn, loss_fct) # res.logits = F.softmax(res.logits, dim=-1) # 数值稳定的,因为它内部使用了数学等价但数值更稳定的实现方式 # cds_start_region_loss, te_start_region_loss, cds_end_region_loss, te_end_region_loss = 0, 0, 0, 0 # loss = sum_loss_model(ans) loss += loss + te_loss + res.aux_loss return loss def forward_step(self,model=None,src_data=None,twod_data=None,aa_idx=None,continuous_features=None,species_features=None,truncated_features=None,target_nn=None,tgt_te=None,masked_logits_list=None,nn_prob=None,loss_fct=None,loss_mse=None,args=None): res = model(input_ids=src_data, twod_tokens=twod_data, aa_idx=aa_idx, continuous_features=continuous_features, species_features=species_features, truncated_features=truncated_features, # targets_nn=target_nn,targets_te=tgt_te ) # find_unused_parameters(model, res.te) res.te = res.te.view(-1) nn_prob = torch.masked_fill(nn_prob, mask=nn_prob == 0, value=float('-inf')) masked_logits = masked_logits_list[:,-1,...] 
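        # Note (interpretation, added comment): nn_prob entries equal to 0 were just replaced
        # with -inf, so adding nn_prob to the logits acts as an additive mask that removes
        # zero-probability codons from the softmax. masked_logits_list[:, -1, ...] appears to be
        # the last per-frame mask; it is added to the logits that are returned for prediction,
        # while step_loss() below accumulates the loss over every per-frame mask.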
res.logits = res.logits + nn_prob ans = dict() step_loss = self.step_loss(target_nn=target_nn,tgt_te=tgt_te, masked_logits_list=masked_logits_list,nn_prob=nn_prob, res=res,loss_fct=loss_fct,loss_mse=loss_mse,args=args) ans.update({'loss':step_loss}) res.logits = res.logits + masked_logits ans.update({'res':res}) return ans def train_epoch(self,model=None,wandb=None,ddp=None,train_loader=None,optimizer=None, epoch=None,prefix="TR",loss_fct=None,loss_mse=None,args=None,scaler=None,Logger=None,lr_scheduler=None): model.train() epoch_loss = 0 # all_preds, all_labels = [], [] all_preds, all_labels = defaultdict(list), defaultdict(list) start_time = time.time() ans = {} for step,(src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, \ target_nn, target, masked_logits_list,nn_prob,maotao_id) in enumerate(train_loader): src_data = src_data.to(args.device) twod_data = twod_data.to(args.device) aa_idx = aa_idx.to(args.device) continuous_features = continuous_features.to(args.device) species_features = species_features.to(args.device) truncated_features = truncated_features.to(args.device) masked_logits_list = masked_logits_list.to(args.device) # [12, 1200, 32] nn_prob = nn_prob.to(args.device) # [12, 1200, 32] target_nn = target_nn.to(args.device) tgt_te = target.to(args.device) with torch.cuda.amp.autocast(enabled=scaler.is_enabled()): # torch.cuda.amp 只能在 CUDA 设备上使用[ results = self.forward_step(model=model,src_data=src_data, twod_data=twod_data,aa_idx=aa_idx, continuous_features=continuous_features, species_features=species_features, truncated_features=truncated_features, target_nn=target_nn,tgt_te=tgt_te, masked_logits_list=masked_logits_list, nn_prob=nn_prob,loss_fct=loss_fct, loss_mse=loss_mse,args=args) res, loss = results['res'], results['loss'] scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() # Logger('Epoch:[{}/{}][step:{}/{}] loss:{:.4f} lr:{:.6f}'.format(epoch, args.epochs, step, len(train_loader), loss.item(), lr)) Logger('Epoch:[{}/{}][batch:{}/{}] loss:{:.4f}'.format(epoch, args.epochs, step, len(train_loader), loss.item())) epoch_loss += loss.item() # masked_logits = masked_logits_list[:, 2, ...] 
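            # Bookkeeping note: detached logits / TE predictions / aux_loss are appended to the
            # accumulators below and summarised by metric_monitor every 100 steps (step % 100 == 1),
            # after which the accumulators are reset.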
# res.logits = res.logits +nn_prob+masked_logits all_preds['logits'].append(res.logits.detach()) all_preds['te'].append(res.te.detach()) all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device)) all_labels['logits'].append(target_nn.detach()) all_labels['te'].append(tgt_te.detach()) if step%100==1: ans = {} # ans = metric_monitor(all_preds['logits'], all_labels['logits'], epoch, prefix+'_logits',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_fct,Logger=Logger,cls='binary') ans.update(metric_monitor(all_preds['logits'], all_labels['logits'], epoch, prefix+'_codon',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_fct,Logger=Logger,cls='identity')) # ans.update(metric_monitor(all_preds['te'], all_labels['te'], epoch, prefix+'_cai',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_mse,Logger=Logger,cls='regression')) ans.update({f'{prefix}_loss':epoch_loss}) all_preds, all_labels = defaultdict(list), defaultdict(list) epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0 Logger('\n'+'#'*10+f' {prefix}_loss: {epoch_loss:.4f} '+'#'*10+'\n') return ans def Maotao_DataLoader(self,file_path, args, tokenizer, data_tag='TR', ddp=None, Logger=None, returnid=None): if args.debug: args.limit = 320 train_ds = MaotaoDataset(file_path, tokenizer, args=args, limit=args.limit, seq_len=args.seq_len,returnid=returnid,codon_table_path=args.codon_table_path) train_sampler = DistributedSampler(train_ds) if ddp else None drop_last = True if ddp and data_tag == 'TR' else False if data_tag == 'TR': train_loader = DataLoader( train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=drop_last, # 以避免各卡处理的批次数量不同。 shuffle=False, num_workers=args.num_workers, sampler=train_sampler # 验证集不需要 ) else: train_loader = DataLoader( train_ds, batch_size=args.batch_size, pin_memory=True, drop_last=drop_last, # 以避免各卡处理的批次数量不同。 shuffle=False, num_workers=args.num_workers ) Logger(f'loading data from ', file_path, args.seq_len, args.column, args.label, len(train_loader.dataset)) # 验证集(VL)没有使用 DistributedSampler,导致所有GPU都处理完整的验证集,而不是分布到不同卡上。 return train_loader def load_data(self,args=None,task='sft',tokenizer=None,ddp=None,Logger=None): # csv_path = os.path.join(os.path.dirname(__file__),args.downstream_data_path, task, "{}.csv") csv_path = os.path.join(args.downstream_data_path, task, "{}.csv") check_file_flag = [os.access(csv_path.format(tag),os.F_OK) for tag in ['TR', 'VL', 'TS']] loaders = [] for flag,tag in zip(check_file_flag,['TR', 'VL', 'TS']): if flag: train_loader = self.Maotao_DataLoader(os.path.join(csv_path.format(tag)), args, tokenizer, data_tag=tag, ddp=ddp, Logger=Logger) # --wandb_project=IRES_circle else: print(f'Warning: {tag} data not found, skip it.',csv_path.format(tag)) train_loader = None loaders.append(train_loader) assert check_file_flag[-1],f'请检查数据集路径是否正确,至少需要一个TS数据集,{os.path.abspath(csv_path)},{csv_path.format("TS")}' return loaders def calculate_loss(self,logits, loss_mask, tgt_data, loss_ce): # print('tgt_data',tgt_data.shape) # [5, 30] # print('res.logits',res.logits.shape) # [5, 30, 10] # print('loss_mask',loss_mask.shape) # [5, 30] loss_mask = loss_mask == 1 # print('torch.masked_select(res.logits, loss_mask==1).view(-1, res.logits.size(-1))',torch.masked_select(res.logits, loss_mask==1).view(-1, res.logits.size(-1)).shape) # print('torch.masked_select(tgt_data, loss_mask==1).view(-1)',torch.masked_select(tgt_data, loss_mask==1).view(-1).shape.shape) # logits = torch.softmax(logits,dim=-1) # logits = F.softmax(logits, dim=-1) # 
数值稳定的,因为它内部使用了数学等价但数值更稳定的实现方式 loss = loss_ce( torch.masked_select(logits, loss_mask.unsqueeze(-1).repeat(1, 1, logits.size(-1))).view(-1,logits.size( -1)), # [150, 10] torch.masked_select(tgt_data, loss_mask).view(-1) ).mean() # 2.5530 return loss def evaluate(self,model, data_loader,stage='epoch', prefix="VL",args=None,loss_fct=None, loss_mse=None,wandb=None,ddp=None,ctx=None,Logger=None,fpred_out=None,tokenizer=None): start_time = time.time() epoch_loss = 0 model.eval() # all_preds, all_labels = [], [] all_preds, all_labels = defaultdict(list), defaultdict(list) all_num = args.batch_size * len(data_loader) with torch.no_grad(): for idx,(src_data, twod_data, aa_idx, continuous_features, species_features, truncated_features, \ target_nn, target, masked_logits_list, nn_prob, maotao_id) in enumerate(data_loader): src_data = src_data.to(args.device) twod_data = twod_data.to(args.device) aa_idx = aa_idx.to(args.device) continuous_features = continuous_features.to(args.device) species_features = species_features.to(args.device) truncated_features = truncated_features.to(args.device) masked_logits_list = masked_logits_list.to(args.device) # [12, 1200, 32] # masked_logits = masked_logits.to(args.device) # [12, 1200, 32] nn_prob = nn_prob.to(args.device) # [12, 1200, 32] target_nn = target_nn.to(args.device) tgt_te = target.to(args.device) results = self.forward_step(model=model,src_data=src_data, twod_data=twod_data,aa_idx=aa_idx, continuous_features=continuous_features, species_features=species_features, truncated_features=truncated_features, target_nn=target_nn,tgt_te=tgt_te, masked_logits_list=masked_logits_list, nn_prob=nn_prob,loss_fct=loss_fct, loss_mse=loss_mse,args=args) res, loss = results['res'], results['loss'] epoch_loss += loss.item() all_preds['logits'].append(res.logits.detach()) all_preds['te'].append(res.te.reshape(-1).detach()) all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device)) all_labels['logits'].append(target_nn) all_labels['te'].append(tgt_te.reshape(-1)) if args.predict and (not ddp or dist.get_rank() == 0): pred_logis = res.logits.detach() pred_te = res.te.reshape(-1).detach() for idj,(_id, logits, nn, te) in enumerate(zip(maotao_id, pred_logis, target_nn, pred_te)): pred_cai = te.item() # logits.argmax(1) # 只有一种结果,改成不同seed不同结果 temperature = 0.8 probs = torch.softmax(logits / temperature, dim=-1) probs = torch.nan_to_num(probs, nan=1e-9) tokens = torch.multinomial(probs, num_samples=1) tokens = tokens.squeeze(-1) tokens_hard = logits.argmax(1) pred_nn = ''.join([tokenizer.symbols[x] for x,y in zip(tokens.cpu().numpy(), tokens_hard.cpu().numpy()) if y]).replace( 'U', 'T') ans = metric_monitor(logits, nn, stage, '' + '', start_time, ddp=ddp, wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity') df = pd.DataFrame([ans]) df['maotao_id'] = _id df['pred_cai'] = pred_cai df['pred_nn'] = pred_nn if os.access(fpred_out, os.F_OK): df.to_csv(fpred_out, mode='a', header=False, index=False) else: df.to_csv(fpred_out, mode='w', header=True, index=False) Logger(f'{args.batch_size*idx+idj}/{data_loader.dataset.__len__()},fpred_out:{fpred_out}') epoch_loss /=len(data_loader) if (wandb is not None) and (not ddp or dist.get_rank() == 0): wandb.log({f'{prefix}_epoch_loss': epoch_loss}) ans = metric_monitor(all_preds['logits'], all_labels['logits'], stage, prefix + '_logits', start_time, ddp=ddp, wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='binary') ans.update( metric_monitor(all_preds['logits'], all_labels['logits'], 
stage, prefix + '_codon', start_time, ddp=ddp, wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity')) if not args.predict:ans.update( metric_monitor(all_preds['te'], all_labels['te'], stage, prefix + '_cai', start_time, ddp=ddp, wandb=wandb, loss_fct=loss_mse, Logger=Logger, cls='regression')) ans.update({f'{prefix}_loss': epoch_loss}) wandb_ans = ans if not ddp or dist.get_rank() == 0: if wandb: wandb.log(wandb_ans) # Logger( # f"Metrics - {prefix} " + ', '.join([f"{k}: {v:.4f}" for k, v in wandb_ans.items()])) return wandb_ans # def predict(self,model, data_loader,stage='epoch', prefix="TS",args=None,loss_fct=None,loss_mse=None,wandb=None,ddp=None,ctx=None,Logger=None,fpred_out=None): # start_time = time.time() # model.eval() # with torch.no_grad(): # for src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, \ # target_nn, target, masked_logits_list,nn_prob,maotao_id in data_loader: # src_data = src_data.to(args.device) # twod_data = twod_data.to(args.device) # aa_idx = aa_idx.to(args.device) # continuous_features = continuous_features.to(args.device) # species_features = species_features.to(args.device) # truncated_features = truncated_features.to(args.device) # # masked_logits_list = masked_logits_list.to(args.device) # [12, 1200, 32] # # masked_logits = masked_logits.to(args.device) # [12, 1200, 32] # nn_prob = nn_prob.to(args.device) # [12, 1200, 32] # # results = self.forward_step(model=model,src_data=src_data, # twod_data=twod_data,aa_idx=aa_idx, # continuous_features=continuous_features, # species_features=species_features, # truncated_features=truncated_features, # target_nn=None,tgt_te=None, # masked_logits_list=masked_logits_list, # nn_prob=nn_prob,loss_fct=loss_fct, # loss_mse=loss_mse,args=args) # # res = results['res'] # # fpred_out = args.out_dir + f'/{prefix}_pred.csv' # if not ddp or dist.get_rank() == 0: # pred_logis = res.logits.detach() # pred_te = res.te.reshape(-1).detach() # for _id,logits,nn,te in zip(maotao_id,pred_logis,target_nn,pred_te): # pred_cai = te.item() # pred_nn = ''.join([tokenizer.symbols[x] for x in logits.argmax(1).cpu().numpy() if x]).replace('U','T') # ans = metric_monitor(logits, nn, stage, '' + '', # start_time, ddp=ddp, # wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity') # df = pd.DataFrame([ans]) # df['maotao_id'] = _id # df['pred_cai'] = pred_cai # df['pred_nn'] = pred_nn # if os.access(fpred_out, os.F_OK): # df.to_csv(fpred_out, mode='a', header=False, index=False) # else: # df.to_csv(fpred_out, mode='w', header=True, index=False) # # Logger(fpred_out) # return pred_results def init_model(args,ckp = f'./out/full_dist_256_epoch.pth',lm_config=None,tokenizer=None,Logger=None,require_ckp=False): if args.debug:lm_config.n_layers = 1 model = MiniMindLM_Maotao(lm_config) print(model) print(lm_config) print('ckp=',ckp) if ckp is not None and os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0: print('loading model from', ckp) state_dict = torch.load(ckp, map_location=args.device) model.load_state_dict(state_dict, strict=False) print(f'finetune ({args.finetune}) from, {os.path.abspath(ckp)}') else: if require_ckp: print('ckp') exit('not found model'+ckp) # exit('not found model,'+ckp) # print('learning from scratch') if args.finetune: for name, value in model.named_parameters(): if 'layers' in name: value.requires_grad = False print(name, value.numel(), value.requires_grad) else: for name, value in model.named_parameters(): print(name, value.numel(), value.requires_grad) print( f'LLM参数量 
训练/总计:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万 / {sum(p.numel() for p in model.parameters()) / 1e6:.3f} 百万') print( f'LLM参数量 训练/总计:{sum(p.numel() for p in model.parameters() if p.requires_grad)} / {sum(p.numel() for p in model.parameters())}') model = model.to(args.device) return model, tokenizer,lm_config def save_metrics(ans, epoch, out_dir, filename='history_metrics.csv'): if not ddp or ddp_local_rank == 0: df = pd.DataFrame([ans]) df['epoch'] = epoch df['path'] = out_dir if epoch == 0: df.to_csv(os.path.join(out_dir, filename), mode='w') elif epoch == 'TS': df = df.T df.to_csv(os.path.join(out_dir, filename), mode='w') else: df.to_csv(os.path.join(out_dir, filename), mode='a', header=False) def sft_process_maotao(max_seq_len=-1,ctx=None,ddp=False,ddp_local_rank=0,args=None,ckp=None,out_ckp=None,lm_config=None, tokenizer=None,Logger=None,task=None,seq_pkl_path=None,sft=None,require_ckp=False): print('3. fine-tune on downstream tasks...') print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) start = time.time() # if args.predict:ckp = out_ckp # model, tokenizer = sft.init_model(args=args,ddp=ddp,ddp_local_rank=ddp_local_rank,Logger=Logger,ckp=ckp,lm_config=lm_config,tokenizer=tokenizer) model, tokenizer, lm_config_student = init_model(args,ckp=ckp,lm_config=lm_config,tokenizer=tokenizer,Logger=Logger,require_ckp=require_ckp) if ddp: model._ddp_params_and_buffers_to_ignore = {"pos_cis"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank],find_unused_parameters=True) train_loader, val_loader, test_loader = sft.load_data(args=args,tokenizer=tokenizer, task=task,ddp=ddp,Logger=Logger) if train_loader: args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{len(train_loader.dataset)}" else: args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{0}" if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb wandb.init(project=args.wandb_project, name=args.wandb_run_name, mode="offline",config=args,anonymous="allow") Logger(f'init wandb with id {wandb.run.id}') else: wandb = None loss_ce = nn.CrossEntropyLoss() loss_mse = nn.MSELoss() if args.predict: fpred_out = args.out_dir + f'/{args.task}/TS_pred.csv' os.system(f'rm -f {fpred_out}') os.makedirs(os.path.dirname(fpred_out), exist_ok=True) Logger(f'predicting {fpred_out} with {os.path.abspath(ckp)}') final_metrics = sft.evaluate(model, test_loader, stage='predict', prefix='TS', args=args, loss_fct=loss_ce,loss_mse=loss_mse, wandb=wandb, ddp=ddp, ctx=ctx, Logger=Logger,fpred_out=fpred_out,tokenizer=tokenizer) return ckp, final_metrics, -1 final_metrics = None if not ddp or ddp_local_rank == 0: Logger('predict and saving to file...',os.path.join(args.out_dir, f"zeroshot_metrics.csv")) for tag,loader in zip(['TR','VL','TS'],[train_loader,val_loader,test_loader]): if tag=='TR':continue if loader is not None: fpred_out = args.out_dir + f'/{prefix}_pred.csv' os.system(f'rm -f {fpred_out}') final_metrics = sft.evaluate(model, loader, stage='predict',prefix=tag,args=args,loss_fct=loss_fct,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,fpred_out=fpred_out,tokenizer=tokenizer) save_metrics(final_metrics,tag,args.out_dir,filename=f"zeroshot_metrics.csv") # final_metrics = [f'{k}_{tag}: {v:.4f}' for k,v in final_metrics.items()] # with open(os.path.join(args.out_dir, f"{args.wandb_run_name}_zeroshot_metrics.csv"), 'a') as f: # 
f.write(','.join([f'{e:.4f}' for e in final_metrics])+'\n') return ckp, final_metrics, -1 # 最终测试 # Logger('first evaluation on TS') # if not ddp or ddp_local_rank == 0: # df = pd.DataFrame(final_metrics).T # df.to_csv(os.path.join(args.out_dir, f"{args.wandb_run_name}_zeroshot_metrics.csv"),index=False) with open(os.path.join(args.out_dir, f"{args.wandb_run_name}.params"), 'w') as f: f.write(f'#args={args}' + '\n') f.write(f'#model={model}' + '\n') f.write(f'#lm_config={lm_config}' + '\n') f.write(f'#tokenizer={tokenizer.indices}' + '\n') scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) # lr_scheduler = LearningRateScheduler(base_lr=args.learning_rate, total_epochs=args.epochs, total_steps_per_epoch=len(train_loader)) lr_scheduler = None optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if out_ckp is not None: ckp = out_ckp else: ckp = f'{args.save_dir}/{args.wandb_run_name}.pth' # 新增:初始化Early Stopper(示例:监控验证集MSE,越小越好) early_stopping = EarlyStopping(patience=3, verbose=True, path=ckp) early_stopping.save_model(model, ckp) # 主训练循环 epoch = 0 for epoch in range(args.epochs): if ddp: train_loader.sampler.set_epoch(epoch) if epoch ==0: Logger(f'iterating over {len(train_loader)} batches per epoch, {len(train_loader.dataset)} samples per epoch, batch_size={train_loader.batch_size}, gpu_num={torch.cuda.device_count()}') # 训练一个epoch # current_epoch: int, current_step: int, warmup_epochs: int = 2 ans = sft.train_epoch(model=model,wandb=wandb, train_loader=train_loader, prefix="TR", optimizer=optimizer, ddp=ddp,epoch=epoch, loss_fct=loss_ce,loss_mse=loss_mse, scaler=scaler,args=args,Logger=Logger,lr_scheduler=lr_scheduler) ans_vl = sft.evaluate(model, val_loader, stage=f'{epoch}',prefix="VL",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer) ans.update(ans_vl) val_mse = ans['VL_logits_epoch_loss'] save_metrics(ans,epoch,args.out_dir,filename=f"{args.wandb_run_name}_history_metrics.csv") #val_spr, val_pr, val_mse, val_rmse,val_r2 if ddp: # 分布式训练逻辑 if ddp_local_rank == 0: early_stopping(val_mse, model) # 如果监控的是SPR,直接传入-SPR即可 # 广播 should_stop 的值到其他进程 to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device) dist.broadcast(to_broadcast, 0) else: # 非主进程等待主进程广播 # print('非主进程等待主进程广播') to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device) dist.broadcast(to_broadcast, 0) early_stopping.early_stop = bool(to_broadcast.item()) else: # 单机单卡训练逻辑 early_stopping(val_mse, model) # 如果监控的是SPR,直接传入-SPR即可 if early_stopping.early_stop:break # 恢复最佳模型(可选) if os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0: # epoch ==0 的时候不会保存模型 state_dict = torch.load(ckp, map_location=args.device) model.load_state_dict(state_dict, strict=False) Logger("Loaded best model for final evaluation.") else:early_stopping.save_model(model,ckp) final_metrics = sft.evaluate(model, test_loader,stage='final', prefix="TS",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer) save_metrics(final_metrics, 'TS', args.out_dir, filename=f"{args.wandb_run_name}_final_metrics.csv") # 最终测试 Logger('final evaluation on TS') # if ddp and ddp_local_rank == 0: dist.destroy_process_group() if args.use_wandb and (not ddp or ddp_local_rank == 0):wandb.finish() Logger(f'the end of experiment stf process: time = {(time.time() - start) // 60:.0f} min') return ckp,final_metrics,epoch def trainable(config): # 从config中获取超参数 sft = maotao() # 
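    # Note: besides the two hyperparameters read from `config` below, this trainable relies on
    # module-level globals (ctx, ddp, ddp_local_rank, args, in_ckp, out_ckp, lm_config, tokenizer,
    # Logger, task) that are set in the `if __name__ == '__main__'` block. If the function were
    # launched by Ray Tune in a separate worker process, those values would have to be passed in
    # explicitly (for example via tune.with_parameters) rather than read from globals.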
    # assume the remaining parameters stay fixed or are taken from args
    # ctx, ddp, ddp_local_rank, args, in_ckp, out_ckp, lm_config, tokenizer, Logger, task, seq_pkl_path = get_fixed_params()
    args.batch_size = config["batch_size"]
    args.learning_rate = config["learning_rate"]
    max_seq_len = args.max_seq_len
    # call the sft_process function
    # ckp, (spr, pr, mse, rmse, r2), _ = sft_process(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
    #                                                ddp_local_rank=ddp_local_rank,
    #                                                args=args, ckp=in_ckp, out_ckp=out_ckp,
    #                                                lm_config=lm_config, tokenizer=tokenizer, Logger=Logger,
    #                                                task=task, seq_pkl_path=seq_pkl_path)
    ckp, final_metrics, _ = sft_process_maotao(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
                                               ddp_local_rank=ddp_local_rank,
                                               args=args, ckp=in_ckp, out_ckp=out_ckp,
                                               lm_config=lm_config, tokenizer=tokenizer,
                                               Logger=Logger, task=task, sft=sft)
    # # report the result back to Ray Tune
    # tune.report(accuracy=spr)
    return {"identity": final_metrics['TS_codon_identity_codon']}


# Define a get_fixed_params() function according to your actual setup; it should return the
# parameters that are not tuned by Ray Tune.


def test_ray_tune():
    Logger('hello world')
    from ray import tune

    def objective(config):  # ①
        score = config["a"] ** 2 + config["b"]
        return {"score": score}

    search_space = {  # ②
        "a": tune.grid_search([0.001, 0.01, 0.1, 1.0, 0.4]),
        "b": tune.choice([1, 2, 3]),
    }
    tuner = tune.Tuner(objective,
                       param_space=search_space,
                       # run_config=ray.air.RunConfig(
                       #     storage_path="./ray_results/tune_results",  # storage path
                       #     name="my_experiment"  # experiment name
                       # )
                       )  # ③
    results = tuner.fit()
    Logger(results.get_best_result(metric="score", mode="max").config)


def init_config(vocab_path, n_layers, max_seq_len):
    # note: this local definition overrides the init_config imported from model.tools
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('')  # ['', '', '', '', 'G', 'A', 'U', 'C', 'N', '']
    tokenizer.indices['T'] = tokenizer.indices['U']
    tokenizer.indices['_'] = tokenizer.pad_index
    lm_config = LMaoTaoConfig(dim=256, logit_dim=len(tokenizer), n_layers=n_layers, max_seq_len=max_seq_len,
                              vocab_size=len(tokenizer), padding_idx=tokenizer.pad_index)  # n_layers 8,
    # [tokenizer.add_symbol(word) for word in AA_str]  # 10-31
    return lm_config, tokenizer


def seed_everything(seed=2022):
    print('seed_everything to ', seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # each run of the program is reproducible, although successive random draws within one run still differ
    # https://blog.csdn.net/qq_42951560/article/details/112174334
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # minibatch lengths keep changing, so this autotuning would mostly waste time


if __name__ == '__main__':
    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()
    ''' test '''
    # test_ray_tune()

    # import train_full_sft as sft
    sft = maotao()
    parser = get_pretraining_args()
    args = parser.parse_args()
    args.downstream_data_path = 'maotao_file/'
    args.seed = int(args.seed)
    # torch.manual_seed(args.seed)
    # torch.manual_seed(1337)
    seed_everything(seed=args.seed)
    if args.predict:
        task = args.task
    else:
        task = 'AA2CDS_data'
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        print('init distributed mode')
        ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp)
    args.device = torch.device(DEVICE)
    Logger('args.device:', args.device)
    Logger('setting args', args)
    max_seq_len = 1200
    args.seq_len = max_seq_len
    args.save_dir = os.path.join(args.out_dir)
    # os.system(f"rm -rf {args.save_dir}")  # todo
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * max_seq_len
    lm_config, tokenizer = init_config(args.arg_overrides['data'] + '/small_dict.txt', args.n_layers, max_seq_len)
    lm_config.use_moe = args.use_moe
    wandb_project = args.wandb_project

    '''3. benchmark downstream tasks'''
    prefix = 'TS'
    # with open(args.save_dir+'/benchmark_result.tsv','w') as f:
    #     f.write('Project\tModel\tTask\tSPR\tPR\tMSE\tRMSE\tR2\tckp\tepoch\n')
    epochs = args.epochs
    args.out_dir = os.path.abspath(args.out_dir)
    os.makedirs(args.out_dir, exist_ok=True)
    model_dir = args.out_dir  # 'exp_log/out_demo4/'
    model_dir = os.path.abspath(model_dir)
    # model_dir = 'exp_log/out_demo250810/'
    data_dir = args.downstream_data_path  # 'dataset/downstreamV4/'
    data_dir = os.path.abspath(data_dir)
    Logger(f'model_dir:{model_dir}')
    os.makedirs(model_dir, exist_ok=True)
    args.save_dir = os.path.abspath(args.save_dir)
    args.downstream_data_path = os.path.abspath(args.downstream_data_path)
    # args.codon_table_path = 'maotao_file/codon_table/codon_usage_{species}.csv'
    Logger('args.downstream_data_path:', args.downstream_data_path)

    out_ckp = args.save_dir + f'/AA2CDS.pth'
    out_ckp = os.path.abspath(out_ckp)
    # os.system(f"rm -rf {out_ckp}")
    in_ckp = model_dir + '/AA2CDS.pth'
    # copy_current_code(args.out_dir+'/code/')

    search_space = {
        # "max_seq_len": 1200,
        "batch_size": args.batch_size,
        "learning_rate": args.learning_rate,
    }
    trainable(search_space)

    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)
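

# The __main__ block above builds `search_space` from fixed args values and calls trainable()
# once, directly. The sketch below (a new, uncalled helper with hypothetical sampling ranges)
# shows how the same trainable could instead be handed to Ray Tune, following
# https://docs.ray.io/en/latest/tune/index.html; trainable() reports {"identity": ...}, so that
# key is used as the optimisation metric.
def run_ray_tune_search():
    from ray import tune
    search_space = {
        # illustrative ranges only, not values taken from the project
        "batch_size": tune.choice([4, 8, 16]),
        "learning_rate": tune.loguniform(1e-5, 1e-3),
    }
    tuner = tune.Tuner(trainable, param_space=search_space)
    results = tuner.fit()
    # pick the configuration with the highest codon identity on the test set
    best = results.get_best_result(metric="identity", mode="max")
    Logger(best.config)
    return best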