#!/usr/bin/env python # -*- coding: utf-8 -*- """ Title : tokenizer.py project : minimind_RiboUTR Created by: julse Created on: 2025/2/12 16:40 des: TODO """ from typing import List import argparse import os import pickle import random import re import time from collections import defaultdict from itertools import chain from random import shuffle import math import numpy as np import pandas as pd import torch from torch.utils.data import Dataset, DataLoader import torch.nn.functional as F import transformers from copy import copy, deepcopy from model.codon_attr import Codon # for debug only os.chdir('../../') # print(__file__,os.getcwd()) import sys from utils.ernie_rna.dictionary import Dictionary from utils.ernie_rna.position_prob_mask import calculate_mask_prob from transformers import DebertaTokenizerFast from model.codon_tables import CODON_TO_AA, AA_str, AA_TO_CODONS, reverse_dictionary, create_codon_mask # from utils.esm3.tokenizer import EsmSequenceTokenizer base_range_lst = [1] lamda_lst = [0.8] import torch from torch.utils.data import Dataset import numpy as np import pandas as pd class BaseDataset(Dataset): """公共基类,包含共享属性和方法""" def __init__( self, tokenizer, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, ): # 参数校验 assert 0.0 < mask_prob < 1.0 assert 0.0 <= random_token_prob <= 1.0 assert 0.0 <= leave_unmasked_prob <= 1.0 assert random_token_prob + leave_unmasked_prob <= 1.0 # 初始化公共属性 self.tokenizer = tokenizer self.pad_idx = tokenizer.pad_index self.mask_idx = tokenizer.mask_index self.return_masked_tokens = return_masked_tokens self.seed = seed self.mask_prob = mask_prob self.leave_unmasked_prob = leave_unmasked_prob self.random_token_prob = random_token_prob self.two_dim_score = two_dim_score self.two_dim_mask = two_dim_mask self.mask_whole_words = mask_whole_words self.region = region self.limit = limit # 初始化权重(如果需要) if random_token_prob > 0.0: weights = np.array(tokenizer.count) if freq_weighted_replacement else np.ones(len(tokenizer)) weights[: tokenizer.nspecial] = 0 self.weights = weights / weights.sum() self.tokenizer.indices['T']=self.tokenizer.indices['U'] self.amino_acid_to_codons = {} for aa, codons in AA_TO_CODONS.items(): codons_num = [] for codon in codons: codon_num = [] for base in codon: codon_num.append(self.tokenizer.indices[base]) # 如果碱基不在映射中,使用4表示未知 codons_num.append(codon_num) self.amino_acid_to_codons[self.tokenizer.indices[aa.lower()]] = codons_num # region === 公共方法 === # @staticmethod # def prepare_input_for_ernierna(index, seq_len): # shorten_index = index[:seq_len + 2] # 截断到seq_len+2 # one_d = torch.from_numpy(shorten_index).long().reshape(1, -1) # two_d = np.zeros((1, seq_len + 2, seq_len + 2)) # two_d[0, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8) # # two_d[:, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8) # two_d = two_d.transpose(1, 2, 0) # two_d = torch.from_numpy(two_d).reshape(1, seq_len + 2, seq_len + 2, 1) # return one_d, two_d @staticmethod def translate(nucleotide_seq,repeate=3): amino_acid_list = [] for i in range(0, len(nucleotide_seq), 3): codon = nucleotide_seq[i:i + 3] amino_acid_list.append(CODON_TO_AA.get(codon, '-')*repeate) amino_acid_seq = ''.join(amino_acid_list) return amino_acid_seq @staticmethod def prepare_input_for_ernierna(index, seq_len): # (1, 1205), 1205 if index.ndim == 2: index = np.squeeze(index) shorten_index = index[:seq_len] # 截断到seq_len one_d = torch.from_numpy(shorten_index).long().reshape(1, -1) two_d = np.zeros((1, seq_len, seq_len)) two_d[0, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8) # new_matrix = creatmat(item.numpy(), base_range, lamda) # [1205] two_d = two_d.transpose(1, 2, 0) two_d = torch.from_numpy(two_d).reshape(1, seq_len, seq_len, 1) return one_d, two_d def generate_inputs(self,x): region = self.region # utr5 = x["UTR5"] if 'UTR5' in x else UTR5 # utr3 = x["UTR3"] if 'UTR3' in x else UTR3 # cds = x["CDS"] if 'CDS' in x else CDS utr5 = x["UTR5"] utr3 = x["UTR3"] cds = x["CDS"] seq = utr5 + cds + utr3 cds_start = len(utr5) cds_stop = len(utr5) + len(cds) # utr5 = seq[:cds_start] # cds = seq[cds_start:cds_stop] # utr3 = seq[cds_stop:] utr5_limit = 300 if region > 300 else region # seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit) seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit) return seq def process_sequence(self,seq, cds_start, cds_stop, region, pad_mark, bos, eos,link,utr5_limit): utr5 = seq[:cds_start] cds = seq[cds_start:cds_stop] utr3 = seq[cds_stop:] # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # seq = utr5 + cds_h + cds_t + utr3 # seq = seq[:region*2+1]+link*3+seq[-region*2-1:] utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) seq = utr5 + cds_h + cds_t + utr3 seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:] # c1 = seq[cds_start:] # c2 = seq[:cds_stop] # # utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # cds_h = self.process_utr(c1, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # cds_t = self.process_utr(c2, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # 这样会导致CDS和UTR混在一起,后面不太好mask猜AA # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # seq = utr5 + cds_h + cds_t + utr3 # seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:] # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # pre_processed = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # behind_processed = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # seq = pre_processed + cds_part + behind_processed # # seq = seq[:region*2+1]+link*3+seq[-region*2-1:] if isinstance(seq,list): seq = np.array(seq) return seq @staticmethod def process_utr(utr, input_len, pad_method, pad_mark='_',bos='<',eos='>'): if len(utr) < input_len: if pad_method == 'pre': padded_utr = pad_mark * (input_len - len(utr)) + bos + utr elif pad_method == 'behind': padded_utr = utr+eos + pad_mark * (input_len - len(utr)) else: if pad_method == 'pre': padded_utr = bos+utr[-input_len:] elif pad_method == 'behind': padded_utr = utr[:input_len]+eos return padded_utr # self.process_utr = process_utr # @staticmethod # def seqs_to_index(sequences, pad_idx=1, unk_idx=3): # ''' # input: # sequences: list of string (difference length) # # return: # rna_index: numpy matrix, shape like: [len(sequences), max_seq_len+2] # rna_len_lst: list of length # # # examples: # rna_index, rna_len_lst = seq_to_index(sequences) # for i, (index, seq_len) in enumerate(zip(rna_index, rna_len_lst)): # one_d, two_d = prepare_input_for_ernierna(index, seq_len) # one_d = one_d.to(device) # two_d = two_d.to(device) # # output = my_model(one_d, two_d, layer_idx=layer_idx).cpu().detach().numpy() # ''' # # rna_len_lst = [len(ss) for ss in sequences] # max_len = max(rna_len_lst) # # assert max_len <= 1022 # seq_nums = len(rna_len_lst) # rna_index = np.ones((seq_nums, max_len + 2)) # for i in range(seq_nums): # for j in range(rna_len_lst[i]): # 4,5,6,7 --->GATC # if sequences[i][j] in set("Aa"): # rna_index[i][j + 1] = 5 # elif sequences[i][j] in set("Cc"): # rna_index[i][j + 1] = 7 # elif sequences[i][j] in set("Gg"): # rna_index[i][j + 1] = 4 # elif sequences[i][j] in set('TUtu'): # rna_index[i][j + 1] = 6 # elif sequences[i][j] in set('_'): # rna_index[i][j + 1] = pad_idx # else: # rna_index[i][j + 1] = unk_idx # rna_index[i][rna_len_lst[i] + 1] = 2 # add 'eos' token # rna_index[:, 0] = 0 # add 'cls' token # return rna_index, rna_len_lst # @staticmethod # def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3): # l = len(seq) # X = np.ones((1, l + 2)) # for j in range(l): # if seq[j] in set('Aa'): # X[0, j + 1] = 5 # elif seq[j] in set('UuTt'): # X[0, j + 1] = 6 # elif seq[j] in set('Cc'): # X[0, j + 1] = 7 # elif seq[j] in set('Gg'): # X[0, j + 1] = 4 # elif seq[j] in set('_'): # X[0,j + 1] = pad_idx # else: # X[0,j + 1] = unk_idx # # X[0, l + 1] = 2 # X[0, 0] = 0 # return X @staticmethod def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3): # rna_alphabet_list:str="""GAUC_""", # '': 0, '': 1, '': 2, '': 3, # 'G': 4, 'A': 5, 'U': 6, 'C': 7, 'N': 8, '': 9, # 'a': 10, 'y': 29, '*': 30, '-': 31, 'T': 6 l = len(seq) X = np.ones((1, l)) for j in range(l): if seq[j] in set('Aa'): X[0, j] = 5 elif seq[j] in set('UuTt'): X[0, j] = 6 elif seq[j] in set('Cc'): X[0, j] = 7 elif seq[j] in set('Gg'): X[0, j] = 4 elif seq[j] in set('_'): X[0,j] = pad_idx elif seq[j] in set('<'): X[0,j] = 0 elif seq[j] in set('>'): X[0,j] = 2 else: X[0,j] = unk_idx # X[0, l + 1] = 2 # X[0, 0] = 0 return X # def generate_mask(self, X,seq_len): # one_d, twod_d = self.prepare_input_for_ernierna(X, seq_len) # # return one_d, twod_data # [1,L+2],[1,L+2,L+2,1],[1,L,4] # '''generate src_data, tgt_data, twod_data ''' # item = one_d.view(-1) # # assert ( # self.mask_idx not in item # ), "Dataset contains mask_idx (={}), this is not expected!".format( # self.mask_idx, # ) # # if self.mask_whole_words is not None: # todo: check when need # word_begins_mask = self.mask_whole_words.gather(0, item) # word_begins_idx = word_begins_mask.nonzero().view(-1) # sz = len(word_begins_idx) # words = np.split(word_begins_mask, word_begins_idx)[1:] # assert len(words) == sz # word_lens = list(map(len, words)) # # sz = len(item) # # decide elements to mask # mask = np.full(sz, False) # # # 找出非 padding 的位置 # non_pad_indices = np.where( # (item != self.tokenizer.pad_index) & # (item != self.tokenizer.unk_index) # )[0] # # 计算需要掩码的数量 # num_non_pad = len(non_pad_indices) # num_mask = int( # self.mask_prob * num_non_pad + np.random.rand() # ) # # # 在非 padding 的位置中随机选择要掩码的元素,根据position prob 进行mask # target_positions = [self.region + 1, self.region *3 + 4] # sigma = 90 # 控制概率衰减的速度,数值越小,衰减越快 # probabilities = np.array([calculate_mask_prob(i, target_positions, sigma) for i in range(sz)]) # non_pad_probabilities = probabilities[non_pad_indices] # non_pad_probabilities = non_pad_probabilities/non_pad_probabilities.sum() # if num_non_pad >= 1: # mask[non_pad_indices[np.random.choice(num_non_pad, num_mask, replace=False,p=non_pad_probabilities)]] = True # # twod_data # two_dim_matrix =torch.squeeze(twod_d, dim=-1).numpy() # # item_len = len(item.numpy()) # # two_dim_matrix = np.zeros((len(base_range_lst) * len(lamda_lst), item_len, item_len)) # 只有0和-1 # padding_dim = 0 # for base_range in base_range_lst: # for lamda in lamda_lst: # new_matrix = creatmat(item.numpy(), base_range, lamda) # new_matrix[mask, :] = -1 # new_matrix[:, mask] = -1 # two_dim_matrix[padding_dim, :, :] = new_matrix # padding_dim += 1 # # use -1 represent mask # # matrix[mask,:] = self.two_dim_mask # # matrix[:,mask] = self.two_dim_mask # # print(two_dim_matrix.shape) # twod_data = torch.from_numpy(two_dim_matrix)#.unsqueeze(-1) # [1, L+2, L+2, 1] # # # if self.mask_whole_words is not None: # mask = np.repeat(mask, word_lens) # # new_item = np.full(len(mask), self.pad_idx) # new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] # tgt_data = torch.from_numpy(new_item)#.unsqueeze(0) # [L,1] # # # decide unmasking and random replacement # rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob # if rand_or_unmask_prob > 0.0: # rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) # if self.random_token_prob == 0.0: # unmask = rand_or_unmask # rand_mask = None # elif self.leave_unmasked_prob == 0.0: # unmask = None # rand_mask = rand_or_unmask # else: # unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob # decision = np.random.rand(sz) < unmask_prob # unmask = rand_or_unmask & decision # rand_mask = rand_or_unmask & (~decision) # else: # unmask = rand_mask = None # # if unmask is not None: # mask = mask ^ unmask # # if self.mask_whole_words is not None: # mask = np.repeat(mask, word_lens) # # new_item = np.copy(item) # new_item[mask] = self.mask_idx # if rand_mask is not None: # num_rand = rand_mask.sum() # if num_rand > 0: # if self.mask_whole_words is not None: # rand_mask = np.repeat(rand_mask, word_lens) # num_rand = rand_mask.sum() # # 以概率突变 # new_item[rand_mask] = np.random.choice( # len(self.tokenizer), # num_rand, # p=self.weights, # ) # src_data = torch.from_numpy(new_item)#.unsqueeze(0) # loss_mask = torch.tensor(mask, dtype=torch.long)#.unsqueeze(0) # # return src_data,tgt_data,twod_data,loss_mask def generate_mask(self, X,seq_len,mask=None,input_mask=True): """ :param X: seuqnce in number index :param seq_len: :param mask: 1d mask array, if None, generate by dual center mask :param input_mask: true: use 2d mask :return: """ one_d, twod_d = self.prepare_input_for_ernierna(X, seq_len) # X: [1,1205] # return one_d, twod_data # [1,L+2],[1,L+2,L+2,1],[1,L,4] '''generate src_data, tgt_data, twod_data ''' item = one_d.view(-1) assert ( self.mask_idx not in item ), "Dataset contains mask_idx (={}), this is not expected!".format( self.mask_idx, ) if self.mask_whole_words is not None: # todo: check when need word_begins_mask = self.mask_whole_words.gather(0, item) word_begins_idx = word_begins_mask.nonzero().view(-1) sz = len(word_begins_idx) words = np.split(word_begins_mask, word_begins_idx)[1:] assert len(words) == sz word_lens = list(map(len, words)) sz = len(item) # decide elements to mask if mask is None: # 1D mask mask = np.full(sz, False) # 找出非 padding 的位置 non_pad_indices = np.where( (item != self.tokenizer.pad_index) & (item != self.tokenizer.unk_index) )[0] # # 计算需要掩码的数量 num_non_pad = len(non_pad_indices) num_mask = int( self.mask_prob * num_non_pad + np.random.rand() ) # 在非 padding 的位置中随机选择要掩码的元素,根据position prob 进行mask target_positions = [self.region + 1, self.region *3 + 4] # ATG seq[301:301+3], TGG seq[901:901+3] sigma = 90 # 控制概率衰减的速度,数值越小,衰减越快 probabilities = np.array([calculate_mask_prob(i, target_positions, sigma) for i in range(sz)]) non_pad_probabilities = probabilities[non_pad_indices] non_pad_probabilities = non_pad_probabilities/non_pad_probabilities.sum() if num_non_pad >= 1: mask[non_pad_indices[np.random.choice(num_non_pad, num_mask, replace=False,p=non_pad_probabilities)]] = True mask[target_positions[0]:target_positions[0]+3]=False # (301, 304) ATG mask[target_positions[1]-3:target_positions[1]]=False # (901, 904) TAA mask[target_positions[0]+300:target_positions[0]+303]=False # (601, 604) NNN # decide unmasking and random replacement rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob if rand_or_unmask_prob > 0.0: rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) if self.random_token_prob == 0.0: unmask = rand_or_unmask rand_mask = None elif self.leave_unmasked_prob == 0.0: unmask = None rand_mask = rand_or_unmask else: unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob decision = np.random.rand(sz) < unmask_prob unmask = rand_or_unmask & decision rand_mask = rand_or_unmask & (~decision) else: unmask = rand_mask = None if unmask is not None: mask = mask ^ unmask # twod_data if input_mask: twod_data = self.get_twod_data(item,twod_d.detach(),mask) # mask = [1, 1205, 1205, 1]# else: twod_data = self.get_twod_data(item,twod_d.detach(),np.zeros_like(mask)) # mask = [1, 1205, 1205, 1]# if self.mask_whole_words is not None: mask = np.repeat(mask, word_lens) # new_item = np.full(len(mask), self.pad_idx) # new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] # tgt_data = torch.from_numpy(new_item)#.unsqueeze(0) # [L,1] tgt_data = item#.unsqueeze(0) # [L,1] if self.mask_whole_words is not None: mask = np.repeat(mask, word_lens) new_item = np.copy(item) new_item[mask] = self.mask_idx # if rand_mask is not None: # num_rand = rand_mask.sum() # if num_rand > 0: # if self.mask_whole_words is not None: # rand_mask = np.repeat(rand_mask, word_lens) # num_rand = rand_mask.sum() # # 以概率突变 # new_item[rand_mask] = np.random.choice( # 9,#len(self.tokenizer) # num_rand, # p=self.weights, # ) src_data = torch.from_numpy(new_item)#.unsqueeze(0) # loss_mask = torch.tensor(mask, dtype=torch.long)#.unsqueeze(0) return src_data,tgt_data,twod_data,mask def get_twod_data(self,item,twod_d,mask): two_dim_matrix =torch.squeeze(twod_d, dim=-1).numpy() # item_len = len(item.numpy()) # two_dim_matrix = np.zeros((len(base_range_lst) * len(lamda_lst), item_len, item_len)) # 只有0和-1 padding_dim = 0 for base_range in base_range_lst: for lamda in lamda_lst: new_matrix = creatmat(item.numpy(), base_range, lamda) new_matrix[mask==1, :] = -1 new_matrix[:, mask==1] = -1 two_dim_matrix[padding_dim, :, :] = new_matrix padding_dim += 1 # use -1 represent mask # matrix[mask,:] = self.two_dim_mask # matrix[:,mask] = self.two_dim_mask # print(two_dim_matrix.shape) twod_data = torch.from_numpy(two_dim_matrix)#.unsqueeze(-1) # [1, L+2, L+2, 1] return twod_data @staticmethod def read_text_file(file_path): try: with open(file_path, 'r') as file: return [line.strip() for line in file] except FileNotFoundError: print(f"Error: File '{file_path}' not found.") return [] @staticmethod def create_base_prob(target_protein,ith_nn_prob,rna_alphabet,tokenizer): mask_nn_logits = torch.full(size=(len(target_protein)*3,len(tokenizer)),fill_value=float("-inf")) for i,a in enumerate(target_protein): if a not in ith_nn_prob[0]: continue for j in range(3): for n in rna_alphabet: mask_nn_logits[i*3+j,tokenizer.index(n)] = ith_nn_prob[j][a][n] return mask_nn_logits @staticmethod def create_codon_mask(target_protein, backbone_cds, amino_acid_to_codons, tokenizer): # logits = torch.full() # batch_size, seq_length, vocab_size = logits.shape seq_length = len(backbone_cds) vocab_size = len(tokenizer) # mask = torch.full_like(logits, float("-inf")) mask = torch.full(size=(seq_length,vocab_size),fill_value=float("-inf")) for i, amino_acid in enumerate(target_protein): codon_start = i * 3 # 每个氨基酸对应 3 个碱基 codon_end = codon_start + 3 if codon_end > seq_length: continue # 超出序列长度,跳过 possible_codons = amino_acid_to_codons.get(amino_acid.item(), []) # filter_codons = [] for pos in range(codon_start, codon_end): base_pos = pos % 3 # 当前碱基在密码子中的位置(0, 1, 2) for codon in possible_codons: flag = True for j, nt in enumerate(backbone_cds[codon_start:codon_end]): nt = nt.item() if tokenizer.mask_index == nt: continue if codon[j] != nt: flag = False # filter_codons.append(codon) if flag: base_idx = codon[base_pos] mask[pos, base_idx] = 0 # a = mask.numpy() return mask # endregion # region === 需要子类实现的方法 === def load_data(self, path, **kwargs): raise NotImplementedError("Subclasses must implement load_data") def __getitem__(self, idx): raise NotImplementedError("Subclasses must implement __getitem__") # endregion class RNADataset(BaseDataset): """处理RNA序列的Dataset""" def __init__( self, path, tokenizer, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, ): # 调用父类初始化 super().__init__( tokenizer=tokenizer, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) # 加载数据 self.samples = self.load_data(path, region=self.region, limit=limit) def load_data(self, path, region=300, limit=-1): return self.read_fasta_file(path, region=region, limit=limit) @staticmethod def read_fasta_file(file_path, region=300, cds_min=100, limit=-1): ''' input: file_path: str, fasta file path of input seqs return: seqs_dict: dict[str], dict of seqs { 'ENST00000231420.11': { # 转录本的标识符 'cds_start': 57, # CDS的起始位置(基于0的索引) 'cds_stop': 1599, # CDS的终止位置(不包括该位置,基于0的索引) 'full': 'AGTTAGAGCCCGGCCTCCAATCTGCTTCCATGGGGTTGGCTTTCTGAGTGGGAGAAATGACTCTAATCTGGAGACA...', # 完整的mRNA序列 'start_context': '___GAAATGTCT', # CDS起始位置前的序列上下文, padding left _,essential 'stop_context': 'AAGTAAGGG___' # CDS终止位置后的序列上下文, padding right _, essential } } ''' # region = getattr(args, 'region', region) # limit = getattr(args, 'limit', limit) try: with open(file_path) as fa: seqs_dicts = [] cds_start = 0 cds_stop = 0 count = 0 seq_name = '' # for line in fa.read().splitlines(): for line in fa: line = line.replace('\n', '') if line.startswith('>'): transcript_id, gene_id, cds_start, cds_stop = line[1:].split( ' ') # # ENST00000332160.5 ENSG00000185432.12 24 756 cds_start = int(cds_start) cds_stop = int(cds_stop) if cds_stop - cds_start < cds_min: continue seq_name = transcript_id # seqs_dict[seq_name] = {} # seqs_dict[seq_name]['cds_start'] = cds_start # seqs_dict[seq_name]['cds_stop'] = cds_stop else: expand_mRNA = '_' * region + line + '_' * region cds_start += region cds_stop += region # seqs_dict[seq_name]['full'] = line start_context = expand_mRNA[cds_start - region:cds_start + region] stop_context = expand_mRNA[cds_stop - region:cds_stop + region] seqs_dicts.append( {'_id': seq_name, 'start_context': start_context, 'stop_context': stop_context}) count += 1 if count > limit and limit != -1: break return seqs_dicts except FileNotFoundError: print(f"Error: File '{file_path}' not found.") return [] def __len__(self): return len(self.samples) def __getitem__(self, idx): # to check ''' GAUC 4567 unk 3 :param idx: :return: ''' sample = self.samples[idx] seq = sample['start_context'] + 'NNN' + sample['stop_context'] X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) if '_' in sample['start_context']: X[:, sample['start_context'].count('_')] = self.tokenizer.bos_index if '_' in sample['stop_context']: X[:, -sample['stop_context'].count('_')-1] = self.tokenizer.eos_index '''generate src_data, tgt_data, twod_data ''' src_data,tgt_data,twod_data,loss_mask = self.generate_mask(X,len(seq)) return src_data,tgt_data,twod_data,loss_mask class RiboDataPipeline(): """ 处理预训练任务的Dataset,生成mRNA.fa,加载ribosome_density, ribo_counts, rna_counts,划分TR,VL,TS Loading from origin bw """ def __init__( self, path, ribo_experiment,rna_experiment, seq_only=False, region: int = 300, cds_min: int = -1, # -1,不检查cds 长度 limit: int = -1, env : int = 0, norm = True ): self.seq_only = seq_only self.cds_min = cds_min self.env = env # self.reference_transcript_dict = {'ENST00000303577.7': 'PCBP1', # IRES chr2 # 'ENST00000309311.7': 'EEF2'} # cap dependent # chr19 # self.reference_transcript_dict = { # 'ENST00000309311.7': 'EEF2'} # cap dependent # chr19 self.reference_transcript_dict = {} # cap dependent # chr19 # 加载数据 self.samples = self.load_data(path, ribo_experiment=ribo_experiment,rna_experiment=rna_experiment,region=region, limit=limit,norm=norm) # self.ref_norm = np.mean([self.samples[key][4].sum() for key in self.reference_transcript_dict.keys()]) if self.reference_transcript_dict else 1 # RNA_counts def load_data(self, path, ribo_experiment=None,rna_experiment=None,region=300, limit=-1,norm=True): ''' 读取数据 1. 根据ribo_experiment,从meta中查询species,avg_len,total counts 等 2. 查询mRNA.fa文件是否存在 不存在: 查询mRNA.tsv是否存在 不存在: genome.gtf 生成 mRNA.gtf文件 (只含有mRNA相关的行和列,scale the size of gtf) 根据genome.fa 和mRNA.gtf 生成mRNA.fa文件 (包括start or stop codon positions) 3. 读取track文件,生成ribosome_density, ribo_counts, rna_counts {'ENST00000303577.7': 'PCBP1', # IRES 'ENST00000309311.7': 'EEF2'} # cap dependent :param path: :param reference_path: :param region: :param limit: :return: samples path = ./dataset/pretraining/ ''' """1. input ribo_experiment, meta""" seq_only = self.seq_only cds_min = self.cds_min # print('load_data in Pipeline') reference_path = os.path.join(path,'reference') meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'),ribo_experiment,rna_experiment,seq_only = seq_only) # todo totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species = meta fribo_track = os.path.join(path, 'track', f'{ribo_experiment}.bw') frna_track = os.path.join(path, 'track', f'{rna_experiment}.bw') if not seq_only: if os.access(fribo_track,os.F_OK) and os.access(frna_track,os.F_OK): print(f'load {ribo_experiment} and {rna_experiment} tracks') else: print(f'Error: {fribo_track} or {frna_track} not found.') return None """2. check mRNA.fa, .pkl""" # # 读取 chromosomes reference # print('load_data in Pipeline',meta) mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa') if region != -1: mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa') mrna_fa_path = mrna_fa_path.replace('.fa', f'.pkl') self.mrna_region_pkl_path = mrna_fa_path # sequence if seq_only and os.access(mrna_fa_path, os.F_OK): with open(mrna_fa_path, 'rb') as f: sample_dict = pickle.load(f) limited_sample_dict = {} for key in sample_dict.keys(): # if limit != -1: # limited_sample_dict[key] = [[transcript_id]+list(sample_dict[key][transcript_id]) for transcript_id in list(sample_dict[key].keys())[:limit]] # else :limited_sample_dict[key] = [[transcript_id]+list(sample_dict[key][transcript_id]) for transcript_id in list(sample_dict[key].keys())] if limit != -1: limited_sample_dict[key] = sample_dict[:limit] else: limited_sample_dict[key] = sample_dict[key] return limited_sample_dict # gtf mrna_tsv_path = os.path.join(reference_path,species, 'mRNA.tsv') if not os.access(mrna_tsv_path, os.F_OK): genome_gtf_path = os.path.join(reference_path,species, 'genome.gtf') genome_fa_path = os.path.join(reference_path,species, 'genome.fa') mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path) else: mrna_tsv = pd.read_table(mrna_tsv_path)#.iloc[:100] # 读取已经生成的mRNA.tsv文件 # print('load_data in Pipeline',mrna_tsv.shape) # 1459048, 11 """3. read track files""" # if limit<10: # debug mode # debug_ids = 'ENST00000332831.5,ENST00000000233.10'.split(',') # mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(debug_ids)].reindex() # region = 6 print(f'filter limit={limit},region={region}') print('load_data in Pipeline,before filter',mrna_tsv.shape) # 1459048, 11 reference_transcript_ids = list(self.reference_transcript_dict.keys()) keeping_transcript_ids = mrna_tsv[mrna_tsv['seqname'].isin(['chr10','chr15'])].transcript_id.unique().tolist() print('keeping transcript_ids',len(reference_transcript_ids+keeping_transcript_ids)) # if args.debug: # region = 6 # limit = 2000 if limit!=-1: other_transcript_ids = mrna_tsv[~mrna_tsv['transcript_id'].isin(reference_transcript_ids+keeping_transcript_ids)].transcript_id.unique().tolist() shuffle(other_transcript_ids) shuffle(keeping_transcript_ids) mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids+keeping_transcript_ids[:limit]+other_transcript_ids[:limit])] print('load_data in Pipeline, after filter',mrna_tsv.shape) # 1459048, 11 if not seq_only: import pyBigWig ribo_bw,rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)] print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}') def iterfunc(x,bw): chrom, start, end = x['seqname'], x['start'], x['end'] if chrom in bw.chroms(): return np.array(bw.values(chrom, start - 1, end)) else: return np.zeros(end - start) # if x['seqname'] in bw or x['end']={ribo_bw.chroms(x['seqname'])},{x['end'], region=6 ribo_counts, rna_counts, ribosome_density, te, self.env, cds_len, mRNA_len,junction_counts) = ans ref_norm.append((sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)) if len(ref_norm)==0 and norm: print(f'Error: no qualified reference transcript (housekeeping when norm=True)') return None ref_norm = np.mean(ref_norm,axis=0) if norm and len(ref_norm)>0 else None print('ref_norm',ref_norm,'sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)') '''generate by norm''' for transcript_id, data in mrna_tsv.groupby('transcript_id'): tag = data['dataset'].iloc[0] ans = self.merge_transcript_level(data,total_counts_info=total_counts_info,seq_only=seq_only,cds_min=cds_min,region=region,ref_norm=ref_norm) if ans is None:continue sample_dict[tag].append([transcript_id] + ans) count += 1 if limit == count: break # datasplit = 'TR_VL_TS.tsv' # # ~/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/dataset_split $ head TR_VL_TS.tsv # # transcript_id dataset seqname gene_id self.samples = sample_dict if seq_only: mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa') if region !=-1: mrna_fa_path = mrna_fa_path.replace('.fa',f'_{region}.fa') if not (os.access(mrna_fa_path,os.F_OK) and os.path.getsize(mrna_fa_path)>0): print(f'generate {sum([len(a) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}') self.generate_mRNA_fa(mrna_fa_path,sample_dict,force_regenerate=True) # ,force_regenerate=True mrna_fa_path = mrna_fa_path.replace('.fa',f'.pkl') if not os.access(mrna_fa_path, os.F_OK): with open(mrna_fa_path, 'wb') as f: pickle.dump(sample_dict, f) self.mrna_region_pkl_path = mrna_fa_path return sample_dict # transcript_id 作为 groupby key def utr5_limit(self,args,x,region): utr5_limit = 300 if args.region>300 else args.region seq = list( x[region - utr5_limit:region + 1 + args.region] \ + 'NNN' + x[3 * region + 4 - args.region:3 * region + 4 + args.region+1]) if seq[-1] not in {'_','>'}:seq[-1]='>' if seq[0] not in {'_','<'}:seq[0]='<' return seq def merge_transcript_level(self,data,total_counts_info=None,seq_only=False,cds_min=-1,region=300,ref_norm=None): # [1,1,1,1] # print(transcript_id) ans = self.qualified_samples(data, seq_only=seq_only, cds_min=cds_min) junction_counts = len(data[data['feature'] == 'CDS']) if ans is not None: seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict = ans cds_len = cds_stop - cds_start mRNA_len = len(seq) if region!=-1: utr5_limit = 300 if region > 300 else region seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N',utr5_limit) # anno = self.process_sequence(anno, cds_start, cds_stop, region, '_', '<', '>', 'N') if metadict is not None: # seq_only = False totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF = total_counts_info if metadict['ribo_recovery'] > 0.9 and metadict['rna_recovery'] > 0.9: # high quality samples for TE te = self.calculate_ribosome_density(metadict['ribo_avg_count'], metadict['rna_avg_count'], totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF) te = float(te) else: te = -1 # low quality samples for TE # pad_or_truncate_utr # seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N') anno = self.process_sequence(anno, cds_start, cds_stop, region, '_', '<', '>', 'N',utr5_limit) ribo_counts\ = self.process_sequence(ribo_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1],utr5_limit) rna_counts = self.process_sequence(rna_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1],utr5_limit) if sum(ribo_counts[ribo_counts != -1]) <= 100 or sum( rna_counts[rna_counts != -1]) <= 100: # 这里有padding的-1,不应该放入计算中, 质量控制 # print(f"No reads for {data['transcript_id'].iloc[0]}") return None ''' normalized by total counts https://rcxqhxlmkf.feishu.cn/docx/MdEvd008poMIaexhX9Xc7EAEnth#share-SNGtdmaQ2oATE0xax1Nc6m3jnbp ''' ribo_counts += 1 # max 1130 rna_counts += 1 # max 3628 ribosome_density = deepcopy(ribo_counts) ribosome_density[ribosome_density != 0] = self.calculate_ribosome_density( ribo_counts[ribo_counts != 0], rna_counts[rna_counts != 0], totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF) # 2.38 if ref_norm is not None: ribo_counts, rna_counts = ribo_counts / (ref_norm[0] * readsLength_RPF), rna_counts / ( ref_norm[1] * readsLength_RNA) # demo4 # print([(max(a), min(a)) for a in [ribo_counts, rna_counts]]) cds_start, cds_stop = anno.index('|'), anno.rindex('|', 1) + 4 return [seq, cds_start, cds_stop, # CDS_region = seq[cds_start:cds_stop] , region=6 ribo_counts, rna_counts, ribosome_density, te, self.env,cds_len,mRNA_len,junction_counts] return [seq, cds_start, cds_stop,cds_len,mRNA_len,junction_counts] # def process_sequence(self,seq, cds_start, cds_stop, region, pad_mark, bos, eos,link,utr5_limit): # utr5 = seq[:cds_start] # cds = seq[cds_start:cds_stop] # utr3 = seq[cds_stop:] # # # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # # cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # seq = utr5 + cds_h + cds_t + utr3 # # seq = seq[:region*2+1]+link*3+seq[-region*2-1:] # # utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # seq = utr5 + cds_h + cds_t + utr3 # seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:] # # # c1 = seq[cds_start:] # # c2 = seq[:cds_stop] # # # # utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # # cds_h = self.process_utr(c1, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # cds_t = self.process_utr(c2, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # 这样会导致CDS和UTR混在一起,后面不太好mask猜AA # # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # seq = utr5 + cds_h + cds_t + utr3 # # seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:] # # # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # # # pre_processed = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # # behind_processed = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos) # # # seq = pre_processed + cds_part + behind_processed # # # seq = seq[:region*2+1]+link*3+seq[-region*2-1:] # if isinstance(seq,list): # seq = np.array(seq) # return seq # # # # @staticmethod # def process_utr(utr, input_len, pad_method, pad_mark='_',bos='<',eos='>'): # if len(utr) < input_len: # if pad_method == 'pre': # padded_utr = pad_mark * (input_len - len(utr)) + bos + utr # elif pad_method == 'behind': # padded_utr = utr+eos + pad_mark * (input_len - len(utr)) # else: # if pad_method == 'pre': # padded_utr = bos+utr[-input_len:] # elif pad_method == 'behind': # padded_utr = utr[:input_len]+eos # return padded_utr # # self.process_utr = process_utr def generate_mRNA_fa(self,mrna_fa_path,sample_dict,force_regenerate=False): '''for pretrain''' if force_regenerate: '''generate mRNA.fa''' print('generate mRNA.fa to',mrna_fa_path) with open(mrna_fa_path, 'w') as f: # for tag, data in self.samples.items(): for tag, data in sample_dict.items(): for transcript_id, seq, cds_start, cds_stop, cds_len,mRNA_len,*_ in data: # seq = seq[1:-1] if '<' ==seq[0] else seq f.write(f">{transcript_id}|cds_start={cds_start}|cds_stop={cds_stop}|cds_len={cds_len}|mRNA_len={mRNA_len}|dataset={tag}\n{re.sub(r'[^ACGT]', 'N', seq.replace('U','T'))}\n") # print(f'>{transcript_id}|{cds_start}|{cds_stop}|{tag}\n{seq}\n') # # print(seq[cds_start:cds_stop]) @staticmethod def qualified_samples(data,seq_only=False,cds_min=-1): """ 过滤掉不合格的样本 :param df_total_counts: :return: """ """load elements""" strand = data['strand'].iloc[0] num_start = data[data.feature == 'start_codon'].shape[0] num_stop = data[data.feature == 'stop_codon'].shape[0] if num_start == 0 or num_stop == 0: # print(f"No start or stop codon for {data['transcript_id'].iloc[0]}") return None # 没有标记起始密码子或者终止密码子 data = data[(data.feature!='start_codon') & (data.feature!='stop_codon')] seq = ''.join(list(chain(*data['seq']))) anno = ''.join(list(chain(*data['anno']))) # - or | represent UTR or CDS if not seq_only: ribo_counts = list(chain(*data['ribo_counts'])) rna_counts = list(chain(*data['rna_counts'])) if sum(ribo_counts) == 0 or sum(rna_counts) == 0: # print(f"No reads for {data['transcript_id'].iloc[0]}") return None # ribosome_density = list(chain(*data['ribosome_density'])) if strand == '-': from pyfaidx import complement seq = complement(seq[::-1]) anno = anno[::-1] if not seq_only: ribo_counts = ribo_counts[::-1] rna_counts = rna_counts[::-1] # ribosome_density = ribosome_density[::-1] cds_start = anno.index('|') cds_stop = anno.rindex('|') + 4 if cds_min!=-1: if cds_stop - cds_start < cds_min: # print(f"CDS length is less than {cds_min} for {data['transcript_id'].iloc[0]}") return None # CDS长度太短 trible = anno.count('|') % 3 if trible != 0: return None if not seq_only: # CDS recovery, RNA-seq, Ribo-seq, CDS length, metadict = dict() counts = np.array([ribo_counts, rna_counts]) t = counts[:, cds_start:cds_stop] > 0 metadict['cds_len'] = cds_stop - cds_start metadict['ribo_recovery'], metadict['rna_recovery'] = t.sum(axis=1) / metadict['cds_len'] metadict['ribo_avg_count'], metadict['rna_avg_count'] = counts.sum(axis=1) / metadict['cds_len'] if seq_only: return seq, cds_start, cds_stop,None,None,anno,None # metadict['5utr_len'] = cds_start # metadict['3utr_len'] = len(anno) - cds_stop return seq,cds_start,cds_stop,ribo_counts,rna_counts,anno,metadict @staticmethod def generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path): # returns GTF with essential columns such as "feature", "seqname", "start", "end" # alongside the names of any optional keys which appeared in the attribute column from gtfparse import read_gtf from pyfaidx import Fasta import polars as pl gtf = read_gtf(genome_gtf_path) # 先读这个文件,并简化这个文件,再读fasta,不然内存溢出 # gtf format described in 'https://www.gencodegenes.org/pages/data_format.html' features_to_keep = 'CDS,UTR,start_codon,stop_codon,five_prime_utr,three_prime_utr'.split(',') # start_codon 在CDS中, stop_codon 在UTR中 # "five_prime_utr", "three_prime_utr" 部分版本是这个 columns_to_keep = ['seqname','gene_id','transcript_id','protein_id','transcript_type','start', 'end', 'feature','strand'] gtf = gtf.filter(pl.col("feature").is_in(features_to_keep)) gtf = gtf.select(columns_to_keep) gtf = gtf.to_pandas() gtf = gtf.sort_values(by=['seqname', 'start','end']) # gtf.feature.unique() # ['UTR', 'start_codon', 'CDS', 'stop_codon'] # ['gene', 'transcript', 'exon', 'CDS', 'start_codon', 'stop_codon', 'UTR'] genome_fa = Fasta(genome_fa_path) gtf['seq'] = gtf.apply(lambda x: genome_fa[x['seqname']][x['start']-1:x['end']].seq, axis=1) gtf['anno'] = gtf.apply(lambda x: '-'* (x['end'] - x['start']+1) if x['feature'] in ['UTR','five_prime_utr','three_prime_utr'] else '|'*(x['end']-x['start']+1) , axis=1) # 所以start codon 和stop codon会被标记为 | gtf.to_csv(mrna_tsv_path,index=None,sep='\t') # # 比原来的文件缩小十倍,只保留了mRNA相关的内容,保留了mRNA序列,reverse和互补配对之前的序列 del genome_fa print(f"generate mRNA.tsv file: {mrna_tsv_path}\n{gtf.shape}\t{gtf[['seq','anno']].head()}") return gtf @staticmethod def calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF): ''' 计算ribosome_density :param numReads_RNA: :param totalNumReads_RNA: :param totalNumReads_RPF: :param readsLength_RNA: :param readsLength_RPF: :return: example: # 示例值 numReads_RPF = 1000 numReads_RNA = 2000 totalNumReads_RNA = 5000000 totalNumReads_RPF = 3000000 readsLength_RNA = 150 readsLength_RPF = 100 result = calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF) print("Ribosome Density:", result) ''' # Riboseq数据在处理时只保留了20-40nt长度的 reads readsLength_RPF = np.where(readsLength_RPF > 40, 30, readsLength_RPF) ratio_numReads = numReads_RPF / numReads_RNA ratio_totalNumReads = totalNumReads_RNA / totalNumReads_RPF ratio_readsLength = readsLength_RNA / readsLength_RPF ribosome_density = np.log2(ratio_numReads * ratio_totalNumReads * ratio_readsLength + 1) ribosome_density = np.where(numReads_RNA==-1, -1, ribosome_density) return ribosome_density def read_meta_file(self, file_path, ribo_experiment, rna_experiment, seq_only=False): df = pd.read_table(file_path) if seq_only: if ribo_experiment: species = df[df['ribo_experiment'] == ribo_experiment]['Ref'].iloc[0] elif rna_experiment: species = df[df['rna_experiment'] == ribo_experiment]['Ref'].iloc[0] else: raise ValueError("ribo_experiment or rna_experiment should be provided") return None,None,None,None,species row = df[(df['ribo_experiment'] == ribo_experiment) & (df['rna_experiment'] == rna_experiment)].iloc[0] totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF,species = row['totalNumReads_RNA'], row['totalNumReads_RPF'], row['readsLength_RNA'], row['readsLength_RPF'],row['Ref'] return totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species class RiboBwDataPipeline(RiboDataPipeline): def __init__(self, data_path, ribo_experiment, rna_experiment, seq_only=False, limit=-1, ): super().__init__(data_path, ribo_experiment, rna_experiment, seq_only, limit) def load_data(self, path, ribo_experiment=None,rna_experiment=None,region=300, limit=-1,norm=True): ''' 读取数据 1. 根据ribo_experiment,从meta中查询species,avg_len,total counts 等 2. 查询mRNA.fa文件是否存在 不存在: 查询mRNA.tsv是否存在 不存在: genome.gtf 生成 mRNA.gtf文件 (只含有mRNA相关的行和列,scale the size of gtf) 根据genome.fa 和mRNA.gtf 生成mRNA.fa文件 (包括start or stop codon positions) 3. 读取track文件,生成ribosome_density, ribo_counts, rna_counts {'ENST00000303577.7': 'PCBP1', # IRES 'ENST00000309311.7': 'EEF2'} # cap dependent :param path: :param reference_path: :param region: :param limit: :return: samples path = ./dataset/pretraining/ ''' """1. input ribo_experiment, meta""" seq_only = self.seq_only cds_min = self.cds_min # print('load_data in Pipeline') reference_path = os.path.join(path,'reference') meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'),ribo_experiment,rna_experiment,seq_only = seq_only) # todo totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species = meta """2. check mRNA.fa""" # # 读取 chromosomes reference # print('load_data in Pipeline',meta) mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa') if region != -1: mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa') mrna_fa_path = mrna_fa_path.replace('.fa', f'.pkl') self.mrna_region_pkl_path = mrna_fa_path if seq_only and os.access(mrna_fa_path, os.F_OK): with open(mrna_fa_path, 'rb') as f: sample_dict = pickle.load(f) if limit!=-1: limited_sample_dict = {} for key in sample_dict.keys(): limited_sample_dict[key] = {transcript_id:sample_dict[key][transcript_id] for transcript_id in list(sample_dict[key].keys())[:limit]} return limited_sample_dict mrna_tsv_path = os.path.join(reference_path,species, 'mRNA.tsv') if not os.access(mrna_tsv_path, os.F_OK): genome_gtf_path = os.path.join(reference_path,species, 'genome.gtf') genome_fa_path = os.path.join(reference_path,species, 'genome.fa') mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path) else: mrna_tsv = pd.read_table(mrna_tsv_path)#.iloc[:100] # 读取已经生成的mRNA.tsv文件 # print('load_data in Pipeline',mrna_tsv.shape) # 1459048, 11 """3. read track files""" # if limit<10: # debug mode # debug_ids = 'ENST00000332831.5,ENST00000000233.10'.split(',') # mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(debug_ids)].reindex() # region = 6 print(f'filter limit={limit},region={region}') print('load_data in Pipeline,before filter',mrna_tsv.shape) # 1459048, 11 reference_transcript_ids = list(self.reference_transcript_dict.keys()) # if args.debug: # region = 6 # limit = 2000 if limit!=-1: mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids+list(mrna_tsv.transcript_id.unique()[:limit]))] print('load_data in Pipeline, after filter',mrna_tsv.shape) # 1459048, 11 if not seq_only: import pyBigWig fribo_track = os.path.join(path,'track', f'{ribo_experiment}.bw') frna_track = os.path.join(path,'track', f'{rna_experiment}.bw') if os.access(fribo_track,os.F_OK) and os.access(frna_track,os.F_OK): print(f'load {ribo_experiment} and {rna_experiment} tracks') else: print(f'Error: {fribo_track} or {frna_track} not found.') return None ribo_bw,rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)] print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}') def iterfunc(x,bw): chrom, start, end = x['seqname'], x['start'], x['end'] if chrom in bw.chroms(): return np.array(bw.values(chrom, start - 1, end)) else: return np.zeros(end - start) # if x['seqname'] in bw or x['end']={ribo_bw.chroms(x['seqname'])},{x['end'], region=6 ribo_counts, rna_counts, ribosome_density, te, self.env, cds_len, mRNA_len,junction_counts) = ans ref_norm.append((sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)) if len(ref_norm)==0 and norm: print(f'Error: no qualified reference transcript (housekeeping when norm=True)') return None ref_norm = np.mean(ref_norm,axis=0) if norm and len(ref_norm)>0 else None print('ref_norm',ref_norm,'sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)') '''generate by norm''' for transcript_id, data in mrna_tsv.groupby('transcript_id'): tag = data['dataset'].iloc[0] ans = self.merge_transcript_level(data,total_counts_info=total_counts_info,seq_only=seq_only,cds_min=cds_min,region=region,ref_norm=ref_norm) if ans is None:continue sample_dict[tag].append([transcript_id] + ans) count += 1 if limit == count: break # datasplit = 'TR_VL_TS.tsv' # # ~/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/dataset_split $ head TR_VL_TS.tsv # # transcript_id dataset seqname gene_id self.samples = sample_dict if seq_only: mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa') if region !=-1: mrna_fa_path = mrna_fa_path.replace('.fa',f'_{region}.fa') if not os.access(mrna_fa_path,os.F_OK) or os.path.getsize(mrna_fa_path)==0: print(f'generate {sum([len(a.keys()) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}') self.generate_mRNA_fa(mrna_fa_path,force_regenerate=True) # ,force_regenerate=True mrna_fa_path = mrna_fa_path.replace('.fa',f'.pkl') if not os.access(mrna_fa_path, os.F_OK) or os.path.getsize(mrna_fa_path)==0: with open(mrna_fa_path, 'wb') as f: pickle.dump(sample_dict, f) self.mrna_region_pkl_path = mrna_fa_path return sample_dict # transcript_id 作为 groupby key class RegionDataset(BaseDataset): """DST""" def __init__( self, samples, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, ): # 调用父类初始化 super().__init__( tokenizer=tokenizer, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) # 加载数据 self.args = args self.samples = samples if limit!=-1: self.samples = self.samples[:limit] self.teacher_tokenizer = DebertaTokenizerFast.from_pretrained("./src/mRNA2vec/tokenizer", use_fast=True) self.teacher_tokenizer.padding_side = "left" def __len__(self): return len(self.samples) def __getitem__(self, idx): _id, seq, cds_start, cds_stop,*_ = self.samples[idx] # if len( self.samples[idx]) == 7: # _id,seq, cds_start, cds_stop, cds_len, mRNA_len,junction_counts = self.samples[idx] # else: # _id,seq, cds_start, cds_stop, cds_len, mRNA_len = self.samples[idx] aa_seq = '-'+self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1].replace('U','T'))).lower()+'-' aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long) # aa20 = torch.tensor(np.array([self.tokenizer.indices.get(aa.lower()) for aa in AA_str]),dtype=torch.long) # nt12 = torch.tensor(np.array([self.seq_to_rnaindex(nn) for aa in AA_str[:-1].upper() for nn in AA_TO_CODONS[aa]]),dtype=torch.long) # nt12 = defaultdict(list) # [nt12[self.tokenizer.indices.get(aa.lower())].append(self.seq_to_rnaindex(nn)[0]) for aa in AA_str[:-1].upper() for nn in AA_TO_CODONS[aa]] # # 准备1D和2D输入数据 X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) '''generate src_data, tgt_data, twod_data ''' src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq)) if "ernierna" in self.args.mlm_pretrained_model_path or 'teacher' in self.args.mlm_pretrained_model_path: teacher_input_ids = src_data elif "mrna2vec" in self.args.mlm_pretrained_model_path: teacher_encoder = self.teacher_tokenizer(seq[1:-1], padding='max_length', max_length=403, truncation=True, add_special_tokens=True, return_tensors="pt", ) teacher_input_ids = teacher_encoder['input_ids'].squeeze(0) # src_data = torch.where(torch.from_numpy(loss_mask),aa_idx,src_data) return (src_data,teacher_input_ids, tgt_data, twod_data,aa_idx, loss_mask) @staticmethod def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3): seq = seq.upper() if seq.count('<') > 1 or seq.count('/') > 0: seq = seq.replace('', '_').replace('', 'V').replace('', '^').replace('/','NNN') l = len(seq) X = np.ones((1, l)) for j in range(l): if seq[j] in set('Aa'): X[0, j] = 5 elif seq[j] in set('UuTt'): X[0, j] = 6 elif seq[j] in set('Cc'): X[0, j] = 7 elif seq[j] in set('Gg'): X[0, j] = 4 elif seq[j] in set('_'): X[0,j] = pad_idx elif seq[j] in set('^'): X[0,j] = 2 # eos else: X[0,j] = unk_idx # linker return X '''generate''' class BackBoneDataset(RegionDataset): '''for distillation using ribo dataset''' def __init__( self, samples, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, input_mask = True,Kozak_GS6H_Stop3='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG' ): # 调用父类初始化 super().__init__( samples=samples, tokenizer=tokenizer, args=args, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) self.input_mask = input_mask self.Kozak_GS6H_Stop3 = Kozak_GS6H_Stop3.upper() def __getitem__(self, idx): # (_id,seq,cds_start, cds_stop, # ribo_counts,rna_counts, # ribosome_density,te,env,*_) = self.samples[idx] #cds_len,mRNA_len,junction_counts data = self.samples.iloc[idx] _id = data['_id'] seq = data['sequence'] seq = seq.replace('U','T') # for translate start,stop = self.region + 1, self.region * 3 + 4 # ATG seq[301:301+3], TGG seq[901:901+3] Kozak, GS6H, Stop3 = self.Kozak_GS6H_Stop3.split(',') if ',' in self.Kozak_GS6H_Stop3 else '','','' # Kozak,GS6H,Stop3 = 'GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG'.split(',') '''fix nt, not opt''' seq = seq[:start-len(Kozak)].replace('ATG','ATC') + Kozak + seq[start:stop-len(GS6H)-len(Stop3)] + GS6H+ Stop3 +seq[stop:] # Kozak = 'GCCACC' # GCCACCATGGCG # seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] +'TAATAATAA'+seq[stop:-6] # https://www.nature.com/articles/s41467-024-48387-x#Fig1 # seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] + 'TGATAATAG' +seq[stop:-6] # # mus frequent:TGATAATAG;3 stop codon, GS6H # seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] + 'TGATAATAG' +seq[stop+6:] # # mus frequent:TGATAATAG;3 stop codon, GS6H '''whole mask''' aa_seq = '-'+self.translate(re.sub(r'[^ACGTU]', 'N', seq[1:-1])).lower()+'-' aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long) # 准备1D和2D输入数据 # 对src,tgt,twod_data,loss_mask进行 mask 处理 X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) _,_,_,mask = self.generate_mask(X,len(seq)) # [1205] # (1, 1205) seq_length = len(seq) vocab_size = len(self.tokenizer) # mask = torch.full_like(logits, float("-inf")) masked_logits = torch.full(size=(seq_length,vocab_size),fill_value=float("-inf")) # [1205, 32] masked_logits[np.arange(X.shape[1]),X.reshape(-1)]=0 # target 的位置不mask '''CDS mask''' X_CDS,masked_logits_CDS = self.CDS_mask(seq, start, stop- len(GS6H) - len(Stop3)) # (603,),torch.Size([1, 603, 32]) mask[start:stop- len(GS6H) - len(Stop3)] = X_CDS==self.tokenizer.mask_index special = self.seq_to_rnaindex('ACGT').reshape(-1) masked_logits[:start-len(Kozak),special]=0 # UTR5 masked_logits[stop:,special]=0 # UTR3 masked_logits[start:stop- len(GS6H) - len(Stop3)] = masked_logits_CDS for token in ['', '', '', '']: masked_logits[X.reshape(-1) == self.tokenizer.indices.get(token), :] = float("-inf") masked_logits[X.reshape(-1)==self.tokenizer.indices.get(token),self.tokenizer.indices.get(token)] = 0 # masked_logits = masked_logits.unsqueeze(0) mask[start-len(Kozak):start+3] = False mask[stop-len(GS6H)-len(Stop3):stop] = False src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq),mask=mask,input_mask=self.input_mask) loss_mask = torch.tensor(mask, dtype=torch.bool)#.unsqueeze(0) src_data = torch.where(loss_mask,aa_idx,src_data) src_env = torch.tensor(self.args.env_id, dtype=torch.long) # test = torch.concat([torch.from_numpy(exp_one_d).float(),src_exp_data,loss_mask.unsqueeze(1).float(),src_exp_mask,tgt_exp_data],dim=1).detach().cpu().numpy() return (_id,src_data,tgt_data,twod_data,loss_mask,masked_logits,src_env) # tgt_exp_data,tgt_data 传送被mask的部分会更节约资源 def CDS_mask(self,seq,start,stop): # target_positions = [self.region + 1,self.region *3 + 4] # ATG seq[301:301+3], TGG seq[901:901+3] # backbone_cds = re.sub(r'[^ACGT]', 'N', seq[target_positions[0]:target_positions[1]]) backbone_cds = re.sub(r'[^ACGT]', 'N', seq[start:stop]) # 目标氨基酸序列 # target_protein = ['M', 'A', 'L'] target_protein = self.translate(backbone_cds,repeate=1).upper() target_protein_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in target_protein.lower()]),dtype=torch.long) X = self.seq_to_rnaindex(backbone_cds, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1) num_rows, num_cols = len(target_protein),3 cds_mask = np.zeros([num_rows, num_cols],dtype=int) # 计算要掩码的行数(20%) rows_to_mask = int(num_rows * self.mask_prob *2) # 随机选择要掩码的行 masked_rows = random.sample(range(num_rows), rows_to_mask) # 为每个掩码行随机选择一列 masked_cols = np.random.randint(0, num_cols, size=rows_to_mask) # 使用 NumPy 的向量化操作更新数据框 cds_mask[masked_rows, masked_cols] = 1 cds_mask = cds_mask.reshape(-1) X[cds_mask==1]=self.tokenizer.mask_index # mask # 假设的 logits 输出,形状为 (batch_size, seq_length, vocab_size) # 这里假设 batch_size=1,seq_length=9(即 3 个密码子),vocab_size=4(A, U, C, G) # logits = torch.zeros(len(backbone_cds), len(self.tokenizer)) # 创建掩码 # backbone_cds = 'AT_G_C_TC' # base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'} # reverse_dictionary(base_map) masked_logits = self.create_codon_mask(target_protein_idx.numpy(), X, self.amino_acid_to_codons,self.tokenizer) # joint_mask = create_codon_mask(logits, target_protein,backbone_cds, AA_TO_CODONS) # 应用掩码 # masked_logits = mask + logits return X,masked_logits class RiboDataset(RegionDataset): '''for distillation using ribo dataset''' def __init__( self, samples, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, ): # 调用父类初始化 super().__init__( samples=samples, tokenizer=tokenizer, args=args, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) def __getitem__(self, idx): # (_id,seq,cds_start, cds_stop, # ribo_counts,rna_counts, # ribosome_density,te,env,*_) = self.samples[idx] #cds_len,mRNA_len,junction_counts (_id,seq,cds_start, cds_stop, ribo_counts,rna_counts, ribosome_density,te,env,cds_len,mRNA_len,junction_counts) = self.samples[idx] #cds_len,mRNA_len,junction_counts aa_seq = '-'+self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1])).lower()+'-' aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long) # 准备1D和2D输入数据 # 对src,tgt,twod_data,loss_mask进行 mask 处理 X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) src_data,tgt_data,twod_data,mask = self.generate_mask(X,len(seq)) loss_mask = torch.tensor(mask, dtype=torch.bool)#.unsqueeze(0) src_data = torch.where(loss_mask,aa_idx,src_data) window = 31 exp_one_d = np.stack([ribo_counts,rna_counts,ribosome_density],axis=1) tgt_exp_data = torch.from_numpy(exp_one_d).float() tgt_exp_data = tgt_exp_data.permute(1,0) tgt_exp_data = F.avg_pool1d(tgt_exp_data,kernel_size=window,padding=window//2,stride=1) tgt_exp_data = tgt_exp_data.permute(1,0) tgt_exp_data[~loss_mask,:] = -1 # [1205, 3] src_exp_data = torch.from_numpy(exp_one_d).float() src_exp_mask = F.max_pool1d(loss_mask.unsqueeze(0).repeat(3,1).float(),kernel_size=window,padding=window//2,stride=1).permute(1,0) # 形状变为 (1, L).permute(1,0) src_exp_data = torch.where(src_exp_mask.bool(),torch.zeros_like(src_exp_mask),src_exp_data) # src_exp_data = torch.zeros_like(tgt_exp_data) # src_exp_data = [] # zero or ones are tried #src_exp_data = None TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found src_data = torch.where(loss_mask,aa_idx,src_data) src_env = torch.tensor(env, dtype=torch.long) src_feature = np.array([cds_len,mRNA_len,junction_counts]) src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32 src_feature = torch.log(src_feature+1) # 取对数 tgt_te = torch.tensor(te, dtype=torch.float32) # test = torch.concat([torch.from_numpy(exp_one_d).float(),src_exp_data,loss_mask.unsqueeze(1).float(),src_exp_mask,tgt_exp_data],dim=1).detach().cpu().numpy() return src_data,src_exp_data,src_env,src_feature,tgt_data,tgt_exp_data,tgt_te,twod_data,loss_mask # tgt_exp_data,tgt_data 传送被mask的部分会更节约资源 class DownstreamDataset(RegionDataset): def __init__( self, samples, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, seq_len: int = 174, pad_method: str = "pre", column: str = "sequence", cds_len:str='cds_len', mRNA_len:str='mRNA_len', label: str = "IRES_Activity", ): # 调用父类初始化 super().__init__( samples=samples, tokenizer=tokenizer, args=args, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) # 特有属性 self.label = label self.column = column self.seq_len = seq_len self.cds_len = cds_len self.mRNA_len = mRNA_len self.pad_method = pad_method if limit!=-1: self.samples = self.samples.iloc[:limit] ''' eGFP https://www.ncbi.nlm.nih.gov/nuccore/L29345.1 >L29345.1 Aequorea victoria green-fluorescent protein (GFP) mRNA, complete cds| 26..742 TACACACGAATAAAAGATAACAAAGATGAGTAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGCGATGTTAATGGGCAAAAATTCTCTGTCAGTGGAGAGGGTGAAGGTGATGCAACATACGGAAAACTTACCCTTAAATTTATTTGCACTACTGGGAAGCTACCTGTTCCATGGCCAACACTTGTCACTACTTTCTCTTATGGTGTTCAATGCTTTTCAAGATACCCAGATCATATGAAACAGCATGACTTTTTCAAGAGTGCCATGCCCGAAGGTTATGTACAGGAAAGAACTATATTTTACAAAGATGACGGGAACTACAAGACACGTGCTGAAGTCAAGTTTGAAGGTGATACCCTTGTTAATAGAATCGAGTTAAAAGGTATTGATTTTAAAGAAGATGGAAACATTCTTGGACACAAAATGGAATACAACTATAACTCACATAATGTATACATCATGGCAGACAAACCAAAGAATGGAATCAAAGTTAACTTCAAAATTAGACACAACATTAAAGATGGAAGCGTTCAATTAGCAGACCATTATCAACAAAATACTCCAATTGGCGATGGCCCTGTCCTTTTACCAGACAACCATTACCTGTCCACACAATCTGCCCTTTCCAAAGATCCCAACGAAAAGAGAGATCACATGATCCTTCTTGAGTTTGTAACAGCTGCTGGGATTACACATGGCATGGATGAACTATACAAATAAATGTCCAGACTTCCAATTGACACTAAAGTGTCCGAACAATTACTAAATTCTCAGGGTTCCTGGTTAAATTCAGGCTGAGACTTTATTTATATATTTATAGATTCATTAAAATTTTATGAATAATTTATTGATGTTATTAATAGGGGCTATTTTCTTATTAAATAGGCTACTGGAGTGTAT ''' def __len__(self): return len(self.samples) def __getitem__(self, idx): data = self.samples.iloc[idx] seq = data[self.column] target = data[self.label] X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq)) one_d = one_d.view(-1) # 转换为PyTorch张量 if not torch.is_tensor(one_d): src_data = torch.from_numpy(one_d) # 假设one_d是你想要的1D输入特征 else: src_data = one_d if not torch.is_tensor(twod_d): twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 二维输入特征 else: twod_data = twod_d.squeeze(dim=-1) src_env = torch.tensor(self.args.env_id, dtype=torch.long) cds_len = data[self.cds_len] if self.cds_len in data else 742-26+1 mRNA_len = data[self.mRNA_len] if self.mRNA_len in data else 922 # cds_len = 742-26+1 # mRNA_len = 922 junction_counts= 0 src_feature = np.array([cds_len,mRNA_len,junction_counts]) src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32 # src_feature = torch.log(src_feature+1) # 取对数 # 获取回归任务的目标值 target = torch.tensor(target, dtype=torch.float32) # 假设每个样本都有一个'target'字段表示其回归目标 return src_data, twod_data,src_env, src_feature, target # 返回所有必要的输入和目标值 class RegressionDataset(BaseDataset): """处理回归任务的Dataset""" def __init__( self, path, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, seq_len: int = 174, pad_method: str = "pre", column: str = "sequence", label: str = "IRES_Activity", returnid=None ): # 调用父类初始化 super().__init__( tokenizer=tokenizer, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) # 特有属性 self.label = label self.column = column self.seq_len = seq_len self.pad_method = pad_method self.args = args self.returnid = returnid # 加载数据 self.samples = self.load_data( path, seq_len=seq_len, column=column, pad_method=pad_method ) if limit!=-1: self.samples = self.samples.iloc[:limit] def load_data(self, path, **kwargs): return self.read_csv_file( path, seq_len=kwargs['seq_len'], column=kwargs['column'], pad_method=kwargs['pad_method'] ) def read_csv_file(self,file_path, **kwargs): # 保持原有CSV读取逻辑 try: column = kwargs['column'] data = pd.read_csv(file_path) if column not in data.columns: data[column] = data.apply(self.generate_inputs, axis=1) # 预处理数据 return pad_or_truncate_utr( data, pad_method=kwargs['pad_method'], column=kwargs['column'], input_len=kwargs['seq_len'] ) except FileNotFoundError: print(f"Error: File '{file_path}' not found.") return [] def __len__(self): return len(self.samples) def __getitem__(self,idx): data = self.samples.iloc[idx] seq = data[self.column] target = data[self.label] # X, data_seq = self.seq_to_rnaindex_and_onehot(seq) X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) # 准备1D和2D输入数据 one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq)) one_d = one_d.view(-1) # 转换为PyTorch张量 if not torch.is_tensor(one_d): src_data = torch.from_numpy(one_d) # 假设one_d是你想要的1D输入特征 else: src_data = one_d if not torch.is_tensor(twod_d): twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 二维输入特征 else: twod_data = twod_d.squeeze(dim=-1) src_env = torch.tensor(self.args.env_id, dtype=torch.long) cds_len = 742-26+1 mRNA_len = 922 junction_counts= 0 src_feature = np.array([cds_len,mRNA_len,junction_counts]) src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32 src_feature = torch.log(src_feature+1) # 取对数 # 获取回归任务的目标值 target = torch.tensor(target, dtype=torch.float32) # 假设每个样本都有一个'target'字段表示其回归目标 if self.returnid is None:return src_data, twod_data,src_env, src_feature, target # 返回所有必要的输入和目标值 else: return data[self.returnid],src_data, twod_data,src_env, src_feature, target class MaotaoDataset(BaseDataset): """处理回归任务的Dataset""" def __init__( self, path, tokenizer, args, region: int = 300, limit: int = -1, return_masked_tokens: bool = False, seed: int = 1, mask_prob: float = 0.15, leave_unmasked_prob: float = 0.1, random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, two_dim_score: bool = False, two_dim_mask: int = -1, mask_whole_words: torch.Tensor = None, seq_len: int = 1200, column: str = 'off_start,off_end,full_len,type,_id,species,maotao_id,truncated_aa,cai_best_nn', label: str = "truncated_nn,cai_nature", codon_table_path: str='maotao_file/codon_table/codon_usage_{species}.csv', species_list:str="""mouse,Ec,Sac,Pic,Human""", type_list:str="""full,head,tail,boundary,middle""", # protein_alphabet_list:str="""_ACDEFGHIKLMNPQRSTVWY*""", # padding 被写死为1了 # # 10-31 rna_alphabet_list:str="""GAUC""",# 用网络自带的编码 returnid = None ): # 调用父类初始化 super().__init__( tokenizer=tokenizer, region=region, limit=limit, return_masked_tokens=return_masked_tokens, seed=seed, mask_prob=mask_prob, leave_unmasked_prob=leave_unmasked_prob, random_token_prob=random_token_prob, freq_weighted_replacement=freq_weighted_replacement, two_dim_score=two_dim_score, two_dim_mask=two_dim_mask, mask_whole_words=mask_whole_words, ) # 特有属性 self.species = {k:v for v,k in enumerate(species_list.split(','))} self.species.update({v:v for v,k in enumerate(species_list.split(','))}) self.seq_types = {k:v for v,k in enumerate(type_list.split(','))} self.seq_types.update({v:v for v,k in enumerate(type_list.split(','))}) # self.protein_alphabet = {k:v for v,k in enumerate(protein_alphabet_list)} self.rna_alphabet = {k:v+4 for v,k in enumerate(rna_alphabet_list)} self.label = label.split(',') self.column = column.split(',') self.seq_len = seq_len self.args = args # 加载数据 self.samples = self.load_data(path) # 加载codontable self.codon_instance_rna = {self.species[species]: Codon(codon_table_path.format(species=species), rna=True) for species in species_list.split(',')} if limit!=-1: self.samples = self.samples.iloc[:limit] def load_data(self, path, **kwargs): if os.access(path.replace('.csv','_processed.pickle'), os.R_OK): df = pd.read_pickle(path.replace('.csv','_processed.pickle')) else: df = pd.read_csv(path) df['truncated_aa'] = df['truncated_aa'].apply(lambda x: re.sub(r'[^acdefghiklmnpqrstvwy*_]', '_', x.lower())) df['cai_best_nn'] = df['cai_best_nn'].apply(lambda x: x.upper().replace('T','U')) df['species'] = df['species'].apply(lambda x: self.species[x]) df['type'] = df['type'].apply(lambda x: self.seq_types[x]) df.to_csv(path.replace('.csv','_processed.csv'),index=False) with open(path.replace('.csv','_processed.pickle'), 'wb') as f: pickle.dump(df,f) return df def __len__(self): return len(self.samples) def __getitem__(self,idx): data = self.samples.iloc[idx] maotao_id = data['maotao_id'] aa_index = np.array([self.tokenizer.index(x) for x in data['truncated_aa']]) # input idx aa_idx = torch.from_numpy(aa_index).long() seq = data['cai_best_nn'] '''prepare 1D and 2D input data''' # X, data_seq = self.seq_to_rnaindex_and_onehot(seq) X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) # 准备1D和2D输入数据 one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq)) one_d = one_d.view(-1) # 转换为PyTorch张量 if not torch.is_tensor(one_d): src_data = torch.from_numpy(one_d) # 假设one_d是你想要的1D输入特征 else: src_data = one_d if not torch.is_tensor(twod_d): twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 二维输入特征 else: twod_data = twod_d.squeeze(dim=-1) continuous_features = np.array([data['off_start'],data['off_end'],data['full_len']]) continuous_features = np.log(np.maximum(continuous_features+3,0)+1) continuous_features = torch.from_numpy(continuous_features).float() # .float() ==torch.float32 # continuous_features = torch.log(torch.max(torch.tensor(continuous_features+3),torch.tensor(0))+1) # 取对数 species_features = torch.tensor(data['species'],dtype=torch.long) truncated_features = torch.tensor(data['type'],dtype=torch.long) ith_nn_prob = self.codon_instance_rna[data['species']].frame_ith_aa_base_fraction nn_prob = self.create_base_prob(data['truncated_aa'],ith_nn_prob,self.rna_alphabet,self.tokenizer) '''output''' if 'truncated_nn' in data: target_nn = self.seq_to_rnaindex(data['truncated_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1) # 获取回归任务的目标值 target = torch.tensor(data['cai_nature'], dtype=torch.float32) # 假设每个样本都有一个'target'字段表示其回归目标 else: target_nn = self.seq_to_rnaindex(data['cai_best_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1) target = torch.tensor(0, dtype=torch.float32) # 假设每个样本都有一个'target'字段表示其回归目标 target_nn = torch.from_numpy(target_nn).long() frames = [1, 2, 3] backbone_cds_list = self.modify_codon_by_frames(target_nn, frames=frames, masked_token=self.tokenizer.mask_index) # backbone_cds_list = self.modify_codon_by_frames(src_data, frames = frames,masked_token=self.tokenizer.mask_index) masked_logits_list = [] for backbone_cds, frame in zip(backbone_cds_list, frames): masked_logits = self.create_codon_mask(aa_idx, backbone_cds, self.amino_acid_to_codons, self.tokenizer) masked_logits_list.append(masked_logits.unsqueeze(0)) # 'UUCACCCAGGCCACGCGGAGUACGAUCGAGUGUACAGUGAA' # test = masked_logits.numpy() masked_logits_list = torch.cat(masked_logits_list, dim=0) return src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, target_nn, target,masked_logits_list[...,:10],nn_prob[...,:10], maotao_id # [(a.shape, a.dtype) for a in [src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, target_nn, target]] # Out[10]: # [(torch.Size([1200]), torch.int64), # (torch.Size([1, 1200, 1200]), torch.float64), # (torch.Size([400]), torch.int64), # (torch.Size([3]), torch.float32), # (torch.Size([]), torch.int64), # (torch.Size([]), torch.int64), # (torch.Size([1200]), torch.int64), # (torch.Size([]), torch.float32)] @staticmethod def modify_codon_by_frames(sequence, frames=[1, 2, 3], masked_token='_'): """ 高级版本:支持自定义修改函数 参数: sequence (str): 输入序列 frame (int): 要修改的密码子位置 (1, 2, 3) modify_func (callable): 修改函数,接收原帧字符串,返回修改后的字符串 返回: str: 修改后的重建序列 """ # 清理序列 # seq = sequence.upper().replace(' ', '').replace('\n', '') seq = sequence # seq = seq[:len(seq) - len(seq) % 3] # 使用切片提取帧 frames_seq = [seq[0::3], seq[1::3], seq[2::3]] reconstructed_list = [] # 应用修改函数 for ith,frame in enumerate(frames_seq): if ith+1 in frames: tmp_seq = deepcopy(frames_seq) tmp_seq[ith] = [masked_token] * len(frames_seq[ith]) # 重建序列 # reconstructed = None if isinstance(seq,str): reconstructed = ''.join( tmp_seq[0][i] + tmp_seq[1][i] + tmp_seq[2][i] for i in range(len(tmp_seq[0])) ) elif isinstance(seq,torch.Tensor): tmp_seq[ith] = torch.from_numpy(np.array(tmp_seq[ith])) reconstructed = torch.stack(tmp_seq, dim=1).reshape(-1) elif isinstance(seq,np.ndarray): tmp_seq[ith] = np.array(tmp_seq[ith]) reconstructed = np.stack(tmp_seq, axis=1).reshape(-1) else: raise ValueError(type(seq)) # reconstructed = torch.cat([tmp_seq[0][i],tmp_seq[1][i],tmp_seq[2][i] # for i in range(len(tmp_seq[0]))]) # reconstructed = [tmp_seq[0][i] + tmp_seq[1][i] + tmp_seq[2][i] # for i in range(len(tmp_seq[0]))] reconstructed_list.append(deepcopy(reconstructed)) return reconstructed_list def gaussian(x): return math.exp(-0.5*(x*x)) def paired(x,y,lamda=0.8): if x == 5 and y == 6: return 2 elif x == 4 and y == 7: return 3 elif x == 4 and y == 6: return lamda elif x == 6 and y == 5: return 2 elif x == 7 and y == 4: return 3 elif x == 6 and y == 4: return lamda else: return 0 def pad_or_truncate_utr(data, input_len, pad_method,column='utr',pad_mark='_'): def process_utr(utr): if len(utr) < input_len: if pad_method == 'pre': padded_utr = pad_mark * (input_len - len(utr)) + utr elif pad_method == 'behind': padded_utr = utr + pad_mark * (input_len - len(utr)) else: padded_utr = utr[-input_len:] return padded_utr data[column] = data[column].apply(process_utr) return data # def do_createmat(data, base_range=30, lamda=0.8): # paird_map = np.array([[paired(i, j, lamda) for i in range(30)] for j in range(30)]) # token # data_index = np.arange(0, len(data)) # # np.indices((2,2))    # coefficient = np.zeros([len(data), len(data)]) # # mat = np.zeros((len(data),len(data))) # score_mask = np.full((len(data), len(data)), True) # for add in range(base_range): # data_index_x = data_index - add # data_index_y = data_index + add # score_mask = ((data_index_x >= 0)[:, None] & (data_index_y < len(data))[None, :]) & score_mask # data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1), # data_index_y.clip(0, len(data) - 1), indexing='ij') # score = paird_map[data[data_index_x], data[data_index_y]] # score_mask = score_mask & (score != 0) # # coefficient = coefficient + score * score_mask * gaussian(add) # if ~(score_mask.any()): # break # score_mask = coefficient > 0 # for add in range(1, base_range): # data_index_x = data_index + add # data_index_y = data_index - add # score_mask = ((data_index_x < len(data))[:, None] & (data_index_y >= 0)[None, :]) & score_mask # data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1), # data_index_y.clip(0, len(data) - 1), indexing='ij') # score = paird_map[data[data_index_x], data[data_index_y]] # score_mask = score_mask & (score != 0) # coefficient = coefficient + score * score_mask * gaussian(add) # if ~(score_mask.any()): # break # return coefficient def do_createmat(data, base_range=30, lamda=0.8): paird_map = np.array([[paired(i, j, lamda) for i in range(30)] for j in range(30)]) # token data_index = np.arange(0, len(data)) # np.indices((2,2))    coefficient = np.zeros([len(data), len(data)]) # mat = np.zeros((len(data),len(data))) score_mask = np.full((len(data), len(data)), True) for add in [0,300]: data_index_x = data_index - add data_index_y = data_index + add score_mask = ((data_index_x >= 0)[:, None] & (data_index_y < len(data))[None, :]) & score_mask data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1), data_index_y.clip(0, len(data) - 1), indexing='ij') score = paird_map[data[data_index_x], data[data_index_y]] score_mask = score_mask & (score != 0) coefficient = coefficient + score * score_mask * gaussian(add) if ~(score_mask.any()): break return coefficient def creatmat(data, base_range=30, lamda=0.8): return do_createmat(data, base_range=base_range, lamda=lamda) # if len(data.shape)==1:return do_createmat(data,base_range=base_range,lamda =lamda) # else: # coefficient = np.zeros((data.shape[0],data.shape[1],data.shape[1])) # for i in range(data.shape[0]): # coefficient[i,:,:] = do_createmat(data[i:i+1,:], base_range=base_range, lamda=lamda) # return coefficient import argparse if __name__ == '__main__': print('start generating') # # 获取 pretraining 和 dataset 的 args # from model.tools import get_dataset_args, get_pretraining_args # pretraining_parser = get_pretraining_args() # dataset_parser = get_dataset_args() # # # 合并 args # parser = argparse.ArgumentParser(parents=[pretraining_parser, dataset_parser], add_help=False,conflict_handler='resolve') # # dataset_parser = get_dataset_args() # ## 合并 args # # parser = argparse.ArgumentParser(parents=[ dataset_parser], add_help=False, # # conflict_handler='resolve') # # args = parser.parse_args() # args.batch_size = 5 # # args.ffasta = '/public/home/jiang_jiuhong/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/full.fa' # # ans = RNADataset.read_fasta_file(args.ffasta) # # args.device = 'cpu' # vocab_path = args.arg_overrides['data'] + '/dict.txt' # tokenizer = Dictionary.load(vocab_path) # tokenizer.mask_index = tokenizer.add_symbol('') # # train_ds = RNADataset(args.ffasta,max_length=args.region * 2,tokenizer=tokenizer) # # train_loader = DataLoader( # # train_ds, # # batch_size=args.batch_size, # # pin_memory=True, # # drop_last=False, # # shuffle=False, # # num_workers=args.num_workers, # # sampler=None # # ) # # # # for step, (src_data,tgt_data,twod_data,loss_mask) in enumerate(train_loader): # # print(step, [a.shape for a in [src_data, tgt_data, twod_data, # # loss_mask]]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])] # # a = [a.numpy()[0] for a in [src_data, tgt_data, twod_data, # # loss_mask]] # # # # train_ds = RegressionDataset(args.downstream_data_path+'/IRES_linear/TS.csv', tokenizer, seq_len=args.seq_len, column=args.column, label=args.label, # # pad_method=args.pad_method) # # train_loader = DataLoader( # # train_ds, # # batch_size=args.batch_size, # # pin_memory=True, # # drop_last=False, # # shuffle=False, # # ) # # for step, data in enumerate(train_loader): # # print(step, [a.shape for a in data]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])] # # '''RiboDataPipeline''' # ribo_experiment, rna_experiment = 'SRX5164421','SRX5164417' #TAIR10# # ribo_experiment, rna_experiment = 'SRX12763793','SRX12763783' # human # ribo_experiment, rna_experiment = 'SRX9444526','SRX9444530' # mouse # print(os.path.abspath(args.exp_pretrain_data_path)) # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100) # generate mRNA.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=-1,cds_min=100) # generate mRNA_region.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path, ribo_experiment, rna_experiment, seq_only=False, region=6,limit=300,cds_min=100) # loading ribosome counts data # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS'] # # with open('./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl','wb') as f: # # pickle.dump((TR,VL,TS),f) # # # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=300,cds_min=100,norm=True) # generate mRNA_region.fa # # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS'] # # with open('./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl','wb') as f: # # pickle.dump(RDP.samples,f) # # # '''why 300''' # # region = 1000 # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=False, region=1000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=2000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=2000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa # # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS'] # # ftrack = f'./dataset/experiment/nature/track_{ribo_experiment}_{rna_experiment}_{region}_counts_not_norm.pkl' # # ref_norm = [690.98991075,2214.2488917] # # # ref_norm = [1,1] # # # # with open(ftrack,'wb') as f: # # data = RDP.samples # # TS = deepcopy(data['TR']) # # df = pd.DataFrame(TS).T # # df.columns = 'seq,cds_start,cds_stop,ribo_counts,rna_counts,ribosome_density,te,env,cds_len,mRNA_len'.split(',') # # df['RPF_counts'] = df['ribo_counts'].apply(lambda x: sum(x[1001:3004])) # # df['RNA_counts'] = df['rna_counts'].apply(lambda x: sum(x[1001:3004])) # # # # df = df[df['RPF_counts']>100] # # df = df[df['RNA_counts']>100] # # df = df[df['cds_len']>2000] # # df = df[df['te']!=-1] # # RPF = np.array(df['ribo_counts'].tolist())*ref_norm[0] # # RNA = np.array(df['rna_counts'].tolist())*ref_norm[1] # # density = np.array(df['ribosome_density'].tolist()) # # pickle.dump((RPF,RNA,density),f) # # # fname = 'track_SRX12763793_SRX12763783_1000.pkl' # # with open(os.path.join(WDIR, fname), 'rb') as f: # # data = pickle.load(f) # # # # train_ds = RiboDataset(TR, tokenizer) # # train_loader = DataLoader( # # train_ds, # # batch_size=args.batch_size, # # pin_memory=True, # # drop_last=False, # # shuffle=False, # # ) # # for step, data in enumerate(train_loader): # # print(step, [a.shape for a in data]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])] # # # '''submit task to get a lot of pretraining data''' # ribo_experiment, rna_experiment = 'SRX5164421','SRX5164417' #TAIR10# # ribo_experiment, rna_experiment = 'SRX12763793','SRX12763783' # human # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100) # generate mRNA.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=-1,cds_min=100) # generate mRNA_region.fa # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=1998,limit=-1,cds_min=100) # generate mRNA_region.fa # # RDP = RiboDataPipeline(args.exp_pretrain_data_path, ribo_experiment, rna_experiment, seq_only=False, region=6,limit=300,cds_min=100) # loading ribosome counts data # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS']