Coverage for skema/img2mml/models/image2mml_xfmer.py: 83%
29 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import torch
2import torch.nn as nn
class Image2MathML_Xfmer(nn.Module):
    """Image-to-MathML model chaining a CNN encoder, a transformer encoder
    and a transformer decoder.

    The class only wires the three sub-modules together; all learnable
    behavior lives in the modules passed to the constructor.
    """

    def __init__(self, encoder, decoder, vocab, device):
        """
        :param encoder: mapping with the keys "CNN" and "XFMER" holding the
            CNN encoder and the transformer encoder
        :param decoder: transformer decoder
        :param vocab: vocabulary object; ``vocab.stoi`` is used to look up
            "<sos>" and "<pad>" when not running inference
        :param device: device to use for model: cpu or gpu
        """
        super().__init__()

        self.cnn_encoder = encoder["CNN"]
        self.xfmer_encoder = encoder["XFMER"]
        self.xfmer_decoder = decoder
        self.vocab = vocab
        self.device = device

    def forward(
        self,
        src,
        trg,
        is_test=False,
        is_inference=False,
        SOS_token=None,
        EOS_token=None,
        PAD_token=None,
    ):
        """Encode ``src`` and decode it into MathML tokens.

        Training / testing (``is_inference`` falsy): the special-token ids
        are read from ``self.vocab.stoi`` (any ``SOS_token`` / ``PAD_token``
        arguments are ignored) and the decoder is called once with ``trg``.

        Inference (``is_inference`` truthy): ``trg`` is ignored. Decoding
        starts from ``SOS_token`` and greedily appends the arg-max token of
        each step until ``EOS_token`` is produced or as many steps as the
        first dimension of the transformer-encoder output have been taken.

        :param src: input passed to the CNN encoder
        :param trg: target sequence forwarded to the decoder (training /
            testing only)
        :param is_test: forwarded to the decoder in training / testing mode
        :param is_inference: selects the greedy decoding path
        :param SOS_token: start-of-sequence id; required for inference
        :param EOS_token: end-of-sequence id; if it is never predicted,
            decoding simply runs for the maximum number of steps
        :param PAD_token: padding id, forwarded to the decoder unchanged
        :return: ``(xfmer_dec_outputs, preds)`` from the decoder in training
            / testing mode; a flat ``list`` of token ids (starting with
            ``SOS_token``) in inference mode
        :raises ValueError: if ``is_inference`` is set but ``SOS_token`` is
            ``None``
        """
        # Fail fast with a clear message instead of an opaque tensor
        # construction error after the encoders have already run.
        if is_inference and SOS_token is None:
            raise ValueError(
                "SOS_token must be provided when is_inference=True"
            )

        # run the encoder --> get flattened FV of images
        # for inference Batch(B)=1
        cnn_enc_output = self.cnn_encoder(src)  # (B, L, dec_hid_dim)
        xfmer_enc_output = self.xfmer_encoder(
            cnn_enc_output
        )  # (max_len, B, dec_hid_dim)

        if not is_inference:
            # normal training and testing part:
            # special tokens come from the torchtext.vocab object
            SOS_token = self.vocab.stoi["<sos>"]
            PAD_token = self.vocab.stoi["<pad>"]

            xfmer_dec_outputs, preds = self.xfmer_decoder(
                trg,
                xfmer_enc_output,
                SOS_token,
                PAD_token,
                is_test=is_test,
            )
            return xfmer_dec_outputs, preds

        # inference: greedy decoding, one token per step
        max_len = xfmer_enc_output.shape[0]
        trg = torch.tensor(
            [[SOS_token]], dtype=torch.long, device=self.device
        )

        # Only token ids are returned, so no autograd graph is needed.
        with torch.no_grad():
            for i in range(max_len):
                output = self.xfmer_decoder(
                    trg,
                    xfmer_enc_output,
                    SOS_token,
                    PAD_token,
                    is_inference=is_inference,
                )

                # Prediction for step i; .item() requires a single element,
                # i.e. batch size 1. The id is extracted once and reused for
                # both the new target entry and the EOS check.
                token_id = output[i, :, :].argmax(1).item()

                next_token = torch.tensor(
                    [[token_id]], dtype=torch.long, device=self.device
                )
                trg = torch.cat((trg, next_token), dim=0)

                # Stop if model predicts end of sentence
                if token_id == EOS_token:
                    break

        return trg.view(-1).tolist()