Coverage for skema/img2mml/models/encoding/row_encoding.py: 42%
19 statements
import torch


class RowEncoding(torch.nn.Module):
    def __init__(self, device, dec_hid_dim, dropout):
        super(RowEncoding, self).__init__()
        self.device = device
        # Embedding table for the row index; supports feature maps up to 256 rows tall.
        self.emb = torch.nn.Embedding(256, 512)
        # Note: with num_layers=1, PyTorch ignores the dropout argument
        # (dropout only applies between stacked LSTM layers) and warns about it.
        self.lstm = torch.nn.LSTM(
            512,
            dec_hid_dim,
            num_layers=1,
            dropout=dropout,
            bidirectional=False,
            batch_first=False,
        )

    def forward(self, enc_output):
        # enc_output: (B, 512, H, W)
        # Encode each row of the feature map with the LSTM, prepending a
        # learned embedding of the row index as a positional signal.
        outputs = []
        for wh in range(0, enc_output.shape[2]):
            # Each row is a [batch, 512, W] slice of the feature map.
            row = enc_output[:, :, wh, :]  # [batch, 512, W]
            row = row.permute(2, 0, 1)  # [W, batch, 512]
            # One copy of the current row index per batch element.
            position_vector = torch.full(
                (row.shape[1],), wh, dtype=torch.long, device=self.device
            )  # [batch]
            # self.emb(position_vector) ==> [batch, 512]
            lstm_input = torch.cat(
                (self.emb(position_vector).unsqueeze(0), row), dim=0
            )  # [W+1, batch, 512]
            lstm_output, (hidden, cell) = self.lstm(lstm_input)
            # lstm_output = [W+1, batch, dec_hid_dim]
            # hidden/cell = [1, batch, dec_hid_dim] (one unidirectional layer)
            outputs.append(lstm_output.unsqueeze(0))
        final_output = torch.cat(
            outputs, dim=0
        )  # [H, W+1, batch, dec_hid_dim]
        # Flatten rows and columns into one sequence: [H*(W+1), batch, dec_hid_dim]
        final_output = final_output.view(
            final_output.shape[0] * final_output.shape[1],
            final_output.shape[2],
            final_output.shape[3],
        )
        # output: [batch, L, dec_hid_dim] with L = H*(W+1);
        # hidden/cell: [1, batch, dec_hid_dim], taken from the last row's LSTM pass.
        return final_output.permute(1, 0, 2), hidden, cell
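
A minimal usage sketch (an assumption, not part of the repository: this module sits behind a CNN encoder that emits a (B, 512, H, W) feature map; the shapes and hyperparameters below are illustrative):

if __name__ == "__main__":
    # Hypothetical shapes: batch of 2, 512 channels, 4 rows, 10 columns.
    device = torch.device("cpu")
    encoder = RowEncoding(device, dec_hid_dim=256, dropout=0.0)
    feature_map = torch.randn(2, 512, 4, 10)
    output, hidden, cell = encoder(feature_map)
    print(output.shape)  # torch.Size([2, 44, 256]) -> [B, H*(W+1), dec_hid_dim]
    print(hidden.shape)  # torch.Size([1, 2, 256])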