Coverage for skema/isa/lib.py: 48%
436 statements
1# -*- coding: utf-8 -*-
2"""
3All the functions required for performing incremental structure alignment (ISA)
4Author: Liang Zhang (liangzh@arizona.edu)
5Updated date: December 18, 2023
6"""
7import json
8import warnings
9from typing import List, Any, Union, Dict, Tuple
10from numpy import ndarray
11from pydot import Dot
12from skema.rest.proxies import SKEMA_RS_ADDESS
14warnings.filterwarnings("ignore")
15import requests
16import pydot
17import numpy as np
18from graspologic.match import graph_match
19from graphviz import Source
20import graphviz
21from copy import deepcopy
22import Levenshtein
23from typing import Tuple
24import re
25import xml.etree.ElementTree as ET
26import html
27from sentence_transformers import SentenceTransformer, util
28import json
29import ast
32# Set up the random seed
33np.random.seed(4)
34rng = np.random.default_rng(4)
36# The encodings of basic operators used when converting a graph to its adjacency matrix
37op_dict = {"+": 1, "-": 2, "*": 3, "/": 4, "=": 5, "√": 6}
39# Greek letters mapping
40# List of Greek letters mapping to their lowercase, name, and Unicode representation
41greek_letters: List[List[str]] = [
42 ["α", "alpha", "α"],
43 ["β", "beta", "β"],
44 ["γ", "gamma", "γ"],
45 ["δ", "delta", "δ"],
46 ["ε", "epsilon", "ε"],
47 ["ζ", "zeta", "ζ"],
48 ["η", "eta", "η"],
49 ["θ", "theta", "θ"],
50 ["ι", "iota", "ι"],
51 ["κ", "kappa", "κ"],
52 ["λ", "lambda", "λ"],
53 ["μ", "mu", "μ"],
54 ["ν", "nu", "ν"],
55 ["ξ", "xi", "ξ"],
56 ["ο", "omicron", "ο"],
57 ["π", "pi", "π"],
58 ["ρ", "rho", "ρ"],
59 ["σ", "sigma", "σ"],
60 ["τ", "tau", "τ"],
61 ["υ", "upsilon", "υ"],
62 ["φ", "phi", "φ"],
63 ["χ", "chi", "χ"],
64 ["ψ", "psi", "ψ"],
65 ["ω", "omega", "ω"],
66 ["Α", "Alpha", "Α"],
67 ["Β", "Beta", "Β"],
68 ["Γ", "Gamma", "Γ"],
69 ["Δ", "Delta", "Δ"],
70 ["Ε", "Epsilon", "Ε"],
71 ["Ζ", "Zeta", "Ζ"],
72 ["Η", "Eta", "Η"],
73 ["Θ", "Theta", "Θ"],
74 ["Ι", "Iota", "Ι"],
75 ["Κ", "Kappa", "Κ"],
76 ["Λ", "Lambda", "Λ"],
77 ["Μ", "Mu", "Μ"],
78 ["Ν", "Nu", "Ν"],
79 ["Ξ", "Xi", "Ξ"],
80 ["Ο", "Omicron", "Ο"],
81 ["Π", "Pi", "Π"],
82 ["Ρ", "Rho", "Ρ"],
83 ["Σ", "Sigma", "Σ"],
84 ["Τ", "Tau", "Τ"],
85 ["Υ", "Upsilon", "Υ"],
86 ["Φ", "Phi", "Φ"],
87 ["Χ", "Chi", "Χ"],
88 ["Ψ", "Psi", "Ψ"],
89 ["Ω", "Omega", "Ω"],
90]
92mathml_operators = [
93 "sin",
94 "cos",
95 "tan",
96 "sec",
97 "csc",
98 "cot",
99 "log",
100 "ln",
101 "exp",
102 "sqrt",
103 "sum",
104 "prod",
105 "lim",
106]
109def levenshtein_similarity(var1: str, var2: str) -> float:
110 """
111 Compute the Levenshtein similarity between two variable names.
112 The Levenshtein similarity is one minus the ratio of the Levenshtein distance to the length of the longer name.
113 Args:
114 var1: The first variable name.
115 var2: The second variable name.
116 Returns:
117 The Levenshtein similarity between the two variable names.
118 """
119 distance = Levenshtein.distance(var1, var2)
120 max_length = max(len(var1), len(var2))
121 similarity = 1 - (distance / max_length)
122 return similarity
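# Illustrative usage (not part of the original module): a quick check of
# levenshtein_similarity on a hypothetical pair of variable names. One edit out
# of a maximum length of five gives 1 - 1/5 = 0.8.
#   >>> levenshtein_similarity("beta", "betaa")
#   0.8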
125def jaccard_similarity(var1: str, var2: str) -> float:
126 """
127 Compute the Jaccard similarity between two variable names.
128 The Jaccard similarity is the size of the intersection of the two names' character sets divided by the size of their union.
129 Args:
130 var1: The first variable name.
131 var2: The second variable name.
132 Returns:
133 The Jaccard similarity between the two variable names.
134 """
135 set1 = set(var1)
136 set2 = set(var2)
137 intersection = len(set1.intersection(set2))
138 union = len(set1.union(set2))
139 similarity = intersection / union
140 return similarity
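# Illustrative usage (not part of the original module): the character sets of
# "alpha" and "alps" share {a, l, p} out of {a, l, p, h, s}, so the Jaccard
# similarity is 3 / 5 = 0.6.
#   >>> jaccard_similarity("alpha", "alps")
#   0.6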
143def cosine_similarity(var1: str, var2: str) -> float:
144 """
145 Compute the cosine similarity between two variable names.
146 The cosine similarity is the dot product of the character frequency vectors divided by the product of their norms.
147 Args:
148 var1: The first variable name.
149 var2: The second variable name.
150 Returns:
151 The cosine similarity between the two variable names.
152 """
153 char_freq1 = {char: var1.count(char) for char in var1}
154 char_freq2 = {char: var2.count(char) for char in var2}
156 dot_product = sum(
157 char_freq1.get(char, 0) * char_freq2.get(char, 0) for char in set(var1 + var2)
158 )
159 norm1 = sum(freq**2 for freq in char_freq1.values()) ** 0.5
160 norm2 = sum(freq**2 for freq in char_freq2.values()) ** 0.5
162 similarity = dot_product / (norm1 * norm2)
163 return similarity
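# Illustrative usage (not part of the original module): the character-frequency
# vectors of "aab" and "ab" give a dot product of 3 and norms sqrt(5) and
# sqrt(2), so the similarity is about 0.95. Note that an empty string would make
# a norm zero and raise a ZeroDivisionError.
#   >>> round(cosine_similarity("aab", "ab"), 2)
#   0.95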
166def generate_graph(file: str = "", render: bool = False) -> pydot.Dot:
167 """
168 Call the math-exp-graph REST API to convert the MathML input to its GraphViz representation
169 Ensure the REST API is running before calling this function
170 Input: a file path or the MathML string itself
171 Output: the GraphViz representation (pydot.Dot)
172 """
173 if "<math>" in file and "</math>" in file:
174 content = file
175 else:
176 with open(file) as f:
177 content = f.read()
179 #SKEMA_RS_ADDESS = "http://localhost:8080"
180 digraph = requests.put(
181 f"{SKEMA_RS_ADDESS}/mathml/math-exp-graph", data=content.encode("utf-8")
182 )
184 if render:
185 src = Source(digraph.text)
186 src.render("doctest-output/mathml_exp_tree", view=True)
187 graph = pydot.graph_from_dot_data(str(digraph.text))[0]
188 return graph
191def generate_code_graphs(graph_string: str = "") -> Dict[str, pydot.Dot]:
192 """
193 Convert the string representation of the code expression graphs (the output of the code-exp-graphs REST API) into pydot objects
194 Ensure the graph string has been produced by the REST API before calling this function
195 Input: the string representation of a dictionary mapping expression identifiers to DOT graph strings
196 Output: a dictionary of the GraphViz representation (pydot.Dot)
197 """
199 # Safely evaluate the string as a literal Python expression
200 code_exp_graphs_dict = ast.literal_eval(graph_string)
202 # Convert the string representations to pydot.Dot objects
203 for key, value in code_exp_graphs_dict.items():
204 code_exp_graphs_dict[key] = pydot.graph_from_dot_data(value)[0]
206 return code_exp_graphs_dict
208def generate_amatrix(graph: pydot.Dot) -> Tuple[ndarray, List[str]]:
209 """
210 Convert the GraphViz representation to its corresponding adjacency matrix
211 Input: the GraphViz representation
212 Output: the adjacency matrix and the list of the names of the variables and terms that appear in the expression
213 """
214 node_labels = []
215 for node in graph.get_nodes():
216 node_labels.append(node.obj_dict["attributes"]["label"].replace('"', ""))
218 amatrix = np.zeros((len(node_labels), len(node_labels)))
220 for edge in graph.get_edges():
221 x, y = edge.obj_dict["points"]
222 label = edge.obj_dict["attributes"]["label"].replace('"', "")
223 amatrix[int(x)][int(y)] = op_dict[label] if label in op_dict else 7
225 return amatrix, node_labels
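# Illustrative usage (not part of the original module): generate_amatrix on a
# small hand-built pydot graph. Node names must be the stringified indices
# "0", "1", ... because the edge endpoints are used directly as matrix indices;
# graphs produced by generate_graph already follow this convention. The edge
# labels "=" and "*" are encoded as 5 and 3 via op_dict (unknown labels map to 7).
#   >>> g = pydot.Dot(graph_type="digraph")
#   >>> for idx, lbl in enumerate(["=", "S", "I"]):
#   ...     g.add_node(pydot.Node(str(idx), label=lbl))
#   >>> g.add_edge(pydot.Edge("0", "1", label="="))
#   >>> g.add_edge(pydot.Edge("0", "2", label="*"))
#   >>> amatrix, labels = generate_amatrix(g)
#   >>> labels
#   ['=', 'S', 'I']
#   >>> amatrix[0]
#   array([0., 5., 3.])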
228def heuristic_compare_variable_names(var1: str, var2: str) -> bool:
229 """
230 Compare two variable names in a formula, accounting for Unicode representations.
231 Convert the variable names to English letter representations before comparison.
233 Args:
234 var1 (str): The first variable name.
235 var2 (str): The second variable name.
237 Returns:
238 bool: True if the variable names are the same, False otherwise.
239 """
240 # Mapping of Greek letters to English letter representations
241 greek_letters = {
242 "α": "alpha",
243 "β": "beta",
244 "γ": "gamma",
245 "δ": "delta",
246 "ε": "epsilon",
247 "ζ": "zeta",
248 "η": "eta",
249 "θ": "theta",
250 "ι": "iota",
251 "κ": "kappa",
252 "λ": "lambda",
253 "μ": "mu",
254 "ν": "nu",
255 "ξ": "xi",
256 "ο": "omicron",
257 "π": "pi",
258 "ρ": "rho",
259 "σ": "sigma",
260 "τ": "tau",
261 "υ": "upsilon",
262 "φ": "phi",
263 "χ": "chi",
264 "ψ": "psi",
265 "ω": "omega",
266 }
268 # Convert Unicode representations to English letter representations
269 var1 = re.sub(r"&#x(\w+);?", lambda m: chr(int(m.group(1), 16)), var1).lower()
270 var2 = re.sub(r"&#x(\w+);?", lambda m: chr(int(m.group(1), 16)), var2).lower()
272 # Convert Greek letter representations to English letter representations
273 for greek_letter, english_letter in greek_letters.items():
274 var1 = var1.replace(greek_letter, english_letter)
275 var2 = var2.replace(greek_letter, english_letter)
277 # Remove trailing quotation marks, if present
278 var1 = var1.strip("'\"")
279 var2 = var2.strip("'\"")
281 # Compare the variable names
282 return var1.lower() == var2.lower()
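# Illustrative usage (not part of the original module): the comparison is
# case-insensitive and treats a Greek symbol, its HTML hex entity, and its
# spelled-out name as the same variable.
#   >>> heuristic_compare_variable_names("β", "Beta")
#   True
#   >>> heuristic_compare_variable_names("&#x3B2;", "beta")
#   True
#   >>> heuristic_compare_variable_names("S", "I")
#   False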
285def extract_var_information(
286 data: Dict[
287 str,
288 Union[
289 List[
290 Dict[
291 str,
292 Union[Dict[str, Union[str, int]], List[Dict[str, Union[str, int]]]],
293 ]
294 ],
295 None,
296 ],
297 ]
298) -> List[Dict[str, str]]:
299 """
300 Extracts variable information from the given JSON data of the SKEMA mention extraction.
302 Parameters:
303 - data (Dict[str, Union[List[Dict[str, Union[Dict[str, Union[str, int]], List[Dict[str, Union[str, int]]]]]], None]]): The input JSON data.
305 Returns:
306 - List[Dict[str, str]]: A list of dictionaries containing extracted information (name, definition, and document_id).
307 """
308 outputs = data.get("outputs", [])
309 extracted_data = []
311 for output in outputs:
312 attributes = output.get("data", {}).get("attributes", [])
314 for attribute in attributes:
315 payload = attribute.get("payload", {})
316 mentions = payload.get("mentions", [])
317 text_descriptions = payload.get("text_descriptions", [])
319 for mention in mentions:
320 name = mention.get("name", "")
321 extraction_source = mention.get("extraction_source", {})
322 document_reference = extraction_source.get("document_reference", {})
323 document_id = document_reference.get("id", "")
325 for text_description in text_descriptions:
326 description = text_description.get("description", "")
327 extraction_source = text_description.get("extraction_source", {})
329 extracted_data.append(
330 {
331 "name": name,
332 "definition": description,
333 "document_id": document_id,
334 }
335 )
337 return extracted_data
340def organize_into_json(
341 extracted_data: List[Dict[str, str]]
342) -> Dict[str, List[Dict[str, str]]]:
343 """
344 Organizes the extracted information into a new JSON format.
346 Parameters:
347 - extracted_data (List[Dict[str, str]]): A list of dictionaries containing extracted information (name, definition, and document_id).
349 Returns:
350 - Dict[str, List[Dict[str, str]]]: A dictionary containing organized information in a new JSON format.
351 """
352 organized_data = {"variables": []}
354 for item in extracted_data:
355 organized_data["variables"].append(
356 {
357 "name": item["name"],
358 "definition": item["definition"],
359 "document_id": item["document_id"],
360 }
361 )
363 return organized_data
366def extract_var_defs_from_metions(input_file: str) -> Dict[str, List[Dict[str, str]]]:
367 """
368 Processes the input JSON file, extracts the variable information, and returns it organized in a new JSON format.
370 Parameters:
371 - input_file (str): The path to the input JSON file.
373 Returns:
374 - organized_data (Dict[str, List[Dict[str, str]]]): The dictionary of the variable names and their definitions.
375 """
376 try:
377 # Read the original JSON data
378 with open(input_file, "r", encoding="utf-8") as file:
379 original_data = json.load(file)
380 except FileNotFoundError:
381 print(f"Error: File '{input_file}' not found.")
382 return {}
383 except json.JSONDecodeError as e:
384 print(f"Error: JSON decoding failed. Details: {e}")
385 return {}
387 # Extract information
388 extracted_data = extract_var_information(original_data)
390 # Organize into a new JSON format
391 organized_data = organize_into_json(extracted_data)
393 return organized_data
396def find_definition(
397 variable_name: str, extracted_data: Dict[str, List[Dict[str, str]]]
398) -> str:
399 """
400 Finds the definition for a variable name in the extracted data.
402 Args:
403 variable_name (str): Variable name to find.
404 extracted_data (Dict[str, List[Dict[str, str]]]): Dictionary of extracted variable information, keyed by "variables".
406 Returns:
407 str: Definition for the variable name, or an empty string if not found.
408 """
409 for attribute in extracted_data["variables"]:
410 if heuristic_compare_variable_names(variable_name, attribute["name"]):
411 return attribute["definition"]
413 return ""
416def calculate_similarity(
417 definition1: str, definition2: str, field: str = "biomedical"
418) -> float:
419 """
420 Calculates semantic similarity between two variable definitions using BERT embeddings.
422 Args:
423 definition1 (str): First variable definition.
424 definition2 (str): Second variable definition.
425 field (str): The intended domain of the language model (currently unused; a general-purpose model is always loaded).
427 Returns:
428 float: Semantic similarity score between 0 and 1.
429 """
430 pre_trained_model = "msmarco-distilbert-base-v2"
431 model = SentenceTransformer(pre_trained_model)
433 # Convert definitions to BERT embeddings
434 embedding1 = model.encode(definition1, convert_to_tensor=True)
435 embedding2 = model.encode(definition2, convert_to_tensor=True)
437 # Calculate cosine similarity between embeddings
438 cosine_similarity = util.pytorch_cos_sim(embedding1, embedding2)[0][0].item()
440 return cosine_similarity
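# Illustrative usage (not part of the original module; downloads the
# msmarco-distilbert-base-v2 sentence-transformer on first use). The exact score
# depends on the model, but near-paraphrase definitions are expected to score
# much higher than unrelated ones.
#   >>> calculate_similarity("transmission rate of the disease",
#   ...                      "rate at which the infection spreads")  # doctest: +SKIP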
443def match_variable_definitions(
444 list1: List[str],
445 list2: List[str],
446 json_path1: str,
447 json_path2: str,
448 threshold: float,
449) -> Tuple[List[int], List[int]]:
450 """
451 Match variable definitions for given variable names in two lists.
453 Args:
454 list1 (List[str]): List of variable names from the first equation.
455 list2 (List[str]): List of variable names from the second equation.
456 json_path1 (str): Path to the JSON file containing variable definitions for the first article.
457 json_path2 (str): Path to the JSON file containing variable definitions for the second article.
458 threshold (float): Similarity threshold for considering a match.
460 Returns:
461 Tuple[List[int], List[int]]: Lists of indices for matched variable names in list1 and list2.
462 """
463 extracted_data1 = extract_var_defs_from_metions(json_path1)
464 extracted_data2 = extract_var_defs_from_metions(json_path2)
466 var_idx_list1 = []
467 var_idx_list2 = []
469 for idx1, var1 in enumerate(list1):
470 max_similarity = 0.0
471 matching_idx = -1
472 for idx2, var2 in enumerate(list2):
473 def1 = find_definition(var1, extracted_data1)
474 def2 = find_definition(var2, extracted_data2)
476 if def1 and def2:
477 similarity = calculate_similarity(def1, def2)
478 if similarity > max_similarity and similarity >= threshold:
479 max_similarity = similarity
480 matching_idx = idx2
482 if matching_idx != -1:
483 if idx1 not in var_idx_list1:
484 var_idx_list1.append(idx1)
485 var_idx_list2.append(matching_idx)
487 return var_idx_list1, var_idx_list2
490def get_seeds(
491 node_labels1: List[str],
492 node_labels2: List[str],
493 method: str = "heuristic",
494 threshold: float = 0.8,
495 mention_json1: str = "",
496 mention_json2: str = "",
497) -> Tuple[List[int], List[int]]:
498 """
499 Calculate the seeds in the two equations.
501 Args:
502 node_labels1: The name lists of the variables and terms in equation 1.
503 node_labels2: The name lists of the variables and terms in equation 2.
504 method: The method to get seeds; besides the options below, "var_defs" matches variables by the semantic similarity of their text definitions in the mention JSON files.
505 - "heuristic": Based on variable name identification.
506 - "levenshtein": Based on Levenshtein similarity of variable names.
507 - "jaccard": Based on Jaccard similarity of variable names.
508 - "cosine": Based on cosine similarity of variable names.
509 threshold: The threshold to use for Levenshtein, Jaccard, and cosine methods.
510 mention_json1: The JSON file path of the mention extraction of paper 1.
511 mention_json2: The JSON file path of the mention extraction of paper 2.
513 Returns:
514 A tuple of two lists:
515 - seed1: The seed indices from equation 1.
516 - seed2: The seed indices from equation 2.
517 """
518 seed1 = []
519 seed2 = []
520 if method == "var_defs":
521 seed1, seed2 = match_variable_definitions(
522 node_labels1,
523 node_labels2,
524 json_path1=mention_json1,
525 json_path2=mention_json2,
526 threshold=0.9,
527 )
528 else:
529 for i in range(0, len(node_labels1)):
530 for j in range(0, len(node_labels2)):
531 if method == "heuristic":
532 if heuristic_compare_variable_names(
533 node_labels1[i], node_labels2[j]
534 ):
535 if i not in seed1:
536 seed1.append(i)
537 seed2.append(j)
538 elif method == "levenshtein":
539 if (
540 levenshtein_similarity(node_labels1[i], node_labels2[j])
541 > threshold
542 ):
543 if i not in seed1:
544 seed1.append(i)
545 seed2.append(j)
546 elif method == "jaccard":
547 if jaccard_similarity(node_labels1[i], node_labels2[j]) > threshold:
548 if i not in seed1:
549 seed1.append(i)
550 seed2.append(j)
551 elif method == "cosine":
552 if cosine_similarity(node_labels1[i], node_labels2[j]) > threshold:
553 if i not in seed1:
554 seed1.append(i)
555 seed2.append(j)
557 return seed1, seed2
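# Illustrative usage (not part of the original module): with the default
# "heuristic" method, "S" matches "S" and "β" matches "Beta", giving the seed
# index pairs (0, 1) and (1, 0).
#   >>> get_seeds(["S", "β", "*"], ["Beta", "S", "+"])
#   ([0, 1], [1, 0])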
560def has_edge(dot: pydot.Dot, src: str, dst: str) -> bool:
561 """
562 Check if an edge exists between two nodes in a PyDot graph object.
564 Args:
565 dot (pydot.Dot): PyDot graph object.
566 src (str): Source node ID.
567 dst (str): Destination node ID.
569 Returns:
570 bool: True if an edge exists between src and dst, False otherwise.
571 """
572 edges = dot.get_edges()
573 for edge in edges:
574 if edge.get_source() == src and edge.get_destination() == dst:
575 return True
576 return False
579def get_union_graph(
580 graph1: pydot.Dot,
581 graph2: pydot.Dot,
582 aligned_idx1: List[int],
583 aligned_idx2: List[int],
584) -> pydot.Dot:
585 """
586 Return the union graph for visualizing the alignment results
587 Input: the dot representations of Graph 1 and Graph 2, and the aligned node indices in each graph
588 Output: the union dot graph
589 """
590 g2idx2g1idx = {str(x): str(-1) for x in range(len(graph2.get_nodes()))}
591 union_graph = deepcopy(graph1)
592 """
593 render the aligned variables or terms as blue circles;
594 if their names are the same, show one name;
595 if not, show both names joined by '<<|>>'
596 """
597 for i in range(len(aligned_idx1)):
598 if (
599 union_graph.get_nodes()[aligned_idx1[i]]
600 .obj_dict["attributes"]["label"]
601 .replace('"', "")
602 .lower()
603 != graph2.get_nodes()[aligned_idx2[i]]
604 .obj_dict["attributes"]["label"]
605 .replace('"', "")
606 .lower()
607 ):
608 union_graph.get_nodes()[aligned_idx1[i]].obj_dict["attributes"]["label"] = (
609 union_graph.get_nodes()[aligned_idx1[i]]
610 .obj_dict["attributes"]["label"]
611 .replace('"', "")
612 + " <<|>> "
613 + graph2.get_nodes()[aligned_idx2[i]]
614 .obj_dict["attributes"]["label"]
615 .replace('"', "")
616 )
618 union_graph.get_nodes()[aligned_idx1[i]].obj_dict["attributes"][
619 "color"
620 ] = "blue"
621 g2idx2g1idx[str(aligned_idx2[i])] = str(aligned_idx1[i])
623 # represent the nodes only in graph 1 as a red circle
624 for i in range(len(union_graph.get_nodes())):
625 if i not in aligned_idx1:
626 union_graph.get_nodes()[i].obj_dict["attributes"]["color"] = "red"
628 # represent the nodes only in graph 2 as a green circle
629 for i in range(len(graph2.get_nodes())):
630 if i not in aligned_idx2:
631 graph2.get_nodes()[i].obj_dict["attributes"]["color"] = "green"
632 graph2.get_nodes()[i].obj_dict["name"] = str(len(union_graph.get_nodes()))
633 union_graph.add_node(graph2.get_nodes()[i])
634 g2idx2g1idx[str(i)] = str(len(union_graph.get_nodes()) - 1)
636 # add the edges of graph 2 to graph 1
637 for edge in union_graph.get_edges():
638 edge.obj_dict["attributes"]["color"] = "red"
640 for edge in graph2.get_edges():
641 x, y = edge.obj_dict["points"]
642 if has_edge(union_graph, g2idx2g1idx[x], g2idx2g1idx[y]):
643 if (
644 union_graph.get_edge(g2idx2g1idx[x], g2idx2g1idx[y])[0]
645 .obj_dict["attributes"]["label"]
646 .lower()
647 == edge.obj_dict["attributes"]["label"].lower()
648 ):
649 union_graph.get_edge(g2idx2g1idx[x], g2idx2g1idx[y])[0].obj_dict[
650 "attributes"
651 ]["color"] = "blue"
652 else:
653 e = pydot.Edge(
654 g2idx2g1idx[x],
655 g2idx2g1idx[y],
656 label=edge.obj_dict["attributes"]["label"],
657 color="green",
658 )
659 union_graph.add_edge(e)
660 else:
661 e = pydot.Edge(
662 g2idx2g1idx[x],
663 g2idx2g1idx[y],
664 label=edge.obj_dict["attributes"]["label"],
665 color="green",
666 )
667 union_graph.add_edge(e)
669 return union_graph
672def check_square_array(arr: np.ndarray) -> List[int]:
673 """
674 Given a square numpy array, returns a list whose length equals the size of the array,
675 where the i-th element is 0 if the i-th row and the i-th column of the input array
676 are all zeros, and 1 otherwise.
678 Parameters:
679 arr (np.ndarray): a square numpy array
681 Returns:
682 List[int]: a list of 0s and 1s
683 """
685 n = arr.shape[0] # get the size of the array
686 result = []
687 for i in range(n):
688 # Check if the ith row and ith column are all 0s
689 if np.all(arr[i, :] == 0) and np.all(arr[:, i] == 0):
690 result.append(0) # if so, append 0 to the result list
691 else:
692 result.append(1) # otherwise, append 1 to the result list
693 return result
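# Illustrative usage (not part of the original module): only node 2 has an
# all-zero row and column in the matrix below, so it is the only node reported
# as strictly matched.
#   >>> check_square_array(np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]]))
#   [1, 1, 0]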
696def align_mathml_eqs(
697 mml1: str = "",
698 mml2: str = "",
699 mention_json1: str = "",
700 mention_json2: str = "",
701 mode: int = 2,
702) -> Tuple[
703 Any, Any, List[str], List[str], Union[int, Any], Union[int, Any], Dot, List[int]
704]:
705 """
706 Align two equation graphs using the seeded graph matching (SGM) algorithm [1].
708 [1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019).
709 Seeded graph matching. Pattern recognition, 87, 203-215.
711 Input: mml1 & mml2: the file paths or contents of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json2: the mention file of paper 2;
712 mode 0: no priors; mode 1: a heuristic prior based on the similarity of node labels;
713 mode 2: a prior based on the variable definitions
714 Output:
715 matching_ratio: the matching ratio between equation 1 and equation 2
716 num_diff_edges: the number of different edges between equation 1 and equation 2
717 node_labels1: the name list of the variables and terms in equation 1
718 node_labels2: the name list of the variables and terms in equation 2
719 aligned_indices1: the aligned indices in the name list of equation 1
720 aligned_indices2: the aligned indices in the name list of equation 2
721 union_graph: the visualization of the alignment result
722 perfectly_matched_indices1: strictly matched node indices in Graph 1
723 """
724 graph1 = generate_graph(mml1)
725 graph2 = generate_graph(mml2)
727 amatrix1, node_labels1 = generate_amatrix(graph1)
728 amatrix2, node_labels2 = generate_amatrix(graph2)
730 # If no mention files are provided, fall back to mode 1
731 if (mention_json1 == "" or mention_json2 == "") and mode == 2:
732 mode = 1
734 if mode == 0:
735 seed1 = []
736 seed2 = []
737 elif mode == 1:
738 seed1, seed2 = get_seeds(node_labels1, node_labels2)
739 else:
740 seed1, seed2 = get_seeds(
741 node_labels1,
742 node_labels2,
743 method="var_defs",
744 threshold=0.9,
745 mention_json1=mention_json1,
746 mention_json2=mention_json2,
747 )
749 partial_match = np.column_stack((seed1, seed2))
751 matched_indices1, matched_indices2, _, _ = graph_match(
752 amatrix1,
753 amatrix2,
754 partial_match=partial_match,
755 padding="adopted",
756 rng=rng,
757 max_iter=50,
758 )
760 big_graph_idx = 0 if len(node_labels1) >= len(node_labels2) else 1
761 if big_graph_idx == 0:
762 big_graph = amatrix1
763 big_graph_matched_indices = matched_indices1
764 small_graph = amatrix2
765 small_graph_matched_indices = matched_indices2
766 else:
767 big_graph = amatrix2
768 big_graph_matched_indices = matched_indices2
769 small_graph = amatrix1
770 small_graph_matched_indices = matched_indices1
772 small_graph_aligned = small_graph[small_graph_matched_indices][
773 :, small_graph_matched_indices
774 ]
775 small_graph_aligned_full = np.zeros(big_graph.shape)
776 small_graph_aligned_full[
777 np.ix_(big_graph_matched_indices, big_graph_matched_indices)
778 ] = small_graph_aligned
780 num_edges = ((big_graph + small_graph_aligned_full) > 0).sum()
781 diff_edges = abs(big_graph - small_graph_aligned_full)
782 diff_edges[diff_edges > 0] = 1
783 perfectly_matched_indices1 = check_square_array(
784 diff_edges
785 ) # strictly aligned node indices of Graph 1
786 num_diff_edges = np.sum(diff_edges)
787 matching_ratio = round(1 - (num_diff_edges / num_edges), 2)
789 long_len = (
790 len(node_labels1)
791 if len(node_labels1) >= len(node_labels2)
792 else len(node_labels2)
793 )
794 aligned_indices1 = np.zeros((long_len)) - 1
795 aligned_indices2 = np.zeros((long_len)) - 1
796 for i in range(long_len):
797 if i < len(node_labels1):
798 if i in matched_indices1:
799 aligned_indices1[i] = matched_indices2[
800 np.where(matched_indices1 == i)[0][0]
801 ]
802 aligned_indices2[
803 matched_indices2[np.where(matched_indices1 == i)[0][0]]
804 ] = i
806 # The visualization of the alignment result.
807 union_graph = get_union_graph(
808 graph1,
809 graph2,
810 [int(i) for i in matched_indices1.tolist()],
811 [int(i) for i in matched_indices2.tolist()],
812 )
814 return (
815 matching_ratio,
816 num_diff_edges,
817 node_labels1,
818 node_labels2,
819 aligned_indices1,
820 aligned_indices2,
821 union_graph,
822 perfectly_matched_indices1,
823 )
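# Illustrative usage sketch (not part of the original module): aligning two
# hypothetical MathML files. This requires the math-exp-graph REST service at
# SKEMA_RS_ADDESS to be running; "seir_eq1.xml" and "sir_eq1.xml" are
# placeholder file names.
#   >>> (ratio, n_diff, labels1, labels2, idx1, idx2,
#   ...  union, perfect1) = align_mathml_eqs("seir_eq1.xml", "sir_eq1.xml")  # doctest: +SKIP
#   >>> union.write_png("alignment.png")  # doctest: +SKIP
# matching_ratio is 1.0 when the two expression graphs are identical under the alignment.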
826def align_eqn_code(
827 eqn_mml: str = "",
828 code: str = "",
829 mode: int = 1,
830) -> Dict[
831 Any,
832 Tuple[Any, Any, List[str], List[str], Union[int, Any], Union[int, Any], str, List[int]],
833]:
834 """
835 Align the MathML equation graph and the code graphs using the seeded graph matching (SGM) algorithm [1].
837 [1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019).
838 Seeded graph matching. Pattern recognition, 87, 203-215.
840 Input: eqn_mml: the file path or contents of the equation MathML
841 code: the string representation of the code expression graphs;
842 mode 0: no priors; mode 1: a heuristic prior based on the similarity of node labels
844 Output:
845 A dictionary keyed by code expression index, where each value has the following structure:
846 matching_ratio: the matching ratio between the equation and the code expression graph
847 num_diff_edges: the number of different edges between the equation and the code expression graph
848 node_labels1: the name list of the variables and terms in the equation
849 node_labels2: the name list of the variables and terms in the code expression graph
850 aligned_indices1: the aligned indices in the name list of the equation
851 aligned_indices2: the aligned indices in the name list of the code expression graph
852 union_graph: the DOT-string visualization of the alignment result
853 perfectly_matched_indices1: strictly matched node indices in Graph 1
854 """
855 eqn_graph = generate_graph(eqn_mml)
856 code_graphs = generate_code_graphs(code)
858 amatrix1, node_labels1 = generate_amatrix(eqn_graph)
859 matching_results = {}
861 for exp_idx, exp_graph in code_graphs.items():
862 amatrix2, node_labels2 = generate_amatrix(exp_graph)
864 if mode == 0:
865 seed1 = []
866 seed2 = []
867 else:
868 seed1, seed2 = get_seeds(node_labels1, node_labels2)
870 partial_match = np.column_stack((seed1, seed2))
872 matched_indices1, matched_indices2, _, _ = graph_match(
873 amatrix1,
874 amatrix2,
875 partial_match=partial_match,
876 padding="adopted",
877 rng=rng,
878 max_iter=50,
879 )
881 big_graph_idx = 0 if len(node_labels1) >= len(node_labels2) else 1
882 if big_graph_idx == 0:
883 big_graph = amatrix1
884 big_graph_matched_indices = matched_indices1
885 small_graph = amatrix2
886 small_graph_matched_indices = matched_indices2
887 else:
888 big_graph = amatrix2
889 big_graph_matched_indices = matched_indices2
890 small_graph = amatrix1
891 small_graph_matched_indices = matched_indices1
893 small_graph_aligned = small_graph[small_graph_matched_indices][
894 :, small_graph_matched_indices
895 ]
896 small_graph_aligned_full = np.zeros(big_graph.shape)
897 small_graph_aligned_full[
898 np.ix_(big_graph_matched_indices, big_graph_matched_indices)
899 ] = small_graph_aligned
901 num_edges = ((big_graph + small_graph_aligned_full) > 0).sum()
902 diff_edges = abs(big_graph - small_graph_aligned_full)
903 diff_edges[diff_edges > 0] = 1
904 perfectly_matched_indices1 = check_square_array(
905 diff_edges
906 ) # strictly aligned node indices of Graph 1
907 num_diff_edges = np.sum(diff_edges)
908 matching_ratio = round(1 - (num_diff_edges / num_edges), 2)
910 long_len = (
911 len(node_labels1)
912 if len(node_labels1) >= len(node_labels2)
913 else len(node_labels2)
914 )
915 aligned_indices1 = np.zeros((long_len)) - 1
916 aligned_indices2 = np.zeros((long_len)) - 1
917 for i in range(long_len):
918 if i < len(node_labels1):
919 if i in matched_indices1:
920 aligned_indices1[i] = matched_indices2[
921 np.where(matched_indices1 == i)[0][0]
922 ]
923 aligned_indices2[
924 matched_indices2[np.where(matched_indices1 == i)[0][0]]
925 ] = i
927 # The visualization of the alignment result.
928 union_graph = get_union_graph(
929 eqn_graph,
930 exp_graph,
931 [int(i) for i in matched_indices1.tolist()],
932 [int(i) for i in matched_indices2.tolist()],
933 )
935 matching_results[exp_idx] = (
936 matching_ratio,
937 num_diff_edges,
938 node_labels1,
939 node_labels2,
940 aligned_indices1,
941 aligned_indices2,
942 union_graph.to_string(),
943 perfectly_matched_indices1,
944 )
946 return matching_results
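# Illustrative usage sketch (not part of the original module): "code_graphs_str"
# is a placeholder for the string representation of the code expression graphs
# (see generate_code_graphs); the result is keyed by code expression index.
#   >>> results = align_eqn_code(eqn_mml="seir_eq1.xml", code=code_graphs_str)  # doctest: +SKIP
#   >>> best = max(results, key=lambda k: results[k][0])  # highest matching ratio  # doctest: +SKIP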
949def extract_variables_with_subsup(mathml_str: str) -> List[str]:
950 # Function to extract variable names from MathML
951 root = ET.fromstring(mathml_str)
952 variables = []
954 def process_math_element(element) -> str:
955 if element.tag == "mi": # If it's a simple variable
956 variable_name = element.text
957 return variable_name
958 elif element.tag in ["msup", "msub", "msubsup"]:
959 # Handling superscripts, subscripts, and their combinations
960 base_name = process_math_element(element[0])
961 if element.tag == "msup":
962 modifier = "^" + process_math_element(element[1])
963 elif element.tag == "msub":
964 modifier = "_" + process_math_element(element[1])
965 else: # msubsup
966 modifier = (
967 "_"
968 + process_math_element(element[1])
969 + "^"
970 + process_math_element(element[2])
971 )
972 variable_name = base_name + modifier
973 return variable_name
974 elif element.tag == "mrow":
975 # Handling row elements by concatenating children's results
976 variable_name = ""
977 for child in element:
978 variable_name += process_math_element(child)
979 return variable_name
980 elif element.tag in ["mfrac", "msqrt", "mroot"]:
981 # Handling fractions, square roots, and root expressions
982 base_name = process_math_element(element[0])
983 if element.tag == "mfrac":
984 modifier = "/" + process_math_element(element[1])
985 elif element.tag == "msqrt":
986 modifier = "√(" + base_name + ")"
987 else: # mroot
988 modifier = "^" + process_math_element(element[1])
989 variable_name = base_name + modifier
990 return variable_name
991 elif element.tag in ["mover", "munder", "munderover"]:
992 # Handling overlines, underlines, and combinations
993 base_name = process_math_element(element[0])
994 if element.tag == "mover":
995 modifier = "^" + process_math_element(element[1])
996 elif element.tag == "munder":
997 modifier = "_" + process_math_element(element[1])
998 else: # munderover
999 modifier = (
1000 "_"
1001 + process_math_element(element[1])
1002 + "^"
1003 + process_math_element(element[2])
1004 )
1005 variable_name = base_name + modifier
1006 return variable_name
1007 elif element.tag in ["mo", "mn"]:
1008 # Handling operators and numbers
1009 variable_name = element.text
1010 return variable_name
1011 elif element.tag == "mtext":
1012 # Handling mtext
1013 variable_name = element.text
1014 return variable_name
1015 else:
1016 # Handling any other tag
1017 try:
1018 variable_name = element.text
1019 return variable_name
1020 except:
1021 return ""
1023 for elem in root.iter():
1024 if elem.tag in ["mi", "msup", "msub", "msubsup"]:
1025 variables.append(process_math_element(elem))
1026 result_list = list(set(variables))
1027 result_list = [item for item in result_list if item not in mathml_operators]
1028 return result_list # Returning unique variable names
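# Illustrative usage (not part of the original module): subscripted variables
# are returned both as the combined name and as their parts, because the inner
# <mi> elements are visited as well; sorted() is used here only to make the
# output order deterministic.
#   >>> mml = "<math><mrow><msub><mi>S</mi><mi>t</mi></msub><mo>=</mo><mi>β</mi></mrow></math>"
#   >>> sorted(extract_variables_with_subsup(mml))
#   ['S', 'S_t', 't', 'β']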
1031def format_subscripts_and_superscripts(latex_str: str) -> str:
1032 # Function to format subscripts and superscripts in a LaTeX string
1033 # Returns the string with subscript and superscript arguments wrapped in braces
1034 def replace_sub(match):
1035 return f"{match.group(1)}_{{{match.group(2)}}}"
1037 def replace_sup(match):
1038 superscript = match.group(2)
1039 return f"{match.group(1)}^{{{superscript}}}"
1041 pattern_sub = r"(\S+)_(\S+)"
1042 pattern_sup = r"(\S+)\^(\S+)"
1044 formatted_str = re.sub(pattern_sup, replace_sup, latex_str)
1045 formatted_str = re.sub(pattern_sub, replace_sub, formatted_str)
1047 return formatted_str
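# Illustrative usage (not part of the original module): simple one-level
# subscripts and superscripts get their arguments wrapped in braces.
#   >>> format_subscripts_and_superscripts("S_t")
#   'S_{t}'
#   >>> format_subscripts_and_superscripts("R^0")
#   'R^{0}'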
1050def replace_greek_with_unicode(input_str):
1051 # Function to replace Greek letters and their names with Unicode
1052 # Returns the replaced string if replacements were made, otherwise an empty string
1053 replaced_str = input_str
1054 for gl in greek_letters:
1055 replaced_str = replaced_str.replace(gl[0], gl[2])
1056 replaced_str = replaced_str.replace(gl[1], gl[2])
1057 return replaced_str if replaced_str != input_str else ""
1060def replace_unicode_with_symbol(input_str):
1061 # Function to replace Unicode representations with corresponding symbols
1062 # Returns the replaced string if replacements were made, otherwise an empty string
1063 pattern = r"&#x[A-Fa-f0-9]+;"
1064 matches = re.findall(pattern, input_str)
1066 replaced_str = input_str
1067 for match in matches:
1068 unicode_char = html.unescape(match)
1069 replaced_str = replaced_str.replace(match, unicode_char)
1071 return replaced_str if replaced_str != input_str else ""
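# Illustrative usage (not part of the original module): an HTML hex entity is
# unescaped to its symbol, while a string without entities yields "".
#   >>> replace_unicode_with_symbol("&#x3B2;")
#   'β'
#   >>> replace_unicode_with_symbol("beta")
#   ''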
1074def transform_variable(variable: str) -> List[Union[str, List[str]]]:
1075 # Function to transform a variable into a list containing different representations
1076 # Returns a list containing various representations of the variable
1077 if variable.startswith("&#x"):
1078 for gl in greek_letters:
1079 if variable in gl:
1080 return gl
1081 return [html.unescape(variable), variable]
1082 elif variable.isalpha():
1083 if len(variable) == 1:
1084 for gl in greek_letters:
1085 if variable in gl:
1086 return gl
1087 return [variable, "&#x{:04X};".format(ord(variable))]
1088 else:
1089 return [variable]
1090 else:
1091 if len(variable) == 1:
1092 return [variable, "&#x{:04X};".format(ord(variable))]
1093 else:
1094 variable_list = [variable, format_subscripts_and_superscripts(variable)]
1095 if replace_greek_with_unicode(variable) != "":
1096 variable_list.append(replace_greek_with_unicode(variable))
1097 variable_list.append(
1098 replace_greek_with_unicode(
1099 format_subscripts_and_superscripts(variable)
1100 )
1101 )
1102 if replace_unicode_with_symbol(variable) != "":
1103 variable_list.append(replace_unicode_with_symbol(variable))
1104 variable_list.append(format_subscripts_and_superscripts(variable))
1106 return variable_list
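# Illustrative usage (not part of the original module): a plain Latin letter is
# paired with its hex-entity form, and a multi-character name is paired with its
# braced subscript/superscript form.
#   >>> transform_variable("x")
#   ['x', '&#x0078;']
#   >>> transform_variable("S_t")
#   ['S_t', 'S_{t}']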
1109def create_variable_dictionary(
1110 variables: List[str],
1111) -> Dict[str, List[Union[str, List[str]]]]:
1112 # Function to create a dictionary mapping variables to their representations
1113 # Returns a dictionary with variables as keys and their representations as values
1114 variable_dict = {}
1115 for variable in variables:
1116 variable_dict[variable] = transform_variable(variable)
1117 return variable_dict
1120def generate_variable_dict(mathml_string):
1121 # Function to generate a variable dictionary from MathML
1122 try:
1123 variables = extract_variables_with_subsup(mathml_string)
1124 variable_dict = create_variable_dictionary(variables)
1125 return variable_dict
1126 except:
1127 return {}
1129def convert_to_dict(obj: Any) -> Union[List[Any], Dict[Any, Any], np.ndarray, Any]:
1130 """
1131 Recursively converts an object to a dictionary, handling numpy arrays and nested structures.
1133 Args:
1134 obj (Any): The input object to be converted.
1136 Returns:
1137 Union[List[Any], Dict[Any, Any], np.ndarray, Any]: The converted object in dictionary form.
1139 Note:
1140 This function recursively processes the input object and converts it to a dictionary.
1141 - If the input is a numpy array, it is converted to a Python list.
1142 - If the input is a list or tuple, each element is processed recursively.
1143 - If the input is a dictionary, each value is processed recursively.
1144 - Other types are returned as is.
1145 """
1146 if isinstance(obj, np.ndarray):
1147 return obj.tolist()
1148 elif isinstance(obj, (list, tuple)):
1149 return [convert_to_dict(item) for item in obj]
1150 elif isinstance(obj, dict):
1151 return {key: convert_to_dict(value) for key, value in obj.items()}
1152 else:
1153 return obj
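# Illustrative usage (not part of the original module): numpy arrays become
# plain lists and tuples become lists, so the alignment results can be
# serialized with json.dumps.
#   >>> convert_to_dict({"idx": np.array([0, 2]), "pair": (np.array([1.0]), "x")})
#   {'idx': [0, 2], 'pair': [[1.0], 'x']}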