forked from prehistoric-systems/comixify
22 lines
735 B
Python
22 lines
735 B
Python
|
|
import torch.nn as nn
|
||
|
|
from torch.nn import functional as F
|
||
|
|
|
||
|
|
__all__ = ['DSN']
|
||
|
|
|
||
|
|
|
||
|
|
class DSN(nn.Module):
    """Deep Summarization Network.

    A bidirectional recurrent encoder (LSTM or GRU) followed by a linear
    head that maps every time step to a single score, squashed by a sigmoid
    into the range [0, 1]. Presumably these are per-frame selection /
    importance probabilities for summarization -- confirm against callers.

    Args:
        in_dim: Size of each input feature vector.
        hid_dim: Hidden size of the RNN in each direction. The two directions
            are concatenated, so the linear head sees ``2 * hid_dim`` features.
        num_layers: Number of stacked recurrent layers.
        cell: Recurrent cell type, either ``'lstm'`` or ``'gru'``.

    Raises:
        ValueError: If ``cell`` is not ``'lstm'`` or ``'gru'``.
    """

    def __init__(self, in_dim=1024, hid_dim=256, num_layers=1, cell='lstm'):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently fall through to the GRU branch
        # for any unrecognised ``cell`` value.
        if cell not in ('lstm', 'gru'):
            raise ValueError(
                f"cell must be either 'lstm' or 'gru', got {cell!r}"
            )
        rnn_cls = nn.LSTM if cell == 'lstm' else nn.GRU
        self.rnn = rnn_cls(
            in_dim,
            hid_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        # A bidirectional RNN concatenates both directions: 2 * hid_dim features.
        self.fc = nn.Linear(hid_dim * 2, 1)

    def forward(self, x):
        """Score every time step of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, in_dim)``. The RNN is built
                with ``batch_first=True``, so the batch axis comes first.

        Returns:
            Tensor of shape ``(batch, seq_len, 1)`` with sigmoid scores in [0, 1].
        """
        # The RNN also returns its final hidden state (plus cell state for an
        # LSTM); only the per-step outputs are used here.
        h, _ = self.rnn(x)
        # Tensor.sigmoid() rather than F.sigmoid, which older PyTorch releases
        # flagged as deprecated in favour of torch.sigmoid.
        p = self.fc(h).sigmoid()
        return p
|