Coverage for skema/img2mml/models/decoders/xfmer_decoder.py: 76%
49 statements
coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
import math

import torch
import torch.nn as nn
from skema.img2mml.utils import generate_square_subsequent_mask
from skema.img2mml.models.encoding.positional_encoding_for_xfmer import (
    PositionalEncoding,
)
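# NOTE: generate_square_subsequent_mask (imported above, not shown in this file) is
# assumed to follow the standard PyTorch causal-mask recipe: it returns a (T, T)
# float tensor with 0.0 on and below the diagonal and -inf above it, so position t
# can only attend to positions <= t. For example, for T = 3 the mask would be:
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]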
class Transformer_Decoder(nn.Module):
    def __init__(
        self,
        emb_dim,
        nheads,
        dec_hid_dim,
        output_dim,
        dropout,
        max_len,
        n_xfmer_decoder_layers,
        dim_feedfwd,
        device,
    ):
        super(Transformer_Decoder, self).__init__()
        self.device = device
        self.output_dim = output_dim
        self.emb_dim = emb_dim
        self.embed = nn.Embedding(output_dim, emb_dim)
        self.pos = PositionalEncoding(emb_dim, dropout, max_len)

        """
        NOTE:
        batch_first is not passed to nn.TransformerDecoderLayer, so it defaults to
        False and the decoder operates on sequences shaped (max_len, B).
        """
        xfmer_dec_layer = nn.TransformerDecoderLayer(
            d_model=dec_hid_dim,
            nhead=nheads,
            dim_feedforward=dim_feedfwd,
            dropout=dropout,
        )

        self.xfmer_decoder = nn.TransformerDecoder(
            xfmer_dec_layer, num_layers=n_xfmer_decoder_layers
        )

        self.modify_dimension = nn.Linear(emb_dim, dec_hid_dim)
        self.final_linear = nn.Linear(dec_hid_dim, output_dim)
        self.init_weights()
    def init_weights(self):
        self.modify_dimension.bias.data.zero_()
        self.modify_dimension.weight.data.uniform_(-0.1, 0.1)
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.final_linear.bias.data.zero_()
        self.final_linear.weight.data.uniform_(-0.1, 0.1)
    def create_pad_mask(
        self, matrix: torch.Tensor, pad_token: int
    ) -> torch.Tensor:
        # If matrix = [1, 2, 3, 0, 0, 0] and pad_token=0, the resulting mask is
        # [False, False, False, True, True, True]
        return matrix == pad_token
    def forward(
        self,
        trg,
        xfmer_enc_output,
        sos_idx,
        pad_idx,
        is_test=False,
        is_inference=False,
    ):
        # xfmer_enc_output: (max_len, B, dec_hid_dim)
        # trg: (B, max_len)
        """
        Training:
        the decoder is given the input [<sos>, x1, x2, ...] and is expected to
        produce [x1, x2, ..., <eos>], so <sos> has to be prepended to the final
        preds.

        Inference:
        trg: sequence containing the tokens predicted so far.
        xfmer_enc_output: output of the encoder, used as decoder memory.
        """
        if not is_inference:
            (B, max_len) = trg.shape
            _preds = torch.zeros(max_len, B).to(self.device)  # (max_len, B)
            trg = trg.permute(1, 0)  # (max_len, B)
            trg = trg[:-1, :]  # (max_len-1, B)

        sequence_length = trg.shape[0]
        trg_attn_mask = generate_square_subsequent_mask(sequence_length).to(
            self.device
        )  # (max_len-1, max_len-1)

        # no padding mask is needed during inference
        if is_inference:
            trg_padding_mask = None
        else:
            trg_padding_mask = self.create_pad_mask(trg, pad_idx).permute(
                1, 0
            )  # (B, max_len-1)

        trg = self.embed(trg) * math.sqrt(
            self.emb_dim
        )  # (max_len-1, B, emb_dim)
        pos_trg = self.pos(trg)  # (max_len-1, B, emb_dim)
        pos_trg = self.modify_dimension(pos_trg)  # (max_len-1, B, dec_hid_dim)

        # outputs: (max_len-1, B, dec_hid_dim)
        xfmer_dec_outputs = self.xfmer_decoder(
            tgt=pos_trg,
            memory=xfmer_enc_output,
            tgt_mask=trg_attn_mask,
            tgt_key_padding_mask=trg_padding_mask,
        )

        xfmer_dec_outputs = self.final_linear(
            xfmer_dec_outputs
        )  # (max_len-1, B, output_dim)

        if is_inference:
            return xfmer_dec_outputs  # (-1, B, output_dim)
        else:
            # preds
            _preds[0, :] = torch.full(_preds[0, :].shape, sos_idx)
            if is_test:
                for i in range(xfmer_dec_outputs.shape[0]):
                    top1 = xfmer_dec_outputs[i, :, :].argmax(1)  # (B)
                    _preds[i + 1, :] = top1

            # xfmer_dec_outputs: (max_len-1, B, output_dim); _preds: (max_len, B)
            # permute both to make them batch-first
            return xfmer_dec_outputs.permute(1, 0, 2), _preds.permute(1, 0)
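A minimal usage sketch of the decoder in training mode. The hyperparameter values below are illustrative only (the project's actual values live in its configuration), and the encoder output is replaced with a random tensor in place of a real encoder forward pass.

import torch

device = torch.device("cpu")
decoder = Transformer_Decoder(
    emb_dim=256,
    nheads=4,
    dec_hid_dim=256,
    output_dim=1500,          # illustrative vocabulary size
    dropout=0.1,
    max_len=350,
    n_xfmer_decoder_layers=3,
    dim_feedfwd=1024,
    device=device,
).to(device)

B, max_len, sos_idx, pad_idx = 8, 350, 1, 0
trg = torch.randint(0, 1500, (B, max_len))        # (B, max_len) token ids
xfmer_enc_output = torch.rand(max_len, B, 256)    # stand-in for the encoder output

outputs, preds = decoder(trg, xfmer_enc_output, sos_idx, pad_idx, is_test=True)
# outputs: (B, max_len - 1, output_dim); preds: (B, max_len)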