Coverage for skema/metal/model_linker/skema_model_linker/linkers/amr_linker.py: 26%
85 statements
coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
import abc
from abc import ABC
from collections import defaultdict
from typing import Iterable, Dict, List, Any, Tuple, Optional, Union

import pandas as pd
import torch
from askem_extractions.data_model import Attribute, AnchoredEntity, AttributeCollection, AttributeType
from sentence_transformers import SentenceTransformer, util

from ..walkers import JsonNode, JsonDictWalker
class Linker(ABC):
    """Base class for linking model elements to text extractions via sentence embeddings."""

    def __init__(self, model_name: str, sim_threshold: float = 0.7, device: Optional[str] = None):
        self._model_name = model_name
        self._model = SentenceTransformer(model_name)
        self._threshold = sim_threshold

        if device:
            self._model.to(device)

        # Inference only: no gradient updates are needed for linking
        self._model.eval()
    @abc.abstractmethod
    def _build_walker(self, amr_data: Dict[str, Any]) -> JsonDictWalker:
        """Builds a walker over the AMR data to iterate over its elements."""
        pass
    @abc.abstractmethod
    def _generate_linking_sources(self, elements: Iterable[JsonNode]) -> Dict[str, List[Any]]:
        """Generates candidate texts from model elements to link to text extractions."""
        pass
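    # Illustrative sketch (not part of the original module): a concrete subclass
    # would typically map each candidate text to the AMR element it came from,
    # e.g. {"S: Susceptible population": {"id": "S", ...}}, so that the keys can
    # be embedded and the matching element (and its "id") recovered after alignment.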
    def _align_texts(self, sources: List[str], targets: List[str], threshold: float) -> List[Tuple[str, str]]:
        """Returns all (source, target) pairs whose embedding cosine similarity is at least the threshold."""

        if len(sources) > 0 and len(targets) > 0:
            with torch.no_grad():
                s_embs = self._model.encode(sources)
                t_embs = self._model.encode(targets)

            similarities = util.pytorch_cos_sim(s_embs, t_embs)

            # Indices of every (source, target) pair at or above the similarity threshold
            indices = (similarities >= threshold).nonzero()

            ret = list()
            for ix in indices:
                ret.append((sources[ix[0]], targets[ix[1]]))

            return ret
        else:
            return []
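    # For example (illustrative, not part of the original module), with threshold=0.7,
    # _align_texts(["beta: infection rate"], ["infection rate", "recovery rate"])
    # would typically return [("beta: infection rate", "infection rate")] and drop
    # the dissimilar pair.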
    def _generate_linking_targets(self, extractions: Iterable[Attribute]) -> Dict[str, List[AnchoredEntity]]:
        """Generates candidate texts from the text extractions to link to model elements."""
        ret = defaultdict(list)
        for ex in extractions:
            for name in ex.payload.mentions:
                if len(ex.payload.text_descriptions) > 0:
                    for desc in ex.payload.text_descriptions:
                        # Index the extraction both by "name: description" and by the description alone
                        ret[f"{name.name.strip()}: {desc.description.strip()}"].append(ex)
                        ret[desc.description.strip()].append(ex)
                else:
                    # Fall back to the mention name when no description is available
                    candidate_text = f"{name.name.strip()}"
                    ret[candidate_text].append(ex)
        return ret
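    # Illustrative sketch (not part of the original module): an extraction whose
    # mention is "beta" and whose description is "transmission rate" would be
    # indexed under both "beta: transmission rate" and "transmission rate",
    # each key mapping back to the originating Attribute.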
    @abc.abstractmethod
    def link_model_to_text_extractions(self, data: Union[Any, Dict[str, Any]],
                                       extractions: AttributeCollection) -> Dict[str, Any]:
        pass
class AMRLinker(Linker, ABC):

    def link_model_to_manual_annotations(self, data: Dict[str, Any], candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Similar to linking a model to text extractions, but links it to ground-truth extractions instead.
        Used mostly for debugging.
        """

        # Make a copy of the amr to avoid mutating the original model
        data = {**data}

        # Filter out the targets from the annotations
        targets = defaultdict(list)
        for candidate in candidates:
            if candidate['type'] == "Highlight" and candidate['color'] == "#f9cd59":  # This color marks an anchored extraction
                key = candidate["text"]
                targets[key].append(candidate)

        walker = self._build_walker(data)

        to_link = list(walker.walk())
        sources = self._generate_linking_sources(to_link)

        pairs = self._align_texts(list(sources.keys()), list(targets.keys()), threshold=self._threshold)

        linked_targets = list()
        for s_key, t_key in pairs:
            source = sources[s_key]
            target = targets[t_key]

            # Get the AMR ID of the source and add it to the target extractions
            for t in target:
                t['amr_element_id'] = source['id']
                linked_targets.append(t)

        # Attach the linked annotations to the AMR metadata, after alignment
        data["metadata"] = linked_targets

        return data
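    # Illustrative sketch (not part of the original module): each candidate
    # annotation is expected to look like {"type": "Highlight", "color": "#f9cd59",
    # "text": "transmission rate", ...}; after alignment it additionally carries
    # the matched "amr_element_id".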
    def link_model_to_text_extractions(self, data: Dict[str, Any], extractions: AttributeCollection) -> Dict[str, Any]:

        # Make a copy of the amr to avoid mutating the original model
        data = {**data}

        targets = self._generate_linking_targets(
            e for e in extractions.attributes if e.type == AttributeType.anchored_entity)

        walker = self._build_walker(data)

        to_link = list(walker.walk())
        sources = self._generate_linking_sources(to_link)

        pairs = self._align_texts(list(sources.keys()), list(targets.keys()), threshold=self._threshold)

        for s_key, t_key in pairs:
            source = sources[s_key]
            target = targets[t_key]

            # Get the AMR ID of the source and add it to the target extractions
            for t in target:
                t.amr_element_id = source['id']

        # Serialize the attribute collection to a dict, after alignment, and attach it as metadata
        attribute_dict = extractions.model_dump(exclude_unset=True)
        data["metadata"] = attribute_dict

        return data
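

# The block below is an illustrative, self-contained sketch of the embedding-based
# alignment that _align_texts performs; it is not part of the original module.
# The model name "all-MiniLM-L6-v2", the threshold, and the toy strings are
# assumptions chosen for the demo only.
if __name__ == "__main__":
    demo_model = SentenceTransformer("all-MiniLM-L6-v2")
    demo_sources = ["beta: transmission rate", "gamma: recovery rate"]
    demo_targets = ["rate of transmission between individuals", "rate at which individuals recover"]

    with torch.no_grad():
        s_embs = demo_model.encode(demo_sources)
        t_embs = demo_model.encode(demo_targets)

    sims = util.pytorch_cos_sim(s_embs, t_embs)

    # Keep every (source, target) pair whose cosine similarity reaches the threshold
    for ix in (sims >= 0.5).nonzero():
        print(demo_sources[ix[0]], "->", demo_targets[ix[1]], f"({sims[ix[0], ix[1]]:.2f})")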