Coverage for skema/img2mml/models/image2mml_xfmer.py: 83%

29 statements  

coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

import torch
import torch.nn as nn


class Image2MathML_Xfmer(nn.Module):
    def __init__(self, encoder, decoder, vocab, device):
        """
        :param encoder: dict holding the CNN and XFMER encoders
        :param decoder: decoder
        :param vocab: torchtext vocab object mapping tokens to indices
        :param device: device to run the model on: cpu or gpu
        """
        super(Image2MathML_Xfmer, self).__init__()

        self.cnn_encoder = encoder["CNN"]
        self.xfmer_encoder = encoder["XFMER"]
        self.xfmer_decoder = decoder
        self.vocab = vocab
        self.device = device

    def forward(
        self,
        src,
        trg,
        is_test=False,
        is_inference=False,
        SOS_token=None,
        EOS_token=None,
        PAD_token=None,
    ):

        # Run the encoders: the CNN flattens each image into a sequence of
        # feature vectors, which the transformer encoder then contextualizes.
        # For inference, batch size B = 1.
        cnn_enc_output = self.cnn_encoder(src)  # (B, L, dec_hid_dim)
        xfmer_enc_output = self.xfmer_encoder(
            cnn_enc_output
        )  # (max_len, B, dec_hid_dim)

        if not is_inference:
            # Normal training and testing: look up the special tokens in the
            # torchtext vocab object. At inference time the caller supplies
            # them as arguments instead.
            SOS_token = self.vocab.stoi["<sos>"]
            EOS_token = self.vocab.stoi["<eos>"]
            PAD_token = self.vocab.stoi["<pad>"]

            xfmer_dec_outputs, preds = self.xfmer_decoder(
                trg,
                xfmer_enc_output,
                SOS_token,
                PAD_token,
                is_test=is_test,
            )

            return xfmer_dec_outputs, preds

        else:
            # Inference: greedy decoding, one token at a time, starting from
            # a single <sos> token.
            max_len = xfmer_enc_output.shape[0]
            trg = torch.tensor(
                [[SOS_token]], dtype=torch.long, device=self.device
            )

            for i in range(max_len):
                output = self.xfmer_decoder(
                    trg,
                    xfmer_enc_output,
                    SOS_token,
                    PAD_token,
                    is_inference=is_inference,
                )

                # Greedy choice: the most probable token at position i.
                top1a = output[i, :, :].argmax(1)

                # .item() extracts the plain int, so the new (1, 1) token
                # tensor has a well-defined shape for the concatenation.
                next_token = torch.tensor(
                    [[top1a.item()]], dtype=torch.long, device=self.device
                )
                trg = torch.cat((trg, next_token), dim=0)

                # Stop if the model predicts end of sentence.
                if next_token.view(-1).item() == EOS_token:
                    break

            return trg.view(-1).tolist()
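
For reference, below is a minimal sketch of how both forward paths are driven. Everything other than Image2MathML_Xfmer itself (the stub encoder/decoder modules, the toy vocab, and all sizes) is a hypothetical stand-in invented for illustration; the real modules live elsewhere in skema/img2mml.

# Hypothetical usage sketch: the stub modules and toy sizes below are
# stand-ins, not the real skema/img2mml encoders and decoder.
import torch
import torch.nn as nn

MAX_LEN, HID, VOCAB_SIZE = 10, 16, 8  # toy sizes (assumptions)

class StubCNN(nn.Module):
    # Stands in for encoder["CNN"]: images -> (B, L, dec_hid_dim).
    def forward(self, src):
        return torch.zeros(src.shape[0], 5, HID)

class StubXfmerEncoder(nn.Module):
    # Stands in for encoder["XFMER"]: (B, L, H) -> (max_len, B, H).
    def forward(self, x):
        return torch.zeros(MAX_LEN, x.shape[0], HID)

class StubXfmerDecoder(nn.Module):
    # Training/test calls return (outputs, preds); inference calls return
    # per-position vocab scores of shape (max_len, B, vocab_size).
    def forward(self, trg, enc_output, sos, pad,
                is_test=False, is_inference=False):
        scores = torch.zeros(MAX_LEN, enc_output.shape[1], VOCAB_SIZE)
        scores[:, :, 2] = 1.0  # greedy pick is always token 2 (<eos>)
        if is_inference:
            return scores
        return scores, scores.argmax(-1)

class ToyVocab:
    stoi = {"<pad>": 0, "<sos>": 1, "<eos>": 2}

model = Image2MathML_Xfmer(
    encoder={"CNN": StubCNN(), "XFMER": StubXfmerEncoder()},
    decoder=StubXfmerDecoder(),
    vocab=ToyVocab(),
    device=torch.device("cpu"),
)
images = torch.zeros(1, 3, 32, 32)  # (B, C, H, W); B must be 1 for inference

# Training/testing path: special-token ids come from the vocab itself.
outputs, preds = model(images, trg=torch.zeros(4, 1, dtype=torch.long))

# Inference path: the caller supplies the special-token ids and receives
# the greedily decoded ids as a plain list, here [1, 2] (<sos>, <eos>).
token_ids = model(images, trg=None, is_inference=True,
                  SOS_token=1, EOS_token=2, PAD_token=0)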