#!/home/zhusitao/anaconda3/bin/python
# -*- coding:utf-8 -*-
'''
filename: models.py
date: 2021/12/26 9:42 AM
author: Sitao Zhu
mail: zhusitao1990@163.com
'''
import torch
from torch import nn, einsum
from torch.autograd import Variable
from torch.nn import Sequential
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from inspect import isfunction
# "same" padding: padding = dilation * (kernel_size - 1) / 2
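# Worked example of the formula above: kernel_size=5, dilation=1 gives
# padding = 1 * (5 - 1) / 2 = 2, matching the 5-wide conv layers below.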
##### classification ######
# CNN + GRU
class CnnGru(nn.Module):
"""
CNN + bidirectional GRU model for TSS classification.
"""
def __init__(self, l1,l2):
super(CnnGru, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding =1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer3 = Sequential(
nn.Conv1d(256, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer4 = Sequential(
nn.Conv1d(256, 256,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=2,
batch_first=True, bidirectional=True)
self.fc1 = nn.Linear(256*32, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# gru
x = torch.transpose(x, 2, 1)
x, x_hidden = self.gru(x)
# x = torch.cat((x1,x2,x3),dim=1) # concat channel
# flat for Linear layers
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
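# Shape sketch for CnnGru (an assumption inferred from the layer sizes, not stated in the code):
# with input (N, 4, 512), the four conv blocks with stride-2 max pooling reduce the length
# 512 -> 256 -> 128 -> 64 -> 32, the bidirectional GRU (hidden_size=128) returns (N, 32, 256),
# and flattening yields the 256 * 32 features expected by fc1.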
class CnnGruCat(nn.Module):
"""
CNN + bidirectional GRU model for TSS classification (conv-branch concatenation currently disabled).
"""
def __init__(self, l1,l2):
super(CnnGruCat, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding =1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer3 = Sequential(
nn.Conv1d(4, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1),
)
self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=1,
batch_first=True, bidirectional=True)
self.fc1 = nn.Linear(256*128, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
# x3 = self.layer3(x)
# x4 = torch.cat((x1,x2,x3),dim=1) # cat channel
# gru
x = torch.transpose(x, 2, 1) # N,C,L --> N,L,C
x, x_hidden = self.gru(x)
# p2 = torch.transpose(p2,2,1) # N,L,C --> N,C,L
# x = torch.cat((x,p2),dim=1) # concat channel, dim2=1
# flat for Linear layers
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
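# Shape sketch for CnnGruCat (assuming input (N, 4, 512)): only layer1 and layer2 are used,
# so the length shrinks 512 -> 256 -> 128; the bidirectional GRU returns (N, 128, 256),
# which flattens to the 256 * 128 features expected by fc1.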
# CNN + LSTM
class CnnLstm(nn.Module):
"""
CNN + bidirectional LSTM model for TSS classification.
"""
def __init__(self, l1,l2):
super(CnnLstm, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding =1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer3 = Sequential(
nn.Conv1d(256, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer4 = Sequential(
nn.Conv1d(256, 256,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2,
batch_first=True, bidirectional=True)
self.fc1 = nn.Linear(256*32, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# lstm
x = torch.transpose(x, 2, 1)
x, (hn, cn) = self.lstm(x)
# x = torch.cat((x1,x2,x3),dim=1) # concat channel
# flat for Linear layers
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
# CNN + Attention
class CnnAttention(nn.Module):
"""
CNN + attention-pooling model for TSS classification.
"""
def __init__(self):
super(CnnAttention, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(128,momentum=0.01),
nn.ReLU(),
nn.Dropout(p=0.25)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256, momentum=0.01),
nn.ReLU(),
nn.Dropout(p=0.25)
)
self.head_att = nn.Sequential(
nn.Linear(256, 256),
nn.Tanh(),
nn.Linear(256, 1),
nn.Softmax(dim=1)
)
self.MLP = nn.Sequential(
nn.Linear(256 * 1, 64),
nn.BatchNorm1d(64, momentum=0.01),
nn.PReLU(64),
nn.Dropout(p=0.25),
nn.Linear(64, 1)
)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = x.view(x.size(0), 256, 512)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
att_x = self.head_att(x) # attention weight per position: (N, 512, 1)
x = x.permute(0, 2, 1).bmm(att_x) # attention-weighted sum over positions: (N, 256, 1)
x = x.view(x.size(0), -1)
x = self.MLP(x)
return x
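# Note on CnnAttention (assuming input (N, 4, 512)): no pooling is applied, so the length stays
# 512; head_att produces a softmax weight per position and the bmm computes the attention-weighted
# sum of the 256 channel features, giving (N, 256, 1) before the MLP head.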
# CNN
class Cnn(nn.Module):
"""
Plain CNN model for TSS classification.
"""
def __init__(self, l1=128,l2=64):
super(Cnn, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding =1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer3 = Sequential(
nn.Conv1d(256, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer4 = Sequential(
nn.Conv1d(256, 256,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.fc1 = nn.Linear(256*32, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# x = torch.cat((x1,x2,x3),dim=1) # concat channel
# flat for Linear layers
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc3(x)
return x
class ResBlock(nn.Module):
"""
1D residual block: two Conv1d layers with an optional convolutional shortcut and optional max pooling.
"""
def __init__(self,channel_in,channel_out,stride,kernel_size,padding_size,shortcut=None,pool=False):
super(ResBlock,self).__init__()
self.left = nn.Sequential(
nn.Conv1d(channel_in,channel_out,
kernel_size=kernel_size,
padding=padding_size,
bias=True),
nn.BatchNorm1d(channel_out),
nn.ReLU(inplace=True),
nn.Dropout(p=0.3),
nn.Conv1d(channel_out,channel_out,
kernel_size=kernel_size,
padding=padding_size,
bias=True),
nn.BatchNorm1d(channel_out)
)
self.right = shortcut
self.pool = pool
self.MaxPool = nn.MaxPool1d(kernel_size=2, stride=2)
def forward(self,x):
out = self.left(x)
residual = x if self.right is None else self.right(x)
out += residual
out = F.relu(out)
if self.pool:
out = self.MaxPool(out)
return out
class ResNet(nn.Module):
def __init__(self):
super(ResNet, self).__init__()
self.layer1 = self.make_layer(channel_in=4, channel_out=128, block_num=1, kernel_size=3, padding_size=1,
stride=1)
self.layer2 = self.make_layer(channel_in=128, channel_out=128, block_num=2, kernel_size=5, padding_size=2,
stride=1)
self.layer3 = self.make_layer(channel_in=128, channel_out=128, block_num=2, kernel_size=9, padding_size=4,
stride=1)
self.dropout = nn.Dropout(0.5)
self.flatten = torch.nn.Flatten()
self.ff1 = nn.Linear(2048, 32, bias=True)
self.ff2 = nn.Linear(32, 1, bias=True)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.flatten(x)
x = self.ff1(x)
x = F.relu(x) # x = t.tanh(x)
x = self.ff2(x)
return x
def make_layer(self,channel_in,channel_out,block_num,kernel_size,padding_size,stride=1):
'''
Build a stack of residual blocks.
:return: nn.Sequential of ResBlock modules
'''
shortcut = conv_layer(channel_in ,channel_out, kernel_size=9, padding_size=4, dilation=1)
layers = []
# the first block bridges adjacent ResBlocks and adapts the channel count when it differs between blocks
layers.append(ResBlock(channel_in,channel_out,stride,kernel_size,padding_size,shortcut,pool=True))
# resnet core
for i in range(1,block_num):
layers.append(ResBlock(channel_out,channel_out,stride,kernel_size,padding_size,pool=True))
return nn.Sequential(*layers)
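# Feature-size sketch for ResNet (assuming input (N, 4, 512)): layer1 + layer2 + layer3 contain
# 1 + 2 + 2 = 5 ResBlocks, each followed by stride-2 max pooling, so the length becomes
# 512 / 2**5 = 16; with 128 channels this gives the 128 * 16 = 2048 features expected by ff1.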
def conv_layer(chann_in, chann_out, kernel_size, padding_size, dilation):
layer = nn.Sequential(
nn.Conv1d(chann_in, chann_out,
kernel_size=kernel_size,
padding=padding_size,
stride=1,
dilation=dilation,
groups=1,
bias=True),
nn.BatchNorm1d(chann_out),
nn.ReLU(),
nn.Dropout(p=0.3))
return layer
##### regression ######
# CNN
class TssRegression(nn.Module):
"""
CNN model for TSS regression.
"""
def __init__(self, l1, l2):
super(TssRegression, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding =1),
nn.BatchNorm1d(128),
nn.ReLU(),
# nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.2)
)
self.layer2 = Sequential(
nn.Conv1d(128, 128,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(128),
nn.ReLU(),
# nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.2)
)
self.layer3 = Sequential(
nn.Conv1d(128, 128,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(128),
nn.ReLU(),
# nn.MaxPool1d(kernel_size=2, stride=2, dilation=1, padding=0),
nn.Dropout(p=0.2)
)
self.fc1 = nn.Linear(128*512, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
# flat for Linear layers
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
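# Note on TssRegression (assuming input (N, 4, 512)): the pooling layers are commented out, so the
# length stays 512 and the 128 channels flatten to the 128 * 512 features expected by fc1.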
# CNN + GRU
class TssRegressionGru(nn.Module):
"""
CNN + bidirectional GRU model for TSS regression.
"""
def __init__(self, l1, l2):
super(TssRegressionGru, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer2 = Sequential(
nn.Conv1d(128, 256,
kernel_size=5,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=2),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer3 = Sequential(
nn.Conv1d(256, 256,
kernel_size=7,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=3),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.layer4 = Sequential(
nn.Conv1d(256, 256,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.1)
)
self.gru = nn.GRU(input_size=256, hidden_size=128, num_layers=2,
batch_first=True, bidirectional=True)
self.fc1 = nn.Linear(256 * 32, l1)
self.fc2 = nn.Linear(l1, l2)
self.fc3 = nn.Linear(l2, 1)
def forward(self,x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# gru
x = torch.transpose(x, 2, 1)
x, x_hidden = self.gru(x)
# x = torch.cat((x1,x2,x3),dim=1) # concat channel
# flat for Linear layers
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc2(x)
x = torch.relu(x)
x = F.dropout(x, p=0.2, training=self.training)
x = self.fc3(x)
return x
# CNN + Attention
class RegressAttention(nn.Module):
def __init__(self):
super(RegressAttention, self).__init__()
self.layer1 = Sequential(
nn.Conv1d(4, 128, # input channel, output channel
kernel_size=3,
stride=1,
dilation=1,
groups=1,
bias=True,
padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(p=0.2)
)
self.attn = Attention(dim=128, heads=8)
self.fc1 = nn.Linear(256*128, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 1)
def forward(self, x):
x = self.layer1(x)
x = torch.transpose(x, 2, 1)
x = self.attn(x)
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
x = torch.relu(x)
x = self.fc3(x)
return x
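# Note on RegressAttention (assuming input (N, 4, 512)): layer1 halves the length to 256, the
# transpose gives (N, 256, 128), and Attention(dim=128) returns the same shape via to_out, so the
# flattened size is the 256 * 128 expected by fc1.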
def init_zero_(layer):
nn.init.constant_(layer.weight, 0.)
if exists(layer.bias):
nn.init.constant_(layer.bias, 0.)
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
class Attention(nn.Module):
def __init__(
self,
dim,
seq_len = None,
heads = 8,
dim_head = 64,
dropout = 0.,
gating = True
):
super().__init__()
inner_dim = dim_head * heads
self.seq_len = seq_len
self.heads= heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.gating = nn.Linear(dim, inner_dim)
nn.init.constant_(self.gating.weight, 0.)
nn.init.constant_(self.gating.bias, 1.)
self.dropout = nn.Dropout(dropout)
# init_zero_(self.to_out)
def forward(self, x, mask = None, attn_bias = None, context = None, context_mask = None, tie_dim = None):
device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)
context = default(context, x)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
i, j = q.shape[-2], k.shape[-2]
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# scale
q = q * self.scale
# query / key similarities
if exists(tie_dim):
# as in the paper, for the extra MSAs
# they average the queries along the rows of the MSAs
# they named this particular module MSAColumnGlobalAttention
q, k = map(lambda t: rearrange(t, '(b r) ... -> b r ...', r = tie_dim), (q, k))
q = q.mean(dim = 1)
dots = einsum('b h i d, b r h j d -> b r h i j', q, k)
dots = rearrange(dots, 'b r ... -> (b r) ...')
else:
dots = einsum('b h i d, b h j d -> b h i j', q, k)
# add attention bias, if supplied (for pairwise to msa attention communication)
if exists(attn_bias):
dots = dots + attn_bias
# masking
if exists(mask):
mask = default(mask, lambda: torch.ones(1, i, device = device).bool())
context_mask = mask if not has_context else default(context_mask, lambda: torch.ones(1, k.shape[-2], device = device).bool())
mask_value = -torch.finfo(dots.dtype).max
mask = mask[:, None, :, None] * context_mask[:, None, None, :]
dots = dots.masked_fill(~mask, mask_value)
# attention
dots = dots - dots.max(dim = -1, keepdims = True).values
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# gating
gates = self.gating(x)
out = out * gates.sigmoid()
# combine to out
out = self.to_out(out)
return out
class LogisticRegression(torch.nn.Module):
def __init__(self, input_dim=4*512, output_dim=1):
super(LogisticRegression, self).__init__()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x):
# outputs = torch.sigmoid(self.linear(x))
x = x.reshape(-1, 4*512)
outputs = self.linear(x) # the sigmoid is applied inside the loss function
return outputs
if __name__ == '__main__':
from torchsummary import summary
# model = Cnn(128, 64)
# model = CnnGru(256, 128)
# model = CnnLstm(256, 128)
# model = ResNet()
# model = CnnGruCat(32,10)
model = TssRegressionGru(256, 128)
# model = RegressAttention()
summary(model)
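# A minimal smoke-test sketch (hedged: the (4, 512) input shape is inferred from the fully
# connected layer sizes, and some torchsummary variants require an explicit input size,
# e.g. summary(model, (4, 512))):
# dummy = torch.randn(2, 4, 512)
# print(model(dummy).shape)  # expected: torch.Size([2, 1])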