Coverage for skema/program_analysis/CAST2FN/cast.py: 29%
127 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import ast
2import json
3import difflib
4import typing
5import networkx as nx
7from skema.program_analysis.CAST2FN.model.cast import (
8 AstNode,
9 Assignment,
10 Attribute,
11 Call,
12 CASTLiteralValue,
13 FunctionDef,
14 Goto,
15 Label,
16 Loop,
17 ModelBreak,
18 ModelContinue,
19 ModelIf,
20 ModelImport,
21 ModelReturn,
22 Module,
23 Name,
24 Operator,
25 RecordDef,
26 ScalarType,
27 StructureType,
28 SourceRef,
29 VarType,
30 Var,
31 ValueConstructor,
32)
33# from skema.program_analysis.CAST2FN.visitors import (
34# CASTToAIRVisitor,
35#)
36from skema.model_assembly.air import AutoMATES_IR
37from skema.model_assembly.networks import GroundedFunctionNetwork
38from skema.model_assembly.structures import (
39 GenericContainer,
40 GenericStmt,
41 GenericIdentifier,
42 GenericDefinition,
43 TypeDefinition,
44 VariableDefinition,
45)
# Every CAST node class that `CAST.parse_cast_json` can instantiate. The
# decoder compares a JSON object's "node_type" string against the `__name__`
# of each class in this list, so a node class missing from the list cannot be
# decoded from JSON.
CAST_NODES_TYPES_LIST = [
    AstNode,
    Assignment,
    Attribute,
    Call,
    CASTLiteralValue,
    FunctionDef,
    Goto,
    Label,
    Loop,
    ModelBreak,
    ModelContinue,
    ModelIf,
    ModelImport,
    ModelReturn,
    Module,
    Name,
    Operator,
    RecordDef,
    ScalarType,
    StructureType,
    SourceRef,
    VarType,
    Var,
    ValueConstructor,
]
def compare_name_nodes(name1: Name, name2: Name) -> bool:
    """
    Compare two Name nodes using only their `name` fields.

    The Swagger generated `__eq__` on `Name` also compares the `id`
    attribute. That attribute is a UUID and is not expected to be stable
    across CAST generations, so it is deliberately ignored here.

    Returns False when either argument is not a `Name` instance.
    """
    both_are_names = all(
        isinstance(candidate, Name) for candidate in (name1, name2)
    )
    return both_are_names and name1.name == name2.name
class CASTJsonException(Exception):
    """
    Error raised when CAST json cannot be encoded or decoded.
    """
class CAST(object):
    """
    Represents the Common Abstract Syntax Tree (CAST) that will be used to
    generically represent any language's AST.

    A CAST instance holds a list of top-level nodes plus the name of the
    language they were generated from. It can be compared with other CAST
    instances, serialized to / parsed from JSON, and converted to AIR / GrFN.
    """

    # Top-level CAST nodes, in order.
    nodes: typing.List[AstNode]
    # Name of the source language the CAST was generated from.
    cast_source_language: str

    def __init__(self, nodes: typing.List[AstNode], cast_source_language: str):
        """
        Store the top-level nodes and the source language name.

        Args:
            nodes: List of top-level CAST nodes.
            cast_source_language: Name of the language the CAST came from.
        """
        self.nodes = nodes
        self.cast_source_language = cast_source_language

    def __eq__(self, other):
        """
        For equality, the two CAST objects must have the same node data.
        When checking each node, we allow a custom node comparison function.
        Currently, the only case where we do a custom comparison is for Name
        nodes (see `compare_name_nodes`). For all other nodes we use their
        Swagger generated `__eq__` method.

        `cast_source_language` is not part of the comparison. When a pair of
        nodes differs, a line diff of their string forms is printed to help
        locate the difference.

        Returns:
            True / False for CAST operands, `NotImplemented` for any other
            operand type so Python can fall back to its default handling
            instead of failing on a missing `nodes` attribute.
        """
        if not isinstance(other, CAST):
            return NotImplemented

        if len(self.nodes) != len(other.nodes):
            return False

        for node, other_node in zip(self.nodes, other.nodes):
            if isinstance(node, Name):
                nodes_match = compare_name_nodes(node, other_node)
            else:
                nodes_match = node == other_node

            if not nodes_match:
                # printing diff to help locating difference
                print("CAST __eq__ failed:")
                self_lines = str(node).splitlines()
                other_lines = str(other_node).splitlines()
                for line_no, diff in enumerate(
                    difflib.ndiff(self_lines, other_lines)
                ):
                    # ndiff prefixes unchanged lines with a space; skip them
                    if diff[0] == " ":
                        continue
                    print(f"Line {line_no}: {diff}")
                return False

        return True

    def to_AGraph(self):
        """
        Build a pygraphviz AGraph of the CAST.

        Returns:
            The AGraph produced by `nx.nx_agraph.to_agraph` with dpi, font
            and rank direction attributes applied.
        """
        G = nx.DiGraph()
        for node in self.nodes:
            print("node", node)
            print("type", type(node))
            # NOTE(review): ast.walk is written for Python `ast` nodes;
            # confirm it behaves as intended on CAST node bodies.
            for ast_node in ast.walk(node.body):
                for child_node in ast_node.children:
                    G.add_edge(ast_node, child_node)

        # `import skema.utils.misc.test_pygraphviz` only binds the name
        # `skema`, so the call below used to raise NameError. Import the
        # callable itself instead.
        # NOTE(review): assumes test_pygraphviz is a function defined in
        # skema.utils.misc -- confirm.
        from skema.utils.misc import test_pygraphviz

        test_pygraphviz(
            "The to_AGraph method requires the pygraphviz package to be installed!"
        )

        A = nx.nx_agraph.to_agraph(G)
        A.graph_attr.update(
            {"dpi": 227, "fontsize": 20, "fontname": "Menlo", "rankdir": "TB"}
        )
        A.node_attr.update({"fontname": "Menlo"})
        return A

    def to_air_dict(self):
        """
        Convert the CAST to an AIR dictionary and pick its entrypoint.

        The entrypoint is the first container whose name ends with "::main";
        if there is none, the first container that is never called from
        another container's body is used.

        Returns:
            The AIR dict with an added "entrypoint" key.

        Raises:
            Exception: If no main or root container can be found.
        """
        # NOTE(review): the module-level import of CASTToAIRVisitor is
        # commented out in this file, so this line raises NameError unless
        # the name is provided some other way -- confirm before relying on
        # this method.
        c2a_visitor = CASTToAIRVisitor(self.nodes, self.cast_source_language)
        air = c2a_visitor.to_air()

        main_container = [
            c["name"]
            for c in air["containers"]
            if c["name"].endswith("::main")
        ]

        called_containers = [
            s["function"]["name"]
            for c in air["containers"]
            for s in c["body"]
            if s["function"]["type"] == "container"
        ]
        root_containers = [
            c["name"]
            for c in air["containers"]
            if c["name"] not in called_containers
        ]

        if main_container:
            container_id_to_start_from = main_container[0]
        elif root_containers:
            container_id_to_start_from = root_containers[0]
        else:
            # TODO
            raise Exception(
                "Error: Unable to find root container to build GrFN."
            )

        air["entrypoint"] = container_id_to_start_from

        return air

    def to_AIR(self):
        """
        Build an `AutoMATES_IR` object from the AIR dictionary.

        Returns:
            AutoMATES_IR constructed from the entrypoint identifier and the
            container, variable and type definitions found in the AIR dict.
        """
        air = self.to_air_dict()

        C, V, T = dict(), dict(), dict()

        # Create variable definitions
        for var_data in air["variables"]:
            new_var = VariableDefinition.from_data(var_data)
            V[new_var.identifier] = new_var

        # Create type definitions
        for type_data in air["types"]:
            new_type = TypeDefinition.from_dict(type_data)
            T[new_type.identifier] = new_type

        # Create container definitions
        for con_data in air["containers"]:
            new_container = GenericContainer.from_dict(con_data)
            # Container arguments without an explicit variable definition
            # get one created from their identifier.
            for in_var in new_container.arguments:
                if in_var not in V:
                    V[in_var] = VariableDefinition.from_identifier(in_var)
            C[new_container.identifier] = new_container

        return AutoMATES_IR(
            GenericIdentifier.from_str(air["entrypoint"]), C, V, T, [], [], []
        )

    def to_GrFN(self):
        """
        Convert the CAST to a GroundedFunctionNetwork by way of AIR.
        """
        air = self.to_AIR()
        grfn = GroundedFunctionNetwork.from_AIR(air)
        return grfn

    def write_cast_object(self, cast_value):
        """
        Recursively convert a CAST value into JSON-serializable data.

        Lists are converted element by element. `AstNode` and `SourceRef`
        instances become dicts of their `attribute_map` attributes plus a
        "node_type" entry holding the class name. Any other value is
        returned unchanged.
        """
        if isinstance(cast_value, list):
            return [self.write_cast_object(val) for val in cast_value]
        if not isinstance(cast_value, (AstNode, SourceRef)):
            return cast_value

        serialized = {
            attr: self.write_cast_object(getattr(cast_value, attr))
            for attr in cast_value.attribute_map.keys()
        }
        serialized["node_type"] = type(cast_value).__name__
        return serialized

    def to_json_object(self):
        """
        Returns a json object of the CAST
        """
        return {"nodes": [self.write_cast_object(n) for n in self.nodes]}

    def to_json_str(self):
        """
        Returns a json string of the CAST
        """
        return json.dumps(
            self.to_json_object(),
            sort_keys=True,
            indent=4,
        )

    @classmethod
    def parse_cast_json(cls, data):
        """
        Recursively decode JSON data into CAST nodes.

        Args:
            data: A list, None, a primitive (float, int, str, bool) or a
                dict describing a source reference or a CAST node.

        Raises:
            CASTJsonException: If a dict is neither a source reference nor
                a node whose "node_type" names a class in
                `CAST_NODES_TYPES_LIST`.

        Returns:
            The decoded value: a list of decoded items, None, the primitive
            itself, a `SourceRef`, or an instance of the matching node class.
        """
        if isinstance(data, list):
            # If we see a list parse each one of its elements
            return [cls.parse_cast_json(item) for item in data]
        if data is None:
            return None
        if isinstance(data, (float, int, str, bool)):
            # If we see a primitive type, simply return its value
            return data
        if all(
            k in data for k in ("row_start", "row_end", "col_start", "col_end")
        ):
            # The four position keys identify a source reference
            return SourceRef(
                row_start=data["row_start"],
                row_end=data["row_end"],
                col_start=data["col_start"],
                col_end=data["col_end"],
                source_file_name=data["source_file_name"],
            )

        if "node_type" in data:
            # Create the object specified by "node_type" object with the values
            # from its children nodes
            for node_type in CAST_NODES_TYPES_LIST:
                if node_type.__name__ == data["node_type"]:
                    node_results = {
                        k: cls.parse_cast_json(v)
                        for k, v in data.items()
                        if k != "node_type"
                    }
                    return node_type(**node_results)

        raise CASTJsonException(
            f"Unable to decode json CAST field with field names: {set(data.keys())}"
        )

    @classmethod
    def from_json_data(cls, json_data, cast_source_language="unknown"):
        """
        Parses json CAST data object and returns the created CAST object

        Args:
            json_data: JSON object with a "nodes" field containing a
                list of the top level nodes
            cast_source_language: Source language name stored on the
                created CAST object

        Returns:
            CAST: The parsed CAST object.
        """
        nodes = cls.parse_cast_json(json_data["nodes"])
        return cls(nodes, cast_source_language)

    @classmethod
    def from_json_file(cls, json_filepath):
        """
        Loads json CAST data from a file and returns the created CAST object

        Args:
            json_filepath: string of a full filepath to a JSON file
                representing a CAST with a `nodes` field

        Returns:
            CAST: The parsed CAST object.
        """
        # Use a context manager so the file handle is always closed
        with open(json_filepath, "r") as json_file:
            json_data = json.load(json_file)
        return cls.from_json_data(json_data)

    @classmethod
    def from_json_str(cls, json_str):
        """
        Parses json CAST string and returns the created CAST object

        Args:
            json_str: JSON string representing a CAST with a "nodes" field
                containing a list of the top level nodes

        Raises:
            CASTJsonException: If we encounter an unknown CAST node

        Returns:
            CAST: The parsed CAST object.
        """
        return cls.from_json_data(json.loads(json_str))

    @classmethod
    def from_python_ast(cls):
        """
        Placeholder; not implemented and currently returns None.
        """
        pass