Coverage for skema/metal/model_linker/skema_model_linker/linkers/amr_linker.py: 26%

85 statements  

coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

import abc
from abc import ABC
from collections import defaultdict
from typing import Iterable, Dict, List, Any, Tuple, Optional, Union

import pandas as pd
import torch
from askem_extractions.data_model import Attribute, AnchoredEntity, AttributeCollection, AttributeType
from sentence_transformers import SentenceTransformer, util

from ..walkers import JsonNode, JsonDictWalker

class Linker(ABC):
    def __init__(self, model_name: str, sim_threshold: float = 0.7, device: Optional[str] = None):
        self._model_name = model_name
        self._model = SentenceTransformer(model_name)
        self._threshold = sim_threshold

        if device:
            self._model.to(device)

        self._model.eval()

    @abc.abstractmethod
    def _build_walker(self, amr_data: Dict[str, Any]) -> JsonDictWalker:
        pass

    @abc.abstractmethod
    def _generate_linking_sources(self, elements: Iterable[JsonNode]) -> Dict[str, List[Any]]:
        """ Generate candidate texts to link to text extractions """
        pass

    def _align_texts(self, sources: List[str], targets: List[str], threshold: float) -> List[Tuple[str, str]]:

        if len(sources) > 0 and len(targets) > 0:
            with torch.no_grad():
                s_embs = self._model.encode(sources)
                t_embs = self._model.encode(targets)

            similarities = util.pytorch_cos_sim(s_embs, t_embs)

            # Keep every (source, target) pair whose cosine similarity clears the threshold
            indices = (similarities >= threshold).nonzero()

            ret = list()
            for ix in indices:
                ret.append((sources[ix[0]], targets[ix[1]]))

            return ret
        else:
            return []

    def _generate_linking_targets(self, extractions: Iterable[Attribute]) -> Dict[str, List[AnchoredEntity]]:
        """ Generate candidate texts to link to model elements """
        ret = defaultdict(list)
        for ex in extractions:
            for name in ex.payload.mentions:
                if len(ex.payload.text_descriptions) > 0:
                    # Pair each mention name with each of its text descriptions
                    for desc in ex.payload.text_descriptions:
                        ret[f"{name.name.strip()}: {desc.description.strip()}"].append(ex)
                        ret[desc.description.strip()].append(ex)
                else:
                    candidate_text = f"{name.name.strip()}"
                    ret[candidate_text].append(ex)
        return ret

    @abc.abstractmethod
    def link_model_to_text_extractions(self, data: Union[Any, Dict[str, Any]], extractions: AttributeCollection) -> \
            Dict[str, Any]:
        pass

class AMRLinker(Linker, ABC):

    def link_model_to_manual_annotations(self, data: Dict[str, Any], candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Similar to linking a model to text extractions, but links it to ground-truth (manually annotated) extractions.
        Used mostly for debugging.
        """

        # Make a copy of the amr to avoid mutating the original model
        data = {**data}

        # Select the targets from the annotations
        targets = defaultdict(list)
        for candidate in candidates:
            if candidate['type'] == "Highlight" and candidate['color'] == "#f9cd59":  # This color marks an anchored extraction
                key = candidate["text"]
                targets[key].append(candidate)

        walker = self._build_walker(data)

        to_link = list(walker.walk())
        sources = self._generate_linking_sources(to_link)

        pairs = self._align_texts(list(sources.keys()), list(targets.keys()), threshold=self._threshold)

        linked_targets = list()
        for s_key, t_key in pairs:
            source = sources[s_key]
            target = targets[t_key]

            # Get the AMR ID of the source and add it to the target extractions
            for t in target:
                t['amr_element_id'] = source['id']
                linked_targets.append(t)

        # Attach the aligned annotations to the AMR copy
        data["metadata"] = linked_targets

        return data

    def link_model_to_text_extractions(self, data: Dict[str, Any], extractions: AttributeCollection) -> Dict[str, Any]:

        # Make a copy of the amr to avoid mutating the original model
        data = {**data}

        targets = self._generate_linking_targets(
            e for e in extractions.attributes if e.type == AttributeType.anchored_entity)

        walker = self._build_walker(data)

        to_link = list(walker.walk())
        sources = self._generate_linking_sources(to_link)

        pairs = self._align_texts(list(sources.keys()), list(targets.keys()), threshold=self._threshold)

        for s_key, t_key in pairs:
            source = sources[s_key]
            target = targets[t_key]

            # Get the AMR ID of the source and add it to the target extractions
            for t in target:
                t.amr_element_id = source['id']

        # Serialize the attribute collection to json, after alignment
        attribute_dict = extractions.model_dump(exclude_unset=True)
        data["metadata"] = attribute_dict

        return data
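
For reference, a minimal usage sketch (not part of the file above) of how a concrete subclass might drive this linker: the ToyAMRLinker class, the direct JsonDictWalker(amr_data) construction, the node access pattern ("name"/"id" keys), the embedding model name, and the input file names are illustrative assumptions, not this package's actual API. The abstract hooks _build_walker and _generate_linking_sources are the only format-specific pieces; alignment itself is the shared embed-and-threshold logic in Linker._align_texts.

# toy_linker.py -- illustrative sketch only; names marked "assumption" are not from amr_linker.py
import json
from typing import Any, Dict, Iterable

from askem_extractions.data_model import AttributeCollection

from skema.metal.model_linker.skema_model_linker.linkers.amr_linker import AMRLinker
from skema.metal.model_linker.skema_model_linker.walkers import JsonDictWalker, JsonNode


class ToyAMRLinker(AMRLinker):
    """Hypothetical concrete linker, used only to show the call sequence."""

    def _build_walker(self, amr_data: Dict[str, Any]) -> JsonDictWalker:
        # Assumption: JsonDictWalker can be constructed directly from the AMR dict
        return JsonDictWalker(amr_data)

    def _generate_linking_sources(self, elements: Iterable[JsonNode]) -> Dict[str, Any]:
        # Assumption: each walked node exposes "name" and "id" like a dict;
        # the real walkers in this package may surface nodes differently
        return {node["name"]: node for node in elements}


if __name__ == "__main__":
    linker = ToyAMRLinker(model_name="all-MiniLM-L6-v2", sim_threshold=0.7)

    with open("amr.json") as f:            # an AMR model serialized as JSON (assumed input)
        amr = json.load(f)
    with open("extractions.json") as f:    # an askem_extractions AttributeCollection dump (assumed input)
        extractions = AttributeCollection.model_validate(json.load(f))

    # Returns a copy of the AMR whose "metadata" field holds the aligned extractions
    linked = linker.link_model_to_text_extractions(amr, extractions)
    print(list(linked["metadata"].keys()))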