Coverage for skema/img2mml/models/encoders/xfmer_encoder.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import torch, math 

2import torch.nn as nn 

3from skema.img2mml.utils import generate_square_subsequent_mask 

4from skema.img2mml.models.encoding.positional_encoding_for_xfmer import ( 

5 PositionalEncoding, 

6) 

7 

8 

class Transformer_Encoder(nn.Module):
    """Transformer encoder applied to the feature sequence produced by a CNN.

    The incoming sequence length is first projected to ``max_len`` with a
    linear layer, then the features are scaled by ``sqrt(dec_hid_dim)``,
    passed through ``PositionalEncoding`` and finally through a stack of
    ``nn.TransformerEncoderLayer`` blocks.

    Args:
        emb_dim: Accepted for interface compatibility; not used by this module.
        dec_hid_dim: Feature size of the input and output (``d_model``).
        nheads: Number of attention heads in every encoder layer.
        dropout: Dropout probability given to the positional encoding and to
            the encoder layers.
        device: Stored on the instance as ``self.device``; ``forward`` does
            not need it.
        max_len: Sequence length after the ``change_length`` projection.
        n_xfmer_encoder_layers: Number of stacked encoder layers.
        dim_feedfwd: Hidden size of the feed-forward sub-layer.
        len_dim: Sequence length ``L`` of the incoming CNN features.
    """

    def __init__(
        self,
        emb_dim,
        dec_hid_dim,
        nheads,
        dropout,
        device,
        max_len,
        n_xfmer_encoder_layers,
        dim_feedfwd,
        len_dim,
    ):
        super().__init__()
        self.dec_hid_dim = dec_hid_dim
        self.device = device
        # Maps the CNN sequence length (len_dim) onto max_len.
        self.change_length = nn.Linear(len_dim, max_len)
        self.pos = PositionalEncoding(dec_hid_dim, dropout, max_len)

        # NOTE: the encoder layer is built without ``batch_first``, so it uses
        # PyTorch's default layout and expects sequence-first tensors of shape
        # (max_len, B, dec_hid_dim).
        xfmer_enc_layer = nn.TransformerEncoderLayer(
            d_model=dec_hid_dim,
            nhead=nheads,
            dim_feedforward=dim_feedfwd,
            dropout=dropout,
        )

        self.xfmer_encoder = nn.TransformerEncoder(
            xfmer_enc_layer, num_layers=n_xfmer_encoder_layers
        )

    def forward(self, src_from_cnn):
        """Encode CNN features.

        Args:
            src_from_cnn: Tensor of shape ``(B, L, dec_hid_dim)`` where ``L``
                equals the ``len_dim`` given to the constructor.

        Returns:
            Tensor of shape ``(max_len, B, dec_hid_dim)``.
        """
        # Change the sequence length L (= H*W of the CNN map) to max_len.
        # nn.Linear acts on the last dimension, hence the permutes around it.
        src = src_from_cnn.permute(0, 2, 1)  # (B, dec_hid_dim, L)
        src = self.change_length(src)  # (B, dec_hid_dim, max_len)
        src = src.permute(2, 0, 1)  # (max_len, B, dec_hid_dim)

        # No embedding layer is needed: the CNN output already has dec_hid_dim
        # as its feature dimension. Only the sqrt(d_model) scaling is applied
        # (out of place, so no view is mutated).
        src = src * math.sqrt(self.dec_hid_dim)  # (max_len, B, dec_hid_dim)

        # Positional encoding.
        pos_src = self.pos(src)  # (max_len, B, dec_hid_dim)

        # The encoder attends over the whole sequence, so no attention mask is
        # passed. The subsequent-position mask that used to be built here was
        # never handed to the encoder and has been dropped.
        xfmer_enc_output = self.xfmer_encoder(
            src=pos_src, mask=None
        )  # (max_len, B, dec_hid_dim)

        return xfmer_enc_output