Coverage for skema/rest/metal_proxy.py: 44%
43 statements
Generated by coverage.py v7.5.0 on 2024-04-30 17:15 +0000
1import itertools as it
2import json
4from askem_extractions.data_model import AttributeCollection
5from fastapi import UploadFile, File, APIRouter, FastAPI
7from skema.metal.model_linker.skema_model_linker.link_amr import replace_xml_codepoints
8from skema.metal.model_linker.skema_model_linker.linkers import PetriNetLinker, RegNetLinker
9from skema.metal.model_linker.skema_model_linker.linkers.generalizer_amr_linker import GeneralizedAMRLinker
10from skema.rest.schema import AMRLinkingEvaluationResults
11from skema.rest.utils import compute_amr_linking_evaluation
# Router that collects every endpoint defined in this module; it is mounted on
# the standalone `app` created at the bottom of the module.
router: APIRouter = APIRouter()
@router.post(
    "/link_amr",
)
def link_amr(similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",
             similarity_threshold: float = 0.5,
             amr_file: UploadFile = File(...),
             text_extractions_file: UploadFile = File(...)):
    """ Links an AMR to a text extractions file

    The AMR type is taken from ``schema_name``, either at the top level of the
    AMR or inside its ``header``. Supported types (case-insensitive) are
    ``petrinet``, ``regnet`` and ``generalized amr``.

    The extractions file may be either a single attribute collection, or the
    TR Proxy output shape (an object with an ``outputs`` list whose items carry
    a ``data`` attribute collection); in the latter case all collections are
    merged into one.

    Responds with 400 if the AMR has no ``schema_name``; raises
    ``NotImplementedError`` for an unsupported AMR type.

    ### Python example
    ```
    import requests

    with open("amr.json", "rb") as amr, open("extractions.json", "rb") as extractions:
        files = {
            "amr_file": ("amr.json", amr, "application/json"),
            "text_extractions_file": ("extractions.json", extractions, "application/json"),
        }
        response = requests.post(f"{ENDPOINT}/metal/link_amr", files=files)

    if response.status_code == 200:
        enriched_amr = response.json()
    ```
    """
    # Local import keeps this block self-contained; fastapi is already a dependency of this module.
    from fastapi import HTTPException

    # Load the AMR
    amr = json.load(amr_file.file)
    amr = replace_xml_codepoints(amr)

    # Load the extractions, that come out of the TR Proxy endpoint
    raw_extractions = json.load(text_extractions_file.file)
    if 'outputs' in raw_extractions:
        text_extractions = [AttributeCollection.from_json(o['data']) for o in raw_extractions['outputs']]
        # Merge all the attribute collections into a single one
        extractions = AttributeCollection(
            attributes=list(
                it.chain.from_iterable(o.attributes for o in text_extractions)
            )
        )
    else:
        extractions = AttributeCollection.from_json(raw_extractions)

    # Get the AMR type from the top level or from the header of the json
    if 'schema_name' in amr:
        schema_name = amr['schema_name']
    elif 'header' in amr and 'schema_name' in amr['header']:
        schema_name = amr['header']['schema_name']
    else:
        # A malformed upload is a client error: answer 400 rather than an opaque 500.
        # HTTPException subclasses Exception, so broad `except Exception` handlers still match.
        raise HTTPException(status_code=400, detail="Schema name missing in AMR")
    amr_type = schema_name.lower()

    # Pick the linker implementation for this AMR type
    linkers = {
        "petrinet": PetriNetLinker,
        "regnet": RegNetLinker,
        "generalized amr": GeneralizedAMRLinker,
    }
    linker_cls = linkers.get(amr_type)
    if linker_cls is None:
        raise NotImplementedError(f"{amr_type} AMR currently not supported")

    # Link the AMR
    linker = linker_cls(model_name=similarity_model, sim_threshold=similarity_threshold)

    return linker.link_model_to_text_extractions(amr, extractions)
@router.get(
    "/healthcheck",
    response_model=int,
    status_code=200,
    responses={
        200: {
            "model": int,
            "description": "All component services are healthy (200 status)",
        },
        500: {
            "model": int,
            "description": "Internal error occurred",
            "example_value": 500
        }
    },
)
def healthcheck():
    """Report that this service is up.

    The handler unconditionally returns the integer 200. It does not probe any
    downstream service itself, so reaching it only demonstrates that this
    application is serving requests.
    """
    healthy_status = 200
    return healthy_status
@router.post("/eval", response_model=AMRLinkingEvaluationResults, status_code=200)
def quantitative_eval(linked_amr_file: UploadFile, gt_linked_amr_file: UploadFile) -> AMRLinkingEvaluationResults:
    """
    # Gets performance metrics of a linked amr with variable extractions against a ground truth linked amr.

    Both uploads must be JSON documents; they are parsed and handed to
    `compute_amr_linking_evaluation` as (prediction, ground truth).

    ## Example:
    ```python
    import requests

    with open("linked_amr_file.json", 'rb') as linked, open("gt_linked_amr_file.json", 'rb') as gt:
        files = {
            "linked_amr_file": ("linked_amr_file.json", linked),
            "gt_linked_amr_file": ("gt_linked_amr_file.json", gt),
        }
        response = requests.post(f"{endpoint}/metal/eval", files=files)
    ```

    """
    # Parse both uploads; keep the upload parameters bound to the uploads
    # instead of rebinding one of them to the parsed JSON.
    linked_amr = json.load(linked_amr_file.file)
    gt_linked_amr = json.load(gt_linked_amr_file.file)

    return compute_amr_linking_evaluation(linked_amr, gt_linked_amr)
# Standalone FastAPI application that serves only this module's router.
app: FastAPI = FastAPI()
app.include_router(router)