Skip to content

Commit 18e3057

Browse files
committed
Update onmt library
1 parent a14b072 commit 18e3057

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+11628
-0
lines changed

python/onmt/decoders/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Module defining decoders."""
2+
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \
3+
StdRNNDecoder
4+
from onmt.decoders.transformer import TransformerDecoder
5+
from onmt.decoders.cnn_decoder import CNNDecoder
6+
7+
8+
str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,
9+
"cnn": CNNDecoder, "transformer": TransformerDecoder}
10+
11+
__all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder",
12+
"InputFeedRNNDecoder", "str2dec"]

python/onmt/decoders/cnn_decoder.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Implementation of the CNN Decoder part of
2+
"Convolutional Sequence to Sequence Learning"
3+
"""
4+
import torch
5+
import torch.nn as nn
6+
7+
from onmt.modules import ConvMultiStepAttention, GlobalAttention
8+
from onmt.utils.cnn_factory import shape_transform, GatedConv
9+
from onmt.decoders.decoder import DecoderBase
10+
11+
SCALE_WEIGHT = 0.5 ** 0.5
12+
13+
14+
class CNNDecoder(DecoderBase):
15+
"""Decoder based on "Convolutional Sequence to Sequence Learning"
16+
:cite:`DBLP:journals/corr/GehringAGYD17`.
17+
18+
Consists of residual convolutional layers, with ConvMultiStepAttention.
19+
"""
20+
21+
def __init__(self, num_layers, hidden_size, attn_type,
22+
copy_attn, cnn_kernel_width, dropout, embeddings,
23+
copy_attn_type):
24+
super(CNNDecoder, self).__init__()
25+
26+
self.cnn_kernel_width = cnn_kernel_width
27+
self.embeddings = embeddings
28+
29+
# Decoder State
30+
self.state = {}
31+
32+
input_size = self.embeddings.embedding_size
33+
self.linear = nn.Linear(input_size, hidden_size)
34+
self.conv_layers = nn.ModuleList(
35+
[GatedConv(hidden_size, cnn_kernel_width, dropout, True)
36+
for i in range(num_layers)]
37+
)
38+
self.attn_layers = nn.ModuleList(
39+
[ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
40+
)
41+
42+
# CNNDecoder has its own attention mechanism.
43+
# Set up a separate copy attention layer if needed.
44+
assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
45+
if copy_attn:
46+
self.copy_attn = GlobalAttention(
47+
hidden_size, attn_type=copy_attn_type)
48+
else:
49+
self.copy_attn = None
50+
51+
@classmethod
52+
def from_opt(cls, opt, embeddings):
53+
"""Alternate constructor."""
54+
return cls(
55+
opt.dec_layers,
56+
opt.dec_rnn_size,
57+
opt.global_attention,
58+
opt.copy_attn,
59+
opt.cnn_kernel_width,
60+
opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
61+
embeddings,
62+
opt.copy_attn_type)
63+
64+
def init_state(self, _, memory_bank, enc_hidden):
65+
"""Init decoder state."""
66+
self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT
67+
self.state["previous_input"] = None
68+
69+
def map_state(self, fn):
70+
self.state["src"] = fn(self.state["src"], 1)
71+
if self.state["previous_input"] is not None:
72+
self.state["previous_input"] = fn(self.state["previous_input"], 1)
73+
74+
def detach_state(self):
75+
self.state["previous_input"] = self.state["previous_input"].detach()
76+
77+
def forward(self, tgt, memory_bank, step=None, **kwargs):
78+
""" See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
79+
80+
if self.state["previous_input"] is not None:
81+
tgt = torch.cat([self.state["previous_input"], tgt], 0)
82+
83+
dec_outs = []
84+
attns = {"std": []}
85+
if self.copy_attn is not None:
86+
attns["copy"] = []
87+
88+
emb = self.embeddings(tgt)
89+
assert emb.dim() == 3 # len x batch x embedding_dim
90+
91+
tgt_emb = emb.transpose(0, 1).contiguous()
92+
# The output of CNNEncoder.
93+
src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
94+
# The combination of output of CNNEncoder and source embeddings.
95+
src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()
96+
97+
emb_reshape = tgt_emb.contiguous().view(
98+
tgt_emb.size(0) * tgt_emb.size(1), -1)
99+
linear_out = self.linear(emb_reshape)
100+
x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
101+
x = shape_transform(x)
102+
103+
pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
104+
105+
pad = pad.type_as(x)
106+
base_target_emb = x
107+
108+
for conv, attention in zip(self.conv_layers, self.attn_layers):
109+
new_target_input = torch.cat([pad, x], 2)
110+
out = conv(new_target_input)
111+
c, attn = attention(base_target_emb, out,
112+
src_memory_bank_t, src_memory_bank_c)
113+
x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
114+
output = x.squeeze(3).transpose(1, 2)
115+
116+
# Process the result and update the attentions.
117+
dec_outs = output.transpose(0, 1).contiguous()
118+
if self.state["previous_input"] is not None:
119+
dec_outs = dec_outs[self.state["previous_input"].size(0):]
120+
attn = attn[:, self.state["previous_input"].size(0):].squeeze()
121+
attn = torch.stack([attn])
122+
attns["std"] = attn
123+
if self.copy_attn is not None:
124+
attns["copy"] = attn
125+
126+
# Update the state.
127+
self.state["previous_input"] = tgt
128+
# TODO change the way attns is returned dict => list or tuple (onnx)
129+
return dec_outs, attns
130+
131+
def update_dropout(self, dropout):
132+
for layer in self.conv_layers:
133+
layer.dropout.p = dropout

0 commit comments

Comments
 (0)