Coverage for skema/model_assembly/linking.py: 0%

220 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from abc import ABC, abstractmethod 

2from dataclasses import dataclass 

3from functools import singledispatch 

4import re 

5 

6from networkx import DiGraph 

7 

8 

@dataclass(repr=False, frozen=True)
class LinkNode(ABC):
    """Abstract, immutable base class for the elements of a link graph.

    Every node carries the three string fields below. Concrete subclasses
    must implement ``get_table_rows``.
    """

    # Where the element came from (interpreted by subclasses).
    source: str
    # The element's text; also its default string form.
    content: str
    # Free-form type tag copied from the input dictionary.
    content_type: str

    def __repr__(self):
        # repr and str are deliberately the same for link nodes.
        return self.__str__()

    def __str__(self):
        return self.content

    @staticmethod
    def from_dict(data: dict):
        """Build the concrete node described by ``data``.

        Args:
            data: Mapping with the keys ``"source"``, ``"content"``,
                ``"content_type"`` and ``"type"``. When ``"type"`` is
                ``"text_var"`` the key ``"svo_query_terms"`` (an iterable of
                strings) is required as well.

        Returns:
            An instance of the subclass selected by ``data["type"]``.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``data["type"]`` is not a recognized element type.
        """
        common = (data["source"], data["content"], data["content_type"])
        kind = data["type"]

        if kind == "identifier":
            return CodeVarNode(*common)
        if kind == "comment_span":
            return CommSpanNode(*common)
        if kind == "equation_span":
            return EqnSpanNode(*common)
        if kind == "text_var":
            # Query terms are stored as one ';'-delimited string.
            return TextVarNode(*common, ";".join(data["svo_query_terms"]))
        if kind == "text_span":
            return TextSpanNode(*common)

        raise ValueError(f"Unrecognized link element type: {kind}")

    @abstractmethod
    def get_table_rows(self, link_graph: DiGraph) -> list:
        """Return this node's table rows for ``link_graph``.

        Subclasses must override this; the base implementation only returns
        the ``NotImplemented`` sentinel.
        """
        return NotImplemented

42 

43 

@dataclass(repr=False, frozen=True)
class CodeVarNode(LinkNode):
    """Link-graph node for a code variable identifier.

    ``content`` is a ``"::"``-separated identifier whose last four fields are
    namespace, scope, base name and index. A fifth, leading field is accepted
    and ignored.
    """

    def __repr__(self):
        return self.__str__()

    def _identifier_fields(self) -> tuple:
        """Split ``content`` into ``(namespace, scope, basename, index)``.

        The string form and ``get_varname`` previously disagreed on the
        number of fields (4 vs. 5), so one of them always raised. Both now
        go through this helper, which accepts either layout and keeps the
        trailing four fields.

        Raises:
            ValueError: If ``content`` has neither 4 nor 5 fields.
        """
        parts = self.content.split("::")
        if len(parts) not in (4, 5):
            raise ValueError(
                "Expected 4 or 5 '::'-separated identifier fields, "
                f"got {len(parts)}: {self.content!r}"
            )
        return tuple(parts[-4:])

    def __str__(self):
        """Return a four-line, labelled rendering of the identifier."""
        (namespace, scope, basename, index) = self._identifier_fields()
        return "\n".join(
            [
                f"NAMESPACE: {namespace}",
                f"SCOPE: {scope}",
                f"NAME: {basename}",
                f"INDEX: {index}",
            ]
        )

    def get_varname(self) -> str:
        """Return the base name field of the identifier."""
        (_, _, basename, _) = self._identifier_fields()
        return basename

    def get_table_rows(self, L: DiGraph) -> list:
        """Collect table rows reachable through comment-span predecessors.

        For every ``CommSpanNode`` predecessor of this node in ``L``, the
        rows produced by that comment node are extended with:

        * ``"vc_score"``: the weight of the comment -> variable edge, and
        * ``"link_score"``: the minimum of ``vc_score``, ``ct_score`` and
          ``te_score`` for the row.

        Args:
            L: Link graph containing this node; edges must carry a
                ``"weight"`` attribute.

        Returns:
            The list of updated row dictionaries (empty if this node has no
            comment-span predecessors).
        """
        comm_span_nodes = [
            n for n in L.predecessors(self) if isinstance(n, CommSpanNode)
        ]

        rows = []
        for comm_node in comm_span_nodes:
            w_vc = L.edges[comm_node, self]["weight"]
            for r in comm_node.get_table_rows(L):
                # A row is only as strong as its weakest link.
                w_row = min(w_vc, r["ct_score"], r["te_score"])
                r.update({"vc_score": w_vc, "link_score": w_row})
                rows.append(r)

        return rows

78 

79 

@dataclass(repr=False, frozen=True)
class TextVarNode(LinkNode):
    """Link-graph node for a variable mentioned in document text."""

    # Query terms packed into a single ';'-delimited string.
    svo_query_str: str

    def get_docname(self) -> str:
        """Return the document name encoded in ``source``.

        The final ``"/"``-separated component of ``source`` is split on
        ``".pdf_"``; the part before the separator is the document name.

        Raises:
            ValueError: If that component does not split into exactly two
                pieces on ``".pdf_"``.
        """
        last_component = self.source.split("/")[-1]
        docname, _ = last_component.split(".pdf_")
        return docname

    def get_svo_terms(self):
        """Return the individual query terms as a list of strings."""
        return self.svo_query_str.split(";")

    def get_table_rows(self, L: DiGraph) -> list:
        """Return no rows; text variables do not contribute to tables yet."""
        # NOTE: nothing to do for now
        return list()

96 

97 

98@dataclass(repr=False, frozen=True) 

99class CommSpanNode(LinkNode): 

100 def __repr__(self): 

101 return self.__str__() 

102 

103 def __str__(self): 

104 tokens = self.content.strip().split() 

105 if len(tokens) <= 4: 

106 return " ".join(tokens) 

107 

108 new_content = "" 

109 while len(tokens) > 4: 

110 new_content += "\n" + " ".join(tokens[:4]) 

111 tokens = tokens[4:] 

112 new_content += "\n" + " ".join(tokens) 

113 return new_content 

114 

115 def get_comment_location(self): 

116 (filename, sub_name, place) = self.source.split("; ") 

117 filename = filename[: filename.rfind(".f")] 

118 return f"{filename}::{sub_name}${place}" 

119 

120 def get_table_rows(self, L: DiGraph) -> list: 

121 txt_span_nodes = [ 

122 n for n in L.predecessors(self) if isinstance(n, TextSpanNode) 

123 ] 

124 

125 rows = list() 

126 for txt_node in txt_span_nodes: 

127 w_ct = L.edges[txt_node, self]["weight"] 

128 for r in txt_node.get_table_rows(L): 

129 r.update({"comm": str(self), "ct_score": w_ct}) 

130 rows.append(r) 

131 

132 return rows 

133 

134 

@dataclass(repr=False, frozen=True)
class TextSpanNode(LinkNode):
    """Link-graph node for a span of document text."""

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Return ``content`` wrapped at four whitespace-separated tokens.

        Text of four tokens or fewer is returned on a single line. Longer
        text is emitted four tokens per line, and every line, including the
        first, is preceded by a newline.
        """
        tokens = self.content.strip().split()
        if len(tokens) <= 4:
            return " ".join(tokens)

        return "".join(
            "\n" + " ".join(tokens[start : start + 4])
            for start in range(0, len(tokens), 4)
        )

    def __data_from_source(self) -> tuple:
        """Split the last ``"/"`` component of ``source`` on ``".pdf_"``."""
        path_pieces = self.source.split("/")
        doc_data = path_pieces[-1]
        return tuple(doc_data.split(".pdf_"))

    def get_docname(self) -> str:
        """Return the part of the source file name before ``".pdf_"``.

        Raises:
            ValueError: If the file name does not split into exactly two
                pieces on ``".pdf_"``.
        """
        (docname, _) = self.__data_from_source()
        return docname

    def get_sentence_id(self) -> str:
        """Return the first of the three numbers that follow ``".pdf_"``.

        The text after ``".pdf_"`` must contain exactly three runs of
        digits, unpacked as sentence number, span start and span stop.

        NOTE(review): this method previously ended with a bare ``return``
        and therefore always produced ``None`` despite the ``-> str``
        annotation. Returning the sentence number is inferred from the
        method and variable names; confirm against callers.

        Raises:
            ValueError: If the source does not split into two pieces on
                ``".pdf_"`` or does not contain exactly three numbers.
        """
        (_, data) = self.__data_from_source()
        # Unpacking all three keeps the "exactly three numbers" validation.
        (sent_num, _span_start, _span_stop) = re.findall(r"[0-9]+", data)
        return sent_num

    def get_table_rows(self, L: DiGraph) -> list:
        """Collect table rows reachable through equation-span predecessors.

        For every ``EqnSpanNode`` predecessor of this node in ``L``, the
        rows produced by that equation node are extended with ``"txt"``
        (this node's string form) and ``"te_score"`` (the weight of the
        equation -> text edge).

        Args:
            L: Link graph containing this node; edges must carry a
                ``"weight"`` attribute.

        Returns:
            The list of updated row dictionaries.
        """
        eqn_span_nodes = [
            n for n in L.predecessors(self) if isinstance(n, EqnSpanNode)
        ]

        rows = []
        for eqn_node in eqn_span_nodes:
            w_te = L.edges[eqn_node, self]["weight"]
            for r in eqn_node.get_table_rows(L):
                r.update({"txt": str(self), "te_score": w_te})
                rows.append(r)

        return rows

179 

180 

@dataclass(repr=False, frozen=True)
class EqnSpanNode(LinkNode):
    """Link-graph node for an equation span; the root of every table row."""

    def get_table_rows(self, L: DiGraph) -> list:
        """Return a single fresh row holding this equation's string form.

        ``L`` is accepted for interface compatibility and is not consulted.
        """
        seed_row = {"eqn": str(self)}
        return [seed_row]

185 

186 

def build_link_graph(link_hypotheses: list) -> DiGraph:
    """Build a directed, weighted link graph from link hypotheses.

    Each hypothesis is a dict with the keys ``"element_1"`` and
    ``"element_2"`` (node descriptions accepted by ``LinkNode.from_dict``)
    and ``"score"`` (a number, rounded to three decimals and stored as the
    edge ``weight``).

    Whatever order the two elements arrive in, edges are always oriented:

    * ``CommSpanNode -> CodeVarNode``
    * ``TextSpanNode -> CommSpanNode``
    * ``EqnSpanNode  -> TextSpanNode``
    * ``TextVarNode  -> TextSpanNode``

    Nodes are added with a ``color`` attribute chosen by node type.

    Args:
        link_hypotheses: Iterable of hypothesis dicts as described above.

    Returns:
        The populated ``DiGraph``.

    Raises:
        ValueError: If an element has an unrecognized type or the two
            elements of a hypothesis cannot be linked to each other.
        KeyError: If a hypothesis or element is missing a required key.
    """
    G = DiGraph()

    def report_bad_link(n1, n2):
        raise ValueError(f"Inappropriate link type: ({type(n1)}, {type(n2)})")

    # --- node insertion, dispatched on node type -------------------------
    @singledispatch
    def add_link_node(node):
        raise ValueError(f"Inappropriate node type: {type(node)}")

    @add_link_node.register
    def _(node: CodeVarNode):
        G.add_node(node, color="darkviolet")

    @add_link_node.register
    def _(node: CommSpanNode):
        G.add_node(node, color="lightskyblue")

    @add_link_node.register
    def _(node: TextSpanNode):
        G.add_node(node, color="crimson")

    @add_link_node.register
    def _(node: EqnSpanNode):
        G.add_node(node, color="orange")

    @add_link_node.register
    def _(node: TextVarNode):
        G.add_node(node, color="deeppink")

    # --- edge insertion, dispatched on the type of the first node --------
    # Every handler uses its own ``score`` argument. Three of them used to
    # read the enclosing loop variable ``link_score`` through the closure,
    # which only worked because of the order of statements in the loop.
    @singledispatch
    def add_link(n1, n2, score):
        raise ValueError(f"Inappropriate node type: {type(n1)}")

    @add_link.register
    def _(n1: CodeVarNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, CommSpanNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: CommSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, CodeVarNode):
            G.add_edge(n1, n2, weight=score)
        elif isinstance(n2, TextSpanNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: TextSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, EqnSpanNode):
            G.add_edge(n2, n1, weight=score)
        elif isinstance(n2, CommSpanNode):
            G.add_edge(n1, n2, weight=score)
        elif isinstance(n2, TextVarNode):
            G.add_edge(n2, n1, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: EqnSpanNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, TextSpanNode):
            G.add_edge(n1, n2, weight=score)
        else:
            report_bad_link(n1, n2)

    @add_link.register
    def _(n1: TextVarNode, n2, score):
        add_link_node(n1)
        add_link_node(n2)

        if isinstance(n2, TextSpanNode):
            G.add_edge(n1, n2, weight=score)
        else:
            report_bad_link(n1, n2)

    for link_dict in link_hypotheses:
        node1 = LinkNode.from_dict(link_dict["element_1"])
        node2 = LinkNode.from_dict(link_dict["element_2"])
        link_score = round(link_dict["score"], 3)
        add_link(node1, node2, link_score)

    return G

284 

285 

def extract_link_tables(L: DiGraph) -> dict:
    """Build one sorted link table per code variable in the graph.

    Args:
        L: Link graph whose ``CodeVarNode`` members are tabulated.

    Returns:
        A dict mapping each code variable's string form to its list of row
        dicts, sorted in descending order of the tuple
        ``(vc_score, ct_score, te_score)``. If several nodes share a string
        form, only the first one encountered is tabulated.
    """

    def score_key(row):
        return (row["vc_score"], row["ct_score"], row["te_score"])

    tables = {}
    for node in L.nodes:
        if not isinstance(node, CodeVarNode):
            continue

        label = str(node)
        if label in tables:
            # An earlier node with the same label already owns this table.
            continue

        rows = node.get_table_rows(L)
        rows.sort(key=score_key, reverse=True)
        tables[label] = rows

    return tables

301 

302 

def print_table_data(table_data: dict) -> None:
    """Print every link table to standard output as tab-separated text.

    For each variable the output is: the variable name, a header line, one
    line per row, and then a ``print`` of two newlines as a separator.

    Args:
        table_data: Mapping of variable name to a list of row dicts. Each
            row must provide ``link_score``, ``comm``, ``vc_score``,
            ``txt``, ``ct_score``, ``eqn`` and ``te_score``; ``comm``,
            ``txt`` and ``eqn`` must already be strings.
    """
    header = "L-SCORE\tComment\tV-C\tText-span\tC-T\tEquation\tT-E"

    for var_name in table_data:
        print(var_name)
        print(header)
        for row in table_data[var_name]:
            # Scores are stringified; the text columns are used as-is.
            cells = (
                str(row["link_score"]),
                row["comm"],
                str(row["vc_score"]),
                row["txt"],
                str(row["ct_score"]),
                row["eqn"],
                str(row["te_score"]),
            )
            print("\t".join(cells))
        print("\n\n")

318 print("\n\n")