#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : tokenizer.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/2/12 16:40
des: TODO
"""
from typing import List
import argparse
import os
import pickle
import random
import re
import time
from collections import defaultdict
from itertools import chain
from random import shuffle
import math
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import transformers
from copy import copy, deepcopy
from model.codon_attr import Codon
# for debug only
os.chdir('../../')
# print(__file__,os.getcwd())
import sys
from utils.ernie_rna.dictionary import Dictionary
from utils.ernie_rna.position_prob_mask import calculate_mask_prob
from transformers import DebertaTokenizerFast
from model.codon_tables import CODON_TO_AA, AA_str, AA_TO_CODONS, reverse_dictionary, create_codon_mask
# from utils.esm3.tokenizer import EsmSequenceTokenizer
base_range_lst = [1]
lamda_lst = [0.8]
class BaseDataset(Dataset):
"""公共基类,包含共享属性和方法"""
def __init__(
self,
tokenizer,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
):
        # validate arguments
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
        # initialize shared attributes
self.tokenizer = tokenizer
self.pad_idx = tokenizer.pad_index
self.mask_idx = tokenizer.mask_index
self.return_masked_tokens = return_masked_tokens
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
self.two_dim_score = two_dim_score
self.two_dim_mask = two_dim_mask
self.mask_whole_words = mask_whole_words
self.region = region
self.limit = limit
        # initialize replacement weights (only used when random_token_prob > 0)
if random_token_prob > 0.0:
weights = np.array(tokenizer.count) if freq_weighted_replacement else np.ones(len(tokenizer))
weights[: tokenizer.nspecial] = 0
self.weights = weights / weights.sum()
self.tokenizer.indices['T']=self.tokenizer.indices['U']
self.amino_acid_to_codons = {}
for aa, codons in AA_TO_CODONS.items():
codons_num = []
for codon in codons:
codon_num = []
for base in codon:
                    codon_num.append(self.tokenizer.indices[base])  # if a base is not in the mapping, 4 denotes unknown
codons_num.append(codon_num)
self.amino_acid_to_codons[self.tokenizer.indices[aa.lower()]] = codons_num
    # region === Shared methods ===
# @staticmethod
# def prepare_input_for_ernierna(index, seq_len):
# shorten_index = index[:seq_len + 2] # 截断到seq_len+2
# one_d = torch.from_numpy(shorten_index).long().reshape(1, -1)
# two_d = np.zeros((1, seq_len + 2, seq_len + 2))
# two_d[0, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8)
# # two_d[:, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8)
# two_d = two_d.transpose(1, 2, 0)
# two_d = torch.from_numpy(two_d).reshape(1, seq_len + 2, seq_len + 2, 1)
# return one_d, two_d
@staticmethod
def translate(nucleotide_seq,repeate=3):
amino_acid_list = []
for i in range(0, len(nucleotide_seq), 3):
codon = nucleotide_seq[i:i + 3]
amino_acid_list.append(CODON_TO_AA.get(codon, '-')*repeate)
amino_acid_seq = ''.join(amino_acid_list)
return amino_acid_seq
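    # A minimal usage sketch (assuming CODON_TO_AA is a DNA-keyed standard codon table):
    # each codon is translated to one amino-acid letter repeated `repeate` times, so the
    # amino-acid string stays base-aligned with the nucleotide sequence.
    #   BaseDataset.translate('ATGGCC')              -> 'MMMAAA'
    #   BaseDataset.translate('ATGGCC', repeate=1)   -> 'MA'
    #   BaseDataset.translate('ATGNNN', repeate=1)   -> 'M-'   (unknown codons map to '-')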
@staticmethod
def prepare_input_for_ernierna(index, seq_len): # (1, 1205), 1205
if index.ndim == 2:
index = np.squeeze(index)
        shorten_index = index[:seq_len]  # truncate to seq_len
one_d = torch.from_numpy(shorten_index).long().reshape(1, -1)
two_d = np.zeros((1, seq_len, seq_len))
two_d[0, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8)
# new_matrix = creatmat(item.numpy(), base_range, lamda) # [1205]
two_d = two_d.transpose(1, 2, 0)
two_d = torch.from_numpy(two_d).reshape(1, seq_len, seq_len, 1)
return one_d, two_d
def generate_inputs(self,x):
region = self.region
# utr5 = x["UTR5"] if 'UTR5' in x else UTR5
# utr3 = x["UTR3"] if 'UTR3' in x else UTR3
# cds = x["CDS"] if 'CDS' in x else CDS
utr5 = x["UTR5"]
utr3 = x["UTR3"]
cds = x["CDS"]
seq = utr5 + cds + utr3
cds_start = len(utr5)
cds_stop = len(utr5) + len(cds)
# utr5 = seq[:cds_start]
# cds = seq[cds_start:cds_stop]
# utr3 = seq[cds_stop:]
utr5_limit = 300 if region > 300 else region
# seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit)
seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit)
return seq
def process_sequence(self,seq, cds_start, cds_stop, region, pad_mark, bos, eos,link,utr5_limit):
utr5 = seq[:cds_start]
cds = seq[cds_start:cds_stop]
utr3 = seq[cds_stop:]
# utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# seq = utr5 + cds_h + cds_t + utr3
# seq = seq[:region*2+1]+link*3+seq[-region*2-1:]
utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
seq = utr5 + cds_h + cds_t + utr3
seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:]
# c1 = seq[cds_start:]
# c2 = seq[:cds_stop]
#
# utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_h = self.process_utr(c1, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_t = self.process_utr(c2, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # 这样会导致CDS和UTR混在一起,后面不太好mask猜AA
# utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# seq = utr5 + cds_h + cds_t + utr3
# seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:]
# utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
#
# pre_processed = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# behind_processed = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # seq = pre_processed + cds_part + behind_processed
# # seq = seq[:region*2+1]+link*3+seq[-region*2-1:]
if isinstance(seq,list):
seq = np.array(seq)
return seq
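    # Length check (a sketch, following the slicing above): the assembled sequence keeps
    # utr5_limit + region + 1 characters from the head, a 3-character linker, and
    # region*2 + 1 characters from the tail, i.e. utr5_limit + 3*region + 5 in total.
    # With the defaults region=300 and utr5_limit=300 this gives 1205, matching the
    # [1, 1205] shapes noted in generate_mask below.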
@staticmethod
def process_utr(utr, input_len, pad_method, pad_mark='_',bos='<',eos='>'):
if len(utr) < input_len:
if pad_method == 'pre':
padded_utr = pad_mark * (input_len - len(utr)) + bos + utr
elif pad_method == 'behind':
padded_utr = utr+eos + pad_mark * (input_len - len(utr))
else:
if pad_method == 'pre':
padded_utr = bos+utr[-input_len:]
elif pad_method == 'behind':
padded_utr = utr[:input_len]+eos
return padded_utr
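    # Example behaviour (a sketch under the defaults pad_mark='_', bos='<', eos='>'):
    #   process_utr('ACGT', 6, 'pre')      -> '__<ACGT'   (left-padded, bos inserted)
    #   process_utr('ACGT', 6, 'behind')   -> 'ACGT>__'   (eos appended, right-padded)
    #   process_utr('ACGTACGT', 4, 'pre')  -> '<ACGT'     (keeps the last input_len bases)
    # Note the returned string is input_len + 1 characters long because of the bos/eos marker.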
# self.process_utr = process_utr
# @staticmethod
# def seqs_to_index(sequences, pad_idx=1, unk_idx=3):
# '''
# input:
# sequences: list of string (difference length)
#
# return:
# rna_index: numpy matrix, shape like: [len(sequences), max_seq_len+2]
# rna_len_lst: list of length
#
#
# examples:
# rna_index, rna_len_lst = seq_to_index(sequences)
# for i, (index, seq_len) in enumerate(zip(rna_index, rna_len_lst)):
# one_d, two_d = prepare_input_for_ernierna(index, seq_len)
# one_d = one_d.to(device)
# two_d = two_d.to(device)
#
# output = my_model(one_d, two_d, layer_idx=layer_idx).cpu().detach().numpy()
# '''
#
# rna_len_lst = [len(ss) for ss in sequences]
# max_len = max(rna_len_lst)
# # assert max_len <= 1022
# seq_nums = len(rna_len_lst)
# rna_index = np.ones((seq_nums, max_len + 2))
# for i in range(seq_nums):
# for j in range(rna_len_lst[i]): # 4,5,6,7 --->GATC
# if sequences[i][j] in set("Aa"):
# rna_index[i][j + 1] = 5
# elif sequences[i][j] in set("Cc"):
# rna_index[i][j + 1] = 7
# elif sequences[i][j] in set("Gg"):
# rna_index[i][j + 1] = 4
# elif sequences[i][j] in set('TUtu'):
# rna_index[i][j + 1] = 6
# elif sequences[i][j] in set('_'):
# rna_index[i][j + 1] = pad_idx
# else:
# rna_index[i][j + 1] = unk_idx
# rna_index[i][rna_len_lst[i] + 1] = 2 # add 'eos' token
# rna_index[:, 0] = 0 # add 'cls' token
# return rna_index, rna_len_lst
# @staticmethod
# def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3):
# l = len(seq)
# X = np.ones((1, l + 2))
# for j in range(l):
# if seq[j] in set('Aa'):
# X[0, j + 1] = 5
# elif seq[j] in set('UuTt'):
# X[0, j + 1] = 6
# elif seq[j] in set('Cc'):
# X[0, j + 1] = 7
# elif seq[j] in set('Gg'):
# X[0, j + 1] = 4
# elif seq[j] in set('_'):
# X[0,j + 1] = pad_idx
# else:
# X[0,j + 1] = unk_idx
#
# X[0, l + 1] = 2
# X[0, 0] = 0
# return X
@staticmethod
def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3):
        # rna_alphabet_list: str = """GAUC_"""
        # special tokens (see the usage below): bos '<': 0, pad '_': 1, eos '>': 2, unk: 3,
        # 'G': 4, 'A': 5, 'U': 6, 'C': 7, 'N': 8, '': 9,
        # 'a': 10, 'y': 29, '*': 30, '-': 31, 'T': 6
l = len(seq)
X = np.ones((1, l))
for j in range(l):
if seq[j] in set('Aa'):
X[0, j] = 5
elif seq[j] in set('UuTt'):
X[0, j] = 6
elif seq[j] in set('Cc'):
X[0, j] = 7
elif seq[j] in set('Gg'):
X[0, j] = 4
elif seq[j] in set('_'):
X[0,j] = pad_idx
elif seq[j] in set('<'):
X[0,j] = 0
elif seq[j] in set('>'):
X[0,j] = 2
else:
X[0,j] = unk_idx
# X[0, l + 1] = 2
# X[0, 0] = 0
return X
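    # Example (a sketch; indices follow the mapping above):
    #   seq_to_rnaindex('<ACGU_>')  -> array([[0., 5., 7., 4., 6., 1., 2.]])
    # 'T'/'t' is treated the same as 'U', and any other character maps to unk_idx.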
# def generate_mask(self, X,seq_len):
# one_d, twod_d = self.prepare_input_for_ernierna(X, seq_len)
# # return one_d, twod_data # [1,L+2],[1,L+2,L+2,1],[1,L,4]
# '''generate src_data, tgt_data, twod_data '''
# item = one_d.view(-1)
#
# assert (
# self.mask_idx not in item
# ), "Dataset contains mask_idx (={}), this is not expected!".format(
# self.mask_idx,
# )
#
# if self.mask_whole_words is not None: # todo: check when need
# word_begins_mask = self.mask_whole_words.gather(0, item)
# word_begins_idx = word_begins_mask.nonzero().view(-1)
# sz = len(word_begins_idx)
# words = np.split(word_begins_mask, word_begins_idx)[1:]
# assert len(words) == sz
# word_lens = list(map(len, words))
#
# sz = len(item)
# # decide elements to mask
# mask = np.full(sz, False)
#
# # 找出非 padding 的位置
# non_pad_indices = np.where(
# (item != self.tokenizer.pad_index) &
# (item != self.tokenizer.unk_index)
# )[0]
# # 计算需要掩码的数量
# num_non_pad = len(non_pad_indices)
# num_mask = int(
# self.mask_prob * num_non_pad + np.random.rand()
# )
#
# # 在非 padding 的位置中随机选择要掩码的元素,根据position prob 进行mask
# target_positions = [self.region + 1, self.region *3 + 4]
# sigma = 90 # 控制概率衰减的速度,数值越小,衰减越快
# probabilities = np.array([calculate_mask_prob(i, target_positions, sigma) for i in range(sz)])
# non_pad_probabilities = probabilities[non_pad_indices]
# non_pad_probabilities = non_pad_probabilities/non_pad_probabilities.sum()
# if num_non_pad >= 1:
# mask[non_pad_indices[np.random.choice(num_non_pad, num_mask, replace=False,p=non_pad_probabilities)]] = True
# # twod_data
# two_dim_matrix =torch.squeeze(twod_d, dim=-1).numpy()
# # item_len = len(item.numpy())
# # two_dim_matrix = np.zeros((len(base_range_lst) * len(lamda_lst), item_len, item_len)) # 只有0和-1
# padding_dim = 0
# for base_range in base_range_lst:
# for lamda in lamda_lst:
# new_matrix = creatmat(item.numpy(), base_range, lamda)
# new_matrix[mask, :] = -1
# new_matrix[:, mask] = -1
# two_dim_matrix[padding_dim, :, :] = new_matrix
# padding_dim += 1
# # use -1 represent mask
# # matrix[mask,:] = self.two_dim_mask
# # matrix[:,mask] = self.two_dim_mask
# # print(two_dim_matrix.shape)
# twod_data = torch.from_numpy(two_dim_matrix)#.unsqueeze(-1) # [1, L+2, L+2, 1]
#
#
# if self.mask_whole_words is not None:
# mask = np.repeat(mask, word_lens)
#
# new_item = np.full(len(mask), self.pad_idx)
# new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
# tgt_data = torch.from_numpy(new_item)#.unsqueeze(0) # [L,1]
#
# # decide unmasking and random replacement
# rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
# if rand_or_unmask_prob > 0.0:
# rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
# if self.random_token_prob == 0.0:
# unmask = rand_or_unmask
# rand_mask = None
# elif self.leave_unmasked_prob == 0.0:
# unmask = None
# rand_mask = rand_or_unmask
# else:
# unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
# decision = np.random.rand(sz) < unmask_prob
# unmask = rand_or_unmask & decision
# rand_mask = rand_or_unmask & (~decision)
# else:
# unmask = rand_mask = None
#
# if unmask is not None:
# mask = mask ^ unmask
#
# if self.mask_whole_words is not None:
# mask = np.repeat(mask, word_lens)
#
# new_item = np.copy(item)
# new_item[mask] = self.mask_idx
# if rand_mask is not None:
# num_rand = rand_mask.sum()
# if num_rand > 0:
# if self.mask_whole_words is not None:
# rand_mask = np.repeat(rand_mask, word_lens)
# num_rand = rand_mask.sum()
# # 以概率突变
# new_item[rand_mask] = np.random.choice(
# len(self.tokenizer),
# num_rand,
# p=self.weights,
# )
# src_data = torch.from_numpy(new_item)#.unsqueeze(0)
# loss_mask = torch.tensor(mask, dtype=torch.long)#.unsqueeze(0)
#
# return src_data,tgt_data,twod_data,loss_mask
def generate_mask(self, X,seq_len,mask=None,input_mask=True):
"""
:param X: seuqnce in number index
:param seq_len:
:param mask: 1d mask array, if None, generate by dual center mask
:param input_mask: true: use 2d mask
:return:
"""
one_d, twod_d = self.prepare_input_for_ernierna(X, seq_len) # X: [1,1205]
# return one_d, twod_data # [1,L+2],[1,L+2,L+2,1],[1,L,4]
'''generate src_data, tgt_data, twod_data '''
item = one_d.view(-1)
assert (
self.mask_idx not in item
), "Dataset contains mask_idx (={}), this is not expected!".format(
self.mask_idx,
)
if self.mask_whole_words is not None: # todo: check when need
word_begins_mask = self.mask_whole_words.gather(0, item)
word_begins_idx = word_begins_mask.nonzero().view(-1)
sz = len(word_begins_idx)
words = np.split(word_begins_mask, word_begins_idx)[1:]
assert len(words) == sz
word_lens = list(map(len, words))
sz = len(item)
# decide elements to mask
if mask is None: # 1D mask
mask = np.full(sz, False)
            # find the non-padding positions
non_pad_indices = np.where(
(item != self.tokenizer.pad_index) &
(item != self.tokenizer.unk_index)
)[0] #
            # number of tokens to mask
num_non_pad = len(non_pad_indices)
num_mask = int(
self.mask_prob * num_non_pad + np.random.rand()
)
            # among the non-padding positions, randomly select the tokens to mask, weighted by a position-dependent probability
            target_positions = [self.region + 1, self.region * 3 + 4]  # start codon seq[301:301+3] and stop codon seq[901:901+3] when region=300
            sigma = 90  # controls how quickly the probability decays; smaller values decay faster
probabilities = np.array([calculate_mask_prob(i, target_positions, sigma) for i in range(sz)])
non_pad_probabilities = probabilities[non_pad_indices]
non_pad_probabilities = non_pad_probabilities/non_pad_probabilities.sum()
if num_non_pad >= 1:
mask[non_pad_indices[np.random.choice(num_non_pad, num_mask, replace=False,p=non_pad_probabilities)]] = True
mask[target_positions[0]:target_positions[0]+3]=False # (301, 304) ATG
mask[target_positions[1]-3:target_positions[1]]=False # (901, 904) TAA
mask[target_positions[0]+300:target_positions[0]+303]=False # (601, 604) NNN
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
if unmask is not None:
mask = mask ^ unmask
# twod_data
if input_mask:
twod_data = self.get_twod_data(item,twod_d.detach(),mask) # mask = [1, 1205, 1205, 1]#
else:
twod_data = self.get_twod_data(item,twod_d.detach(),np.zeros_like(mask)) # mask = [1, 1205, 1205, 1]#
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
# new_item = np.full(len(mask), self.pad_idx)
# new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
# tgt_data = torch.from_numpy(new_item)#.unsqueeze(0) # [L,1]
tgt_data = item#.unsqueeze(0) # [L,1]
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.copy(item)
new_item[mask] = self.mask_idx
# if rand_mask is not None:
# num_rand = rand_mask.sum()
# if num_rand > 0:
# if self.mask_whole_words is not None:
# rand_mask = np.repeat(rand_mask, word_lens)
# num_rand = rand_mask.sum()
# # 以概率突变
# new_item[rand_mask] = np.random.choice(
# 9,#len(self.tokenizer)
# num_rand,
# p=self.weights,
# )
src_data = torch.from_numpy(new_item)#.unsqueeze(0)
# loss_mask = torch.tensor(mask, dtype=torch.long)#.unsqueeze(0)
return src_data,tgt_data,twod_data,mask
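    # Masking summary (a sketch under the default probabilities): about mask_prob of the
    # non-padding tokens are selected, weighted toward the two codon centers; of those,
    # leave_unmasked_prob are left unchanged in src_data and the rest are replaced with
    # mask_idx (the random-token replacement branch is commented out above, so the
    # random_token_prob share is currently also replaced with mask_idx). The start codon,
    # stop codon and NNN linker positions are always excluded from the mask.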
def get_twod_data(self,item,twod_d,mask):
two_dim_matrix =torch.squeeze(twod_d, dim=-1).numpy()
# item_len = len(item.numpy())
        # two_dim_matrix = np.zeros((len(base_range_lst) * len(lamda_lst), item_len, item_len)) # contains only 0 and -1
padding_dim = 0
for base_range in base_range_lst:
for lamda in lamda_lst:
new_matrix = creatmat(item.numpy(), base_range, lamda)
new_matrix[mask==1, :] = -1
new_matrix[:, mask==1] = -1
two_dim_matrix[padding_dim, :, :] = new_matrix
padding_dim += 1
# use -1 represent mask
# matrix[mask,:] = self.two_dim_mask
# matrix[:,mask] = self.two_dim_mask
# print(two_dim_matrix.shape)
twod_data = torch.from_numpy(two_dim_matrix)#.unsqueeze(-1) # [1, L+2, L+2, 1]
return twod_data
@staticmethod
def read_text_file(file_path):
try:
with open(file_path, 'r') as file:
return [line.strip() for line in file]
except FileNotFoundError:
print(f"Error: File '{file_path}' not found.")
return []
@staticmethod
def create_base_prob(target_protein,ith_nn_prob,rna_alphabet,tokenizer):
mask_nn_logits = torch.full(size=(len(target_protein)*3,len(tokenizer)),fill_value=float("-inf"))
for i,a in enumerate(target_protein):
if a not in ith_nn_prob[0]: continue
for j in range(3):
for n in rna_alphabet:
mask_nn_logits[i*3+j,tokenizer.index(n)] = ith_nn_prob[j][a][n]
return mask_nn_logits
@staticmethod
def create_codon_mask(target_protein, backbone_cds, amino_acid_to_codons,
tokenizer):
# logits = torch.full()
# batch_size, seq_length, vocab_size = logits.shape
seq_length = len(backbone_cds)
vocab_size = len(tokenizer)
# mask = torch.full_like(logits, float("-inf"))
mask = torch.full(size=(seq_length,vocab_size),fill_value=float("-inf"))
for i, amino_acid in enumerate(target_protein):
            codon_start = i * 3  # each amino acid spans 3 bases
codon_end = codon_start + 3
if codon_end > seq_length:
                continue  # beyond the sequence length; skip
possible_codons = amino_acid_to_codons.get(amino_acid.item(), [])
# filter_codons = []
for pos in range(codon_start, codon_end):
                base_pos = pos % 3  # position of this base within its codon (0, 1, 2)
for codon in possible_codons:
flag = True
for j, nt in enumerate(backbone_cds[codon_start:codon_end]):
nt = nt.item()
if tokenizer.mask_index == nt: continue
if codon[j] != nt:
flag = False
# filter_codons.append(codon)
if flag:
base_idx = codon[base_pos]
mask[pos, base_idx] = 0
# a = mask.numpy()
return mask
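    # Worked sketch (assuming the standard codon table): if the target amino acid is
    # Leucine and the backbone codon is C <mask> G, only leucine codons matching 'C?G'
    # remain possible (e.g. CTG), so at the masked middle position only 'T' gets a 0
    # logit while every other vocabulary entry stays at -inf; unmasked backbone
    # positions keep their own base as the only allowed token.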
# endregion
    # region === Methods that subclasses must implement ===
def load_data(self, path, **kwargs):
raise NotImplementedError("Subclasses must implement load_data")
def __getitem__(self, idx):
raise NotImplementedError("Subclasses must implement __getitem__")
# endregion
class RNADataset(BaseDataset):
"""处理RNA序列的Dataset"""
def __init__(
self,
path,
tokenizer,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
):
        # initialize the parent class
super().__init__(
tokenizer=tokenizer,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
        # load the data
self.samples = self.load_data(path, region=self.region, limit=limit)
def load_data(self, path, region=300, limit=-1):
return self.read_fasta_file(path, region=region, limit=limit)
@staticmethod
def read_fasta_file(file_path, region=300, cds_min=100, limit=-1):
'''
input:
file_path: str, fasta file path of input seqs
return:
seqs_dict: dict[str], dict of seqs
{
                'ENST00000231420.11': {            # transcript identifier
                    'cds_start': 57,               # CDS start position (0-based)
                    'cds_stop': 1599,              # CDS stop position (exclusive, 0-based)
                    'full': 'AGTTAGAGCCCGGCCTCCAATCTGCTTCCATGGGGTTGGCTTTCTGAGTGGGAGAAATGACTCTAATCTGGAGACA...',  # full mRNA sequence
                    'start_context': '___GAAATGTCT',  # sequence context around the CDS start, left-padded with '_', essential
                    'stop_context': 'AAGTAAGGG___'    # sequence context around the CDS stop, right-padded with '_', essential
}
}
'''
# region = getattr(args, 'region', region)
# limit = getattr(args, 'limit', limit)
try:
with open(file_path) as fa:
seqs_dicts = []
cds_start = 0
cds_stop = 0
count = 0
seq_name = ''
# for line in fa.read().splitlines():
for line in fa:
line = line.replace('\n', '')
if line.startswith('>'):
transcript_id, gene_id, cds_start, cds_stop = line[1:].split(
' ') # # ENST00000332160.5 ENSG00000185432.12 24 756
cds_start = int(cds_start)
cds_stop = int(cds_stop)
if cds_stop - cds_start < cds_min: continue
seq_name = transcript_id
# seqs_dict[seq_name] = {}
# seqs_dict[seq_name]['cds_start'] = cds_start
# seqs_dict[seq_name]['cds_stop'] = cds_stop
else:
expand_mRNA = '_' * region + line + '_' * region
cds_start += region
cds_stop += region
# seqs_dict[seq_name]['full'] = line
start_context = expand_mRNA[cds_start - region:cds_start + region]
stop_context = expand_mRNA[cds_stop - region:cds_stop + region]
seqs_dicts.append(
{'_id': seq_name, 'start_context': start_context, 'stop_context': stop_context})
count += 1
if count > limit and limit != -1: break
return seqs_dicts
except FileNotFoundError:
print(f"Error: File '{file_path}' not found.")
return []
def __len__(self):
return len(self.samples)
def __getitem__(self, idx): # to check
'''
GAUC 4567
unk 3
:param idx:
:return:
'''
sample = self.samples[idx]
seq = sample['start_context'] + 'NNN' + sample['stop_context']
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
if '_' in sample['start_context']:
X[:, sample['start_context'].count('_')] = self.tokenizer.bos_index
if '_' in sample['stop_context']:
X[:, -sample['stop_context'].count('_')-1] = self.tokenizer.eos_index
'''generate src_data, tgt_data, twod_data '''
src_data,tgt_data,twod_data,loss_mask = self.generate_mask(X,len(seq))
return src_data,tgt_data,twod_data,loss_mask
class RiboDataPipeline():
"""
    Data pipeline for the pretraining task: generates mRNA.fa, loads ribosome_density,
    ribo_counts and rna_counts, and splits the data into TR, VL and TS sets.
    Loads from the original bigWig (.bw) tracks.
"""
def __init__(
self,
path,
ribo_experiment,rna_experiment,
seq_only=False,
region: int = 300,
        cds_min: int = -1,  # -1: skip the CDS-length check
limit: int = -1,
env : int = 0,
norm = True
):
self.seq_only = seq_only
self.cds_min = cds_min
self.env = env
# self.reference_transcript_dict = {'ENST00000303577.7': 'PCBP1', # IRES chr2
# 'ENST00000309311.7': 'EEF2'} # cap dependent # chr19
# self.reference_transcript_dict = {
# 'ENST00000309311.7': 'EEF2'} # cap dependent # chr19
self.reference_transcript_dict = {} # cap dependent # chr19
# 加载数据
self.samples = self.load_data(path, ribo_experiment=ribo_experiment,rna_experiment=rna_experiment,region=region, limit=limit,norm=norm)
# self.ref_norm = np.mean([self.samples[key][4].sum() for key in self.reference_transcript_dict.keys()]) if self.reference_transcript_dict else 1 # RNA_counts
def load_data(self, path, ribo_experiment=None,rna_experiment=None,region=300, limit=-1,norm=True):
'''
        Load the data.
        1. Use ribo_experiment to look up species, avg_len, total counts, etc. from the meta table.
        2. Check whether the mRNA.fa file exists.
           If it does not exist:
               check whether mRNA.tsv exists;
               if it does not exist:
                   generate an mRNA.gtf from genome.gtf (keeping only mRNA-related rows and columns to shrink the GTF)
                   generate mRNA.fa from genome.fa and mRNA.gtf (including start/stop codon positions)
        3. Read the track files and derive ribosome_density, ribo_counts and rna_counts
{'ENST00000303577.7': 'PCBP1', # IRES
'ENST00000309311.7': 'EEF2'} # cap dependent
:param path:
:param reference_path:
:param region:
:param limit:
:return: samples
path = ./dataset/pretraining/
'''
"""1. input ribo_experiment, meta"""
seq_only = self.seq_only
cds_min = self.cds_min
# print('load_data in Pipeline')
reference_path = os.path.join(path,'reference')
meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'),ribo_experiment,rna_experiment,seq_only = seq_only) # todo
totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species = meta
fribo_track = os.path.join(path, 'track', f'{ribo_experiment}.bw')
frna_track = os.path.join(path, 'track', f'{rna_experiment}.bw')
if not seq_only:
if os.access(fribo_track,os.F_OK) and os.access(frna_track,os.F_OK):
print(f'load {ribo_experiment} and {rna_experiment} tracks')
else:
print(f'Error: {fribo_track} or {frna_track} not found.')
return None
"""2. check mRNA.fa, .pkl"""
        # # read the chromosome reference
# print('load_data in Pipeline',meta)
mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa')
if region != -1: mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
mrna_fa_path = mrna_fa_path.replace('.fa', f'.pkl')
self.mrna_region_pkl_path = mrna_fa_path
# sequence
if seq_only and os.access(mrna_fa_path, os.F_OK):
with open(mrna_fa_path, 'rb') as f:
sample_dict = pickle.load(f)
limited_sample_dict = {}
for key in sample_dict.keys():
# if limit != -1:
# limited_sample_dict[key] = [[transcript_id]+list(sample_dict[key][transcript_id]) for transcript_id in list(sample_dict[key].keys())[:limit]]
# else :limited_sample_dict[key] = [[transcript_id]+list(sample_dict[key][transcript_id]) for transcript_id in list(sample_dict[key].keys())]
if limit != -1:
limited_sample_dict[key] = sample_dict[:limit]
else:
limited_sample_dict[key] = sample_dict[key]
return limited_sample_dict
# gtf
mrna_tsv_path = os.path.join(reference_path,species, 'mRNA.tsv')
if not os.access(mrna_tsv_path, os.F_OK):
genome_gtf_path = os.path.join(reference_path,species, 'genome.gtf')
genome_fa_path = os.path.join(reference_path,species, 'genome.fa')
mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path)
else:
            mrna_tsv = pd.read_table(mrna_tsv_path)#.iloc[:100] # read the previously generated mRNA.tsv file
# print('load_data in Pipeline',mrna_tsv.shape) # 1459048, 11
"""3. read track files"""
# if limit<10: # debug mode
# debug_ids = 'ENST00000332831.5,ENST00000000233.10'.split(',')
# mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(debug_ids)].reindex()
# region = 6
print(f'filter limit={limit},region={region}')
print('load_data in Pipeline,before filter',mrna_tsv.shape) # 1459048, 11
reference_transcript_ids = list(self.reference_transcript_dict.keys())
keeping_transcript_ids = mrna_tsv[mrna_tsv['seqname'].isin(['chr10','chr15'])].transcript_id.unique().tolist()
print('keeping transcript_ids',len(reference_transcript_ids+keeping_transcript_ids))
# if args.debug:
# region = 6
# limit = 2000
if limit!=-1:
other_transcript_ids = mrna_tsv[~mrna_tsv['transcript_id'].isin(reference_transcript_ids+keeping_transcript_ids)].transcript_id.unique().tolist()
shuffle(other_transcript_ids)
shuffle(keeping_transcript_ids)
mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids+keeping_transcript_ids[:limit]+other_transcript_ids[:limit])]
print('load_data in Pipeline, after filter',mrna_tsv.shape) # 1459048, 11
if not seq_only:
import pyBigWig
ribo_bw,rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)]
print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}')
def iterfunc(x,bw):
chrom, start, end = x['seqname'], x['start'], x['end']
if chrom in bw.chroms():
return np.array(bw.values(chrom, start - 1, end))
else:
return np.zeros(end - start)
            # unpack the per-reference-transcript result (field order follows merge_transcript_level's return value)
            (seq,
             cds_start, cds_stop,
             ribo_counts,
             rna_counts,
             ribosome_density,
             te, self.env, cds_len, mRNA_len, junction_counts) = ans
ref_norm.append((sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA))
if len(ref_norm)==0 and norm:
print(f'Error: no qualified reference transcript (housekeeping when norm=True)')
return None
ref_norm = np.mean(ref_norm,axis=0) if norm and len(ref_norm)>0 else None
print('ref_norm',ref_norm,'sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)')
'''generate by norm'''
for transcript_id, data in mrna_tsv.groupby('transcript_id'):
tag = data['dataset'].iloc[0]
ans = self.merge_transcript_level(data,total_counts_info=total_counts_info,seq_only=seq_only,cds_min=cds_min,region=region,ref_norm=ref_norm)
if ans is None:continue
sample_dict[tag].append([transcript_id] + ans)
count += 1
if limit == count: break
# datasplit = 'TR_VL_TS.tsv'
# # ~/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/dataset_split $ head TR_VL_TS.tsv
# # transcript_id dataset seqname gene_id
self.samples = sample_dict
if seq_only:
mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa')
if region !=-1:
mrna_fa_path = mrna_fa_path.replace('.fa',f'_{region}.fa')
if not (os.access(mrna_fa_path,os.F_OK) and os.path.getsize(mrna_fa_path)>0):
print(f'generate {sum([len(a) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}')
self.generate_mRNA_fa(mrna_fa_path,sample_dict,force_regenerate=True) # ,force_regenerate=True
mrna_fa_path = mrna_fa_path.replace('.fa',f'.pkl')
if not os.access(mrna_fa_path, os.F_OK):
with open(mrna_fa_path, 'wb') as f:
pickle.dump(sample_dict, f)
self.mrna_region_pkl_path = mrna_fa_path
        return sample_dict  # transcript_id is used as the groupby key
def utr5_limit(self,args,x,region):
utr5_limit = 300 if args.region>300 else args.region
seq = list( x[region - utr5_limit:region + 1 + args.region] \
+ 'NNN' + x[3 * region + 4 - args.region:3 * region + 4 + args.region+1])
if seq[-1] not in {'_','>'}:seq[-1]='>'
if seq[0] not in {'_','<'}:seq[0]='<'
return seq
def merge_transcript_level(self,data,total_counts_info=None,seq_only=False,cds_min=-1,region=300,ref_norm=None): # [1,1,1,1]
# print(transcript_id)
ans = self.qualified_samples(data, seq_only=seq_only, cds_min=cds_min)
junction_counts = len(data[data['feature'] == 'CDS'])
if ans is not None:
seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict = ans
cds_len = cds_stop - cds_start
mRNA_len = len(seq)
if region!=-1:
utr5_limit = 300 if region > 300 else region
seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N',utr5_limit)
# anno = self.process_sequence(anno, cds_start, cds_stop, region, '_', '<', '>', 'N')
if metadict is not None: # seq_only = False
totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF = total_counts_info
if metadict['ribo_recovery'] > 0.9 and metadict['rna_recovery'] > 0.9: # high quality samples for TE
te = self.calculate_ribosome_density(metadict['ribo_avg_count'], metadict['rna_avg_count'],
totalNumReads_RNA, totalNumReads_RPF,
readsLength_RNA, readsLength_RPF)
te = float(te)
else:
te = -1 # low quality samples for TE
# pad_or_truncate_utr
# seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N')
anno = self.process_sequence(anno, cds_start, cds_stop, region, '_', '<', '>', 'N',utr5_limit)
ribo_counts\
= self.process_sequence(ribo_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1],utr5_limit)
rna_counts = self.process_sequence(rna_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1],utr5_limit)
if sum(ribo_counts[ribo_counts != -1]) <= 100 or sum(
                        rna_counts[rna_counts != -1]) <= 100:  # exclude the -1 padding from the sums; quality control
# print(f"No reads for {data['transcript_id'].iloc[0]}")
return None
'''
normalized by total counts
https://rcxqhxlmkf.feishu.cn/docx/MdEvd008poMIaexhX9Xc7EAEnth#share-SNGtdmaQ2oATE0xax1Nc6m3jnbp
'''
ribo_counts += 1 # max 1130
rna_counts += 1 # max 3628
ribosome_density = deepcopy(ribo_counts)
ribosome_density[ribosome_density != 0] = self.calculate_ribosome_density(
ribo_counts[ribo_counts != 0], rna_counts[rna_counts != 0], totalNumReads_RNA, totalNumReads_RPF,
readsLength_RNA, readsLength_RPF) # 2.38
if ref_norm is not None:
ribo_counts, rna_counts = ribo_counts / (ref_norm[0] * readsLength_RPF), rna_counts / (
ref_norm[1] * readsLength_RNA) # demo4
# print([(max(a), min(a)) for a in [ribo_counts, rna_counts]])
cds_start, cds_stop = anno.index('|'), anno.rindex('|', 1) + 4
return [seq,
cds_start, cds_stop,
# CDS_region = seq[cds_start:cds_stop] , region=6
ribo_counts,
rna_counts,
ribosome_density,
te, self.env,cds_len,mRNA_len,junction_counts]
return [seq, cds_start, cds_stop,cds_len,mRNA_len,junction_counts]
# def process_sequence(self,seq, cds_start, cds_stop, region, pad_mark, bos, eos,link,utr5_limit):
# utr5 = seq[:cds_start]
# cds = seq[cds_start:cds_stop]
# utr3 = seq[cds_stop:]
#
# # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# # cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # seq = utr5 + cds_h + cds_t + utr3
# # seq = seq[:region*2+1]+link*3+seq[-region*2-1:]
#
# utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# seq = utr5 + cds_h + cds_t + utr3
# seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:]
#
# # c1 = seq[cds_start:]
# # c2 = seq[:cds_stop]
# #
# # utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# # cds_h = self.process_utr(c1, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # cds_t = self.process_utr(c2, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos) # 这样会导致CDS和UTR混在一起,后面不太好mask猜AA
# # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # seq = utr5 + cds_h + cds_t + utr3
# # seq = seq[:utr5_limit+region+1]+link*3+seq[-region*2-1:]
#
# # utr5 = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# # utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# #
# # pre_processed = self.process_utr(utr5, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
# # behind_processed = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
# # # seq = pre_processed + cds_part + behind_processed
# # # seq = seq[:region*2+1]+link*3+seq[-region*2-1:]
# if isinstance(seq,list):
# seq = np.array(seq)
# return seq
#
#
#
# @staticmethod
# def process_utr(utr, input_len, pad_method, pad_mark='_',bos='<',eos='>'):
# if len(utr) < input_len:
# if pad_method == 'pre':
# padded_utr = pad_mark * (input_len - len(utr)) + bos + utr
# elif pad_method == 'behind':
# padded_utr = utr+eos + pad_mark * (input_len - len(utr))
# else:
# if pad_method == 'pre':
# padded_utr = bos+utr[-input_len:]
# elif pad_method == 'behind':
# padded_utr = utr[:input_len]+eos
# return padded_utr
# # self.process_utr = process_utr
def generate_mRNA_fa(self,mrna_fa_path,sample_dict,force_regenerate=False):
'''for pretrain'''
if force_regenerate:
'''generate mRNA.fa'''
print('generate mRNA.fa to',mrna_fa_path)
with open(mrna_fa_path, 'w') as f:
# for tag, data in self.samples.items():
for tag, data in sample_dict.items():
for transcript_id, seq, cds_start, cds_stop, cds_len,mRNA_len,*_ in data:
# seq = seq[1:-1] if '<' ==seq[0] else seq
f.write(f">{transcript_id}|cds_start={cds_start}|cds_stop={cds_stop}|cds_len={cds_len}|mRNA_len={mRNA_len}|dataset={tag}\n{re.sub(r'[^ACGT]', 'N', seq.replace('U','T'))}\n")
# print(f'>{transcript_id}|{cds_start}|{cds_stop}|{tag}\n{seq}\n') #
# print(seq[cds_start:cds_stop])
@staticmethod
def qualified_samples(data,seq_only=False,cds_min=-1):
"""
        Filter out unqualified samples.
        :param data: per-transcript GTF rows carrying seq/anno (and count) columns
        :return: (seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict) or None
"""
"""load elements"""
strand = data['strand'].iloc[0]
num_start = data[data.feature == 'start_codon'].shape[0]
num_stop = data[data.feature == 'stop_codon'].shape[0]
if num_start == 0 or num_stop == 0:
# print(f"No start or stop codon for {data['transcript_id'].iloc[0]}")
            return None  # no annotated start or stop codon
data = data[(data.feature!='start_codon') & (data.feature!='stop_codon')]
seq = ''.join(list(chain(*data['seq'])))
anno = ''.join(list(chain(*data['anno']))) # - or | represent UTR or CDS
if not seq_only:
ribo_counts = list(chain(*data['ribo_counts']))
rna_counts = list(chain(*data['rna_counts']))
if sum(ribo_counts) == 0 or sum(rna_counts) == 0:
# print(f"No reads for {data['transcript_id'].iloc[0]}")
return None
# ribosome_density = list(chain(*data['ribosome_density']))
if strand == '-':
from pyfaidx import complement
seq = complement(seq[::-1])
anno = anno[::-1]
if not seq_only:
ribo_counts = ribo_counts[::-1]
rna_counts = rna_counts[::-1]
# ribosome_density = ribosome_density[::-1]
cds_start = anno.index('|')
cds_stop = anno.rindex('|') + 4
if cds_min!=-1:
if cds_stop - cds_start < cds_min:
# print(f"CDS length is less than {cds_min} for {data['transcript_id'].iloc[0]}")
                return None  # CDS is too short
        remainder = anno.count('|') % 3
        if remainder != 0: return None  # CDS length must be a multiple of 3
if not seq_only:
# CDS recovery, RNA-seq, Ribo-seq, CDS length,
metadict = dict()
counts = np.array([ribo_counts, rna_counts])
t = counts[:, cds_start:cds_stop] > 0
metadict['cds_len'] = cds_stop - cds_start
metadict['ribo_recovery'], metadict['rna_recovery'] = t.sum(axis=1) / metadict['cds_len']
metadict['ribo_avg_count'], metadict['rna_avg_count'] = counts.sum(axis=1) / metadict['cds_len']
if seq_only:
return seq, cds_start, cds_stop,None,None,anno,None
# metadict['5utr_len'] = cds_start
# metadict['3utr_len'] = len(anno) - cds_stop
return seq,cds_start,cds_stop,ribo_counts,rna_counts,anno,metadict
@staticmethod
def generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path):
# returns GTF with essential columns such as "feature", "seqname", "start", "end"
# alongside the names of any optional keys which appeared in the attribute column
from gtfparse import read_gtf
from pyfaidx import Fasta
import polars as pl
        gtf = read_gtf(genome_gtf_path) # read and slim down this file first, then read the fasta; otherwise memory overflows
# gtf format described in 'https://www.gencodegenes.org/pages/data_format.html'
        features_to_keep = 'CDS,UTR,start_codon,stop_codon,five_prime_utr,three_prime_utr'.split(',') # start_codon lies within the CDS and stop_codon within the UTR; some annotation versions use "five_prime_utr"/"three_prime_utr" instead
columns_to_keep = ['seqname','gene_id','transcript_id','protein_id','transcript_type','start', 'end', 'feature','strand']
gtf = gtf.filter(pl.col("feature").is_in(features_to_keep))
gtf = gtf.select(columns_to_keep)
gtf = gtf.to_pandas()
gtf = gtf.sort_values(by=['seqname', 'start','end'])
# gtf.feature.unique() # ['UTR', 'start_codon', 'CDS', 'stop_codon'] # ['gene', 'transcript', 'exon', 'CDS', 'start_codon', 'stop_codon', 'UTR']
genome_fa = Fasta(genome_fa_path)
gtf['seq'] = gtf.apply(lambda x: genome_fa[x['seqname']][x['start']-1:x['end']].seq, axis=1)
        gtf['anno'] = gtf.apply(lambda x: '-'* (x['end'] - x['start']+1) if x['feature'] in ['UTR','five_prime_utr','three_prime_utr'] else '|'*(x['end']-x['start']+1) , axis=1) # hence start and stop codons are annotated as '|'
        gtf.to_csv(mrna_tsv_path,index=None,sep='\t') # ~10x smaller than the original file: keeps only mRNA-related content and the raw sequences (before reverse-complementing)
del genome_fa
print(f"generate mRNA.tsv file: {mrna_tsv_path}\n{gtf.shape}\t{gtf[['seq','anno']].head()}")
return gtf
@staticmethod
def calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,
readsLength_RPF):
'''
        Compute the ribosome density.
        :param numReads_RPF:
:param numReads_RNA:
:param totalNumReads_RNA:
:param totalNumReads_RPF:
:param readsLength_RNA:
:param readsLength_RPF:
:return:
example:
        # example values
numReads_RPF = 1000
numReads_RNA = 2000
totalNumReads_RNA = 5000000
totalNumReads_RPF = 3000000
readsLength_RNA = 150
readsLength_RPF = 100
result = calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF,
readsLength_RNA, readsLength_RPF)
print("Ribosome Density:", result)
'''
        # Ribo-seq preprocessing keeps only reads 20-40 nt long
readsLength_RPF = np.where(readsLength_RPF > 40, 30, readsLength_RPF)
ratio_numReads = numReads_RPF / numReads_RNA
ratio_totalNumReads = totalNumReads_RNA / totalNumReads_RPF
ratio_readsLength = readsLength_RNA / readsLength_RPF
ribosome_density = np.log2(ratio_numReads * ratio_totalNumReads * ratio_readsLength + 1)
ribosome_density = np.where(numReads_RNA==-1, -1, ribosome_density)
return ribosome_density
def read_meta_file(self, file_path, ribo_experiment, rna_experiment, seq_only=False):
df = pd.read_table(file_path)
if seq_only:
if ribo_experiment:
species = df[df['ribo_experiment'] == ribo_experiment]['Ref'].iloc[0]
elif rna_experiment:
                species = df[df['rna_experiment'] == rna_experiment]['Ref'].iloc[0]
else:
raise ValueError("ribo_experiment or rna_experiment should be provided")
return None,None,None,None,species
row = df[(df['ribo_experiment'] == ribo_experiment) & (df['rna_experiment'] == rna_experiment)].iloc[0]
totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF,species = row['totalNumReads_RNA'], row['totalNumReads_RPF'], row['readsLength_RNA'], row['readsLength_RPF'],row['Ref']
return totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species
class RiboBwDataPipeline(RiboDataPipeline):
def __init__(self,
data_path,
ribo_experiment,
rna_experiment,
seq_only=False,
limit=-1,
):
super().__init__(data_path, ribo_experiment, rna_experiment, seq_only, limit)
def load_data(self, path, ribo_experiment=None,rna_experiment=None,region=300, limit=-1,norm=True):
'''
        Load the data.
        1. Use ribo_experiment to look up species, avg_len, total counts, etc. from the meta table.
        2. Check whether the mRNA.fa file exists.
           If it does not exist:
               check whether mRNA.tsv exists;
               if it does not exist:
                   generate an mRNA.gtf from genome.gtf (keeping only mRNA-related rows and columns to shrink the GTF)
                   generate mRNA.fa from genome.fa and mRNA.gtf (including start/stop codon positions)
        3. Read the track files and derive ribosome_density, ribo_counts and rna_counts
{'ENST00000303577.7': 'PCBP1', # IRES
'ENST00000309311.7': 'EEF2'} # cap dependent
:param path:
:param reference_path:
:param region:
:param limit:
:return: samples
path = ./dataset/pretraining/
'''
"""1. input ribo_experiment, meta"""
seq_only = self.seq_only
cds_min = self.cds_min
# print('load_data in Pipeline')
reference_path = os.path.join(path,'reference')
meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'),ribo_experiment,rna_experiment,seq_only = seq_only) # todo
totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,readsLength_RPF,species = meta
"""2. check mRNA.fa"""
        # # read the chromosome reference
# print('load_data in Pipeline',meta)
mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa')
if region != -1: mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
mrna_fa_path = mrna_fa_path.replace('.fa', f'.pkl')
self.mrna_region_pkl_path = mrna_fa_path
if seq_only and os.access(mrna_fa_path, os.F_OK):
with open(mrna_fa_path, 'rb') as f:
sample_dict = pickle.load(f)
if limit!=-1:
limited_sample_dict = {}
for key in sample_dict.keys():
limited_sample_dict[key] = {transcript_id:sample_dict[key][transcript_id] for transcript_id in list(sample_dict[key].keys())[:limit]}
return limited_sample_dict
mrna_tsv_path = os.path.join(reference_path,species, 'mRNA.tsv')
if not os.access(mrna_tsv_path, os.F_OK):
genome_gtf_path = os.path.join(reference_path,species, 'genome.gtf')
genome_fa_path = os.path.join(reference_path,species, 'genome.fa')
mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path,genome_fa_path,mrna_tsv_path)
else:
            mrna_tsv = pd.read_table(mrna_tsv_path)#.iloc[:100] # read the previously generated mRNA.tsv file
# print('load_data in Pipeline',mrna_tsv.shape) # 1459048, 11
"""3. read track files"""
# if limit<10: # debug mode
# debug_ids = 'ENST00000332831.5,ENST00000000233.10'.split(',')
# mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(debug_ids)].reindex()
# region = 6
print(f'filter limit={limit},region={region}')
print('load_data in Pipeline,before filter',mrna_tsv.shape) # 1459048, 11
reference_transcript_ids = list(self.reference_transcript_dict.keys())
# if args.debug:
# region = 6
# limit = 2000
if limit!=-1:
mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids+list(mrna_tsv.transcript_id.unique()[:limit]))]
print('load_data in Pipeline, after filter',mrna_tsv.shape) # 1459048, 11
if not seq_only:
import pyBigWig
fribo_track = os.path.join(path,'track', f'{ribo_experiment}.bw')
frna_track = os.path.join(path,'track', f'{rna_experiment}.bw')
if os.access(fribo_track,os.F_OK) and os.access(frna_track,os.F_OK):
print(f'load {ribo_experiment} and {rna_experiment} tracks')
else:
print(f'Error: {fribo_track} or {frna_track} not found.')
return None
ribo_bw,rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)]
print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}')
def iterfunc(x,bw):
chrom, start, end = x['seqname'], x['start'], x['end']
if chrom in bw.chroms():
return np.array(bw.values(chrom, start - 1, end))
else:
return np.zeros(end - start)
            # unpack the per-reference-transcript result (field order follows merge_transcript_level's return value)
            (seq,
             cds_start, cds_stop,
             ribo_counts,
             rna_counts,
             ribosome_density,
             te, self.env, cds_len, mRNA_len, junction_counts) = ans
ref_norm.append((sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA))
if len(ref_norm)==0 and norm:
print(f'Error: no qualified reference transcript (housekeeping when norm=True)')
return None
ref_norm = np.mean(ref_norm,axis=0) if norm and len(ref_norm)>0 else None
print('ref_norm',ref_norm,'sum(ribo_counts)/cds_len/readsLength_RPF,sum(rna_counts)/mRNA_len/readsLength_RNA)')
'''generate by norm'''
for transcript_id, data in mrna_tsv.groupby('transcript_id'):
tag = data['dataset'].iloc[0]
ans = self.merge_transcript_level(data,total_counts_info=total_counts_info,seq_only=seq_only,cds_min=cds_min,region=region,ref_norm=ref_norm)
if ans is None:continue
sample_dict[tag].append([transcript_id] + ans)
count += 1
if limit == count: break
# datasplit = 'TR_VL_TS.tsv'
# # ~/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/dataset_split $ head TR_VL_TS.tsv
# # transcript_id dataset seqname gene_id
self.samples = sample_dict
if seq_only:
mrna_fa_path = os.path.join(reference_path, species, f'mRNA.fa')
if region !=-1:
mrna_fa_path = mrna_fa_path.replace('.fa',f'_{region}.fa')
if not os.access(mrna_fa_path,os.F_OK) or os.path.getsize(mrna_fa_path)==0:
print(f'generate {sum([len(a.keys()) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}')
                self.generate_mRNA_fa(mrna_fa_path, sample_dict, force_regenerate=True)  # sample_dict is required by generate_mRNA_fa
mrna_fa_path = mrna_fa_path.replace('.fa',f'.pkl')
if not os.access(mrna_fa_path, os.F_OK) or os.path.getsize(mrna_fa_path)==0:
with open(mrna_fa_path, 'wb') as f:
pickle.dump(sample_dict, f)
self.mrna_region_pkl_path = mrna_fa_path
        return sample_dict  # transcript_id is used as the groupby key
class RegionDataset(BaseDataset):
"""DST"""
def __init__(
self,
samples,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
):
        # initialize the parent class
super().__init__(
tokenizer=tokenizer,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
        # load the data
self.args = args
self.samples = samples
if limit!=-1:
self.samples = self.samples[:limit]
self.teacher_tokenizer = DebertaTokenizerFast.from_pretrained("./src/mRNA2vec/tokenizer", use_fast=True)
self.teacher_tokenizer.padding_side = "left"
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
_id, seq, cds_start, cds_stop,*_ = self.samples[idx]
# if len( self.samples[idx]) == 7:
# _id,seq, cds_start, cds_stop, cds_len, mRNA_len,junction_counts = self.samples[idx]
# else:
# _id,seq, cds_start, cds_stop, cds_len, mRNA_len = self.samples[idx]
aa_seq = '-'+self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1].replace('U','T'))).lower()+'-'
aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long)
# aa20 = torch.tensor(np.array([self.tokenizer.indices.get(aa.lower()) for aa in AA_str]),dtype=torch.long)
# nt12 = torch.tensor(np.array([self.seq_to_rnaindex(nn) for aa in AA_str[:-1].upper() for nn in AA_TO_CODONS[aa]]),dtype=torch.long)
# nt12 = defaultdict(list)
# [nt12[self.tokenizer.indices.get(aa.lower())].append(self.seq_to_rnaindex(nn)[0]) for aa in AA_str[:-1].upper() for nn in AA_TO_CODONS[aa]]
#
        # prepare the 1D and 2D inputs
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
'''generate src_data, tgt_data, twod_data '''
src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq))
if "ernierna" in self.args.mlm_pretrained_model_path or 'teacher' in self.args.mlm_pretrained_model_path:
teacher_input_ids = src_data
elif "mrna2vec" in self.args.mlm_pretrained_model_path:
teacher_encoder = self.teacher_tokenizer(seq[1:-1],
padding='max_length',
max_length=403,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
teacher_input_ids = teacher_encoder['input_ids'].squeeze(0)
# src_data = torch.where(torch.from_numpy(loss_mask),aa_idx,src_data)
return (src_data,teacher_input_ids, tgt_data, twod_data,aa_idx, loss_mask)
@staticmethod
def seq_to_rnaindex(seq,pad_idx=1, unk_idx=3):
seq = seq.upper()
if seq.count('<') > 1 or seq.count('/') > 0:
seq = seq.replace('', '_').replace('', 'V').replace('', '^').replace('/','NNN')
l = len(seq)
X = np.ones((1, l))
for j in range(l):
if seq[j] in set('Aa'):
X[0, j] = 5
elif seq[j] in set('UuTt'):
X[0, j] = 6
elif seq[j] in set('Cc'):
X[0, j] = 7
elif seq[j] in set('Gg'):
X[0, j] = 4
elif seq[j] in set('_'):
X[0,j] = pad_idx
elif seq[j] in set('^'):
X[0,j] = 2 # eos
else:
X[0,j] = unk_idx # linker
return X
'''generate'''
class BackBoneDataset(RegionDataset):
'''for distillation using ribo dataset'''
def __init__(
self,
samples,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
input_mask = True,Kozak_GS6H_Stop3='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG'
):
        # initialize the parent class
super().__init__(
samples=samples,
tokenizer=tokenizer,
args=args,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
self.input_mask = input_mask
self.Kozak_GS6H_Stop3 = Kozak_GS6H_Stop3.upper()
def __getitem__(self, idx):
# (_id,seq,cds_start, cds_stop,
# ribo_counts,rna_counts,
# ribosome_density,te,env,*_) = self.samples[idx] #cds_len,mRNA_len,junction_counts
data = self.samples.iloc[idx]
_id = data['_id']
seq = data['sequence']
seq = seq.replace('U','T') # for translate
start,stop = self.region + 1, self.region * 3 + 4 # ATG seq[301:301+3], TGG seq[901:901+3]
        Kozak, GS6H, Stop3 = self.Kozak_GS6H_Stop3.split(',') if ',' in self.Kozak_GS6H_Stop3 else ('', '', '')
# Kozak,GS6H,Stop3 = 'GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG'.split(',')
'''fix nt, not opt'''
seq = seq[:start-len(Kozak)].replace('ATG','ATC') + Kozak + seq[start:stop-len(GS6H)-len(Stop3)] + GS6H+ Stop3 +seq[stop:]
# Kozak = 'GCCACC' # GCCACCATGGCG
# seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] +'TAATAATAA'+seq[stop:-6] # https://www.nature.com/articles/s41467-024-48387-x#Fig1
# seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] + 'TGATAATAG' +seq[stop:-6] # # mus frequent:TGATAATAG;3 stop codon, GS6H
# seq = seq[:start-6].replace('ATG','ATC') + Kozak + seq[start:stop-3] + 'TGATAATAG' +seq[stop+6:] # # mus frequent:TGATAATAG;3 stop codon, GS6H
'''whole mask'''
aa_seq = '-'+self.translate(re.sub(r'[^ACGTU]', 'N', seq[1:-1])).lower()+'-'
aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long)
        # prepare the 1D and 2D inputs
        # apply masking to src, tgt, twod_data and loss_mask
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
_,_,_,mask = self.generate_mask(X,len(seq)) # [1205] # (1, 1205)
seq_length = len(seq)
vocab_size = len(self.tokenizer)
# mask = torch.full_like(logits, float("-inf"))
masked_logits = torch.full(size=(seq_length,vocab_size),fill_value=float("-inf")) # [1205, 32]
        masked_logits[np.arange(X.shape[1]),X.reshape(-1)]=0 # never mask out the target's own token index
'''CDS mask'''
X_CDS,masked_logits_CDS = self.CDS_mask(seq, start, stop- len(GS6H) - len(Stop3)) # (603,),torch.Size([1, 603, 32])
mask[start:stop- len(GS6H) - len(Stop3)] = X_CDS==self.tokenizer.mask_index
special = self.seq_to_rnaindex('ACGT').reshape(-1)
masked_logits[:start-len(Kozak),special]=0 # UTR5
masked_logits[stop:,special]=0 # UTR3
masked_logits[start:stop- len(GS6H) - len(Stop3)] = masked_logits_CDS
for token in ['', '', '', '']:
masked_logits[X.reshape(-1) == self.tokenizer.indices.get(token), :] = float("-inf")
masked_logits[X.reshape(-1)==self.tokenizer.indices.get(token),self.tokenizer.indices.get(token)] = 0
# masked_logits = masked_logits.unsqueeze(0)
mask[start-len(Kozak):start+3] = False
mask[stop-len(GS6H)-len(Stop3):stop] = False
src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq),mask=mask,input_mask=self.input_mask)
loss_mask = torch.tensor(mask, dtype=torch.bool)#.unsqueeze(0)
src_data = torch.where(loss_mask,aa_idx,src_data)
src_env = torch.tensor(self.args.env_id, dtype=torch.long)
# test = torch.concat([torch.from_numpy(exp_one_d).float(),src_exp_data,loss_mask.unsqueeze(1).float(),src_exp_mask,tgt_exp_data],dim=1).detach().cpu().numpy()
        return (_id,src_data,tgt_data,twod_data,loss_mask,masked_logits,src_env)  # sending only the masked portion of tgt_data would save memory
def CDS_mask(self,seq,start,stop):
# target_positions = [self.region + 1,self.region *3 + 4] # ATG seq[301:301+3], TGG seq[901:901+3]
# backbone_cds = re.sub(r'[^ACGT]', 'N', seq[target_positions[0]:target_positions[1]])
backbone_cds = re.sub(r'[^ACGT]', 'N', seq[start:stop])
        # target amino-acid sequence
# target_protein = ['M', 'A', 'L']
target_protein = self.translate(backbone_cds,repeate=1).upper()
target_protein_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in target_protein.lower()]),dtype=torch.long)
X = self.seq_to_rnaindex(backbone_cds, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1)
num_rows, num_cols = len(target_protein),3
cds_mask = np.zeros([num_rows, num_cols],dtype=int)
        # number of codon rows to mask (mask_prob * 2 of the codons)
rows_to_mask = int(num_rows * self.mask_prob *2)
        # randomly choose the rows (codons) to mask
masked_rows = random.sample(range(num_rows), rows_to_mask)
        # randomly choose one column (base position) for each masked row
masked_cols = np.random.randint(0, num_cols, size=rows_to_mask)
        # vectorized update of the mask array
cds_mask[masked_rows, masked_cols] = 1
cds_mask = cds_mask.reshape(-1)
X[cds_mask==1]=self.tokenizer.mask_index # mask
        # hypothetical logits of shape (batch_size, seq_length, vocab_size),
        # e.g. batch_size=1, seq_length=9 (i.e. 3 codons), vocab_size=4 (A, U, C, G)
# logits = torch.zeros(len(backbone_cds), len(self.tokenizer))
        # build the mask
# backbone_cds = 'AT_G_C_TC'
# base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
# reverse_dictionary(base_map)
masked_logits = self.create_codon_mask(target_protein_idx.numpy(), X, self.amino_acid_to_codons,self.tokenizer)
# joint_mask = create_codon_mask(logits, target_protein,backbone_cds, AA_TO_CODONS)
        # apply the mask
# masked_logits = mask + logits
return X,masked_logits
class RiboDataset(RegionDataset):
'''for distillation using ribo dataset'''
def __init__(
self,
samples,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
):
# call the parent-class initializer
super().__init__(
samples=samples,
tokenizer=tokenizer,
args=args,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
def __getitem__(self, idx):
# (_id,seq,cds_start, cds_stop,
# ribo_counts,rna_counts,
# ribosome_density,te,env,*_) = self.samples[idx] #cds_len,mRNA_len,junction_counts
(_id,seq,cds_start, cds_stop,
ribo_counts,rna_counts,
ribosome_density,te,env,cds_len,mRNA_len,junction_counts) = self.samples[idx] #cds_len,mRNA_len,junction_counts
aa_seq = '-'+self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1])).lower()+'-'
aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]),dtype=torch.long)
# prepare the 1D and 2D input data
# apply masking to src, tgt, twod_data and loss_mask
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
src_data,tgt_data,twod_data,mask = self.generate_mask(X,len(seq))
loss_mask = torch.tensor(mask, dtype=torch.bool)#.unsqueeze(0)
src_data = torch.where(loss_mask,aa_idx,src_data)
window = 31
exp_one_d = np.stack([ribo_counts,rna_counts,ribosome_density],axis=1)
tgt_exp_data = torch.from_numpy(exp_one_d).float()
tgt_exp_data = tgt_exp_data.permute(1,0)
tgt_exp_data = F.avg_pool1d(tgt_exp_data,kernel_size=window,padding=window//2,stride=1)
tgt_exp_data = tgt_exp_data.permute(1,0)
tgt_exp_data[~loss_mask,:] = -1 # [1205, 3]
src_exp_data = torch.from_numpy(exp_one_d).float()
src_exp_mask = F.max_pool1d(loss_mask.unsqueeze(0).repeat(3,1).float(),kernel_size=window,padding=window//2,stride=1).permute(1,0) # shape (3, L) after pooling, then permuted to (L, 3)
src_exp_data = torch.where(src_exp_mask.bool(),torch.zeros_like(src_exp_mask),src_exp_data)
# src_exp_data = torch.zeros_like(tgt_exp_data)
# src_exp_data = [] # zeros and ones were also tried
# src_exp_data = None  # raises TypeError in default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists
src_data = torch.where(loss_mask,aa_idx,src_data)
src_env = torch.tensor(env, dtype=torch.long)
src_feature = np.array([cds_len,mRNA_len,junction_counts])
src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32
src_feature = torch.log(src_feature+1) # take the log
tgt_te = torch.tensor(te, dtype=torch.float32)
# test = torch.concat([torch.from_numpy(exp_one_d).float(),src_exp_data,loss_mask.unsqueeze(1).float(),src_exp_mask,tgt_exp_data],dim=1).detach().cpu().numpy()
return src_data,src_exp_data,src_env,src_feature,tgt_data,tgt_exp_data,tgt_te,twod_data,loss_mask # transmitting only the masked portion of tgt_exp_data/tgt_data would save memory/bandwidth
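# Illustrative sketch (toy tensors, not the real tracks) of how the per-nucleotide expression
# tracks are handled above: avg_pool1d with an odd kernel and padding=window//2 smooths each
# track while keeping length L unchanged, and max_pool1d over the boolean loss mask widens the
# masked region so the source track is zeroed in a window around every masked position.
#
# import torch, torch.nn.functional as F
# L, window = 12, 5                                             # hypothetical sizes
# track = torch.arange(L, dtype=torch.float32).unsqueeze(0)     # (1, L)
# smoothed = F.avg_pool1d(track, kernel_size=window, padding=window // 2, stride=1)   # (1, L)
# mask = torch.zeros(L, dtype=torch.bool); mask[4] = True
# widened = F.max_pool1d(mask.float().unsqueeze(0), kernel_size=window, padding=window // 2, stride=1)
# masked_track = torch.where(widened.bool(), torch.zeros_like(track), track)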
class DownstreamDataset(RegionDataset):
def __init__(
self,
samples,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
seq_len: int = 174,
pad_method: str = "pre",
column: str = "sequence",
cds_len:str='cds_len',
mRNA_len:str='mRNA_len',
label: str = "IRES_Activity",
):
# call the parent-class initializer
super().__init__(
samples=samples,
tokenizer=tokenizer,
args=args,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
# subclass-specific attributes
self.label = label
self.column = column
self.seq_len = seq_len
self.cds_len = cds_len
self.mRNA_len = mRNA_len
self.pad_method = pad_method
if limit!=-1:
self.samples = self.samples.iloc[:limit]
'''
eGFP
https://www.ncbi.nlm.nih.gov/nuccore/L29345.1
>L29345.1 Aequorea victoria green-fluorescent protein (GFP) mRNA, complete cds| 26..742
TACACACGAATAAAAGATAACAAAGATGAGTAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGCGATGTTAATGGGCAAAAATTCTCTGTCAGTGGAGAGGGTGAAGGTGATGCAACATACGGAAAACTTACCCTTAAATTTATTTGCACTACTGGGAAGCTACCTGTTCCATGGCCAACACTTGTCACTACTTTCTCTTATGGTGTTCAATGCTTTTCAAGATACCCAGATCATATGAAACAGCATGACTTTTTCAAGAGTGCCATGCCCGAAGGTTATGTACAGGAAAGAACTATATTTTACAAAGATGACGGGAACTACAAGACACGTGCTGAAGTCAAGTTTGAAGGTGATACCCTTGTTAATAGAATCGAGTTAAAAGGTATTGATTTTAAAGAAGATGGAAACATTCTTGGACACAAAATGGAATACAACTATAACTCACATAATGTATACATCATGGCAGACAAACCAAAGAATGGAATCAAAGTTAACTTCAAAATTAGACACAACATTAAAGATGGAAGCGTTCAATTAGCAGACCATTATCAACAAAATACTCCAATTGGCGATGGCCCTGTCCTTTTACCAGACAACCATTACCTGTCCACACAATCTGCCCTTTCCAAAGATCCCAACGAAAAGAGAGATCACATGATCCTTCTTGAGTTTGTAACAGCTGCTGGGATTACACATGGCATGGATGAACTATACAAATAAATGTCCAGACTTCCAATTGACACTAAAGTGTCCGAACAATTACTAAATTCTCAGGGTTCCTGGTTAAATTCAGGCTGAGACTTTATTTATATATTTATAGATTCATTAAAATTTTATGAATAATTTATTGATGTTATTAATAGGGGCTATTTTCTTATTAAATAGGCTACTGGAGTGTAT
'''
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
data = self.samples.iloc[idx]
seq = data[self.column]
target = data[self.label]
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq))
one_d = one_d.view(-1)
# convert to PyTorch tensors
if not torch.is_tensor(one_d):
src_data = torch.from_numpy(one_d) # one_d is the 1D input feature
else:
src_data = one_d
if not torch.is_tensor(twod_d):
twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 2D input feature
else:
twod_data = twod_d.squeeze(dim=-1)
src_env = torch.tensor(self.args.env_id, dtype=torch.long)
cds_len = data[self.cds_len] if self.cds_len in data else 742-26+1
mRNA_len = data[self.mRNA_len] if self.mRNA_len in data else 922
# cds_len = 742-26+1
# mRNA_len = 922
junction_counts= 0
src_feature = np.array([cds_len,mRNA_len,junction_counts])
src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32
# src_feature = torch.log(src_feature+1) # take the log
# regression target
target = torch.tensor(target, dtype=torch.float32) # each sample carries one regression target in the label column
return src_data, twod_data,src_env, src_feature, target # return all model inputs and the target value
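# Minimal usage sketch (df, tokenizer and args are assumed to exist; df is a pandas DataFrame
# holding the sequence and label columns used above):
#
# from torch.utils.data import DataLoader
# ds = DownstreamDataset(samples=df, tokenizer=tokenizer, args=args,
#                        column='sequence', label='IRES_Activity', seq_len=174)
# loader = DataLoader(ds, batch_size=8, shuffle=False)
# for src_data, twod_data, src_env, src_feature, target in loader:
#     break  # inspect one batch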
class RegressionDataset(BaseDataset):
"""处理回归任务的Dataset"""
def __init__(
self,
path,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
seq_len: int = 174,
pad_method: str = "pre",
column: str = "sequence",
label: str = "IRES_Activity",
returnid=None
):
# call the parent-class initializer
super().__init__(
tokenizer=tokenizer,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
# subclass-specific attributes
self.label = label
self.column = column
self.seq_len = seq_len
self.pad_method = pad_method
self.args = args
self.returnid = returnid
# load the data
self.samples = self.load_data(
path,
seq_len=seq_len,
column=column,
pad_method=pad_method
)
if limit!=-1:
self.samples = self.samples.iloc[:limit]
def load_data(self, path, **kwargs):
return self.read_csv_file(
path,
seq_len=kwargs['seq_len'],
column=kwargs['column'],
pad_method=kwargs['pad_method']
)
def read_csv_file(self,file_path, **kwargs):
# keep the original CSV-reading logic
try:
column = kwargs['column']
data = pd.read_csv(file_path)
if column not in data.columns:
data[column] = data.apply(self.generate_inputs, axis=1) # preprocess: build the input column
return pad_or_truncate_utr(
data,
pad_method=kwargs['pad_method'],
column=kwargs['column'],
input_len=kwargs['seq_len']
)
except FileNotFoundError:
print(f"Error: File '{file_path}' not found.")
return []
def __len__(self):
return len(self.samples)
def __getitem__(self,idx):
data = self.samples.iloc[idx]
seq = data[self.column]
target = data[self.label]
# X, data_seq = self.seq_to_rnaindex_and_onehot(seq)
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
# prepare the 1D and 2D input data
one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq))
one_d = one_d.view(-1)
# convert to PyTorch tensors
if not torch.is_tensor(one_d):
src_data = torch.from_numpy(one_d) # one_d is the 1D input feature
else:
src_data = one_d
if not torch.is_tensor(twod_d):
twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 2D input feature
else:
twod_data = twod_d.squeeze(dim=-1)
src_env = torch.tensor(self.args.env_id, dtype=torch.long)
cds_len = 742-26+1
mRNA_len = 922
junction_counts= 0
src_feature = np.array([cds_len,mRNA_len,junction_counts])
src_feature = torch.from_numpy(src_feature).float() # .float() ==torch.float32
src_feature = torch.log(src_feature+1) # take the log
# regression target
target = torch.tensor(target, dtype=torch.float32) # each sample carries one regression target in the label column
if self.returnid is None:return src_data, twod_data,src_env, src_feature, target # return all model inputs and the target value
else: return data[self.returnid],src_data, twod_data,src_env, src_feature, target
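# Sketch of the default eGFP-derived feature vector built above (lengths taken from the
# L29345.1 annotation quoted in DownstreamDataset); log(x + 1) keeps the three length-like
# features on a comparable scale.
#
# import torch, numpy as np
# feats = torch.from_numpy(np.array([742 - 26 + 1, 922, 0])).float()
# feats = torch.log(feats + 1)   # approximately [6.58, 6.83, 0.0]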
class MaotaoDataset(BaseDataset):
"""处理回归任务的Dataset"""
def __init__(
self,
path,
tokenizer,
args,
region: int = 300,
limit: int = -1,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
two_dim_score: bool = False,
two_dim_mask: int = -1,
mask_whole_words: torch.Tensor = None,
seq_len: int = 1200,
column: str = 'off_start,off_end,full_len,type,_id,species,maotao_id,truncated_aa,cai_best_nn',
label: str = "truncated_nn,cai_nature",
codon_table_path: str='maotao_file/codon_table/codon_usage_{species}.csv',
species_list:str="""mouse,Ec,Sac,Pic,Human""",
type_list:str="""full,head,tail,boundary,middle""",
# protein_alphabet_list:str="""_ACDEFGHIKLMNPQRSTVWY*""", # padding index is hard-coded to 1  # 10-31
rna_alphabet_list:str="""GAUC""",# use the network's built-in base encoding
returnid = None
):
# call the parent-class initializer
super().__init__(
tokenizer=tokenizer,
region=region,
limit=limit,
return_masked_tokens=return_masked_tokens,
seed=seed,
mask_prob=mask_prob,
leave_unmasked_prob=leave_unmasked_prob,
random_token_prob=random_token_prob,
freq_weighted_replacement=freq_weighted_replacement,
two_dim_score=two_dim_score,
two_dim_mask=two_dim_mask,
mask_whole_words=mask_whole_words,
)
# subclass-specific attributes
self.species = {k:v for v,k in enumerate(species_list.split(','))}
self.species.update({v:v for v,k in enumerate(species_list.split(','))})
self.seq_types = {k:v for v,k in enumerate(type_list.split(','))}
self.seq_types.update({v:v for v,k in enumerate(type_list.split(','))})
# self.protein_alphabet = {k:v for v,k in enumerate(protein_alphabet_list)}
self.rna_alphabet = {k:v+4 for v,k in enumerate(rna_alphabet_list)}
self.label = label.split(',')
self.column = column.split(',')
self.seq_len = seq_len
self.args = args
# load the data
self.samples = self.load_data(path)
# load the codon tables
self.codon_instance_rna = {self.species[species]: Codon(codon_table_path.format(species=species), rna=True) for species in
species_list.split(',')}
if limit!=-1:
self.samples = self.samples.iloc[:limit]
def load_data(self, path, **kwargs):
if os.access(path.replace('.csv','_processed.pickle'), os.R_OK):
df = pd.read_pickle(path.replace('.csv','_processed.pickle'))
else:
df = pd.read_csv(path)
df['truncated_aa'] = df['truncated_aa'].apply(lambda x: re.sub(r'[^acdefghiklmnpqrstvwy*_]', '_', x.lower()))
df['cai_best_nn'] = df['cai_best_nn'].apply(lambda x: x.upper().replace('T','U'))
df['species'] = df['species'].apply(lambda x: self.species[x])
df['type'] = df['type'].apply(lambda x: self.seq_types[x])
df.to_csv(path.replace('.csv','_processed.csv'),index=False)
with open(path.replace('.csv','_processed.pickle'), 'wb') as f:
pickle.dump(df,f)
return df
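# Preprocessing sketch (hypothetical row values) matching the load_data steps above:
# non-standard amino-acid characters become '_', the RNA string is upper-cased with T -> U,
# and species/type strings are mapped to the integer codes built in __init__. The processed
# frame is cached as a pickle so subsequent runs skip the CSV pass.
#
# import re
# aa = re.sub(r'[^acdefghiklmnpqrstvwy*_]', '_', 'MKT?LS'.lower())   # 'mkt_ls'
# nn = 'atgaaaacg'.upper().replace('T', 'U')                          # 'AUGAAAACG'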
def __len__(self):
return len(self.samples)
def __getitem__(self,idx):
data = self.samples.iloc[idx]
maotao_id = data['maotao_id']
aa_index = np.array([self.tokenizer.index(x) for x in data['truncated_aa']])
# input idx
aa_idx = torch.from_numpy(aa_index).long()
seq = data['cai_best_nn']
'''prepare 1D and 2D input data'''
# X, data_seq = self.seq_to_rnaindex_and_onehot(seq)
X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
# prepare the 1D and 2D input data
one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq))
one_d = one_d.view(-1)
# convert to PyTorch tensors
if not torch.is_tensor(one_d):
src_data = torch.from_numpy(one_d) # one_d is the 1D input feature
else:
src_data = one_d
if not torch.is_tensor(twod_d):
twod_data = torch.from_numpy(twod_d.squeeze(dim=-1)) # 2D input feature
else:
twod_data = twod_d.squeeze(dim=-1)
continuous_features = np.array([data['off_start'],data['off_end'],data['full_len']])
continuous_features = np.log(np.maximum(continuous_features+3,0)+1)
continuous_features = torch.from_numpy(continuous_features).float() # .float() ==torch.float32
# continuous_features = torch.log(torch.max(torch.tensor(continuous_features+3),torch.tensor(0))+1) # take the log
species_features = torch.tensor(data['species'],dtype=torch.long)
truncated_features = torch.tensor(data['type'],dtype=torch.long)
ith_nn_prob = self.codon_instance_rna[data['species']].frame_ith_aa_base_fraction
nn_prob = self.create_base_prob(data['truncated_aa'],ith_nn_prob,self.rna_alphabet,self.tokenizer)
'''output'''
if 'truncated_nn' in data:
target_nn = self.seq_to_rnaindex(data['truncated_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1)
# regression target
target = torch.tensor(data['cai_nature'], dtype=torch.float32) # the 'cai_nature' field is the regression target
else:
target_nn = self.seq_to_rnaindex(data['cai_best_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1)
target = torch.tensor(0, dtype=torch.float32) # no label available; use 0 as a placeholder regression target
target_nn = torch.from_numpy(target_nn).long()
frames = [1, 2, 3]
backbone_cds_list = self.modify_codon_by_frames(target_nn, frames=frames,
masked_token=self.tokenizer.mask_index)
# backbone_cds_list = self.modify_codon_by_frames(src_data, frames = frames,masked_token=self.tokenizer.mask_index)
masked_logits_list = []
for backbone_cds, frame in zip(backbone_cds_list, frames):
masked_logits = self.create_codon_mask(aa_idx, backbone_cds, self.amino_acid_to_codons,
self.tokenizer)
masked_logits_list.append(masked_logits.unsqueeze(0))
# 'UUCACCCAGGCCACGCGGAGUACGAUCGAGUGUACAGUGAA'
# test = masked_logits.numpy()
masked_logits_list = torch.cat(masked_logits_list, dim=0)
return src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, target_nn, target,masked_logits_list[...,:10],nn_prob[...,:10], maotao_id
# [(a.shape, a.dtype) for a in [src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, target_nn, target]]
# Out[10]:
# [(torch.Size([1200]), torch.int64),
# (torch.Size([1, 1200, 1200]), torch.float64),
# (torch.Size([400]), torch.int64),
# (torch.Size([3]), torch.float32),
# (torch.Size([]), torch.int64),
# (torch.Size([]), torch.int64),
# (torch.Size([1200]), torch.int64),
# (torch.Size([]), torch.float32)]
@staticmethod
def modify_codon_by_frames(sequence, frames=[1, 2, 3], masked_token='_'):
"""
Mask one codon position (frame) at a time and rebuild the sequence.
Parameters:
sequence (str | torch.Tensor | np.ndarray): input sequence
frames (list[int]): codon positions to mask, chosen from (1, 2, 3)
masked_token: token used to replace the selected frame
Returns:
list: one reconstructed sequence per requested frame
"""
# clean the sequence
# seq = sequence.upper().replace(' ', '').replace('\n', '')
seq = sequence
# seq = seq[:len(seq) - len(seq) % 3]
# extract the three codon frames by slicing
frames_seq = [seq[0::3], seq[1::3], seq[2::3]]
reconstructed_list = []
# mask each requested frame in turn
for ith,frame in enumerate(frames_seq):
if ith+1 in frames:
tmp_seq = deepcopy(frames_seq)
tmp_seq[ith] = [masked_token] * len(frames_seq[ith])
# rebuild the full-length sequence
# reconstructed = None
if isinstance(seq,str):
reconstructed = ''.join(
tmp_seq[0][i] + tmp_seq[1][i] + tmp_seq[2][i]
for i in range(len(tmp_seq[0]))
)
elif isinstance(seq,torch.Tensor):
tmp_seq[ith] = torch.from_numpy(np.array(tmp_seq[ith]))
reconstructed = torch.stack(tmp_seq, dim=1).reshape(-1)
elif isinstance(seq,np.ndarray):
tmp_seq[ith] = np.array(tmp_seq[ith])
reconstructed = np.stack(tmp_seq, axis=1).reshape(-1)
else:
raise ValueError(type(seq))
# reconstructed = torch.cat([tmp_seq[0][i],tmp_seq[1][i],tmp_seq[2][i]
# for i in range(len(tmp_seq[0]))])
# reconstructed = [tmp_seq[0][i] + tmp_seq[1][i] + tmp_seq[2][i]
# for i in range(len(tmp_seq[0]))]
reconstructed_list.append(deepcopy(reconstructed))
return reconstructed_list
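# Toy example of the frame masking above (string input, masked_token='_'):
# 'AUGGCU' has codons AUG | GCU; masking frame 1 replaces the first base of every codon,
# frame 2 the second, frame 3 the third.
#
# out = MaotaoDataset.modify_codon_by_frames('AUGGCU', frames=[1, 2, 3], masked_token='_')
# # out[0] == '_UG_CU', out[1] == 'A_GG_U', out[2] == 'AU_GC_'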
def gaussian(x):
return math.exp(-0.5*(x*x))
def paired(x,y,lamda=0.8):
if x == 5 and y == 6:
return 2
elif x == 4 and y == 7:
return 3
elif x == 4 and y == 6:
return lamda
elif x == 6 and y == 5:
return 2
elif x == 7 and y == 4:
return 3
elif x == 6 and y == 4:
return lamda
else:
return 0
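# Note on the magic numbers above: the indices 4..7 are assumed to be the token ids of the
# four bases (the rna_alphabet map in MaotaoDataset assigns G/A/U/C to ids 4-7 after the
# special tokens), so paired() scores A-U pairs as 2, G-C pairs as 3, G-U wobble pairs as
# lamda and everything else as 0. A quick sanity check under that assumption:
#
# G, A, U, C = 4, 5, 6, 7
# assert paired(A, U) == 2 and paired(G, C) == 3 and paired(G, U) == 0.8 and paired(A, C) == 0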
def pad_or_truncate_utr(data, input_len, pad_method,column='utr',pad_mark='_'):
def process_utr(utr):
if len(utr) < input_len:
if pad_method == 'pre':
padded_utr = pad_mark * (input_len - len(utr)) + utr
elif pad_method == 'behind':
padded_utr = utr + pad_mark * (input_len - len(utr))
else:
padded_utr = utr[-input_len:]
return padded_utr
data[column] = data[column].apply(process_utr)
return data
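# Small example of pad_or_truncate_utr (hypothetical DataFrame): sequences shorter than
# input_len are padded with '_' either in front ('pre') or at the end ('behind'); longer
# sequences are expected to keep their last input_len characters (utr[-input_len:]).
#
# import pandas as pd
# df = pd.DataFrame({'utr': ['ACG']})
# out = pad_or_truncate_utr(df.copy(), input_len=5, pad_method='pre')
# # out['utr'].tolist() == ['__ACG']; with pad_method='behind' it would be ['ACG__']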
# def do_createmat(data, base_range=30, lamda=0.8):
# paird_map = np.array([[paired(i, j, lamda) for i in range(30)] for j in range(30)]) # token
# data_index = np.arange(0, len(data))
# # np.indices((2,2))
# coefficient = np.zeros([len(data), len(data)])
# # mat = np.zeros((len(data),len(data)))
# score_mask = np.full((len(data), len(data)), True)
# for add in range(base_range):
# data_index_x = data_index - add
# data_index_y = data_index + add
# score_mask = ((data_index_x >= 0)[:, None] & (data_index_y < len(data))[None, :]) & score_mask
# data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1),
# data_index_y.clip(0, len(data) - 1), indexing='ij')
# score = paird_map[data[data_index_x], data[data_index_y]]
# score_mask = score_mask & (score != 0)
#
# coefficient = coefficient + score * score_mask * gaussian(add)
# if ~(score_mask.any()):
# break
# score_mask = coefficient > 0
# for add in range(1, base_range):
# data_index_x = data_index + add
# data_index_y = data_index - add
# score_mask = ((data_index_x < len(data))[:, None] & (data_index_y >= 0)[None, :]) & score_mask
# data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1),
# data_index_y.clip(0, len(data) - 1), indexing='ij')
# score = paird_map[data[data_index_x], data[data_index_y]]
# score_mask = score_mask & (score != 0)
# coefficient = coefficient + score * score_mask * gaussian(add)
# if ~(score_mask.any()):
# break
# return coefficient
def do_createmat(data, base_range=30, lamda=0.8):
paird_map = np.array([[paired(i, j, lamda) for i in range(30)] for j in range(30)]) # token
data_index = np.arange(0, len(data))
# np.indices((2,2))
coefficient = np.zeros([len(data), len(data)])
# mat = np.zeros((len(data),len(data)))
score_mask = np.full((len(data), len(data)), True)
for add in [0,300]:
data_index_x = data_index - add
data_index_y = data_index + add
score_mask = ((data_index_x >= 0)[:, None] & (data_index_y < len(data))[None, :]) & score_mask
data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1),
data_index_y.clip(0, len(data) - 1), indexing='ij')
score = paird_map[data[data_index_x], data[data_index_y]]
score_mask = score_mask & (score != 0)
coefficient = coefficient + score * score_mask * gaussian(add)
if ~(score_mask.any()):
break
return coefficient
def creatmat(data, base_range=30, lamda=0.8):
return do_createmat(data, base_range=base_range, lamda=lamda)
# if len(data.shape)==1:return do_createmat(data,base_range=base_range,lamda =lamda)
# else:
# coefficient = np.zeros((data.shape[0],data.shape[1],data.shape[1]))
# for i in range(data.shape[0]):
# coefficient[i,:,:] = do_createmat(data[i:i+1,:], base_range=base_range, lamda=lamda)
# return coefficient
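# Usage sketch for creatmat (toy index sequence under the G/A/U/C = 4/5/6/7 id assumption
# noted near paired()): the result is an (L, L) matrix of pairing scores used as the model's
# 2D input channel. Note that this trimmed-down do_createmat only considers the add = 0 and
# add = 300 offsets (the original neighborhood loop is kept, commented out, above), so
# base_range is effectively ignored here.
#
# import numpy as np
# toy = np.array([5, 6, 4, 7])                    # A U G C
# mat = creatmat(toy, base_range=1, lamda=0.8)    # shape (4, 4)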
import argparse
if __name__ == '__main__':
print('start generating')
# # get the pretraining and dataset args
# from model.tools import get_dataset_args, get_pretraining_args
# pretraining_parser = get_pretraining_args()
# dataset_parser = get_dataset_args()
#
# # merge the args
# parser = argparse.ArgumentParser(parents=[pretraining_parser, dataset_parser], add_help=False,conflict_handler='resolve')
# # dataset_parser = get_dataset_args()
# ## merge the args
# # parser = argparse.ArgumentParser(parents=[ dataset_parser], add_help=False,
# # conflict_handler='resolve')
#
# args = parser.parse_args()
# args.batch_size = 5
#
# args.ffasta = '/public/home/jiang_jiuhong/Data/RNAdesign/Raw_data/_0_reference/GRCh38.p14/mRNA/full.fa'
# # ans = RNADataset.read_fasta_file(args.ffasta)
#
# args.device = 'cpu'
# vocab_path = args.arg_overrides['data'] + '/dict.txt'
# tokenizer = Dictionary.load(vocab_path)
# tokenizer.mask_index = tokenizer.add_symbol('')
# # train_ds = RNADataset(args.ffasta,max_length=args.region * 2,tokenizer=tokenizer)
# # train_loader = DataLoader(
# # train_ds,
# # batch_size=args.batch_size,
# # pin_memory=True,
# # drop_last=False,
# # shuffle=False,
# # num_workers=args.num_workers,
# # sampler=None
# # )
# #
# # for step, (src_data,tgt_data,twod_data,loss_mask) in enumerate(train_loader):
# # print(step, [a.shape for a in [src_data, tgt_data, twod_data,
# # loss_mask]]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])]
# # a = [a.numpy()[0] for a in [src_data, tgt_data, twod_data,
# # loss_mask]]
#
#
# # train_ds = RegressionDataset(args.downstream_data_path+'/IRES_linear/TS.csv', tokenizer, seq_len=args.seq_len, column=args.column, label=args.label,
# # pad_method=args.pad_method)
# # train_loader = DataLoader(
# # train_ds,
# # batch_size=args.batch_size,
# # pin_memory=True,
# # drop_last=False,
# # shuffle=False,
# # )
# # for step, data in enumerate(train_loader):
# # print(step, [a.shape for a in data]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])]
#
# '''RiboDataPipeline'''
# ribo_experiment, rna_experiment = 'SRX5164421','SRX5164417' #TAIR10#
# ribo_experiment, rna_experiment = 'SRX12763793','SRX12763783' # human
# ribo_experiment, rna_experiment = 'SRX9444526','SRX9444530' # mouse
# print(os.path.abspath(args.exp_pretrain_data_path))
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100) # generate mRNA.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=-1,cds_min=100) # generate mRNA_region.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path, ribo_experiment, rna_experiment, seq_only=False, region=6,limit=300,cds_min=100) # loading ribosome counts data
# # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS']
# # with open('./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl','wb') as f:
# # pickle.dump((TR,VL,TS),f)
# #
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=300,cds_min=100,norm=True) # generate mRNA_region.fa
# # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS']
# # with open('./dataset/experiment/nature/reference/GRCh38.p14/mRNA_300.pkl','wb') as f:
# # pickle.dump(RDP.samples,f)
#
#
# '''why 300'''
# # region = 1000
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=False, region=1000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=2000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=2000,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100,norm=False) # generate mRNA_region.fa
# # # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS']
# # ftrack = f'./dataset/experiment/nature/track_{ribo_experiment}_{rna_experiment}_{region}_counts_not_norm.pkl'
# # ref_norm = [690.98991075,2214.2488917]
# # # ref_norm = [1,1]
# #
# # with open(ftrack,'wb') as f:
# # data = RDP.samples
# # TS = deepcopy(data['TR'])
# # df = pd.DataFrame(TS).T
# # df.columns = 'seq,cds_start,cds_stop,ribo_counts,rna_counts,ribosome_density,te,env,cds_len,mRNA_len'.split(',')
# # df['RPF_counts'] = df['ribo_counts'].apply(lambda x: sum(x[1001:3004]))
# # df['RNA_counts'] = df['rna_counts'].apply(lambda x: sum(x[1001:3004]))
# #
# # df = df[df['RPF_counts']>100]
# # df = df[df['RNA_counts']>100]
# # df = df[df['cds_len']>2000]
# # df = df[df['te']!=-1]
# # RPF = np.array(df['ribo_counts'].tolist())*ref_norm[0]
# # RNA = np.array(df['rna_counts'].tolist())*ref_norm[1]
# # density = np.array(df['ribosome_density'].tolist())
# # pickle.dump((RPF,RNA,density),f)
#
# # fname = 'track_SRX12763793_SRX12763783_1000.pkl'
# # with open(os.path.join(WDIR, fname), 'rb') as f:
# # data = pickle.load(f)
#
#
# # train_ds = RiboDataset(TR, tokenizer)
# # train_loader = DataLoader(
# # train_ds,
# # batch_size=args.batch_size,
# # pin_memory=True,
# # drop_last=False,
# # shuffle=False,
# # )
# # for step, data in enumerate(train_loader):
# # print(step, [a.shape for a in data]) # [torch.Size([1, 1203]), torch.Size([1, 1203]), torch.Size([1, 1203, 1203, 1])]
#
#
# '''submit task to get a lot of pretraining data'''
# ribo_experiment, rna_experiment = 'SRX5164421','SRX5164417' #TAIR10#
# ribo_experiment, rna_experiment = 'SRX12763793','SRX12763783' # human
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=-1,limit=-1,cds_min=100) # generate mRNA.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=300,limit=-1,cds_min=100) # generate mRNA_region.fa
# RDP = RiboDataPipeline(args.exp_pretrain_data_path,ribo_experiment, rna_experiment, seq_only=True, region=1998,limit=-1,cds_min=100) # generate mRNA_region.fa
# # RDP = RiboDataPipeline(args.exp_pretrain_data_path, ribo_experiment, rna_experiment, seq_only=False, region=6,limit=300,cds_min=100) # loading ribosome counts data
# # TR,VL,TS = RDP.samples['TR'],RDP.samples['VL'],RDP.samples['TS']