# skema/img2mml/utils.py (coverage: 24% of 62 statements; coverage.py v7.5.0, 2024-04-30)

import torch
from collections import Counter


class CreateVocab(object):
    """
    Builds the vocabulary for the dataset.

    stoi: string-to-index dictionary
    itos: index-to-string dictionary
    """

    def __init__(self, train, special_tokens=None, min_freq=1):
        # count token frequencies over all equations in the training split
        self.counter = Counter()
        for line in train["EQUATION"]:
            self.counter.update(line.split())

        self.min_freq = min_freq
        self.tok2ind = dict()
        self.ind2tok = dict()

        # special tokens come first so their indices stay fixed
        # regardless of the corpus
        self.s_count = 0
        if special_tokens is not None:
            for i in special_tokens:
                self.tok2ind[i] = self.s_count
                self.ind2tok[self.s_count] = i
                self.s_count += 1

        # keep only tokens seen at least min_freq times
        self.tok_array = [
            tok for (tok, freq) in self.counter.items() if freq >= min_freq
        ]
        self.stoi, self.itos = self.vocab()

    def vocab(self):
        # assign indices to corpus tokens, continuing after the special tokens
        _count = self.s_count
        for t in self.tok_array:
            self.tok2ind[t] = _count
            self.ind2tok[_count] = t
            _count += 1
        return self.tok2ind, self.ind2tok

    def __getitem__(self):
        # note: takes no key and returns both mappings as a (stoi, itos) pair
        return self.stoi, self.itos

    def __len__(self):
        return len(self.tok2ind)
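
# A minimal usage sketch for CreateVocab. The `train` mapping below is
# hypothetical; in the real pipeline it is a DataFrame-like object with an
# "EQUATION" column of space-separated token strings.
#
#   train = {"EQUATION": ["x ^ 2 + y", "x + y"]}
#   specials = ["<pad>", "<sos>", "<eos>", "<unk>"]
#   vocab = CreateVocab(train, special_tokens=specials, min_freq=1)
#   vocab.stoi["<pad>"]   # 0
#   vocab.stoi["x"]       # 4 (first corpus token after the specials)
#   len(vocab)            # 4 specials + 5 unique corpus tokens = 9
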
def garbage2pad(preds, vocab, is_test=False):
    """
    Converts all "garbage" tokens to the <pad> token.
    "garbage" tokens: tokens after the first <eos> token

    params:
        preds: predicted eqns (B, seq_len/max_len)

    return:
        preds: cleaned pred eqns
    """
    # note: is_test is currently unused
    pad_idx = vocab.stoi["<pad>"]
    eos_idx = vocab.stoi["<eos>"]
    for b in range(preds.shape[0]):
        # find the first <eos> in this sequence, if any, and overwrite
        # everything after it with <pad>
        eos_hits = (preds[b, :] == eos_idx).nonzero(as_tuple=False)
        if eos_hits.numel() > 0:
            preds[b, int(eos_hits[0]) + 1 :] = pad_idx
    return preds
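
# Sketch of garbage2pad on a toy batch; `vocab` is assumed to be a
# CreateVocab instance whose special tokens include <pad> and <eos>.
#
#   preds = torch.tensor([[4, 7, eos_idx, 9, 9],
#                         [4, 4, 4, 4, eos_idx]])
#   garbage2pad(preds, vocab)
#   # row 0 -> [4, 7, eos_idx, pad_idx, pad_idx]; row 1 is unchanged
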
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # lower-triangular boolean mask: position i may attend to positions <= i
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    # allowed positions -> 0.0, future positions -> -inf, so the mask can be
    # added directly to attention scores before the softmax
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
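
# For sz = 3 the mask is
#
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
#
# the standard causal (subsequent-token) mask used with nn.Transformer.
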
def calculating_accuracy(pred, mml):
    """
    Counts token-level matches between prediction and ground truth.

    params:
        pred and mml: (B, l)

    return:
        number of positions where pred == mml, as a 0-dim tensor
    """
    train_acc = torch.sum(pred == mml)
    return train_acc
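
# Sketch: despite the name, the function returns a raw match count, not a
# ratio; divide by the number of elements for an accuracy fraction.
#
#   pred = torch.tensor([[1, 2, 3], [4, 5, 6]])
#   mml  = torch.tensor([[1, 2, 0], [4, 5, 6]])
#   calculating_accuracy(pred, mml)                  # tensor(5)
#   calculating_accuracy(pred, mml) / pred.numel()   # tensor(0.8333)
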
def beam_search(data, k, alpha, min_length):
    """
    Predicts the k most probable sequences using beam search.

    params:
        data: (seq_len, output_dim) per-step probability distributions
        k: beam width
        alpha: degree of regularization in length_normalization
        min_length: param for length_normalization
        (alpha and min_length are currently unused; see length_normalization)
    """
    # each hypothesis is [token index list, cumulative negative log-likelihood]
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        for i in range(len(sequences)):
            seq, score = sequences[i]
            log_row = row.log()
            for j in range(len(row)):
                # lower score is better: accumulate -log p(token)
                candidate = [seq + [j], score - log_row[j]]
                all_candidates.append(candidate)

        # order all candidates by score and keep the k best beams
        ordered = sorted(all_candidates, key=lambda t: t[1])
        sequences = ordered[:k]
    return sequences
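
# Minimal sketch of beam_search on a toy distribution: 3 decoding steps over
# a 4-token output space (rows sum to 1, as softmax outputs would).
#
#   data = torch.tensor([[0.1, 0.6, 0.2, 0.1],
#                        [0.2, 0.2, 0.5, 0.1],
#                        [0.7, 0.1, 0.1, 0.1]])
#   beam_search(data, k=2, alpha=0.7, min_length=3)
#   # ≈ [[[1, 2, 0], tensor(1.5606)], [[1, 0, 0], tensor(2.4769)]]  (best first)
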
def length_normalization(sequence_length, alpha, min_length):
    # length penalty ((1 + sequence_length) / (1 + min_length)) ** alpha;
    # dividing a cumulative score by it keeps longer hypotheses competitive
    ln = (1 + sequence_length) ** alpha / (1 + min_length) ** alpha
    return ln
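
# Sketch of the penalty's behavior:
#
#   length_normalization(3, alpha=0.7, min_length=3)   # 1.0
#   length_normalization(8, alpha=0.7, min_length=3)   # (9/4) ** 0.7 ≈ 1.764
#   # normalized = raw_score / length_normalization(len(seq), alpha, min_length)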