Coverage for skema/metal/model_linker/skema_model_linker/link_amr.py: 21%
53 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import json, html
2from pathlib import Path
3from typing import Optional
5import fire.fire_test
6from askem_extractions.data_model import AttributeCollection
8from .linkers import PetriNetLinker, RegNetLinker
9import itertools as it
def replace_xml_codepoints(json):
    """Recursively decode character references inside a JSON-like value.

    Lists and dicts are rebuilt with every element, key and value
    processed; a string is passed through ``html.unescape`` only when it
    starts with ``"&#"``.  Any other value (numbers, booleans, ``None``,
    strings without that prefix) is returned unchanged.
    """
    if isinstance(json, str):
        # Only strings that begin with a numeric character reference are decoded.
        return html.unescape(json) if json.startswith("&#") else json

    if isinstance(json, list):
        return [replace_xml_codepoints(item) for item in json]

    if isinstance(json, dict):
        decoded = {}
        for key, value in json.items():
            # Keys get the same prefix-gated decoding as string values.
            new_key = html.unescape(key) if key.startswith("&#") else key
            decoded[new_key] = replace_xml_codepoints(value)
        return decoded

    return json
def link_amr(
        amr_path: str,  # Path of the AMR model
        attribute_collection: str,  # Path to the attribute collection
        amr_type: str,  # AMR model type. I.e. "petrinet" or "regnet"
        eval_mode: bool = False,  # True when the extractions are manual annotations
        output_path: Optional[str] = None,  # Output file path
        clean_xml_codepoints: Optional[bool] = False,  # Replaces html codepoints with the unicode character
        similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",  # Transformer model to compute similarities
        similarity_threshold: float = 0.7,  # Cosine similarity threshold for linking
        device: Optional[str] = None  # PyTorch device to run the model on
):
    """Link an AMR model to an attribute collection from ASKEM text reading pipelines.

    The AMR JSON at ``amr_path`` is loaded, linked with the linker selected
    by ``amr_type`` and the result is written as JSON to ``output_path``.

    Args:
        amr_path: Path of the AMR model (JSON file).
        attribute_collection: Path to the attribute collection, or to the
            manual annotations JSON when ``eval_mode`` is true.
        amr_type: AMR model type, either ``"petrinet"`` or ``"regnet"``.
        eval_mode: True when the extractions are manual annotations.
        output_path: Output file path. When empty, ``linked_<AMR file name>``
            in the current working directory is used.
        clean_xml_codepoints: When true, the loaded AMR is passed through
            ``replace_xml_codepoints`` before linking.
        similarity_model: Transformer model used to compute similarities.
        similarity_threshold: Cosine similarity threshold for linking.
        device: PyTorch device to run the model on.

    Raises:
        NotImplementedError: If ``amr_type`` is neither ``"petrinet"`` nor
            ``"regnet"``.
    """

    if amr_type == "petrinet":
        Linker = PetriNetLinker
    elif amr_type == "regnet":
        Linker = RegNetLinker
    else:
        raise NotImplementedError(f"{amr_type} AMR currently not supported")

    # Read JSON as UTF-8 explicitly: relying on the platform default encoding
    # corrupts or rejects non-ASCII content on non-UTF-8 locales.
    with open(amr_path, encoding="utf-8") as f:
        amr = json.load(f)
    if clean_xml_codepoints:
        amr = replace_xml_codepoints(amr)

    linker = Linker(model_name=similarity_model, device=device, sim_threshold=similarity_threshold)

    if not eval_mode:
        # Handle extractions from the SKEMA service or directly from the library
        try:
            extractions = AttributeCollection.from_json(attribute_collection)
        except KeyError:
            # Fallback layout: a JSON document with an 'outputs' list whose
            # entries each carry one attribute collection under 'data'.
            with open(attribute_collection, encoding="utf-8") as f:
                service_output = json.load(f)
            collections = [
                AttributeCollection.from_json(output['data'])
                for output in service_output['outputs']
            ]

            # Merge the attributes of every collection into a single one.
            extractions = AttributeCollection(
                attributes=list(it.chain.from_iterable(c.attributes for c in collections)))
        linked_model = linker.link_model_to_text_extractions(amr, extractions)
    else:
        with open(attribute_collection, encoding="utf-8") as f:
            annotations = json.load(f)
        # Manual annotations are always cleaned, independently of clean_xml_codepoints.
        annotations = replace_xml_codepoints(annotations)
        linked_model = linker.link_model_to_manual_annotations(amr, annotations)

    if not output_path:
        input_amr_name = str(Path(amr_path).name)
        output_path = f'linked_{input_amr_name}'

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 for the write to succeed on every platform.
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(linked_model, f, default=str, indent=2, ensure_ascii=False)
def main():
    """Command-line entry point: expose ``link_amr`` through Python Fire."""
    entry_point = link_amr
    fire.Fire(entry_point)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()