Coverage for skema/model_assembly/linking.py: 0%
220 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from abc import ABC, abstractmethod
2from dataclasses import dataclass
3from functools import singledispatch
4import re
6from networkx import DiGraph
@dataclass(repr=False, frozen=True)
class LinkNode(ABC):
    """Abstract base class for every node placed in a link graph.

    Nodes are frozen dataclasses, so they are immutable and hashable and can
    be used directly as graph nodes. ``str()`` and ``repr()`` render the same
    text; by default that text is just ``content``.

    Attributes:
        source: Where the element came from; each subclass parses it in its
            own way.
        content: The element's text or identifier.
        content_type: Type label copied from the input dict; not interpreted
            by this class.
    """

    source: str
    content: str
    content_type: str

    def __repr__(self):
        # The dataclass-generated repr is disabled (repr=False), so repr
        # reuses whatever __str__ the concrete subclass provides.
        return self.__str__()

    def __str__(self):
        return self.content

    @staticmethod
    def from_dict(data: dict):
        """Build the concrete node subclass selected by ``data["type"]``.

        Required keys are ``source``, ``content``, ``content_type`` and
        ``type``. The ``"text_var"`` type also needs ``svo_query_terms``, an
        iterable of strings that is joined with ``";"``.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``data["type"]`` is not a recognised element type.
        """
        # The shared constructor arguments are read first, so a missing key
        # raises KeyError before the type tag is examined.
        args = (data["source"], data["content"], data["content_type"])
        element_type = data["type"]

        # Factories are lambdas so that subclass names (defined later in the
        # module) are looked up only when the matching tag is selected.
        factories = (
            ("identifier", lambda: CodeVarNode(*args)),
            ("comment_span", lambda: CommSpanNode(*args)),
            ("equation_span", lambda: EqnSpanNode(*args)),
            (
                "text_var",
                lambda: TextVarNode(*args, ";".join(data["svo_query_terms"])),
            ),
            ("text_span", lambda: TextSpanNode(*args)),
        )
        for type_tag, make_node in factories:
            if element_type == type_tag:
                return make_node()
        raise ValueError(f"Unrecognized link element type: {element_type}")

    @abstractmethod
    def get_table_rows(self, link_graph: DiGraph) -> list:
        """Return this node's link-table rows computed from ``link_graph``."""
        return NotImplemented
@dataclass(repr=False, frozen=True)
class CodeVarNode(LinkNode):
    """Graph node for a variable identifier found in source code.

    ``content`` is a ``::``-separated identifier. Its trailing four fields
    are read as namespace, scope, base name and index; any extra leading
    fields are ignored. Both ``__str__`` and ``get_varname`` use this rule,
    so they agree on the field layout.
    """

    def __repr__(self):
        return self.__str__()

    def _id_fields(self) -> list:
        """Return ``[namespace, scope, basename, index]`` parsed from ``content``.

        Raises:
            ValueError: if ``content`` has fewer than four ``::`` fields.
        """
        fields = self.content.split("::")
        if len(fields) < 4:
            raise ValueError(
                f"Malformed code variable identifier: {self.content!r}"
            )
        return fields[-4:]

    def __str__(self):
        (namespace, scope, basename, index) = self._id_fields()
        return "\n".join(
            [
                f"NAMESPACE: {namespace}",
                f"SCOPE: {scope}",
                f"NAME: {basename}",
                f"INDEX: {index}",
            ]
        )

    def get_varname(self) -> str:
        """Return the variable's base name (the second-to-last ``::`` field)."""
        return self._id_fields()[2]

    def get_table_rows(self, L: DiGraph) -> list:
        """Collect link-table rows that end at this variable.

        For every ``CommSpanNode`` predecessor, takes that comment's rows and
        adds ``vc_score`` (the comment -> variable edge weight) and
        ``link_score``. ``link_score`` is the minimum of the vc, ct and te
        scores, i.e. the weakest link in the chain.
        """
        comm_span_nodes = [
            n for n in L.predecessors(self) if isinstance(n, CommSpanNode)
        ]

        rows = []
        for comm_node in comm_span_nodes:
            w_vc = L.edges[comm_node, self]["weight"]
            for r in comm_node.get_table_rows(L):
                w_row = min(w_vc, r["ct_score"], r["te_score"])
                r.update({"vc_score": w_vc, "link_score": w_row})
                rows.append(r)

        return rows
@dataclass(repr=False, frozen=True)
class TextVarNode(LinkNode):
    """Graph node for a variable mentioned in document text.

    Attributes:
        svo_query_str: Query terms joined with ``";"``; see ``get_svo_terms``.
    """

    svo_query_str: str

    def get_docname(self) -> str:
        """Return the document name taken from ``source``.

        The last ``/``-separated component of ``source`` must contain
        ``.pdf_`` exactly once. The text before it is the document name;
        any other shape makes the unpacking raise ``ValueError``.
        """
        last_component = self.source.rsplit("/", 1)[-1]
        docname, _remainder = last_component.split(".pdf_")
        return docname

    def get_svo_terms(self):
        """Return the individual query terms stored in ``svo_query_str``."""
        terms = self.svo_query_str.split(";")
        return terms

    def get_table_rows(self, L: DiGraph) -> list:
        """Contribute no rows; text variables are not part of the tables yet."""
        return list()
@dataclass(repr=False, frozen=True)
class CommSpanNode(LinkNode):
    """Graph node for a span of comment text taken from source code."""

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Return the content's tokens wrapped at four tokens per line.

        Content of four or fewer whitespace-separated tokens is returned on a
        single line. Longer content is returned as chunks of four tokens,
        each prefixed with a newline, so the result starts with a newline.
        """
        tokens = self.content.strip().split()
        if len(tokens) <= 4:
            return " ".join(tokens)
        chunks = (" ".join(tokens[i : i + 4]) for i in range(0, len(tokens), 4))
        return "".join("\n" + chunk for chunk in chunks)

    def get_comment_location(self):
        """Return a ``<file stem>::<sub_name>$<place>`` location string.

        ``source`` must hold exactly three fields separated by ``"; "``:
        filename, sub-name and place; otherwise unpacking raises
        ``ValueError``. The filename is cut at its last ``.f`` (for example
        ``model.f`` -> ``model``). A filename without ``.f`` is kept whole;
        previously ``rfind`` returning -1 silently dropped its final
        character.
        """
        (filename, sub_name, place) = self.source.split("; ")
        ext_start = filename.rfind(".f")
        if ext_start != -1:
            filename = filename[:ext_start]
        return f"{filename}::{sub_name}${place}"

    def get_table_rows(self, L: DiGraph) -> list:
        """Collect link-table rows that pass through this comment.

        For every ``TextSpanNode`` predecessor, takes that text span's rows
        and adds this comment's rendered text (``comm``) and the
        text -> comment edge weight (``ct_score``).
        """
        txt_span_nodes = [
            n for n in L.predecessors(self) if isinstance(n, TextSpanNode)
        ]

        rows = []
        for txt_node in txt_span_nodes:
            w_ct = L.edges[txt_node, self]["weight"]
            for r in txt_node.get_table_rows(L):
                r.update({"comm": str(self), "ct_score": w_ct})
                rows.append(r)

        return rows
@dataclass(repr=False, frozen=True)
class TextSpanNode(LinkNode):
    """Graph node for a span of text extracted from a PDF document.

    The last ``/``-separated component of ``source`` is expected to look like
    ``<docname>.pdf_<span data>``.
    """

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Return the content's tokens wrapped at four tokens per line.

        Content of four or fewer whitespace-separated tokens is returned on a
        single line. Longer content is returned as chunks of four tokens,
        each prefixed with a newline, so the result starts with a newline.
        """
        tokens = self.content.strip().split()
        if len(tokens) <= 4:
            return " ".join(tokens)
        chunks = (" ".join(tokens[i : i + 4]) for i in range(0, len(tokens), 4))
        return "".join("\n" + chunk for chunk in chunks)

    def __data_from_source(self) -> tuple:
        """Split the last path component of ``source`` on ``.pdf_``."""
        path_pieces = self.source.split("/")
        doc_data = path_pieces[-1]
        return tuple(doc_data.split(".pdf_"))

    def get_docname(self) -> str:
        """Return the document name (the text before ``.pdf_``).

        Raises:
            ValueError: if ``.pdf_`` does not occur exactly once.
        """
        (docname, _) = self.__data_from_source()
        return docname

    def get_sentence_id(self) -> str:
        """Return the sentence number encoded in the span data.

        The span data (the text after ``.pdf_``) must contain exactly three
        runs of digits, read here as sentence number, span start and span
        stop. The first run is returned; any other count raises
        ``ValueError``. Previously the method parsed the data but returned
        ``None``.
        """
        (_, data) = self.__data_from_source()
        (sent_num, _span_start, _span_stop) = re.findall(r"[0-9]+", data)
        return sent_num

    def get_table_rows(self, L: DiGraph) -> list:
        """Collect link-table rows that pass through this text span.

        For every ``EqnSpanNode`` predecessor, takes that equation's rows and
        adds this span's rendered text (``txt``) and the equation -> text
        edge weight (``te_score``).
        """
        eqn_span_nodes = [
            n for n in L.predecessors(self) if isinstance(n, EqnSpanNode)
        ]

        rows = []
        for eqn_node in eqn_span_nodes:
            w_te = L.edges[eqn_node, self]["weight"]
            for r in eqn_node.get_table_rows(L):
                r.update({"txt": str(self), "te_score": w_te})
                rows.append(r)

        return rows
@dataclass(repr=False, frozen=True)
class EqnSpanNode(LinkNode):
    """Graph node for a span of an equation; link-table rows start here."""

    def get_table_rows(self, L: DiGraph) -> list:
        """Return one new row holding this equation's text under ``eqn``."""
        row = dict(eqn=str(self))
        return [row]
def build_link_graph(link_hypotheses: list) -> DiGraph:
    """Build a directed link graph from a list of link hypotheses.

    Each hypothesis is a dict with ``element_1`` and ``element_2`` (node
    dicts accepted by ``LinkNode.from_dict``) and a numeric ``score``. Edges
    always point along the chain
    ``EqnSpan/TextVar -> TextSpan -> CommSpan -> CodeVar``, whichever element
    is listed first. The score, rounded to three decimals, is stored as the
    edge's ``weight`` attribute. Each node gets a per-class ``color``
    attribute.

    Raises:
        ValueError: if an element has an unsupported node type, or the pair
            of element types is not a permitted link.
        KeyError: if a hypothesis or element dict is missing a key.
    """
    G = DiGraph()

    def report_bad_link(n1, n2):
        raise ValueError(f"Inappropriate link type: ({type(n1)}, {type(n2)})")

    # Adds a node with its class-specific display colour.
    @singledispatch
    def add_link_node(node):
        raise ValueError(f"Inappropriate node type: {type(node)}")

    @add_link_node.register
    def _(node: CodeVarNode):
        G.add_node(node, color="darkviolet")

    @add_link_node.register
    def _(node: CommSpanNode):
        G.add_node(node, color="lightskyblue")

    @add_link_node.register
    def _(node: TextSpanNode):
        G.add_node(node, color="crimson")

    @add_link_node.register
    def _(node: EqnSpanNode):
        G.add_node(node, color="orange")

    @add_link_node.register
    def _(node: TextVarNode):
        G.add_node(node, color="deeppink")

    # Dispatches on the first node's type; each handler orients the edge so
    # it follows the chain described in the docstring. The edge weight always
    # comes from the ``score`` argument, never from enclosing-scope variables.
    @singledispatch
    def add_link(n1, n2, score):
        raise ValueError(f"Inappropriate node type: {type(n1)}")

    @add_link.register
    def _(n1: CodeVarNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, CommSpanNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: CommSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, CodeVarNode):
            G.add_edge(n1, n2, weight=score)
        elif isinstance(n2, TextSpanNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: TextSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, EqnSpanNode):
            G.add_edge(n2, n1, weight=score)
        elif isinstance(n2, CommSpanNode):
            G.add_edge(n1, n2, weight=score)
        elif isinstance(n2, TextVarNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: EqnSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, TextSpanNode):
            G.add_edge(n1, n2, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: TextVarNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, TextSpanNode):
            G.add_edge(n1, n2, weight=score)
        else:
            report_bad_link(n1, n2)

    for link_dict in link_hypotheses:
        node1 = LinkNode.from_dict(link_dict["element_1"])
        node2 = LinkNode.from_dict(link_dict["element_2"])
        link_score = round(link_dict["score"], 3)
        add_link(node1, node2, link_score)

    return G
def extract_link_tables(L: DiGraph) -> dict:
    """Build one link table per code variable in the graph.

    Keys are the rendered variable text (``str(node)``). When several
    ``CodeVarNode``s render identically, only the first one encountered in
    ``L.nodes`` is tabled. Each table is that node's ``get_table_rows``
    output, sorted in descending order by ``(vc_score, ct_score, te_score)``.
    """
    tables = {}
    for node in L.nodes:
        if not isinstance(node, CodeVarNode):
            continue
        key = str(node)
        if key in tables:
            continue
        tables[key] = sorted(
            node.get_table_rows(L),
            key=lambda row: (row["vc_score"], row["ct_score"], row["te_score"]),
            reverse=True,
        )
    return tables
def print_table_data(table_data: dict) -> None:
    """Print every variable's link table as tab-separated text on stdout.

    For each entry, the variable name is printed, then a header line, then
    one line per row, then a blank separator. Score columns are converted
    with ``str``; the text columns (``comm``, ``txt``, ``eqn``) are printed
    as-is.
    """
    column_keys = (
        "link_score",
        "comm",
        "vc_score",
        "txt",
        "ct_score",
        "eqn",
        "te_score",
    )
    score_keys = {"link_score", "vc_score", "ct_score", "te_score"}

    for var_name, rows in table_data.items():
        print(var_name)
        print("L-SCORE\tComment\tV-C\tText-span\tC-T\tEquation\tT-E")
        for row in rows:
            cells = [
                str(row[key]) if key in score_keys else row[key]
                for key in column_keys
            ]
            print("\t".join(cells))
        print("\n\n")