Coverage for skema/metal/model_linker/skema_model_linker/link_amr.py: 21%

53 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import json, html 

2from pathlib import Path 

3from typing import Optional 

4 

5import fire.fire_test 

6from askem_extractions.data_model import AttributeCollection 

7 

8from .linkers import PetriNetLinker, RegNetLinker 

9import itertools as it 

10 

def replace_xml_codepoints(json):
    """Recursively unescape XML/HTML character references in a JSON-like value.

    Lists and dicts are rebuilt element by element. Any string (dict keys
    included) that *begins* with ``&#`` is passed through ``html.unescape``.
    Strings containing ``&#`` only later on are returned unchanged, as are
    all non-string scalars.

    Note: the parameter name shadows the ``json`` module inside this function;
    it is kept as-is so keyword callers keep working.
    """
    def _unescape_if_codepoint(value):
        # Only strings that start with a character reference are decoded.
        if not value.startswith("&#"):
            return value
        return html.unescape(value)

    if isinstance(json, str):
        return _unescape_if_codepoint(json)
    if isinstance(json, list):
        return list(map(replace_xml_codepoints, json))
    if isinstance(json, dict):
        return {
            _unescape_if_codepoint(key): replace_xml_codepoints(value)
            for key, value in json.items()
        }
    # Numbers, booleans, None and any other objects pass through untouched.
    return json

24 

def _load_text_extractions(attribute_collection: str):
    """Load text-reading extractions from ``attribute_collection`` into one AttributeCollection.

    First tries ``AttributeCollection.from_json`` on the path directly (output
    produced by the library). If that raises ``KeyError``, the file is treated
    as a SKEMA service response: each entry of its ``outputs`` list holds a
    serialized collection under ``data``, and all their attributes are merged
    into a single AttributeCollection.
    """
    try:
        return AttributeCollection.from_json(attribute_collection)
    except KeyError:
        # Not a bare collection; fall through to the service-response format.
        pass

    with open(attribute_collection, encoding="utf-8") as f:
        service_output = json.load(f)

    parsed_collections = [
        AttributeCollection.from_json(output["data"])
        for output in service_output["outputs"]
    ]
    return AttributeCollection(
        attributes=list(it.chain.from_iterable(c.attributes for c in parsed_collections))
    )


def link_amr(
    amr_path: str,  # Path of the AMR model
    attribute_collection: str,  # Path to the attribute collection
    amr_type: str,  # AMR model type. I.e. "petrinet" or "regnet"
    eval_mode: bool = False,  # True when the extractions are manual annotations
    output_path: Optional[str] = None,  # Output file path
    clean_xml_codepoints: Optional[bool] = False,  # Replaces html codepoints with the unicode character
    similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",  # Transformer model to compute similarities
    similarity_threshold: float = 0.7,  # Cosine similarity threshold for linking
    device: Optional[str] = None,  # PyTorch device to run the model on
):
    """Link an AMR model to an attribute collection from the ASKEM text reading pipelines.

    Args:
        amr_path: Path of the AMR model JSON file.
        attribute_collection: Path to the extractions. In normal mode this is
            either an AttributeCollection file or a SKEMA service response;
            in ``eval_mode`` it is a JSON file of manual annotations.
        amr_type: ``"petrinet"`` or ``"regnet"``.
        eval_mode: True when the extractions are manual annotations.
        output_path: Where to write the linked model. Defaults to
            ``linked_<AMR file name>`` in the current working directory.
        clean_xml_codepoints: If true, replace ``&#...;`` codepoints in the
            AMR with their unicode characters. (Manual annotations in
            ``eval_mode`` are always cleaned, regardless of this flag.)
        similarity_model: Transformer model used to compute similarities.
        similarity_threshold: Cosine similarity threshold for linking.
        device: PyTorch device to run the model on.

    Raises:
        NotImplementedError: If ``amr_type`` is not supported.
    """
    if amr_type == "petrinet":
        Linker = PetriNetLinker
    elif amr_type == "regnet":
        Linker = RegNetLinker
    else:
        raise NotImplementedError(f"{amr_type} AMR currently not supported")

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(amr_path, encoding="utf-8") as f:
        amr = json.load(f)
    if clean_xml_codepoints:
        amr = replace_xml_codepoints(amr)

    linker = Linker(model_name=similarity_model, device=device, sim_threshold=similarity_threshold)

    if not eval_mode:
        # Handle extractions from the SKEMA service or directly from the library
        extractions = _load_text_extractions(attribute_collection)
        linked_model = linker.link_model_to_text_extractions(amr, extractions)
    else:
        with open(attribute_collection, encoding="utf-8") as f:
            annotations = json.load(f)
        annotations = replace_xml_codepoints(annotations)
        linked_model = linker.link_model_to_manual_annotations(amr, annotations)

    if not output_path:
        output_path = f"linked_{Path(amr_path).name}"

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 or the write fails on non-UTF-8 locales (e.g. cp1252).
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(linked_model, f, default=str, indent=2, ensure_ascii=False)

81 

82 

def main():
    """Module's entry point: expose ``link_amr`` as a command-line interface via Python Fire."""
    # Import the public ``fire`` package explicitly instead of relying on the
    # module-level ``import fire.fire_test`` (Fire's own test module) binding
    # the ``fire`` name as a side effect.
    import fire

    fire.Fire(link_amr)


if __name__ == "__main__":
    main()