Coverage for skema/isa/lib.py: 48%

436 statements  

coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1# -*- coding: utf-8 -*-

2""" 

3All the functions required for performing incremental structure alignment (ISA)

4Author: Liang Zhang (liangzh@arizona.edu) 

5Updated date: December 18, 2023 

6""" 

7import json 

8import warnings 

9from typing import List, Any, Union, Dict, Tuple 

10from numpy import ndarray 

11from pydot import Dot 

12from skema.rest.proxies import SKEMA_RS_ADDESS 

13 

14warnings.filterwarnings("ignore") 

15import requests 

16import pydot 

17import numpy as np 

18from graspologic.match import graph_match 

19from graphviz import Source 

20import graphviz 

21from copy import deepcopy 

22import Levenshtein 

23from typing import Tuple 

24import re 

25import xml.etree.ElementTree as ET 

26import html 

27from sentence_transformers import SentenceTransformer, util 

28import json 

29import ast 

30 

31 

32# Set up the random seed 

33np.random.seed(4) 

34rng = np.random.default_rng(4) 

35 

36# The encodings of basic operators when converting adjacency matrix 

37op_dict = {"+": 1, "-": 2, "*": 3, "/": 4, "=": 5, "√": 6} 

38 

39# Greek letters mapping 

40# Each row maps a Greek letter to its symbol, English name, and numeric character reference

41greek_letters: List[List[str]] = [ 

42 ["α", "alpha", "α"], 

43 ["β", "beta", "β"], 

44 ["γ", "gamma", "γ"], 

45 ["δ", "delta", "δ"], 

46 ["ε", "epsilon", "ε"], 

47 ["ζ", "zeta", "ζ"], 

48 ["η", "eta", "η"], 

49 ["θ", "theta", "θ"], 

50 ["ι", "iota", "ι"], 

51 ["κ", "kappa", "κ"], 

52 ["λ", "lambda", "λ"], 

53 ["μ", "mu", "μ"], 

54 ["ν", "nu", "ν"], 

55 ["ξ", "xi", "ξ"], 

56 ["ο", "omicron", "ο"], 

57 ["π", "pi", "π"], 

58 ["ρ", "rho", "ρ"], 

59 ["σ", "sigma", "σ"], 

60 ["τ", "tau", "τ"], 

61 ["υ", "upsilon", "υ"], 

62 ["φ", "phi", "φ"], 

63 ["χ", "chi", "χ"], 

64 ["ψ", "psi", "ψ"], 

65 ["ω", "omega", "ω"], 

66 ["Α", "Alpha", "Α"], 

67 ["Β", "Beta", "Β"], 

68 ["Γ", "Gamma", "Γ"], 

69 ["Δ", "Delta", "Δ"], 

70 ["Ε", "Epsilon", "Ε"], 

71 ["Ζ", "Zeta", "Ζ"], 

72 ["Η", "Eta", "Η"], 

73 ["Θ", "Theta", "Θ"], 

74 ["Ι", "Iota", "Ι"], 

75 ["Κ", "Kappa", "Κ"], 

76 ["Λ", "Lambda", "Λ"], 

77 ["Μ", "Mu", "Μ"], 

78 ["Ν", "Nu", "Ν"], 

79 ["Ξ", "Xi", "Ξ"], 

80 ["Ο", "Omicron", "Ο"], 

81 ["Π", "Pi", "Π"], 

82 ["Ρ", "Rho", "Ρ"], 

83 ["Σ", "Sigma", "Σ"], 

84 ["Τ", "Tau", "Τ"], 

85 ["Υ", "Upsilon", "Υ"], 

86 ["Φ", "Phi", "Φ"], 

87 ["Χ", "Chi", "Χ"], 

88 ["Ψ", "Psi", "Ψ"], 

89 ["Ω", "Omega", "Ω"], 

90] 

91 

92mathml_operators = [ 

93 "sin", 

94 "cos", 

95 "tan", 

96 "sec", 

97 "csc", 

98 "cot", 

99 "log", 

100 "ln", 

101 "exp", 

102 "sqrt", 

103 "sum", 

104 "prod", 

105 "lim", 

106] 

107 

108 

109def levenshtein_similarity(var1: str, var2: str) -> float: 

110 """ 

111 Compute the Levenshtein similarity between two variable names. 

112 The Levenshtein similarity is 1 minus the ratio of the Levenshtein distance to the maximum length of the two names.

113 Args: 

114 var1: The first variable name. 

115 var2: The second variable name. 

116 Returns: 

117 The Levenshtein similarity between the two variable names. 

118 """ 

119 distance = Levenshtein.distance(var1, var2) 

120 max_length = max(len(var1), len(var2)) 

121 similarity = 1 - (distance / max_length) 

122 return similarity 

123 

124 

125def jaccard_similarity(var1: str, var2: str) -> float: 

126 """ 

127 Compute the Jaccard similarity between two variable names. 

128 The Jaccard similarity is the size of the intersection divided by the size of the union of the two names' character sets.

129 Args: 

130 var1: The first variable name. 

131 var2: The second variable name. 

132 Returns: 

133 The Jaccard similarity between the two variable names. 

134 """ 

135 set1 = set(var1) 

136 set2 = set(var2) 

137 intersection = len(set1.intersection(set2)) 

138 union = len(set1.union(set2)) 

139 similarity = intersection / union 

140 return similarity 

141 

142 

143def cosine_similarity(var1: str, var2: str) -> float: 

144 """ 

145 Compute the cosine similarity between two variable names. 

146 The cosine similarity is the dot product of the character frequency vectors divided by the product of their norms. 

147 Args: 

148 var1: The first variable name. 

149 var2: The second variable name. 

150 Returns: 

151 The cosine similarity between the two variable names. 

152 """ 

153 char_freq1 = {char: var1.count(char) for char in var1} 

154 char_freq2 = {char: var2.count(char) for char in var2} 

155 

156 dot_product = sum( 

157 char_freq1.get(char, 0) * char_freq2.get(char, 0) for char in set(var1 + var2) 

158 ) 

159 norm1 = sum(freq**2 for freq in char_freq1.values()) ** 0.5 

160 norm2 = sum(freq**2 for freq in char_freq2.values()) ** 0.5 

161 

162 similarity = dot_product / (norm1 * norm2) 

163 return similarity 

164 
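# A minimal usage sketch for the three string-similarity helpers above, assuming this
# module is importable as skema.isa.lib (the values follow from the implementations here):
#
#     >>> from skema.isa.lib import levenshtein_similarity, jaccard_similarity, cosine_similarity
#     >>> levenshtein_similarity("beta", "betas")
#     0.8
#     >>> jaccard_similarity("abc", "abd")
#     0.5
#     >>> round(cosine_similarity("ab", "aa"), 3)
#     0.707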

165 

166def generate_graph(file: str = "", render: bool = False) -> pydot.Dot: 

167 """ 

168 Call the REST API of math-exp-graph to convert the MathML input to its GraphViz representation 

169 Ensure the REST API is running before calling this function

170 Input: the MathML file path or the MathML string

171 Output: the GraphViz representation (pydot.Dot) 

172 """ 

173 if "<math>" in file and "</math>" in file: 

174 content = file 

175 else: 

176 with open(file) as f: 

177 content = f.read() 

178 

179 #SKEMA_RS_ADDESS = "http://localhost:8080" 

180 digraph = requests.put( 

181 f"{SKEMA_RS_ADDESS}/mathml/math-exp-graph", data=content.encode("utf-8") 

182 ) 

183 

184 if render: 

185 src = Source(digraph.text) 

186 src.render("doctest-output/mathml_exp_tree", view=True) 

187 graph = pydot.graph_from_dot_data(str(digraph.text))[0] 

188 return graph 

189 
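# A hedged usage sketch for generate_graph: it requires the skema-rs service behind
# SKEMA_RS_ADDESS to be reachable, and the MathML snippet below is only an illustrative input.
#
#     >>> mml = "<math><mi>S</mi><mo>=</mo><mi>N</mi><mo>-</mo><mi>I</mi></math>"
#     >>> graph = generate_graph(mml)  # PUT to {SKEMA_RS_ADDESS}/mathml/math-exp-graph
#     >>> isinstance(graph, pydot.Graph)
#     True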

190 

191def generate_code_graphs(graph_string: str = "") -> Dict[str, pydot.Dot]: 

192 """ 

193 Convert the output of the code-exp-graphs REST API (the string form of a dict of DOT strings) to pydot graphs

194 Ensure the REST API has been called to produce the input string before calling this function

195 Input: the string representation of a dictionary mapping expression ids to DOT strings

196 Output: a dictionary of the GraphViz representation (pydot.Dot) 

197 """ 

198 

199 # Safely evaluate the string as a literal Python expression 

200 code_exp_graphs_dict = ast.literal_eval(graph_string) 

201 

202 # Convert the string representations to pydot.Dot objects 

203 for key, value in code_exp_graphs_dict.items(): 

204 code_exp_graphs_dict[key] = pydot.graph_from_dot_data(value)[0] 

205 

206 return code_exp_graphs_dict 

207 

208def generate_amatrix(graph: pydot.Dot) -> Tuple[ndarray, List[str]]: 

209 """ 

210 Convert the GraphViz representation to its corresponding adjacency matrix 

211 Input: the GraphViz representation 

212 Output: the adjacency matrix and the list of names of the variables and terms that appear in the expression

213 """ 

214 node_labels = [] 

215 for node in graph.get_nodes(): 

216 node_labels.append(node.obj_dict["attributes"]["label"].replace('"', "")) 

217 

218 amatrix = np.zeros((len(node_labels), len(node_labels))) 

219 

220 for edge in graph.get_edges(): 

221 x, y = edge.obj_dict["points"] 

222 label = edge.obj_dict["attributes"]["label"].replace('"', "") 

223 amatrix[int(x)][int(y)] = op_dict[label] if label in op_dict else 7 

224 

225 return amatrix, node_labels 

226 
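# A small self-contained sketch of generate_amatrix on a hand-built DOT graph; node names
# are assumed to be stringified indices, which is what the adjacency conversion expects:
#
#     >>> g = pydot.Dot(graph_type="digraph")
#     >>> for name, label in [("0", "x"), ("1", "y"), ("2", "+")]:
#     ...     g.add_node(pydot.Node(name, label=label))
#     >>> g.add_edge(pydot.Edge("0", "2", label="+"))
#     >>> g.add_edge(pydot.Edge("1", "2", label="+"))
#     >>> amatrix, labels = generate_amatrix(g)
#     >>> labels
#     ['x', 'y', '+']
#     >>> int(amatrix[0][2]), int(amatrix[1][2])   # "+" is encoded as 1 via op_dict
#     (1, 1)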

227 

228def heuristic_compare_variable_names(var1: str, var2: str) -> bool: 

229 """ 

230 Compare two variable names in a formula, accounting for Unicode representations. 

231 Convert the variable names to English letter representations before comparison. 

232 

233 Args: 

234 var1 (str): The first variable name. 

235 var2 (str): The second variable name. 

236 

237 Returns: 

238 bool: True if the variable names are the same, False otherwise. 

239 """ 

240 # Mapping of Greek letters to English letter representations 

241 greek_letters = { 

242 "α": "alpha", 

243 "β": "beta", 

244 "γ": "gamma", 

245 "δ": "delta", 

246 "ε": "epsilon", 

247 "ζ": "zeta", 

248 "η": "eta", 

249 "θ": "theta", 

250 "ι": "iota", 

251 "κ": "kappa", 

252 "λ": "lambda", 

253 "μ": "mu", 

254 "ν": "nu", 

255 "ξ": "xi", 

256 "ο": "omicron", 

257 "π": "pi", 

258 "ρ": "rho", 

259 "σ": "sigma", 

260 "τ": "tau", 

261 "υ": "upsilon", 

262 "φ": "phi", 

263 "χ": "chi", 

264 "ψ": "psi", 

265 "ω": "omega", 

266 } 

267 

268 # Convert hexadecimal character references (e.g., &#x3B1;) to Unicode characters

269 var1 = re.sub(r"&#x(\w+);?", lambda m: chr(int(m.group(1), 16)), var1).lower() 

270 var2 = re.sub(r"&#x(\w+);?", lambda m: chr(int(m.group(1), 16)), var2).lower() 

271 

272 # Convert Greek letter representations to English letter representations 

273 for greek_letter, english_letter in greek_letters.items(): 

274 var1 = var1.replace(greek_letter, english_letter) 

275 var2 = var2.replace(greek_letter, english_letter) 

276 

277 # Remove trailing quotation marks, if present 

278 var1 = var1.strip("'\"") 

279 var2 = var2.strip("'\"") 

280 

281 # Compare the variable names 

282 return var1.lower() == var2.lower() 

283 
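# A minimal sketch of the heuristic comparison: Greek symbols, their English names, and
# hexadecimal character references are treated as the same variable.
#
#     >>> heuristic_compare_variable_names("β", "Beta")
#     True
#     >>> heuristic_compare_variable_names("&#x3B2;", "beta")
#     True
#     >>> heuristic_compare_variable_names("S", "I")
#     False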

284 

285def extract_var_information( 

286 data: Dict[ 

287 str, 

288 Union[ 

289 List[ 

290 Dict[ 

291 str, 

292 Union[Dict[str, Union[str, int]], List[Dict[str, Union[str, int]]]], 

293 ] 

294 ], 

295 None, 

296 ], 

297 ] 

298) -> List[Dict[str, str]]: 

299 """ 

300 Extracts variable information from the given JSON data of the SKEMA mention extraction. 

301 

302 Parameters: 

303 - data (Dict[str, Union[List[Dict[str, Union[Dict[str, Union[str, int]], List[Dict[str, Union[str, int]]]]]], None]]): The input JSON data. 

304 

305 Returns: 

306 - List[Dict[str, str]]: A list of dictionaries containing extracted information (name, definition, and document_id). 

307 """ 

308 outputs = data.get("outputs", []) 

309 extracted_data = [] 

310 

311 for output in outputs: 

312 attributes = output.get("data", {}).get("attributes", []) 

313 

314 for attribute in attributes: 

315 payload = attribute.get("payload", {}) 

316 mentions = payload.get("mentions", []) 

317 text_descriptions = payload.get("text_descriptions", []) 

318 

319 for mention in mentions: 

320 name = mention.get("name", "") 

321 extraction_source = mention.get("extraction_source", {}) 

322 document_reference = extraction_source.get("document_reference", {}) 

323 document_id = document_reference.get("id", "") 

324 

325 for text_description in text_descriptions: 

326 description = text_description.get("description", "") 

327 extraction_source = text_description.get("extraction_source", {}) 

328 

329 extracted_data.append( 

330 { 

331 "name": name, 

332 "definition": description, 

333 "document_id": document_id, 

334 } 

335 ) 

336 

337 return extracted_data 

338 

339 

340def organize_into_json( 

341 extracted_data: List[Dict[str, str]] 

342) -> Dict[str, List[Dict[str, str]]]: 

343 """ 

344 Organizes the extracted information into a new JSON format. 

345 

346 Parameters: 

347 - extracted_data (List[Dict[str, str]]): A list of dictionaries containing extracted information (name, definition, and document_id). 

348 

349 Returns: 

350 - Dict[str, List[Dict[str, str]]]: A dictionary containing organized information in a new JSON format. 

351 """ 

352 organized_data = {"variables": []} 

353 

354 for item in extracted_data: 

355 organized_data["variables"].append( 

356 { 

357 "name": item["name"], 

358 "definition": item["definition"], 

359 "document_id": item["document_id"], 

360 } 

361 ) 

362 

363 return organized_data 

364 

365 

366def extract_var_defs_from_metions(input_file: str) -> Dict[str, List[Dict[str, str]]]: 

367 """ 

368 Processes the input JSON file, extracts variable information, and returns the organized data.

369 

370 Parameters: 

371 - input_file (str): The path to the input JSON file. 

372 

373 Returns: 

374 - organized_data (Dict[str, List[Dict[str, str]]]): The dictionary of the variable names and their definitions. 

375 """ 

376 try: 

377 # Read the original JSON data 

378 with open(input_file, "r", encoding="utf-8") as file: 

379 original_data = json.load(file) 

380 except FileNotFoundError: 

381 print(f"Error: File '{input_file}' not found.") 

382 return {} 

383 except json.JSONDecodeError as e: 

384 print(f"Error: JSON decoding failed. Details: {e}") 

385 return {} 

386 

387 # Extract information 

388 extracted_data = extract_var_information(original_data) 

389 

390 # Organize into a new JSON format 

391 organized_data = organize_into_json(extracted_data) 

392 

393 return organized_data 

394 

395 

396def find_definition( 

397 variable_name: str, extracted_data: Dict[str, List[Dict[str, str]]] 

398) -> str: 

399 """ 

400 Finds the definition for a variable name in the extracted data. 

401 

402 Args: 

403 variable_name (str): Variable name to find. 

404 extracted_data (Dict[str, List[Dict[str, str]]]): Dictionary containing the extracted variable information.

405 

406 Returns: 

407 str: Definition for the variable name, or an empty string if not found. 

408 """ 

409 for attribute in extracted_data["variables"]: 

410 if heuristic_compare_variable_names(variable_name, attribute["name"]): 

411 return attribute["definition"] 

412 

413 return "" 

414 

415 

416def calculate_similarity( 

417 definition1: str, definition2: str, field: str = "biomedical" 

418) -> float: 

419 """ 

420 Calculates semantic similarity between two variable definitions using sentence-transformer (DistilBERT-based) embeddings.

421 

422 Args: 

423 definition1 (str): First variable definition. 

424 definition2 (str): Second variable definition. 

425 field (str): Intended domain of the language model (currently unused; msmarco-distilbert-base-v2 is always loaded).

426 

427 Returns: 

428 float: Semantic similarity score between 0 and 1. 

429 """ 

430 pre_trained_model = "msmarco-distilbert-base-v2" 

431 model = SentenceTransformer(pre_trained_model) 

432 

433 # Convert definitions to BERT embeddings 

434 embedding1 = model.encode(definition1, convert_to_tensor=True) 

435 embedding2 = model.encode(definition2, convert_to_tensor=True) 

436 

437 # Calculate cosine similarity between embeddings 

438 cosine_similarity = util.pytorch_cos_sim(embedding1, embedding2)[0][0].item() 

439 

440 return cosine_similarity 

441 
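# A hedged usage sketch: the first call downloads the msmarco-distilbert-base-v2
# sentence-transformers model, so network access (or a local model cache) is assumed.
#
#     >>> sim = calculate_similarity("rate of infection", "the infection rate")
#     >>> -1.0 <= sim <= 1.0
#     True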

442 

443def match_variable_definitions( 

444 list1: List[str], 

445 list2: List[str], 

446 json_path1: str, 

447 json_path2: str, 

448 threshold: float, 

449) -> Tuple[List[int], List[int]]: 

450 """ 

451 Match variable definitions for given variable names in two lists. 

452 

453 Args: 

454 list1 (List[str]): List of variable names from the first equation. 

455 list2 (List[str]): List of variable names from the second equation. 

456 json_path1 (str): Path to the JSON file containing variable definitions for the first article. 

457 json_path2 (str): Path to the JSON file containing variable definitions for the second article. 

458 threshold (float): Similarity threshold for considering a match. 

459 

460 Returns: 

461 Tuple[List[int], List[int]]: Lists of indices for matched variable names in list1 and list2. 

462 """ 

463 extracted_data1 = extract_var_defs_from_metions(json_path1) 

464 extracted_data2 = extract_var_defs_from_metions(json_path2) 

465 

466 var_idx_list1 = [] 

467 var_idx_list2 = [] 

468 

469 for idx1, var1 in enumerate(list1): 

470 max_similarity = 0.0 

471 matching_idx = -1 

472 for idx2, var2 in enumerate(list2): 

473 def1 = find_definition(var1, extracted_data1) 

474 def2 = find_definition(var2, extracted_data2) 

475 

476 if def1 and def2: 

477 similarity = calculate_similarity(def1, def2) 

478 if similarity > max_similarity and similarity >= threshold: 

479 max_similarity = similarity 

480 matching_idx = idx2 

481 

482 if matching_idx != -1: 

483 if idx1 not in var_idx_list1: 

484 var_idx_list1.append(idx1) 

485 var_idx_list2.append(matching_idx) 

486 

487 return var_idx_list1, var_idx_list2 

488 

489 

490def get_seeds( 

491 node_labels1: List[str], 

492 node_labels2: List[str], 

493 method: str = "heuristic", 

494 threshold: float = 0.8, 

495 mention_json1: str = "", 

496 mention_json2: str = "", 

497) -> Tuple[List[int], List[int]]: 

498 """ 

499 Calculate the seeds in the two equations. 

500 

501 Args: 

502 node_labels1: The name lists of the variables and terms in equation 1. 

503 node_labels2: The name lists of the variables and terms in equation 2. 

504 method: The method to get seeds. 

505 - "heuristic": Based on variable name identification. 

506 - "levenshtein": Based on Levenshtein similarity of variable names. 

507 - "jaccard": Based on Jaccard similarity of variable names. 

508 - "cosine": Based on cosine similarity of variable names. 

509 threshold: The threshold to use for Levenshtein, Jaccard, and cosine methods. 

510 mention_json1: The JSON file path of the mention extraction of paper 1. 

511 mention_json2: The JSON file path of the mention extraction of paper 2. 

512 

513 Returns: 

514 A tuple of two lists: 

515 - seed1: The seed indices from equation 1. 

516 - seed2: The seed indices from equation 2. 

517 """ 

518 seed1 = [] 

519 seed2 = [] 

520 if method == "var_defs": 

521 seed1, seed2 = match_variable_definitions( 

522 node_labels1, 

523 node_labels2, 

524 json_path1=mention_json1, 

525 json_path2=mention_json2, 

526 threshold=0.9, 

527 ) 

528 else: 

529 for i in range(0, len(node_labels1)): 

530 for j in range(0, len(node_labels2)): 

531 if method == "heuristic": 

532 if heuristic_compare_variable_names( 

533 node_labels1[i], node_labels2[j] 

534 ): 

535 if i not in seed1: 

536 seed1.append(i) 

537 seed2.append(j) 

538 elif method == "levenshtein": 

539 if ( 

540 levenshtein_similarity(node_labels1[i], node_labels2[j]) 

541 > threshold 

542 ): 

543 if i not in seed1: 

544 seed1.append(i) 

545 seed2.append(j) 

546 elif method == "jaccard": 

547 if jaccard_similarity(node_labels1[i], node_labels2[j]) > threshold: 

548 if i not in seed1: 

549 seed1.append(i) 

550 seed2.append(j) 

551 elif method == "cosine": 

552 if cosine_similarity(node_labels1[i], node_labels2[j]) > threshold: 

553 if i not in seed1: 

554 seed1.append(i) 

555 seed2.append(j) 

556 

557 return seed1, seed2 

558 

559 
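# A minimal sketch of seed selection with the default heuristic method; the returned
# indices refer to positions in the two label lists ("β"/"beta" and "+"/"+" match):
#
#     >>> get_seeds(["β", "N", "+"], ["beta", "I", "+"], method="heuristic")
#     ([0, 2], [0, 2])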

560def has_edge(dot: pydot.Dot, src: str, dst: str) -> bool: 

561 """ 

562 Check if an edge exists between two nodes in a PyDot graph object. 

563 

564 Args: 

565 dot (pydot.Dot): PyDot graph object. 

566 src (str): Source node ID. 

567 dst (str): Destination node ID. 

568 

569 Returns: 

570 bool: True if an edge exists between src and dst, False otherwise. 

571 """ 

572 edges = dot.get_edges() 

573 for edge in edges: 

574 if edge.get_source() == src and edge.get_destination() == dst: 

575 return True 

576 return False 

577 

578 

579def get_union_graph( 

580 graph1: pydot.Dot, 

581 graph2: pydot.Dot, 

582 aligned_idx1: List[int], 

583 aligned_idx2: List[int], 

584) -> pydot.Dot: 

585 """ 

586 return the union graph for visualizing the alignment results 

587 input: The dot representation of Graph1, the dot representation of Graph2, the aligned node indices in Graph1, the aligned node indices in Graph2 

588 output: dot graph 

589 """ 

590 g2idx2g1idx = {str(x): str(-1) for x in range(len(graph2.get_nodes()))} 

591 union_graph = deepcopy(graph1) 

592 """ 

593 set the aligned variables or terms as a blue circle;  

594 if their names are the same, show one name;  

595 if not, show two names' connection using '<<|>>' 

596 """ 

597 for i in range(len(aligned_idx1)): 

598 if ( 

599 union_graph.get_nodes()[aligned_idx1[i]] 

600 .obj_dict["attributes"]["label"] 

601 .replace('"', "") 

602 .lower() 

603 != graph2.get_nodes()[aligned_idx2[i]] 

604 .obj_dict["attributes"]["label"] 

605 .replace('"', "") 

606 .lower() 

607 ): 

608 union_graph.get_nodes()[aligned_idx1[i]].obj_dict["attributes"]["label"] = ( 

609 union_graph.get_nodes()[aligned_idx1[i]] 

610 .obj_dict["attributes"]["label"] 

611 .replace('"', "") 

612 + " <<|>> " 

613 + graph2.get_nodes()[aligned_idx2[i]] 

614 .obj_dict["attributes"]["label"] 

615 .replace('"', "") 

616 ) 

617 

618 union_graph.get_nodes()[aligned_idx1[i]].obj_dict["attributes"][ 

619 "color" 

620 ] = "blue" 

621 g2idx2g1idx[str(aligned_idx2[i])] = str(aligned_idx1[i]) 

622 

623 # represent the nodes only in graph 1 as a red circle 

624 for i in range(len(union_graph.get_nodes())): 

625 if i not in aligned_idx1: 

626 union_graph.get_nodes()[i].obj_dict["attributes"]["color"] = "red" 

627 

628 # represent the nodes only in graph 2 as a green circle 

629 for i in range(len(graph2.get_nodes())): 

630 if i not in aligned_idx2: 

631 graph2.get_nodes()[i].obj_dict["attributes"]["color"] = "green" 

632 graph2.get_nodes()[i].obj_dict["name"] = str(len(union_graph.get_nodes())) 

633 union_graph.add_node(graph2.get_nodes()[i]) 

634 g2idx2g1idx[str(i)] = str(len(union_graph.get_nodes()) - 1) 

635 

636 # add the edges of graph 2 to graph 1 

637 for edge in union_graph.get_edges(): 

638 edge.obj_dict["attributes"]["color"] = "red" 

639 

640 for edge in graph2.get_edges(): 

641 x, y = edge.obj_dict["points"] 

642 if has_edge(union_graph, g2idx2g1idx[x], g2idx2g1idx[y]): 

643 if ( 

644 union_graph.get_edge(g2idx2g1idx[x], g2idx2g1idx[y])[0] 

645 .obj_dict["attributes"]["label"] 

646 .lower() 

647 == edge.obj_dict["attributes"]["label"].lower() 

648 ): 

649 union_graph.get_edge(g2idx2g1idx[x], g2idx2g1idx[y])[0].obj_dict[ 

650 "attributes" 

651 ]["color"] = "blue" 

652 else: 

653 e = pydot.Edge( 

654 g2idx2g1idx[x], 

655 g2idx2g1idx[y], 

656 label=edge.obj_dict["attributes"]["label"], 

657 color="green", 

658 ) 

659 union_graph.add_edge(e) 

660 else: 

661 e = pydot.Edge( 

662 g2idx2g1idx[x], 

663 g2idx2g1idx[y], 

664 label=edge.obj_dict["attributes"]["label"], 

665 color="green", 

666 ) 

667 union_graph.add_edge(e) 

668 

669 return union_graph 

670 

671 

672def check_square_array(arr: np.ndarray) -> List[int]: 

673 """ 

674 Given a square numpy array, returns a list of size equal to the length of the array, 

675 where each element of the list is either 0 or 1, depending on whether the corresponding 

676 row and column of the input array are all 0s or not. 

677 

678 Parameters: 

679 arr (np.ndarray): a square numpy array 

680 

681 Returns: 

682 List[int]: a list of 0s and 1s 

683 """ 

684 

685 n = arr.shape[0] # get the size of the array 

686 result = [] 

687 for i in range(n): 

688 # Check if the ith row and ith column are all 0s 

689 if np.all(arr[i, :] == 0) and np.all(arr[:, i] == 0): 

690 result.append(0) # if so, append 0 to the result list 

691 else: 

692 result.append(1) # otherwise, append 1 to the result list 

693 return result 

694 
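# A small sketch of check_square_array: an index is marked 1 when its row or column
# contains any non-zero entry (here, only index 2 has an all-zero row and column).
#
#     >>> arr = np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])
#     >>> check_square_array(arr)
#     [1, 1, 0]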

695 

696def align_mathml_eqs( 

697 mml1: str = "", 

698 mml2: str = "", 

699 mention_json1: str = "", 

700 mention_json2: str = "", 

701 mode: int = 2, 

702) -> Tuple[ 

703 Any, Any, List[str], List[str], Union[int, Any], Union[int, Any], Dot, List[int] 

704]: 

705 """ 

706 Align two equation graphs using the seeded graph matching (SGM) algorithm [1].

707 

708 [1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019). 

709 Seeded graph matching. Pattern recognition, 87, 203-215. 

710 

711 Input: mml1 & mml2: the file paths or contents of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json2: the mention file of paper 2;

712 mode 0: without considering any priors; mode 1: having a heuristic prior 

713 with the similarity of node labels; mode 2: using the variable definitions 

714 Output: 

715 matching_ratio: the matching ratio between equation 1 and equation 2

716 num_diff_edges: the number of differing edges between equation 1 and equation 2

717 node_labels1: the name list of the variables and terms in equation 1

718 node_labels2: the name list of the variables and terms in equation 2

719 aligned_indices1: the aligned indices in the name list of equation 1

720 aligned_indices2: the aligned indices in the name list of equation 2

721 union_graph: the visualization of the alignment result 

722 perfectly_matched_indices1: strictly matched node indices in Graph 1 

723 """ 

724 graph1 = generate_graph(mml1) 

725 graph2 = generate_graph(mml2) 

726 

727 amatrix1, node_labels1 = generate_amatrix(graph1) 

728 amatrix2, node_labels2 = generate_amatrix(graph2) 

729 

730 # If no mention files are provided, fall back to mode 1

731 if (mention_json1 == "" or mention_json2 == "") and mode == 2: 

732 mode = 1 

733 

734 if mode == 0: 

735 seed1 = [] 

736 seed2 = [] 

737 elif mode == 1: 

738 seed1, seed2 = get_seeds(node_labels1, node_labels2) 

739 else: 

740 seed1, seed2 = get_seeds( 

741 node_labels1, 

742 node_labels2, 

743 method="var_defs", 

744 threshold=0.9, 

745 mention_json1=mention_json1, 

746 mention_json2=mention_json2, 

747 ) 

748 

749 partial_match = np.column_stack((seed1, seed2)) 

750 

751 matched_indices1, matched_indices2, _, _ = graph_match( 

752 amatrix1, 

753 amatrix2, 

754 partial_match=partial_match, 

755 padding="adopted", 

756 rng=rng, 

757 max_iter=50, 

758 ) 

759 

760 big_graph_idx = 0 if len(node_labels1) >= len(node_labels2) else 1 

761 if big_graph_idx == 0: 

762 big_graph = amatrix1 

763 big_graph_matched_indices = matched_indices1 

764 small_graph = amatrix2 

765 small_graph_matched_indices = matched_indices2 

766 else: 

767 big_graph = amatrix2 

768 big_graph_matched_indices = matched_indices2 

769 small_graph = amatrix1 

770 small_graph_matched_indices = matched_indices1 

771 

772 small_graph_aligned = small_graph[small_graph_matched_indices][ 

773 :, small_graph_matched_indices 

774 ] 

775 small_graph_aligned_full = np.zeros(big_graph.shape) 

776 small_graph_aligned_full[ 

777 np.ix_(big_graph_matched_indices, big_graph_matched_indices) 

778 ] = small_graph_aligned 

779 

780 num_edges = ((big_graph + small_graph_aligned_full) > 0).sum() 

781 diff_edges = abs(big_graph - small_graph_aligned_full) 

782 diff_edges[diff_edges > 0] = 1 

783 perfectly_matched_indices1 = check_square_array( 

784 diff_edges 

785 ) # strictly aligned node indices of Graph 1 

786 num_diff_edges = np.sum(diff_edges) 

787 matching_ratio = round(1 - (num_diff_edges / num_edges), 2) 

788 

789 long_len = ( 

790 len(node_labels1) 

791 if len(node_labels1) >= len(node_labels2) 

792 else len(node_labels2) 

793 ) 

794 aligned_indices1 = np.zeros((long_len)) - 1 

795 aligned_indices2 = np.zeros((long_len)) - 1 

796 for i in range(long_len): 

797 if i < len(node_labels1): 

798 if i in matched_indices1: 

799 aligned_indices1[i] = matched_indices2[ 

800 np.where(matched_indices1 == i)[0][0] 

801 ] 

802 aligned_indices2[ 

803 matched_indices2[np.where(matched_indices1 == i)[0][0]] 

804 ] = i 

805 

806 # The visualization of the alignment result. 

807 union_graph = get_union_graph( 

808 graph1, 

809 graph2, 

810 [int(i) for i in matched_indices1.tolist()], 

811 [int(i) for i in matched_indices2.tolist()], 

812 ) 

813 

814 return ( 

815 matching_ratio, 

816 num_diff_edges, 

817 node_labels1, 

818 node_labels2, 

819 aligned_indices1, 

820 aligned_indices2, 

821 union_graph, 

822 perfectly_matched_indices1, 

823 ) 

824 

825 
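# A hedged end-to-end sketch for align_mathml_eqs. It assumes the skema-rs service behind
# SKEMA_RS_ADDESS is running; both MathML strings below are illustrative inputs only.
#
#     >>> mml1 = "<math><mi>S</mi><mo>=</mo><mi>N</mi><mo>-</mo><mi>I</mi></math>"
#     >>> mml2 = "<math><mi>s</mi><mo>=</mo><mi>n</mi><mo>-</mo><mi>i</mi></math>"
#     >>> (ratio, n_diff, labels1, labels2,
#     ...  idx1, idx2, union_graph, strict1) = align_mathml_eqs(mml1, mml2, mode=1)
#     >>> bool(0.0 <= ratio <= 1.0)
#     True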

826def align_eqn_code( 

827 eqn_mml: str = "", 

828 code: str = "", 

829 mode: int = 1, 

830) -> Dict[ 

831 Any, 

832 Tuple[Any, Any, List[str], List[str], Union[int, Any], Union[int, Any], str, List[int]], 

833]: 

834 """ 

835 Align the MathML equation graph and the code graphs using the seeded graph matching (SGM) algorithm [1].

836 

837 [1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019). 

838 Seeded graph matching. Pattern recognition, 87, 203-215. 

839 

840 Input: eqn_mml: the equation mathml 

841 code: the string representation of the code expression graphs returned by the code-exp-graphs endpoint;

842 mode 0: without considering any priors; mode 1: having a heuristic prior 

843 with the similarity of node labels; 

844 Output: 

845 A dictionary keyed by expression id, each value being the following tuple:

846 matching_ratio: the matching ratio between the equation and the code expression

847 num_diff_edges: the number of differing edges between the equation and the code expression

848 node_labels1: the name list of the variables and terms in the equation

849 node_labels2: the name list of the variables and terms in the code expression

850 aligned_indices1: the aligned indices in the name list of the equation

851 aligned_indices2: the aligned indices in the name list of the code expression

852 union_graph: the visualization of the alignment result 

853 perfectly_matched_indices1: strictly matched node indices in Graph 1 

854 """ 

855 eqn_graph = generate_graph(eqn_mml) 

856 code_graphs = generate_code_graphs(code) 

857 

858 amatrix1, node_labels1 = generate_amatrix(eqn_graph) 

859 matching_results = {} 

860 

861 for exp_idx, exp_graph in code_graphs.items(): 

862 amatrix2, node_labels2 = generate_amatrix(exp_graph) 

863 

864 if mode == 0: 

865 seed1 = [] 

866 seed2 = [] 

867 else: 

868 seed1, seed2 = get_seeds(node_labels1, node_labels2) 

869 

870 partial_match = np.column_stack((seed1, seed2)) 

871 

872 matched_indices1, matched_indices2, _, _ = graph_match( 

873 amatrix1, 

874 amatrix2, 

875 partial_match=partial_match, 

876 padding="adopted", 

877 rng=rng, 

878 max_iter=50, 

879 ) 

880 

881 big_graph_idx = 0 if len(node_labels1) >= len(node_labels2) else 1 

882 if big_graph_idx == 0: 

883 big_graph = amatrix1 

884 big_graph_matched_indices = matched_indices1 

885 small_graph = amatrix2 

886 small_graph_matched_indices = matched_indices2 

887 else: 

888 big_graph = amatrix2 

889 big_graph_matched_indices = matched_indices2 

890 small_graph = amatrix1 

891 small_graph_matched_indices = matched_indices1 

892 

893 small_graph_aligned = small_graph[small_graph_matched_indices][ 

894 :, small_graph_matched_indices 

895 ] 

896 small_graph_aligned_full = np.zeros(big_graph.shape) 

897 small_graph_aligned_full[ 

898 np.ix_(big_graph_matched_indices, big_graph_matched_indices) 

899 ] = small_graph_aligned 

900 

901 num_edges = ((big_graph + small_graph_aligned_full) > 0).sum() 

902 diff_edges = abs(big_graph - small_graph_aligned_full) 

903 diff_edges[diff_edges > 0] = 1 

904 perfectly_matched_indices1 = check_square_array( 

905 diff_edges 

906 ) # strictly aligned node indices of Graph 1 

907 num_diff_edges = np.sum(diff_edges) 

908 matching_ratio = round(1 - (num_diff_edges / num_edges), 2) 

909 

910 long_len = ( 

911 len(node_labels1) 

912 if len(node_labels1) >= len(node_labels2) 

913 else len(node_labels2) 

914 ) 

915 aligned_indices1 = np.zeros((long_len)) - 1 

916 aligned_indices2 = np.zeros((long_len)) - 1 

917 for i in range(long_len): 

918 if i < len(node_labels1): 

919 if i in matched_indices1: 

920 aligned_indices1[i] = matched_indices2[ 

921 np.where(matched_indices1 == i)[0][0] 

922 ] 

923 aligned_indices2[ 

924 matched_indices2[np.where(matched_indices1 == i)[0][0]] 

925 ] = i 

926 

927 # The visualization of the alignment result. 

928 union_graph = get_union_graph( 

929 eqn_graph, 

930 exp_graph, 

931 [int(i) for i in matched_indices1.tolist()], 

932 [int(i) for i in matched_indices2.tolist()], 

933 ) 

934 

935 matching_results[exp_idx] = ( 

936 matching_ratio, 

937 num_diff_edges, 

938 node_labels1, 

939 node_labels2, 

940 aligned_indices1, 

941 aligned_indices2, 

942 union_graph.to_string(), 

943 perfectly_matched_indices1, 

944 ) 

945 

946 return matching_results 

947 

948 

949def extract_variables_with_subsup(mathml_str: str) -> List[str]: 

950 # Function to extract variable names from MathML 

951 root = ET.fromstring(mathml_str) 

952 variables = [] 

953 

954 def process_math_element(element) -> str: 

955 if element.tag == "mi": # If it's a simple variable 

956 variable_name = element.text 

957 return variable_name 

958 elif element.tag in ["msup", "msub", "msubsup"]: 

959 # Handling superscripts, subscripts, and their combinations 

960 base_name = process_math_element(element[0]) 

961 if element.tag == "msup": 

962 modifier = "^" + process_math_element(element[1]) 

963 elif element.tag == "msub": 

964 modifier = "_" + process_math_element(element[1]) 

965 else: # msubsup 

966 modifier = ( 

967 "_" 

968 + process_math_element(element[1]) 

969 + "^" 

970 + process_math_element(element[2]) 

971 ) 

972 variable_name = base_name + modifier 

973 return variable_name 

974 elif element.tag == "mrow": 

975 # Handling row elements by concatenating children's results 

976 variable_name = "" 

977 for child in element: 

978 variable_name += process_math_element(child) 

979 return variable_name 

980 elif element.tag in ["mfrac", "msqrt", "mroot"]: 

981 # Handling fractions, square roots, and root expressions 

982 base_name = process_math_element(element[0]) 

983 if element.tag == "mfrac": 

984 modifier = "/" + process_math_element(element[1]) 

985 elif element.tag == "msqrt": 

986 modifier = "√(" + base_name + ")" 

987 else: # mroot 

988 modifier = "^" + process_math_element(element[1]) 

989 variable_name = base_name + modifier 

990 return variable_name 

991 elif element.tag in ["mover", "munder", "munderover"]: 

992 # Handling overlines, underlines, and combinations 

993 base_name = process_math_element(element[0]) 

994 if element.tag == "mover": 

995 modifier = "^" + process_math_element(element[1]) 

996 elif element.tag == "munder": 

997 modifier = "_" + process_math_element(element[1]) 

998 else: # munderover 

999 modifier = ( 

1000 "_" 

1001 + process_math_element(element[1]) 

1002 + "^" 

1003 + process_math_element(element[2]) 

1004 ) 

1005 variable_name = base_name + modifier 

1006 return variable_name 

1007 elif element.tag in ["mo", "mn"]: 

1008 # Handling operators and numbers 

1009 variable_name = element.text 

1010 return variable_name 

1011 elif element.tag == "mtext": 

1012 # Handling mtext 

1013 variable_name = element.text 

1014 return variable_name 

1015 else: 

1016 # Handling any other tag 

1017 try: 

1018 variable_name = element.text 

1019 return variable_name 

1020 except: 

1021 return "" 

1022 

1023 for elem in root.iter(): 

1024 if elem.tag in ["mi", "msup", "msub", "msubsup"]: 

1025 variables.append(process_math_element(elem)) 

1026 result_list = list(set(variables)) 

1027 result_list = [item for item in result_list if item not in mathml_operators] 

1028 return result_list # Returning unique variable names 

1029 
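# A minimal sketch of variable extraction from (namespace-free) MathML; the order of the
# returned list is not deterministic because a set is used internally, hence the sorted().
#
#     >>> mml = "<math><msub><mi>S</mi><mn>0</mn></msub><mo>=</mo><mi>N</mi></math>"
#     >>> sorted(extract_variables_with_subsup(mml))
#     ['N', 'S', 'S_0']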

1030 

1031def format_subscripts_and_superscripts(latex_str: str) -> str: 

1032 # Function to format subscripts and superscripts in a LaTeX string 

1033 # Returns the string with subscript and superscript arguments wrapped in braces

1034 def replace_sub(match): 

1035 return f"{match.group(1)}_{{{match.group(2)}}}" 

1036 

1037 def replace_sup(match): 

1038 superscript = match.group(2) 

1039 return f"{match.group(1)}^{{{superscript}}}" 

1040 

1041 pattern_sub = r"(\S+)_(\S+)" 

1042 pattern_sup = r"(\S+)\^(\S+)" 

1043 

1044 formatted_str = re.sub(pattern_sup, replace_sup, latex_str) 

1045 formatted_str = re.sub(pattern_sub, replace_sub, formatted_str) 

1046 

1047 return formatted_str 

1048 
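# A small sketch of the LaTeX-style formatting: subscript and superscript arguments
# get wrapped in braces.
#
#     >>> format_subscripts_and_superscripts("S_t")
#     'S_{t}'
#     >>> format_subscripts_and_superscripts("x^2 + y^2")
#     'x^{2} + y^{2}'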

1049 

1050def replace_greek_with_unicode(input_str): 

1051 # Function to replace Greek letters and their English names with their numeric character references

1052 # Returns the replaced string if replacements were made, otherwise an empty string 

1053 replaced_str = input_str 

1054 for gl in greek_letters: 

1055 replaced_str = replaced_str.replace(gl[0], gl[2]) 

1056 replaced_str = replaced_str.replace(gl[1], gl[2]) 

1057 return replaced_str if replaced_str != input_str else "" 

1058 

1059 

1060def replace_unicode_with_symbol(input_str): 

1061 # Function to replace Unicode representations with corresponding symbols 

1062 # Returns the replaced string if replacements were made, otherwise an empty string 

1063 pattern = r"&#x[A-Fa-f0-9]+;" 

1064 matches = re.findall(pattern, input_str) 

1065 

1066 replaced_str = input_str 

1067 for match in matches: 

1068 unicode_char = html.unescape(match) 

1069 replaced_str = replaced_str.replace(match, unicode_char) 

1070 

1071 return replaced_str if replaced_str != input_str else "" 

1072 

1073 

1074def transform_variable(variable: str) -> List[Union[str, List[str]]]: 

1075 # Function to transform a variable into a list containing different representations 

1076 # Returns a list containing various representations of the variable 

1077 if variable.startswith("&#x"): 

1078 for gl in greek_letters: 

1079 if variable in gl: 

1080 return gl 

1081 return [html.unescape(variable), variable] 

1082 elif variable.isalpha(): 

1083 if len(variable) == 1: 

1084 for gl in greek_letters: 

1085 if variable in gl: 

1086 return gl 

1087 return [variable, "&#x{:04X};".format(ord(variable))] 

1088 else: 

1089 return [variable] 

1090 else: 

1091 if len(variable) == 1: 

1092 return [variable, "&#x{:04X};".format(ord(variable))] 

1093 else: 

1094 variable_list = [variable, format_subscripts_and_superscripts(variable)] 

1095 if replace_greek_with_unicode(variable) != "": 

1096 variable_list.append(replace_greek_with_unicode(variable)) 

1097 variable_list.append( 

1098 replace_greek_with_unicode( 

1099 format_subscripts_and_superscripts(variable) 

1100 ) 

1101 ) 

1102 if replace_unicode_with_symbol(variable) != "": 

1103 variable_list.append(replace_unicode_with_symbol(variable)) 

1104 variable_list.append(format_subscripts_and_superscripts(variable)) 

1105 

1106 return variable_list 

1107 

1108 

1109def create_variable_dictionary( 

1110 variables: List[str], 

1111) -> Dict[str, List[Union[str, List[str]]]]: 

1112 # Function to create a dictionary mapping variables to their representations 

1113 # Returns a dictionary with variables as keys and their representations as values 

1114 variable_dict = {} 

1115 for variable in variables: 

1116 variable_dict[variable] = transform_variable(variable) 

1117 return variable_dict 

1118 

1119 

1120def generate_variable_dict(mathml_string): 

1121 # Function to generate a variable dictionary from MathML 

1122 try: 

1123 variables = extract_variables_with_subsup(mathml_string) 

1124 variable_dict = create_variable_dictionary(variables) 

1125 return variable_dict 

1126 except: 

1127 return {} 

1128 

1129def convert_to_dict(obj: Any) -> Union[List[Any], Dict[Any, Any], np.ndarray, Any]: 

1130 """ 

1131 Recursively converts an object to a dictionary, handling numpy arrays and nested structures. 

1132 

1133 Args: 

1134 obj (Any): The input object to be converted. 

1135 

1136 Returns: 

1137 Union[List[Any], Dict[Any, Any], np.ndarray, Any]: The converted object in dictionary form. 

1138 

1139 Note: 

1140 This function recursively processes the input object and converts it to a dictionary. 

1141 - If the input is a numpy array, it is converted to a Python list. 

1142 - If the input is a list or tuple, each element is processed recursively. 

1143 - If the input is a dictionary, each value is processed recursively. 

1144 - Other types are returned as is. 

1145 """ 

1146 if isinstance(obj, np.ndarray): 

1147 return obj.tolist() 

1148 elif isinstance(obj, (list, tuple)): 

1149 return [convert_to_dict(item) for item in obj] 

1150 elif isinstance(obj, dict): 

1151 return {key: convert_to_dict(value) for key, value in obj.items()} 

1152 else: 

1153 return obj
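# A minimal sketch of convert_to_dict, which makes alignment results (numpy arrays,
# tuples, nested dicts) JSON-serializable:
#
#     >>> convert_to_dict({"scores": np.array([1, 2]), "pair": (np.array([0.5]), "x")})
#     {'scores': [1, 2], 'pair': [[0.5], 'x']}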