Coverage for skema/img2mml/models/encoding/row_encoding.py: 42%

19 statements  


import torch


class RowEncoding(torch.nn.Module):
    def __init__(self, device, dec_hid_dim, dropout):
        super(RowEncoding, self).__init__()
        self.device = device
        # Learned positional embedding over row indices (up to 256 rows).
        self.emb = torch.nn.Embedding(256, 512)
        # Note: with num_layers=1 the dropout argument has no effect;
        # PyTorch applies LSTM dropout only between stacked layers and
        # emits a warning for this combination when dropout > 0.
        self.lstm = torch.nn.LSTM(
            512,
            dec_hid_dim,
            num_layers=1,
            dropout=dropout,
            bidirectional=False,
            batch_first=False,
        )


    def forward(self, enc_output):
        # enc_output: [batch, 512, H, W]
        # Encode each row of the feature grid with the LSTM.
        outputs = []
        for wh in range(enc_output.shape[2]):
            # Each row is a [512, W] matrix per batch element.
            row = enc_output[:, :, wh, :]  # [batch, 512, W]
            row = row.permute(2, 0, 1)  # [W, batch, 512]
            # One index per batch element, all filled with this row's number.
            position_vector = (
                torch.Tensor(row.shape[1]).long().fill_(wh).to(self.device)
            )  # [batch]
            # self.emb(position_vector) => [batch, 512]; prepended as an
            # extra timestep so the LSTM sees the row position first.
            lstm_input = torch.cat(
                (self.emb(position_vector).unsqueeze(0), row), dim=0
            )  # [W+1, batch, 512]
            lstm_output, (hidden, cell) = self.lstm(lstm_input)
            # lstm_output = [W+1, batch, dec_hid_dim]
            # hidden/cell = [1, batch, dec_hid_dim]
            # (single-layer, unidirectional LSTM, so one state each)
            outputs.append(lstm_output.unsqueeze(0))


        final_output = torch.cat(
            outputs, dim=0
        )  # [H, W+1, batch, dec_hid_dim]
        # Flatten rows and columns into one sequence:
        # [H*(W+1), batch, dec_hid_dim]
        final_output = final_output.view(
            final_output.shape[0] * final_output.shape[1],
            final_output.shape[2],
            final_output.shape[3],
        )

        # output: [batch, H*(W+1), dec_hid_dim]
        # hidden/cell: [1, batch, dec_hid_dim] (states from the last row)
        return final_output.permute(1, 0, 2), hidden, cell
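
A minimal shape check, as a sketch: the batch size, dec_hid_dim=256, and the 8x8 feature grid below are illustrative assumptions, not values taken from the project's configuration.

# Sanity check for RowEncoding's output shapes (illustrative values;
# dec_hid_dim=256 and the 8x8 grid are assumptions for this example).
device = torch.device("cpu")
encoder = RowEncoding(device, dec_hid_dim=256, dropout=0.0)

enc_output = torch.randn(4, 512, 8, 8)  # [batch=4, 512, H=8, W=8]
output, hidden, cell = encoder(enc_output)

print(output.shape)  # torch.Size([4, 72, 256]) -> [batch, H*(W+1), dec_hid_dim]
print(hidden.shape)  # torch.Size([1, 4, 256])  -> [1, batch, dec_hid_dim]
print(cell.shape)    # torch.Size([1, 4, 256])

With H = W = 8, each row contributes W+1 = 9 timesteps (the positional embedding plus 8 columns), so the flattened sequence length is 8 * 9 = 72.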