# skema/img2mml/utils.py

import torch
from collections import Counter


class CreateVocab:
    """
    Builds the vocabulary for the dataset.

    stoi: string-to-index dictionary
    itos: index-to-string dictionary
    """

    def __init__(self, train, special_tokens=None, min_freq=1):
        self.counter = Counter()
        for line in train["EQUATION"]:
            self.counter.update(line.split())

        self.min_freq = min_freq
        self.tok2ind = dict()
        self.ind2tok = dict()

        # appending special_tokens first so they receive the lowest indices
        self.s_count = 0
        if special_tokens is not None:
            for i in special_tokens:
                self.tok2ind[i] = self.s_count
                self.ind2tok[self.s_count] = i
                self.s_count += 1

        # appending the rest of the vocab, filtered by minimum frequency
        self.tok_array = [
            tok for (tok, freq) in self.counter.items() if freq >= min_freq
        ]
        self.stoi, self.itos = self.vocab()

    def vocab(self):
        _count = self.s_count
        for t in self.tok_array:
            self.tok2ind[t] = _count
            self.ind2tok[_count] = t
            _count += 1
        return self.tok2ind, self.ind2tok

    def __getitem__(self, key=None):
        # the key is ignored; both mappings are returned together
        return self.stoi, self.itos

    def __len__(self):
        return len(self.tok2ind)
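
# Hypothetical usage sketch (the data below is illustrative, not from the
# project): `train` is assumed to be a dict-like object whose "EQUATION"
# column holds space-separated token strings.
#
#   train = {"EQUATION": ["x ^ { 2 }", "x + y"]}
#   vocab = CreateVocab(train, special_tokens=["<pad>", "<sos>", "<eos>", "<unk>"])
#   vocab.stoi["<pad>"]  # -> 0 (special tokens take the lowest indices)
#   len(vocab)           # -> 11 (4 special tokens + 7 distinct equation tokens)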

def garbage2pad(preds, vocab, is_test=False):
    """
    Converts all "garbage" tokens to the <pad> token.
    "garbage" tokens: tokens appearing after the <eos> token

    params:
        preds: predicted eqns (B, seq_len/max_len)

    return:
        preds: cleaned pred eqns
    """

    pad_idx = vocab.stoi["<pad>"]
    eos_idx = vocab.stoi["<eos>"]
    for b in range(preds.shape[0]):
        try:
            # position of the first <eos> in this sequence
            eos_pos = (preds[b, :] == eos_idx).nonzero(as_tuple=False)[0].item()
            # overwrite everything after <eos> with <pad>
            preds[b, eos_pos + 1 :] = pad_idx
        except IndexError:
            # no <eos> in this sequence; leave it untouched
            pass

    return preds
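
# Illustrative example, assuming "<pad>" maps to index 0 and "<eos>" to
# index 2 in the vocab:
#
#   preds = torch.tensor([[5, 7, 2, 9, 9]])  # two garbage tokens after <eos>
#   garbage2pad(preds, vocab)                # -> tensor([[5, 7, 2, 0, 0]])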

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # lower-triangular boolean mask: position i may attend to positions <= i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    # convert to an additive mask: 0.0 where attention is allowed, -inf where not
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
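
# For sz=3 the returned mask is
#
#   tensor([[0., -inf, -inf],
#           [0., 0., -inf],
#           [0., 0., 0.]])
#
# which is the additive attention-mask format expected by torch.nn.Transformer.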

def calculating_accuracy(pred, mml):
    """
    Counts the positions at which the prediction matches the target.

    params:
        pred and mml: (B, l)
    """
    train_acc = torch.sum(pred == mml)
    return train_acc
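
# Note that this returns the raw count of matching tokens rather than a
# normalized accuracy; dividing by the number of elements gives a rate:
#
#   pred = torch.tensor([[1, 2, 3]])
#   mml = torch.tensor([[1, 2, 9]])
#   calculating_accuracy(pred, mml)                 # -> tensor(2)
#   calculating_accuracy(pred, mml) / pred.numel()  # -> tensor(0.6667)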

def beam_search(data, k, alpha, min_length):
    """
    Predicts the k most probable sequences using beam search.

    params:
        data: (seq_len, output_dim) per-step probability distributions
              (squeeze out the leading batch dimension of a
              (1, seq_len, output_dim) tensor before calling)
        k: beam width
        alpha: degree of regularization in length_normalization (currently unused)
        min_length: param for length_normalization (currently unused)
    """

    # each entry is [token-index sequence, accumulated negative log-likelihood]
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        log_row = row.log()
        all_candidates = list()
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                # extend this beam with token j; a lower score is more probable
                candidate = [seq + [j], score - log_row[j]]
                all_candidates.append(candidate)

        # order all candidates by score and keep the k best
        ordered = sorted(all_candidates, key=lambda t: t[1])
        sequences = ordered[:k]
    return sequences
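
# Minimal usage sketch (the tensors are illustrative):
#
#   probs = torch.softmax(torch.randn(4, 10), dim=-1)  # 4 steps, 10-token vocab
#   beams = beam_search(probs, k=3, alpha=0.7, min_length=4)
#   beams[0]  # -> [token indices of the most probable sequence, its score]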

def length_normalization(sequence_length, alpha, min_length):
    # length penalty used to normalize beam-search scores, so longer
    # hypotheses are not unfairly penalized relative to shorter ones
    ln = (1 + sequence_length) ** alpha / (1 + min_length) ** alpha
    return ln
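
# E.g., dividing an accumulated beam score by the hypothesis's penalty:
#
#   score, seq_len = 12.5, 8
#   normalized = score / length_normalization(seq_len, alpha=0.7, min_length=4)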