Coverage for skema/rest/metal_proxy.py: 44%

43 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import itertools as it 

2import json 

3 

4from askem_extractions.data_model import AttributeCollection 

5from fastapi import UploadFile, File, APIRouter, FastAPI 

6 

7from skema.metal.model_linker.skema_model_linker.link_amr import replace_xml_codepoints 

8from skema.metal.model_linker.skema_model_linker.linkers import PetriNetLinker, RegNetLinker 

9from skema.metal.model_linker.skema_model_linker.linkers.generalizer_amr_linker import GeneralizedAMRLinker 

10from skema.rest.schema import AMRLinkingEvaluationResults 

11from skema.rest.utils import compute_amr_linking_evaluation 

12 

# Router collecting this module's AMR-linking, evaluation and healthcheck
# endpoints; it is mounted on the module-level FastAPI `app`.
router: APIRouter = APIRouter()

14 

15 

# Lower-cased AMR schema names mapped to the linker class that handles them.
_LINKERS_BY_AMR_TYPE = {
    "petrinet": PetriNetLinker,
    "regnet": RegNetLinker,
    "generalized amr": GeneralizedAMRLinker,
}


def _merge_text_extractions(raw_extractions):
    """Build one ``AttributeCollection`` from an uploaded extractions payload.

    Two payload shapes are accepted:

    * A TR Proxy response: a dict with an ``outputs`` list whose items hold a
      serialized collection under ``data``. Items whose ``data`` is missing or
      null are skipped rather than failing the whole request. Such items
      presumably correspond to documents that failed text reading; verify
      against the TR Proxy response schema.
    * A bare serialized ``AttributeCollection``.
    """
    if 'outputs' in raw_extractions:
        collections = [
            AttributeCollection.from_json(output['data'])
            for output in raw_extractions['outputs']
            if output.get('data') is not None
        ]
        # Merge the attributes of every per-document collection into one
        return AttributeCollection(
            attributes=list(
                it.chain.from_iterable(c.attributes for c in collections)
            )
        )
    return AttributeCollection.from_json(raw_extractions)


def _get_amr_type(amr):
    """Return the lower-cased schema name of *amr*.

    The name is read from a top-level ``schema_name`` key, falling back to
    ``header.schema_name``.

    Raises:
        fastapi.HTTPException: status 400 when neither location holds a
            schema name. This is a malformed upload, not a server fault.
    """
    # Function-scope import: HTTPException is only needed on this error path.
    from fastapi import HTTPException

    if 'schema_name' in amr:
        return amr['schema_name'].lower()
    if 'header' in amr and 'schema_name' in amr['header']:
        return amr['header']['schema_name'].lower()
    raise HTTPException(status_code=400, detail="Schema name missing in AMR")


@router.post(
    "/link_amr",
)
def link_amr(similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",
             similarity_threshold: float = 0.5,
             amr_file: UploadFile = File(...),
             text_extractions_file: UploadFile = File(...)):
    """ Links an AMR to a text extractions file

    - `similarity_model`: name of the model handed to the linker for similarity scoring.
    - `similarity_threshold`: similarity threshold handed to the linker.
    - `amr_file`: JSON AMR; its schema name must be PetriNet, RegNet or Generalized AMR.
    - `text_extractions_file`: either a TR Proxy response (with `outputs`) or a single
      serialized attribute collection.

    ### Python example
    ```
    import requests

    files = {
      "amr_file": ("amr.json", open("amr.json", "rb"), "application/json"),
      "text_extractions_file": ("extractions.json", open("extractions.json", "rb"), "application/json")
    }

    response = requests.post(f"{ENDPOINT}/metal/link_amr", files=files)
    if response.status_code == 200:
        enriched_amr = response.json()
    ```
    """
    # Load the AMR and pass it through replace_xml_codepoints before linking
    amr = replace_xml_codepoints(json.load(amr_file.file))

    # Load the extractions (TR Proxy output or a single attribute collection)
    extractions = _merge_text_extractions(json.load(text_extractions_file.file))

    # Pick the linker matching the AMR's schema name
    amr_type = _get_amr_type(amr)
    try:
        linker_cls = _LINKERS_BY_AMR_TYPE[amr_type]
    except KeyError:
        raise NotImplementedError(f"{amr_type} AMR currently not supported") from None

    linker = linker_cls(model_name=similarity_model, sim_threshold=similarity_threshold)

    return linker.link_model_to_text_extractions(amr, extractions)

78 

79 

@router.get(
    "/healthcheck",
    response_model=int,
    status_code=200,
    responses={
        200: {
            "model": int,
            "description": "All component services are healthy (200 status)",
        },
        500: {
            "model": int,
            "description": "Internal error occurred",
            "example_value": 500,
        },
    },
)
def healthcheck() -> int:
    """Liveness probe for this router's endpoints.

    Performs no checks of its own: it always answers with the integer 200.
    """
    healthy_status = 200
    return healthy_status

98 

@router.post("/eval", response_model=AMRLinkingEvaluationResults, status_code=200)
def quantitative_eval(linked_amr_file: UploadFile, gt_linked_amr_file: UploadFile) -> AMRLinkingEvaluationResults:
    """
    # Gets performance metrics of a linked amr with variable extractions against a ground truth linked amr.

    Both uploads are parsed as JSON and passed to `compute_amr_linking_evaluation`.
    The multipart field names must match the parameter names:
    `linked_amr_file` and `gt_linked_amr_file`.

    ## Example:
    ```python
    import requests

    files = {
        "linked_amr_file": ("linked_amr_file.json", open("linked_amr_file.json", 'rb')),
        "gt_linked_amr_file": ("gt_linked_amr_file.json", open("gt_linked_amr_file.json", 'rb')),
    }

    response = requests.post(f"{endpoint}/metal/eval", files=files)
    ```
    """
    linked_amr = json.load(linked_amr_file.file)
    # Separate name so the UploadFile parameter is not rebound to a parsed dict
    gt_linked_amr = json.load(gt_linked_amr_file.file)

    return compute_amr_linking_evaluation(linked_amr, gt_linked_amr)

120 

# Standalone FastAPI application that serves only the routes registered on
# this module's `router`.
app: FastAPI = FastAPI()
app.include_router(router)