Source code for TSARC.tsarp

#!/home/zhusitao/anaconda3/bin/python
# -*- coding:utf-8 -*-
'''
filename: models.py
date: 2021/12/26 9:42 AM
author: Sitao Zhu
mail: zhusitao1990@163.com
'''

import torch
from torch import nn, einsum
from torch.autograd import Variable
from torch.nn import Sequential
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from inspect import isfunction


# "same" padding: padding = dilation * (kernel_size - 1) / 2
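# For the conv layers in this file (all use dilation=1) that formula gives:
#   kernel_size 3 -> padding 1, kernel_size 5 -> padding 2,
#   kernel_size 7 -> padding 3, kernel_size 9 -> padding 4,
# so every convolution below preserves the sequence length; only the
# MaxPool1d(kernel_size=2, stride=2) layers halve it.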

##### classification ######
# CNN + GRU
class CnnGru(nn.Module):
    """ TSS prediction model """
    def __init__(self, l1, l2):
        super(CnnGru, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer3 = Sequential(
            nn.Conv1d(256, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer4 = Sequential(
            nn.Conv1d(256, 256, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=2,
                          batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256 * 32, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # gru
        x = torch.transpose(x, 2, 1)  # (N, C, L) -> (N, L, C)
        x, x_hidden = self.gru(x)
        # x = torch.cat((x1,x2,x3),dim=1)  # concat channel
        # flatten for the linear layers
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

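# Usage sketch (an assumption, not stated in this file: the input is a one-hot
# encoded 512-bp sequence, i.e. a float tensor of shape (N, 4, 512); that
# length is consistent with fc1 = nn.Linear(256*32, l1), since the four
# stride-2 max-pools reduce 512 -> 256 -> 128 -> 64 -> 32 and the
# bidirectional GRU outputs 2*128 = 256 features per position):
#
#   model = CnnGru(l1=256, l2=128)
#   logits = model(torch.randn(8, 4, 512))   # -> shape (8, 1)
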
class CnnGruCat(nn.Module):
    """ TSS prediction model """
    def __init__(self, l1, l2):
        super(CnnGruCat, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer3 = Sequential(
            nn.Conv1d(4, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1),
        )
        self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=1,
                          batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256 * 128, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # x3 = self.layer3(x)
        # x4 = torch.cat((x1,x2,x3),dim=1)  # cat channel
        # gru
        x = torch.transpose(x, 2, 1)  # N,C,L --> N,L,C
        x, x_hidden = self.gru(x)
        # p2 = torch.transpose(p2,2,1)  # N,L,C --> N,C,L
        # x = torch.cat((x,p2),dim=1)  # concat channel, dim2=1
        # flatten for the linear layers
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

# CNN + LSTM
class CnnLstm(nn.Module):
    """ TSS prediction model """
    def __init__(self, l1, l2):
        super(CnnLstm, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer3 = Sequential(
            nn.Conv1d(256, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer4 = Sequential(
            nn.Conv1d(256, 256, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2,
                            batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256 * 32, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # lstm
        x = torch.transpose(x, 2, 1)  # (N, C, L) -> (N, L, C)
        x, (hn, cn) = self.lstm(x)
        # x = torch.cat((x1,x2,x3),dim=1)  # concat channel
        # flatten for the linear layers
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

# CNN + Attention
class CnnAttention(nn.Module):
    """ cnn + attention """
    def __init__(self):
        super(CnnAttention, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(128, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(p=0.25)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(p=0.25)
        )
        self.head_att = nn.Sequential(
            nn.Linear(256, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
            nn.Softmax(dim=1)
        )
        self.MLP = nn.Sequential(
            nn.Linear(256 * 1, 64),
            nn.BatchNorm1d(64, momentum=0.01),
            nn.PReLU(64),
            nn.Dropout(p=0.25),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.size(0), 256, 512)
        x = x.permute(0, 2, 1)  # channel / sequence-length exchange: (N, 256, 512) -> (N, 512, 256)
        att_x = self.head_att(x)  # (N, 512, 1) attention weights over positions
        x = x.permute(0, 2, 1).bmm(att_x)  # (N, 256, 1) attention-weighted sum over positions
        x = x.view(x.size(0), -1)
        x = self.MLP(x)
        return x

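# Usage sketch (assuming the same (N, 4, 512) one-hot input; the hard-coded
# x.view(x.size(0), 256, 512) in forward() only works for length-512 inputs):
#
#   model = CnnAttention()
#   logits = model(torch.randn(8, 4, 512))   # attention-pooled over positions -> shape (8, 1)
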
# CNN
class Cnn(nn.Module):
    """ CNN model """
    def __init__(self, l1=128, l2=64):
        super(Cnn, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer3 = Sequential(
            nn.Conv1d(256, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer4 = Sequential(
            nn.Conv1d(256, 256, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.fc1 = nn.Linear(256 * 32, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = torch.cat((x1,x2,x3),dim=1)  # concat channel
        # flatten for the linear layers
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc3(x)
        return x

class ResBlock(nn.Module):
    """ ResBlock module """
    def __init__(self, channel_in, channel_out, stride, kernel_size, padding_size, shortcut=None, pool=False):
        super(ResBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size, padding=padding_size, bias=True),
            nn.BatchNorm1d(channel_out),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Conv1d(channel_out, channel_out, kernel_size=kernel_size, padding=padding_size, bias=True),
            nn.BatchNorm1d(channel_out)
        )
        self.right = shortcut
        self.pool = pool
        self.MaxPool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        out = F.relu(out)
        if self.pool:
            out = self.MaxPool(out)
        return out

class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.layer1 = self.make_layer(channel_in=4, channel_out=128, block_num=1,
                                      kernel_size=3, padding_size=1, stride=1)
        self.layer2 = self.make_layer(channel_in=128, channel_out=128, block_num=2,
                                      kernel_size=5, padding_size=2, stride=1)
        self.layer3 = self.make_layer(channel_in=128, channel_out=128, block_num=2,
                                      kernel_size=9, padding_size=4, stride=1)
        self.dropout = nn.Dropout(0.5)
        self.flatten = torch.nn.Flatten()
        self.ff1 = nn.Linear(2048, 32, bias=True)
        self.ff2 = nn.Linear(32, 1, bias=True)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.flatten(x)
        x = self.ff1(x)
        x = F.relu(x)
        # x = t.tanh(x)
        x = self.ff2(x)
        return x

    def make_layer(self, channel_in, channel_out, block_num, kernel_size, padding_size, stride=1):
        '''
        build a stack of residual blocks
        :return: residual blocks as an nn.Sequential
        '''
        shortcut = conv_layer(channel_in, channel_out, kernel_size=9, padding_size=4, dilation=1)
        layers = []
        # the first block carries the shortcut projection so that a channel
        # mismatch between consecutive residual blocks can be bridged
        layers.append(ResBlock(channel_in, channel_out, stride, kernel_size, padding_size, shortcut, pool=True))
        # resnet core
        for i in range(1, block_num):
            layers.append(ResBlock(channel_out, channel_out, stride, kernel_size, padding_size, pool=True))
        return nn.Sequential(*layers)

def conv_layer(chann_in, chann_out, kernel_size, padding_size, dilation):
    layer = nn.Sequential(
        nn.Conv1d(chann_in, chann_out, kernel_size=kernel_size, padding=padding_size,
                  stride=1, dilation=dilation, groups=1, bias=True),
        nn.BatchNorm1d(chann_out),
        nn.ReLU(),
        nn.Dropout(p=0.3))
    return layer

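# Usage sketch for the residual network (assuming the same (N, 4, 512) input;
# the five pooled ResBlocks reduce the length 512 -> 16, and 128 channels * 16
# positions = 2048 matches ff1 = nn.Linear(2048, 32)):
#
#   model = ResNet()
#   logits = model(torch.randn(8, 4, 512))   # -> shape (8, 1)
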
##### regression ######
# CNN
class TssRegression(nn.Module):
    """ TSS prediction model """
    def __init__(self, l1, l2):
        super(TssRegression, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.2)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 128, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.2)
        )
        self.layer3 = Sequential(
            nn.Conv1d(128, 128, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=2, stride=2, dilation=1, padding=0),
            nn.Dropout(p=0.2)
        )
        self.fc1 = nn.Linear(128 * 512, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # flatten for the linear layers
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

# CNN + GRU
class TssRegressionGru(nn.Module):
    """ TSS prediction model """
    def __init__(self, l1, l2):
        super(TssRegressionGru, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer2 = Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, dilation=1, groups=1, bias=True, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer3 = Sequential(
            nn.Conv1d(256, 256, kernel_size=7, stride=1, dilation=1, groups=1, bias=True, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.layer4 = Sequential(
            nn.Conv1d(256, 256, kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.1)
        )
        self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=2,
                          batch_first=True, bidirectional=True)
        self.fc1 = nn.Linear(256 * 32, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # gru
        x = torch.transpose(x, 2, 1)
        x, x_hidden = self.gru(x)
        # x = torch.cat((x1,x2,x3),dim=1)  # concat channel
        # flatten for the linear layers
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)  # activation function relu
        x = F.dropout(x, p=0.2)
        x = self.fc2(x)
        x = torch.relu(x)  # relu
        x = F.dropout(x, p=0.2)
        x = self.fc3(x)
        return x

# CNN + Attention
class RegressAttention(nn.Module):
    def __init__(self):
        super(RegressAttention, self).__init__()
        self.layer1 = Sequential(
            nn.Conv1d(4, 128,  # input channel, output channel
                      kernel_size=3, stride=1, dilation=1, groups=1, bias=True, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.2)
        )
        self.attn = Attention(dim=128, heads=8)
        self.fc1 = nn.Linear(256 * 128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.transpose(x, 2, 1)
        x = self.attn(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x

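# Usage sketch (assuming a (N, 4, 512) input: layer1 pools the length to 256,
# the transpose gives (N, 256, 128), Attention keeps that shape, and
# 256 * 128 matches fc1 = nn.Linear(256*128, 64)):
#
#   model = RegressAttention()
#   preds = model(torch.randn(8, 4, 512))   # -> shape (8, 1)
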
def init_zero_(layer):
    nn.init.constant_(layer.weight, 0.)
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        seq_len = None,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        gating = True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.gating = nn.Linear(dim, inner_dim)
        nn.init.constant_(self.gating.weight, 0.)
        nn.init.constant_(self.gating.bias, 1.)

        self.dropout = nn.Dropout(dropout)
        # init_zero_(self.to_out)

    def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
        device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        i, j = q.shape[-2], k.shape[-2]

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale
        q = q * self.scale

        # query / key similarities
        if exists(tie_dim):
            # as in the paper, for the extra MSAs
            # they average the queries along the rows of the MSAs
            # they named this particular module MSAColumnGlobalAttention
            q, k = map(lambda t: rearrange(t, '(b r) ... -> b r ...', r = tie_dim), (q, k))
            q = q.mean(dim = 1)

            dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
            dots = rearrange(dots, 'b r ... -> (b r) ...')
        else:
            dots = einsum('b h i d, b h j d -> b h i j', q, k)

        # add attention bias, if supplied (for pairwise to msa attention communication)
        if exists(attn_bias):
            dots = dots + attn_bias

        # masking
        if exists(mask):
            mask = default(mask, lambda: torch.ones(1, i, device = device).bool())
            context_mask = mask if not has_context else default(context_mask, lambda: torch.ones(1, k.shape[-2], device = device).bool())
            mask_value = -torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            dots = dots.masked_fill(~mask, mask_value)

        # attention
        dots = dots - dots.max(dim = -1, keepdim = True).values
        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads
        out = rearrange(out, 'b h n d -> b n (h d)')

        # gating
        gates = self.gating(x)
        out = out * gates.sigmoid()

        # combine to out
        out = self.to_out(out)
        return out

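# Usage sketch for the standalone attention block (shapes follow directly from
# the module: x is (batch, seq_len, dim), an optional boolean mask is
# (batch, seq_len), and the output has the same shape as x):
#
#   attn = Attention(dim=128, heads=8)
#   x = torch.randn(2, 256, 128)
#   mask = torch.ones(2, 256).bool()
#   out = attn(x, mask=mask)   # -> shape (2, 256, 128)
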
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim=4*512, output_dim=1):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # outputs = torch.sigmoid(self.linear(x))
        x = x.reshape(-1, 4*512)
        outputs = self.linear(x)  # the sigmoid is applied in the loss function
        return outputs

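# Usage sketch (the baseline flattens whatever it is given to (N, 4*512), so a
# one-hot (N, 4, 512) batch works directly; the raw linear outputs are intended
# for a sigmoid-based loss such as BCEWithLogitsLoss):
#
#   model = LogisticRegression()
#   logits = model(torch.randn(8, 4, 512))   # -> shape (8, 1)
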
if __name__ == '__main__':
    from torchsummary import summary
    # model = Cnn(128, 64)
    # model = CnnGru(256, 128)
    # model = CnnLstm(256, 128)
    # model = ResNet()
    # model = CnnGruCat(32,10)
    model = TssRegressionGru(256, 128)
    # model = RegressAttention()
    summary(model, input_size=(4, 512))  # torchsummary needs the per-sample input size: (channels, seq_len)