Coverage for skema/img2mml/models/encoders/xfmer_encoder.py: 100%
22 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import torch, math
2import torch.nn as nn
3from skema.img2mml.utils import generate_square_subsequent_mask
4from skema.img2mml.models.encoding.positional_encoding_for_xfmer import (
5 PositionalEncoding,
6)
class Transformer_Encoder(nn.Module):
    """Transformer encoder applied to the feature sequence produced by a CNN.

    The sequence axis of the input is first projected to ``max_len`` with a
    linear layer, the result is scaled by ``sqrt(dec_hid_dim)``, run through
    the positional-encoding module and finally through a stack of
    ``nn.TransformerEncoderLayer`` layers.

    Args:
        emb_dim: Accepted for interface compatibility; not used by this class.
        dec_hid_dim: Feature size of the input, used as ``d_model``.
        nheads: Number of attention heads in every encoder layer.
        dropout: Dropout value passed to the positional-encoding module and
            to the encoder layers.
        device: Stored as ``self.device``; ``forward`` does not read it.
        max_len: Sequence length produced by the length projection.
        n_xfmer_encoder_layers: Number of stacked encoder layers.
        dim_feedfwd: Width of the feed-forward sub-layer of every encoder
            layer.
        len_dim: Sequence length of the incoming CNN features (input size of
            the length projection).
    """

    def __init__(
        self,
        emb_dim,
        dec_hid_dim,
        nheads,
        dropout,
        device,
        max_len,
        n_xfmer_encoder_layers,
        dim_feedfwd,
        len_dim,
    ):
        super().__init__()
        self.dec_hid_dim = dec_hid_dim
        self.device = device
        # Maps the sequence axis of the CNN output from len_dim to max_len.
        self.change_length = nn.Linear(len_dim, max_len)
        self.pos = PositionalEncoding(dec_hid_dim, dropout, max_len)

        # NOTE: ``batch_first`` is not passed to nn.TransformerEncoderLayer,
        # so the encoder is fed sequence-first tensors of shape
        # (max_len, B, dec_hid_dim).
        xfmer_enc_layer = nn.TransformerEncoderLayer(
            d_model=dec_hid_dim,
            nhead=nheads,
            dim_feedforward=dim_feedfwd,
            dropout=dropout,
        )
        self.xfmer_encoder = nn.TransformerEncoder(
            xfmer_enc_layer, num_layers=n_xfmer_encoder_layers
        )

    def forward(self, src_from_cnn):
        """Encode the CNN feature sequence.

        Args:
            src_from_cnn: Tensor of shape (B, L, dec_hid_dim), where L is the
                flattened spatial size (H * W) and must equal ``len_dim``.

        Returns:
            Tensor of shape (max_len, B, dec_hid_dim).
        """
        # Change the sequence length from L = H * W to max_len. nn.Linear acts
        # on the last axis, so the sequence axis is moved there first.
        src = src_from_cnn.permute(0, 2, 1)  # (B, dec_hid_dim, L)
        src = self.change_length(src)  # (B, dec_hid_dim, max_len)
        src = src.permute(2, 0, 1)  # (max_len, B, dec_hid_dim)

        # No embedding layer is needed: the CNN output already has dec_hid_dim
        # as its last dimension, so only the usual sqrt(d_model) scaling is
        # applied. The multiplication is out-of-place so the projection output
        # is not mutated through a permuted view.
        src = src * math.sqrt(self.dec_hid_dim)  # (max_len, B, dec_hid_dim)

        # Add positional encoding.
        pos_src = self.pos(src)  # (max_len, B, dec_hid_dim)

        # The encoder runs without an attention mask, so no mask tensor is
        # built here.
        xfmer_enc_output = self.xfmer_encoder(
            src=pos_src, mask=None
        )  # (max_len, B, dec_hid_dim)

        return xfmer_enc_output