Coverage for skema/rest/utils.py: 57%
221 statements
import itertools as it
import httpx
from collections import defaultdict
from typing import Any, Dict, List

from askem_extractions.data_model import AttributeCollection, AttributeType, AnchoredEntity
from bs4 import BeautifulSoup, Comment

from skema.img2mml.api import get_mathml_from_latex
from skema.rest import config
from skema.rest.schema import TextReadingEvaluationResults, AMRLinkingEvaluationResults


# see https://stackoverflow.com/a/74401249
async def get_client():
    # create a new client for each request
    async with httpx.AsyncClient(timeout=config.SKEMA_RS_DEFAULT_TIMEOUT, follow_redirects=True) as client:
        # yield the client to the endpoint function
        yield client
        # the client is closed automatically when the request is done
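
# Usage sketch (hypothetical route, not part of this module): `get_client` is
# an async generator dependency, so with FastAPI it would be wired up roughly
# like this:
#
#     from fastapi import Depends, FastAPI
#     app = FastAPI()
#
#     @app.get("/healthcheck")
#     async def healthcheck(client: httpx.AsyncClient = Depends(get_client)):
#         resp = await client.get("https://example.com")
#         return {"status": resp.status_code}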


def fn_preprocessor(function_network: Dict[str, Any]):
    fn_data = function_network.copy()

    logs = []

    '''
    We currently preprocess to work around several common bugs:
    1) wire tgt's being -1 -> delete these wires
    2) metadata being inline for bf entries instead of an index into the metadata_collection -> replace with an index of 2
    3) missing function_type field on a bf entry -> replace with function_type: "IMPORTED"
    4) missing body field on a function -> replace "FUNCTION" with "ABSTRACT" and set "name": "Unknown"
    5) -1 entries in the metadata for line spans and col spans -> replace with 1
    6) NOT DONE YET: in the future we will preprocess function calls being used as arguments, in order to simplify extracting the dataflow
    '''

    # first we check the top bf level for wires and inline metadata:
    keys_to_check = ['bf', 'wff', 'wfopi', 'wfopo', 'wopio']
    metadata_keys_to_check = ['line_begin', 'line_end', 'col_begin', 'col_end']
    for key in metadata_keys_to_check:
        try:
            for (i, entry) in enumerate(fn_data['modules'][0]['metadata_collection']):
                try:
                    for (j, datum) in enumerate(entry):
                        try:
                            if datum[key] == -1:
                                datum[key] = 1
                                logs.append(
                                    f"The {j + 1}'th metadata in the {i + 1}'th metadata index has -1 for the {key} entry")
                        except Exception:
                            continue
                except Exception:
                    continue
        except Exception:
            continue

    for key in keys_to_check:
        if key == 'bf':
            try:
                for (i, entry) in enumerate(fn_data['modules'][0]['fn'][key]):
                    try:
                        metadata_obj = entry['metadata']
                        if not isinstance(metadata_obj, int):
                            entry['metadata'] = 2
                            logs.append(f"Inline metadata on {i + 1}'th entry in top level bf")
                    except Exception:
                        continue
                    try:
                        temp = entry['function_type']
                    except Exception:
                        entry['function_type'] = "IMPORTED"
                        logs.append(f"Missing function_type on {i + 1}'th entry in top level bf")
                    try:
                        if entry['function_type'] == "FUNCTION":
                            temp = entry['body']
                    except Exception:
                        entry['function_type'] = "ABSTRACT"
                        entry['name'] = "Unknown"
                        logs.append(f"Missing Function body on {i + 1}'th entry in top level bf")
            except Exception:
                continue
        else:
            try:
                # iterate in reverse so removing wires doesn't skip entries
                for (i, entry) in enumerate(reversed(fn_data['modules'][0]['fn'][key])):
                    try:
                        if entry['tgt'] == -1:
                            try:
                                fn_data['modules'][0]['fn'][key].remove(entry)
                                logs.append(f"The {i + 1}'th {key} wire in the top level bf is targeting -1")
                            except Exception:
                                entry['tgt'] = 1
                    except Exception:
                        continue
            except Exception:
                continue

    # now we iterate through the fn_array and do the same thing
    for (j, fn_ent) in enumerate(fn_data['modules'][0]['fn_array']):
        for key in keys_to_check:
            if key == 'bf':
                try:
                    for (i, entry) in enumerate(fn_ent[key]):
                        try:
                            metadata_obj = entry['metadata']
                            if not isinstance(metadata_obj, int):
                                entry['metadata'] = 2
                                logs.append(f"Inline metadata on {i + 1}'th bf in the {j + 1}'th fn_array")
                        except Exception:
                            continue
                        try:
                            temp = entry['function_type']
                        except Exception:
                            entry['function_type'] = "IMPORTED"
                            logs.append(f"Missing function_type on {i + 1}'th bf in the {j + 1}'th fn_array")
                        try:
                            if entry['function_type'] == "FUNCTION":
                                temp = entry['body']
                        except Exception:
                            entry['function_type'] = "ABSTRACT"
                            entry['name'] = "Unknown"
                            logs.append(f"Missing Function body on {i + 1}'th bf in the {j + 1}'th fn_array")
                except Exception:
                    continue
            else:
                try:
                    for (i, entry) in enumerate(reversed(fn_ent[key])):
                        if entry['tgt'] == -1:
                            try:
                                # drop the dangling wire from the list
                                fn_ent[key].remove(entry)
                                logs.append(f"The {i + 1}'th {key} wire in the {j + 1}'th fn_array is targeting -1")
                            except Exception:
                                entry['tgt'] = 1
                except Exception:
                    continue

    return fn_data, logs
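
# A minimal sketch of running the preprocessor (the toy function network below
# is made up for illustration; real inputs are GROMET JSON documents):
def _demo_fn_preprocessor():
    toy_fn = {
        "modules": [{
            "fn": {
                "bf": [{"metadata": 2, "function_type": "IMPORTED"}],
                "wff": [{"src": 1, "tgt": -1}],  # dangling wire, should be dropped
            },
            "fn_array": [],
            "metadata_collection": [[{"line_begin": -1, "line_end": 3}]],  # -1 span, should become 1
        }]
    }
    cleaned, logs = fn_preprocessor(toy_fn)
    for line in logs:
        print(line)
    return cleaned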


def clean_mml(mml: str) -> str:
    """Cleans/sterilizes pMML for the AMR generation service"""
    # FIXME: revisit if JSON deserialization on the MORAE side changes
    to_remove = ["alttext", "display", "xmlns", "mathvariant", "class"]
    soup = BeautifulSoup(mml, "html.parser")
    # remove comments
    for comment in soup(text=lambda text: isinstance(text, Comment)):
        comment.extract()

    # prune attributes
    for attr in to_remove:
        for tag in soup.find_all(attrs={attr: True}):
            del tag[attr]
    return str(soup).replace("\n", "")
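
# A small illustration (the <math> snippet below is a made-up example):
def _demo_clean_mml():
    mml = ('<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">'
           '<!-- generated --><mi mathvariant="normal">x</mi></math>')
    print(clean_mml(mml))  # -> <math><mi>x</mi></math>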


def parse_equations(eqns: List[str]) -> List[str]:
    """Parses each equation according to whether it is MathML or LaTeX"""
    parsed_eqns: List[str] = []
    for eqn in eqns:
        if "</math>" in eqn:
            parsed_eqns.append(clean_mml(eqn))
        else:
            parsed_eqns.append(clean_mml(get_mathml_from_latex(eqn)))
    return parsed_eqns
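
# Usage sketch: MathML strings pass straight through clean_mml, while anything
# else is first converted by get_mathml_from_latex, so the LaTeX branch needs
# the img2mml model/service to be available. The inputs below are illustrative:
#
#     parse_equations(['<math><mi>x</mi></math>', r'\frac{\beta}{N}'])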


def extraction_matches_annotation(extraction: AnchoredEntity, annotation: Dict[str, Any], json_contents: Dict) -> bool:
    """Determines whether the extraction matches the annotation"""

    # First iteration of the matching algorithm

    # Get the annotation's text
    gt_text = annotation["text"]

    # Get the extraction's text
    src = extraction.extraction_source
    m_text = json_contents[src.block]['content'][src.char_start:src.char_end]

    # Either string containing the other counts as a match
    return gt_text in m_text or m_text in gt_text
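
# Shape sketch (hypothetical values): `json_contents` maps a block id to a dict
# with a 'content' string, so a match is a substring test over the extraction's
# character span:
#
#     annotation = {"text": "transmission rate"}
#     json_contents = {"b0": {"content": "... the transmission rate beta ..."}}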


def compute_text_reading_evaluation(gt_data: list, attributes: AttributeCollection,
                                    json_contents: Dict) -> TextReadingEvaluationResults:
    """Computes the coverage of text reading extractions"""

    # Get the extractions from the attribute collection
    extractions = [a.payload for a in attributes.attributes if a.type == AttributeType.anchored_entity]

    # Get the extraction annotations from the ground truth data
    annotations_by_page = defaultdict(list)
    for a in gt_data:
        if a["type"] == "Highlight" and a["color"] in {"#f9cd59", "#ffd100", "#0000ff"}:
            page = a["page"]
            annotations_by_page[page].append(a)

    def annotation_key(a: Dict):
        return a['page'], tuple(a['start_xy']), a['text']

    # Count the matches
    tp, tn, fp, fn = 0, 0, 0, 0
    matched_annotations = set()
    for e in extractions:
        matched = False
        for m in e.mentions:
            if not matched:
                if m.extraction_source is not None:
                    te = m.extraction_source
                    if te.page is not None:
                        e_page = te.page
                        page_annotations = annotations_by_page[e_page]
                        for a in page_annotations:
                            key = annotation_key(a)
                            if key not in matched_annotations:
                                if extraction_matches_annotation(m, a, json_contents):
                                    matched_annotations.add(key)
                                    matched = True
                                    tp += 1
                                    break
        if not matched:
            fp += 1

    # a tiny epsilon in the denominators guards against division by zero
    recall = tp / len(gt_data)
    precision = tp / (tp + fp + 0.00000000001)
    return TextReadingEvaluationResults(
        num_manual_annotations=len(gt_data),
        yield_=len(extractions),
        correct_extractions=tp,
        recall=recall,
        precision=precision,
        f1=(2 * precision * recall) / (precision + recall + .0000000001)
    )
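
# Ground-truth entries are dicts exported from a PDF annotation tool; the keys
# used above imply a shape like this hypothetical example:
#
#     {"type": "Highlight", "color": "#ffd100", "page": 3,
#      "start_xy": [71.5, 380.2], "text": "transmission rate"}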


greek_alphabet = {
    'Α': 'alpha',
    'α': 'alpha',
    'Β': 'beta',
    'β': 'beta',
    'Γ': 'gamma',
    'γ': 'gamma',
    'Δ': 'delta',
    'δ': 'delta',
    'Ε': 'epsilon',
    'ε': 'epsilon',
    'Ζ': 'zeta',
    'ζ': 'zeta',
    'Η': 'eta',
    'η': 'eta',
    'Θ': 'theta',
    'θ': 'theta',
    'Ι': 'iota',
    'ι': 'iota',
    'Κ': 'kappa',
    'κ': 'kappa',
    'Λ': 'lambda',
    'λ': 'lambda',
    'Μ': 'mu',
    'μ': 'mu',
    'Ν': 'nu',
    'ν': 'nu',
    'Ξ': 'xi',
    'ξ': 'xi',
    'Ο': 'omicron',
    'ο': 'omicron',
    'Π': 'pi',
    'π': 'pi',
    'Ρ': 'rho',
    'ρ': 'rho',
    'Σ': 'sigma',
    'σ': 'sigma',
    'ς': 'sigma',
    'Τ': 'tau',
    'τ': 'tau',
    'Υ': 'upsilon',
    'υ': 'upsilon',
    'Φ': 'phi',
    'φ': 'phi',
    'Χ': 'chi',
    'χ': 'chi',
    'Ψ': 'psi',
    'ψ': 'psi',
    'Ω': 'omega',
    'ω': 'omega'
}


def compute_amr_linking_evaluation(linked_amr, gt_linked_amr) -> AMRLinkingEvaluationResults:
    # Find the amr elements with metadata in the GT
    gt_amr_ids = {m['amr_element_id'] for m in gt_linked_amr['metadata'] if m['amr_element_id'] is not None}

    # Fetch the relevant elements from both amrs
    def get_elem_by_id(data, ids):
        ret = list()
        if isinstance(data, list):
            ret.extend(it.chain.from_iterable(get_elem_by_id(a, ids) for a in data))
        elif isinstance(data, dict):
            if "id" in data and data["id"] in ids:
                ret.append(data)
            else:
                ret.extend(it.chain.from_iterable(get_elem_by_id(v, ids) for k, v in data.items() if k != "metadata"))
        return ret

    gt_elems = get_elem_by_id(gt_linked_amr, gt_amr_ids)
    runtime_elems = get_elem_by_id(linked_amr, gt_amr_ids)

    # Generate metadata dictionaries
    gt_metadata = defaultdict(list)
    for m in gt_linked_amr['metadata']:
        gt_metadata[m['amr_element_id']].append(m)

    runtime_metadata = defaultdict(list)
    for m in linked_amr['metadata']['attributes']:
        runtime_metadata[m['amr_element_id']].append(m)

    # Compute the numbers
    tp, tn, fp, fn = 0, 0, 0, 0

    for amr_id in gt_amr_ids:
        gt = gt_metadata[amr_id]
        rt = runtime_metadata[amr_id]

        # Get the text from the ground truth, also accepting spelled-out
        # greek letters (e.g. "β" -> "beta")
        gt_texts = {e['text'] for e in gt}
        expanded_gt_texts = set()
        for t in gt_texts:
            for k, v in greek_alphabet.items():
                if k in t:
                    expanded_gt_texts.add(t.replace(k, v))
        gt_texts |= expanded_gt_texts

        # Get the text from the automated extractions
        rt_texts = set()
        for e in rt:
            e = e['payload']
            for m in e['mentions']:
                name = m['name']
                for d in e['text_descriptions']:
                    desc = d['description']
                    rt_texts.add((name, desc))
                for v in e['value_descriptions']:
                    val = v['value']['amount']
                    rt_texts.add((name, val))

        # Compute hits and misses
        if len(gt_texts) > 0:
            hit = False
            for gtt in gt_texts:
                if not hit:
                    for (a, b) in rt_texts:
                        # Both the name and the desc have to be present in the
                        # annotation in order to be a "hit"
                        if a in gtt and b in gtt:
                            tp += 1
                            hit = True
                            break
            # If we made it to this point and none of the extractions matched,
            # then this is a false negative
            if not hit:
                fn += 1
        elif len(rt_texts) > 0:
            fp += 1
        else:
            tn += 1

    precision = tp / ((tp + fp) + 0.000000001)
    recall = tp / ((tp + fn) + 0.000000001)

    f1 = (2 * precision * recall) / ((precision + recall) + 0.000000001)

    return AMRLinkingEvaluationResults(
        num_gt_elems_with_metadata=len(gt_amr_ids),
        precision=precision,
        recall=recall,
        f1=f1
    )
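
# Usage sketch (hypothetical file names; both arguments are linked-AMR JSON
# documents, with the ground truth carrying a top-level "metadata" list):
#
#     import json
#     with open("linked_amr.json") as f1, open("gt_linked_amr.json") as f2:
#         results = compute_amr_linking_evaluation(json.load(f1), json.load(f2))
#     print(results.precision, results.recall, results.f1)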