Coverage for skema/rest/utils.py: 57%

221 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import itertools as it 

2import httpx 

3from collections import defaultdict 

4from typing import Any, Dict 

5 

6from typing import List 

7from askem_extractions.data_model import AttributeCollection, AttributeType, AnchoredEntity 

8from bs4 import BeautifulSoup, Comment 

9 

10from skema.img2mml.api import get_mathml_from_latex 

11from skema.rest import config 

12from skema.rest.schema import TextReadingEvaluationResults, AMRLinkingEvaluationResults 

13 

14 

# see https://stackoverflow.com/a/74401249
async def get_client():
    """Yield a freshly opened ``httpx.AsyncClient`` for a single consumer.

    This is an async generator: every call opens a new client configured with
    ``config.SKEMA_RS_DEFAULT_TIMEOUT`` and redirect-following, hands it out via
    ``yield``, and the surrounding ``async with`` closes it once the generator
    is resumed or finalized after the yield.
    """
    client_options = {
        "timeout": config.SKEMA_RS_DEFAULT_TIMEOUT,
        "follow_redirects": True,
    }
    async with httpx.AsyncClient(**client_options) as client:
        yield client

22 

23 

# Wire collections whose entries carry a ``tgt`` index that may be -1.
_WIRE_KEYS = ("wff", "wfopi", "wfopo", "wopio")
# Span fields inside metadata_collection entries that may be -1.
_METADATA_SPAN_KEYS = ("line_begin", "line_end", "col_begin", "col_end")


def _fix_metadata_spans(metadata_collection, logs):
    """Replace -1 span values in a metadata_collection (list of lists of dicts) with 1."""
    if not isinstance(metadata_collection, list):
        return
    # Outer loop over keys keeps the log ordering of the previous implementation.
    for key in _METADATA_SPAN_KEYS:
        for i, entry in enumerate(metadata_collection):
            if not isinstance(entry, (list, tuple)):
                continue
            for j, datum in enumerate(entry):
                if isinstance(datum, dict) and datum.get(key) == -1:
                    datum[key] = 1
                    logs.append(
                        f"The {j + 1}'th metadata in the {i + 1} metadata index has -1 for the {key} entry")


def _fix_bf_entries(bf_entries, label, logs):
    """Repair inline metadata, missing function_type and body-less FUNCTION entries in a bf list.

    ``label`` completes the log location, e.g. ``"entry in top level bf"``.
    """
    if not isinstance(bf_entries, list):
        return
    for i, entry in enumerate(bf_entries):
        if not isinstance(entry, dict):
            continue
        where = f"{i + 1}'th {label}"
        # A missing 'metadata' key only skips this check, not the ones below.
        if "metadata" in entry and not isinstance(entry["metadata"], int):
            entry["metadata"] = 2
            logs.append(f"Inline metadata on {where}")
        if "function_type" not in entry:
            entry["function_type"] = "IMPORTED"
            logs.append(f"Missing function_type on {where}")
        if entry["function_type"] == "FUNCTION" and "body" not in entry:
            entry["function_type"] = "ABSTRACT"
            entry["name"] = "Unknown"
            logs.append(f"Missing Function body on {where}")


def _drop_dangling_wires(container, key, where, logs):
    """Delete, in place, every wire in ``container[key]`` whose ``tgt`` is -1."""
    wires = container.get(key)
    if not isinstance(wires, list):
        return
    kept = []
    for i, wire in enumerate(wires):
        if isinstance(wire, dict) and wire.get("tgt") == -1:
            # i is the wire's original (forward) position in the list.
            logs.append(f"The {i + 1}'th {key} wire in {where} is targeting -1")
        else:
            kept.append(wire)
    # Slice assignment keeps the same list object, as the old in-place removal did.
    wires[:] = kept


def fn_preprocessor(function_network: Dict[str, Any]):
    """Repair common defects in a function network and report what was changed.

    The input is deep-copied, so the caller's object is never modified. The
    following repairs are applied to the first module's top-level ``fn`` and
    to every entry of its ``fn_array``; each repair appends a log message:

    1) metadata spans (``line_begin``, ``line_end``, ``col_begin``,
       ``col_end``) in ``metadata_collection`` equal to -1 are set to 1;
    2) ``bf`` entries whose ``metadata`` is not an int index (inline
       metadata) get ``metadata = 2``;
    3) ``bf`` entries missing ``function_type`` get ``"IMPORTED"``;
    4) ``bf`` entries of type ``"FUNCTION"`` with no ``body`` become
       ``"ABSTRACT"`` with ``name = "Unknown"``;
    5) wires in ``wff``/``wfopi``/``wfopo``/``wopio`` whose ``tgt`` is -1
       are deleted.

    Entries that are malformed in other ways are skipped. A module without
    ``fn`` or ``fn_array`` is tolerated.

    NOT DONE YET: preprocessing function calls used as arguments, to simplify
    dataflow extraction.

    Args:
        function_network: function network as a plain dict (presumably
            deserialized GroMEt JSON -- TODO confirm).

    Returns:
        ``(fn_data, logs)``: the repaired copy and the list of log messages.

    Raises:
        KeyError, IndexError or TypeError if there is no ``modules[0]``.
    """
    from copy import deepcopy

    # Deep copy: the old shallow .copy() let nested repairs leak into the caller's data.
    fn_data = deepcopy(function_network)
    logs: List[str] = []
    module = fn_data["modules"][0]

    _fix_metadata_spans(module.get("metadata_collection"), logs)

    top_level = module.get("fn")
    if isinstance(top_level, dict):
        _fix_bf_entries(top_level.get("bf"), "entry in top level bf", logs)
        for key in _WIRE_KEYS:
            _drop_dangling_wires(top_level, key, "the top level bf", logs)

    for j, fn_ent in enumerate(module.get("fn_array") or []):
        if not isinstance(fn_ent, dict):
            continue
        _fix_bf_entries(fn_ent.get("bf"), f"bf in the {j + 1}'th fn_array", logs)
        for key in _WIRE_KEYS:
            _drop_dangling_wires(fn_ent, key, f"the {j + 1}'th fn_array", logs)

    return fn_data, logs

139 

140 

def clean_mml(mml: str) -> str:
    """Cleans/sterilizes pMML for AMR generation service.

    Parses ``mml`` with BeautifulSoup's ``html.parser`` and removes all comments.
    It also strips the ``alttext``, ``display``, ``xmlns``, ``mathvariant`` and
    ``class`` attributes from every tag. The serialized result is returned with
    all newlines removed.
    """
    # FIXME: revisit if JSON deserialization on MORAE side changes
    to_remove = ["alttext", "display", "xmlns", "mathvariant", "class"]
    soup = BeautifulSoup(mml, "html.parser")
    # remove comments; `string=` replaces the deprecated `text=` keyword
    for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
        comment.extract()

    # prune attributes
    for attr in to_remove:
        for tag in soup.find_all(attrs={attr: True}):
            del tag[attr]
    return str(soup).replace("\n", "")

155 

156 

def parse_equations(eqns: List[str]) -> List[str]:
    """Normalize each equation to cleaned presentation MathML.

    An equation containing a closing ``</math>`` tag is taken to be MathML
    already. Any other equation is treated as LaTeX and converted with
    ``get_mathml_from_latex``. Every result is then passed through
    ``clean_mml``, and the output preserves the input order.
    """
    def to_mathml(eqn: str) -> str:
        return eqn if "</math>" in eqn else get_mathml_from_latex(eqn)

    return [clean_mml(to_mathml(eqn)) for eqn in eqns]

166 

def extraction_matches_annotation(extraction: AnchoredEntity, annotation: Dict[str, Any], json_contents: Dict) -> bool:
    """ Determines whether the extraction matches the annotation

    First iteration of the matching algorithm: the extraction's text is
    ``json_contents[src.block]['content'][src.char_start:src.char_end]``, where
    ``src = extraction.extraction_source``. It matches when either that text or
    ``annotation["text"]`` is a substring of the other.

    NOTE(review): the caller passes a mention object (anything exposing
    ``extraction_source``), not necessarily an ``AnchoredEntity``.

    Returns False when either text is empty or whitespace-only. In Python
    ``"" in s`` is always True, so such a text would otherwise match everything.
    """
    # Get the annotation's text
    gt_text = annotation["text"]

    # Get the extraction's text
    src = extraction.extraction_source
    m_text = json_contents[src.block]['content'][src.char_start:src.char_end]

    if not gt_text.strip() or not m_text.strip():
        return False

    return gt_text in m_text or m_text in gt_text

180 

181 

def compute_text_reading_evaluation(gt_data: list, attributes: AttributeCollection,
                                    json_contents: Dict) -> TextReadingEvaluationResults:
    """ Compute the coverage of text reading extractions

    Only attributes of type ``AttributeType.anchored_entity`` are scored.
    Ground-truth annotations are restricted to ``"Highlight"`` entries in one
    of the accepted colors, grouped by page.

    An extraction is a true positive when any of its mentions (with a source
    page) matches a not-yet-matched annotation on that page, according to
    ``extraction_matches_annotation``. Each annotation can be matched at most
    once. Extractions that match nothing are false positives.

    ``recall`` is 0.0 when ``gt_data`` is empty (previously a ZeroDivisionError).
    """
    # Get the extractions from the attribute collection
    extractions = [a.payload for a in attributes.attributes if a.type == AttributeType.anchored_entity]

    # Get the extraction annotations from the ground truth data
    annotations_by_page = defaultdict(list)
    for a in gt_data:
        if a["type"] == "Highlight" and a["color"] in {"#f9cd59", "#ffd100", "#0000ff"}:
            annotations_by_page[a["page"]].append(a)

    def annotation_key(a: Dict):
        # Identity of an annotation, used to prevent matching it twice
        return a['page'], tuple(a['start_xy']), a['text']

    matched_annotations = set()

    def claims_annotation(mention) -> bool:
        """Match ``mention`` to the first free annotation on its page, recording the match."""
        source = mention.extraction_source
        if source is None or source.page is None:
            return False
        for a in annotations_by_page.get(source.page, ()):
            key = annotation_key(a)
            if key not in matched_annotations and extraction_matches_annotation(mention, a, json_contents):
                matched_annotations.add(key)
                return True
        return False

    # Count the matches; any() stops at the first mention that claims an annotation
    tp, fp = 0, 0
    for e in extractions:
        if any(claims_annotation(m) for m in e.mentions):
            tp += 1
        else:
            fp += 1

    num_manual_annotations = len(gt_data)
    # NOTE(review): the denominator counts every gt_data entry, including those
    # filtered out above (other types/colors) -- confirm this is intended.
    recall = tp / num_manual_annotations if num_manual_annotations else 0.0
    precision = tp / (tp + fp + 0.00000000001)
    return TextReadingEvaluationResults(
        num_manual_annotations=num_manual_annotations,
        yield_=len(extractions),
        correct_extractions=tp,
        recall=recall,
        precision=precision,
        f1=(2 * precision * recall) / (precision + recall + .0000000001)
    )

233 

234 

# Maps each Greek letter (upper- and lower-case, plus final sigma) to its
# spelled-out English name. Each group lists the letters sharing a name.
greek_alphabet = {
    letter: name
    for letters, name in (
        ("Αα", "alpha"),
        ("Ββ", "beta"),
        ("Γγ", "gamma"),
        ("Δδ", "delta"),
        ("Εε", "epsilon"),
        ("Ζζ", "zeta"),
        ("Ηη", "eta"),
        ("Θθ", "theta"),
        ("Ιι", "iota"),
        ("Κκ", "kappa"),
        ("Λλ", "lambda"),
        ("Μμ", "mu"),
        ("Νν", "nu"),
        ("Ξξ", "xi"),
        ("Οο", "omicron"),
        ("Ππ", "pi"),
        ("Ρρ", "rho"),
        ("Σσς", "sigma"),
        ("Ττ", "tau"),
        ("Υυ", "upsilon"),
        ("Φφ", "phi"),
        ("Χχ", "chi"),
        ("Ψψ", "psi"),
        ("Ωω", "omega"),
    )
    for letter in letters
}

286 

287 

def _expand_greek_names(texts):
    """Return ``texts`` plus variants with Greek letters spelled out (see greek_alphabet).

    For each text, one variant per Greek letter it contains is added (that
    letter replaced everywhere). A variant with all Greek letters replaced is
    also added, so texts containing several different letters still get a
    fully spelled-out form.
    """
    expanded = set(texts)
    for t in texts:
        fully_spelled = t
        for letter, name in greek_alphabet.items():
            if letter in t:
                expanded.add(t.replace(letter, name))
                fully_spelled = fully_spelled.replace(letter, name)
        expanded.add(fully_spelled)
    return expanded


def _runtime_name_text_pairs(attributes):
    """Collect ``(mention name, description-or-value)`` string pairs from runtime attributes.

    Each attribute is read as ``attribute['payload']`` with ``mentions``,
    ``text_descriptions`` (``description``) and ``value_descriptions``
    (``value.amount``). Missing lists are treated as empty and ``None`` parts
    are skipped. Values are coerced to ``str`` so the later substring test
    cannot fail on numeric amounts.
    """
    pairs = set()
    for attribute in attributes:
        payload = attribute['payload']
        others = [d['description'] for d in payload.get('text_descriptions') or []]
        others += [v['value']['amount'] for v in payload.get('value_descriptions') or []]
        for mention in payload.get('mentions') or []:
            name = mention['name']
            if name is None:
                continue
            pairs.update((str(name), str(o)) for o in others if o is not None)
    return pairs


def compute_amr_linking_evaluation(linked_amr, gt_linked_amr) -> AMRLinkingEvaluationResults:
    """Score runtime AMR metadata links against ground-truth links.

    Only AMR element ids that carry metadata in ``gt_linked_amr['metadata']``
    are evaluated. For each such id:

    * if it has ground-truth text, it is a true positive when some runtime
      ``(name, description/value)`` pair has both parts contained in one of
      the (Greek-expanded) ground-truth texts, and a false negative otherwise;
    * if it has no ground-truth text but runtime pairs exist, it is a false
      positive.

    Runtime metadata is read from ``linked_amr['metadata']['attributes']``.
    """
    # Find the amr elements with metadata in the GT
    gt_amr_ids = {m['amr_element_id'] for m in gt_linked_amr['metadata'] if m['amr_element_id'] is not None}

    # Generate metadata dictionaries keyed by AMR element id
    gt_metadata = defaultdict(list)
    for m in gt_linked_amr['metadata']:
        gt_metadata[m['amr_element_id']].append(m)

    runtime_metadata = defaultdict(list)
    for m in linked_amr['metadata']['attributes']:
        runtime_metadata[m['amr_element_id']].append(m)

    # Compute the numbers
    tp, fp, fn = 0, 0, 0
    for amr_id in gt_amr_ids:
        gt_texts = _expand_greek_names({e['text'] for e in gt_metadata[amr_id]})
        rt_pairs = _runtime_name_text_pairs(runtime_metadata[amr_id])

        if gt_texts:
            # Both the name and the desc have to be present in the
            # annotation in order to be a "hit"
            hit = any(a in gtt and b in gtt for gtt in gt_texts for (a, b) in rt_pairs)
            if hit:
                tp += 1
            else:
                # Only a miss is a false negative (previously counted unconditionally)
                fn += 1
        elif rt_pairs:
            fp += 1

    precision = tp / ((tp + fp) + 0.000000001)
    recall = tp / ((tp + fn) + 0.000000001)

    f1 = (2 * precision * recall) / ((precision + recall) + 0.000000001)

    return AMRLinkingEvaluationResults(
        num_gt_elems_with_metadata=len(gt_amr_ids),
        precision=precision,
        recall=recall,
        f1=f1
    )