Coverage for skema/img2mml/models/decoders/xfmer_decoder.py: 76%

49 statements  

coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

import math
import torch
import torch.nn as nn
from skema.img2mml.utils import generate_square_subsequent_mask
from skema.img2mml.models.encoding.positional_encoding_for_xfmer import (
    PositionalEncoding,
)


class Transformer_Decoder(nn.Module):
    def __init__(
        self,
        emb_dim,
        nheads,
        dec_hid_dim,
        output_dim,
        dropout,
        max_len,
        n_xfmer_decoder_layers,
        dim_feedfwd,
        device,
    ):
        super(Transformer_Decoder, self).__init__()
        self.device = device
        self.output_dim = output_dim
        self.emb_dim = emb_dim
        self.embed = nn.Embedding(output_dim, emb_dim)
        self.pos = PositionalEncoding(emb_dim, dropout, max_len)

29 """ 

30 NOTE: 

31 updated nn.TransformerDecoderLayer doesn't have 'batch_first' argument anymore. 

32 Therefore, the sequences will be in the shape of (max_len, B) 

33 """ 

        xfmer_dec_layer = nn.TransformerDecoderLayer(
            d_model=dec_hid_dim,
            nhead=nheads,
            dim_feedforward=dim_feedfwd,
            dropout=dropout,
        )

        self.xfmer_decoder = nn.TransformerDecoder(
            xfmer_dec_layer, num_layers=n_xfmer_decoder_layers
        )

        self.modify_dimension = nn.Linear(emb_dim, dec_hid_dim)
        self.final_linear = nn.Linear(dec_hid_dim, output_dim)
        self.init_weights()

    def init_weights(self):
        self.modify_dimension.bias.data.zero_()
        self.modify_dimension.weight.data.uniform_(-0.1, 0.1)
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.final_linear.bias.data.zero_()
        self.final_linear.weight.data.uniform_(-0.1, 0.1)

    def create_pad_mask(
        self, matrix: torch.Tensor, pad_token: int
    ) -> torch.Tensor:
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the resulting mask is
        # [False, False, False, True, True, True]
        return matrix == pad_token
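
    # Illustrative note (not from the original source): the boolean mask returned
    # above is what nn.TransformerDecoder expects as tgt_key_padding_mask, i.e. a
    # (B, seq_len) tensor in which True marks padded positions to be ignored.
    # Since the target here is kept sequence-first, forward() below permutes the
    # mask from (seq_len, B) to (B, seq_len) before passing it in, e.g.:
    #   trg_padding_mask = self.create_pad_mask(trg, pad_idx).permute(1, 0)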

    def forward(
        self,
        trg,
        xfmer_enc_output,
        sos_idx,
        pad_idx,
        is_test=False,
        is_inference=False,
    ):
        # xfmer_enc_output: (max_len, B, dec_hid_dim)
        # trg: (B, max_len)
        """
        We provide the input [<sos>, x1, x2, ...] and get the output
        [x1, x2, ..., <eos>], so <sos> has to be added back to the final preds.

        For inference:
        trg: the sequence of tokens predicted so far.
        xfmer_enc_output: the encoder output.
        """

        if not is_inference:
            (B, max_len) = trg.shape
            _preds = torch.zeros(max_len, B).to(self.device)  # (max_len, B)
            trg = trg.permute(1, 0)  # (max_len, B)
            trg = trg[:-1, :]  # (max_len-1, B)

        sequence_length = trg.shape[0]
        trg_attn_mask = generate_square_subsequent_mask(sequence_length).to(
            self.device
        )  # (max_len-1, max_len-1)

        # no padding mask is needed for inference
        if is_inference:
            trg_padding_mask = None
        else:
            trg_padding_mask = self.create_pad_mask(trg, pad_idx).permute(
                1, 0
            )  # (B, max_len-1)

        trg = self.embed(trg) * math.sqrt(
            self.emb_dim
        )  # (max_len-1, B, emb_dim)
        pos_trg = self.pos(trg)  # (max_len-1, B, emb_dim)
        pos_trg = self.modify_dimension(pos_trg)  # (max_len-1, B, dec_hid_dim)

        # outputs: (max_len-1, B, dec_hid_dim)
        xfmer_dec_outputs = self.xfmer_decoder(
            tgt=pos_trg,
            memory=xfmer_enc_output,
            tgt_mask=trg_attn_mask,
            tgt_key_padding_mask=trg_padding_mask,
        )

        xfmer_dec_outputs = self.final_linear(
            xfmer_dec_outputs
        )  # (max_len-1, B, output_dim)

        if is_inference:
            return xfmer_dec_outputs  # (-1, B, output_dim)
        else:
            # preds
            _preds[0, :] = torch.full(_preds[0, :].shape, sos_idx)
            if is_test:
                for i in range(xfmer_dec_outputs.shape[0]):
                    top1 = xfmer_dec_outputs[i, :, :].argmax(1)  # (B)
                    _preds[i + 1, :] = top1

            # xfmer_dec_outputs: (max_len-1, B, output_dim); _preds: (max_len, B)
            # permute both to make them batch-first
            return xfmer_dec_outputs.permute(1, 0, 2), _preds.permute(1, 0)
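

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the covered module). All sizes, index
# values, and the dummy tensors below are illustrative assumptions: in the
# real pipeline xfmer_enc_output comes from the image encoder and trg from
# the tokenized targets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    B, max_len, output_dim = 2, 10, 100  # assumed batch, sequence, vocab sizes
    decoder = Transformer_Decoder(
        emb_dim=64,
        nheads=4,
        dec_hid_dim=64,
        output_dim=output_dim,
        dropout=0.1,
        max_len=max_len,
        n_xfmer_decoder_layers=2,
        dim_feedfwd=256,
        device=device,
    )

    # Dummy encoder memory, shaped (max_len, B, dec_hid_dim) as forward() expects.
    xfmer_enc_output = torch.randn(max_len, B, 64)
    # Dummy target batch, shaped (B, max_len); 0 is assumed to be the pad index.
    trg = torch.randint(low=1, high=output_dim, size=(B, max_len))

    # Training-style forward pass (teacher forcing). For greedy decoding one
    # would instead call the decoder step by step with is_inference=True and
    # trg holding the tokens generated so far.
    outputs, preds = decoder(trg, xfmer_enc_output, sos_idx=1, pad_idx=0)
    print(outputs.shape)  # (B, max_len-1, output_dim)
    print(preds.shape)    # (B, max_len)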