Coverage for skema/program_analysis/CAST2FN/ann_cast/id_collapse_pass.py: 93%
158 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from re import A
2import typing
3from collections import defaultdict
4from functools import singledispatchmethod
6from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import (
7 call_container_name,
8)
9from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import *
10from skema.program_analysis.CAST2FN.model.cast import (
11 ScalarType,
12 StructureType,
13 ValueConstructor,
14)
class IdCollapsePass:
    """
    Collapse the ids of AnnCastName nodes into a dense range starting at zero.

    The whole pass runs inside the constructor. It:
      * visits every node in ``pipeline_state.nodes``, rewriting Name ids to
        collapsed ids;
      * assigns an ``invocation_index`` to each Call node;
      * registers each FunctionDef in ``pipeline_state.func_id_to_def``;
      * records names seen at module scope in the module node's ``used_vars``;
      * marks each cached Call with ``has_func_def``;
      * stores the next unused collapsed id on
        ``pipeline_state.collapsed_id_counter``.
    """

    def __init__(self, pipeline_state: PipelineState):
        self.pipeline_state = pipeline_state
        # Cache Call nodes so that, after visiting, we can determine which
        # Calls have associated FunctionDefs.
        # Maps call container name -> AnnCastCall node.
        self.cached_call_nodes: typing.Dict[str, AnnCastCall] = {}
        # During the pass, Name ids are collapsed to a range starting from zero.
        self.old_id_to_collapsed_id: typing.Dict[int, int] = {}
        # Next collapsed id to hand out (equals the number of ids used so far).
        self.collapsed_id_counter = 0
        # Maps collapsed function id -> number of invocations seen so far.
        # Used to populate `invocation_index` of AnnCastCall nodes.
        self.func_invocation_counter: typing.DefaultDict[int, int] = defaultdict(int)
        for node in self.pipeline_state.nodes:
            at_module_scope = False
            self.visit(node, at_module_scope)
        self.nodes = self.pipeline_state.nodes
        self.determine_function_defs_for_calls()
        self.store_highest_id()

    def store_highest_id(self) -> None:
        """
        Copy the collapsed-id counter onto the pipeline state.

        The stored value is the next unused collapsed id, i.e. one more than
        the highest collapsed id assigned by this pass.
        """
        self.pipeline_state.collapsed_id_counter = self.collapsed_id_counter

    def collapse_id(self, id: int) -> int:
        """
        Return the collapsed id for `id`, creating one if it does not exist yet.

        New collapsed ids are handed out sequentially from `collapsed_id_counter`.
        """
        if id not in self.old_id_to_collapsed_id:
            self.old_id_to_collapsed_id[id] = self.collapsed_id_counter
            self.collapsed_id_counter += 1

        return self.old_id_to_collapsed_id[id]

    def next_function_invocation(self, coll_func_id: int) -> int:
        """
        Return the next invocation index for the function with collapsed id
        `coll_func_id`, and advance that function's counter.

        Indices start at 0 for the first invocation.
        """
        index = self.func_invocation_counter[coll_func_id]
        self.func_invocation_counter[coll_func_id] += 1

        return index

    def determine_function_defs_for_calls(self) -> None:
        """
        Set `has_func_def` on every cached Call node.

        The function id is taken from `call.func.attr.id` for attribute calls
        (e.g. `obj.method()`) and from `call.func.id` otherwise.
        """
        for call_name, call in self.cached_call_nodes.items():
            if isinstance(call.func, AnnCastAttribute):
                func_id = call.func.attr.id
            else:
                func_id = call.func.id
            call.has_func_def = self.pipeline_state.func_def_exists(func_id)

            # DEBUG printing
            if self.pipeline_state.PRINT_DEBUGGING_INFO:
                print(f"{call_name} has FunctionDef: {call.has_func_def}")

    def visit(self, node: AnnCastNode, at_module_scope: bool):
        """
        Dispatch `node` to its type-specific visitor.

        On failure, print the node's type and source refs, then re-raise the
        original exception unchanged.
        """
        # print current node being visited.
        # this can be useful for debugging
        # class_name = node.__class__.__name__
        # print(f"\nProcessing node type {class_name}")
        try:
            return self._visit(node, at_module_scope)
        except Exception:
            # getattr keeps this diagnostic from raising a second error
            # (e.g. when `node` is None) and hiding the original one.
            source_refs = getattr(node, "source_refs", None)
            print(
                f"id_collapse_pass.py: Error for {type(node)} which has source ref information {source_refs}"
            )
            raise

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode], at_module_scope: bool
    ):
        """Visit each node of `node_list` in order and return the visit results."""
        return [self.visit(node, at_module_scope) for node in node_list]

    @singledispatchmethod
    def _visit(self, node: AnnCastNode, at_module_scope: bool):
        """
        Visit each AnnCastNode, collapsing AnnCastName ids along the way.

        This base implementation handles node types with no registered visitor.
        It prints the node's first source ref (when there is one) and raises.
        """
        source_refs = getattr(node, "source_refs", None)
        # Guard so an empty/missing source_refs does not replace the
        # "Unimplemented" error below with an IndexError/AttributeError.
        if source_refs:
            print(source_refs[0])
        raise Exception(f"Unimplemented AST node of type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment, at_module_scope: bool):
        self.visit(node.right, at_module_scope)
        # The target is a Var, Attribute or Call. It may also be a tuple
        # literal, which handles assignments that unpack into several values
        # rather than one single value.
        assert (
            isinstance(node.left, (AnnCastVar, AnnCastAttribute, AnnCastCall))
            or (
                isinstance(node.left, AnnCastLiteralValue)
                and node.left.value_type == StructureType.TUPLE
            )
        ), f"id_collapse: visit_assignment: node.left is {type(node.left)}"
        self.visit(node.left, at_module_scope)

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute, at_module_scope: bool):
        self.visit(node.value, at_module_scope)
        self.visit(node.attr, at_module_scope)

    @_visit.register
    def visit_call(self, node: AnnCastCall, at_module_scope: bool):
        # Calls whose func is a literal value are left untouched and are not cached.
        if isinstance(node.func, AnnCastLiteralValue):
            return

        assert isinstance(
            node.func, (AnnCastName, AnnCastAttribute)
        ), f"node.func is type {type(node.func)}"

        if isinstance(node.func, AnnCastName):
            node.func.id = self.collapse_id(node.func.id)
            node.invocation_index = self.next_function_invocation(node.func.id)
        else:
            receiver = node.func.value
            # Compound receivers are visited recursively.
            # NOTE: AnnCastSubscript receivers are not visited here; they take
            # the id-collapsing branch below like any other non-literal receiver.
            if isinstance(
                receiver,
                (AnnCastCall, AnnCastAttribute, AnnCastOperator, AnnCastAssignment),
            ):
                self.visit(receiver, at_module_scope)
            elif not isinstance(receiver, AnnCastLiteralValue):
                # Collapsed directly, so it is not added to the module node's used_vars.
                receiver.id = self.collapse_id(receiver.id)
            node.func.attr.id = self.collapse_id(node.func.attr.id)
            node.invocation_index = self.next_function_invocation(
                node.func.attr.id
            )

        # cache Call node to later determine if this Call has a FunctionDef
        call_name = call_container_name(node)
        self.cached_call_nodes[call_name] = node

        self.visit_node_list(node.arguments, at_module_scope)

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef, at_module_scope: bool):
        # Everything inside a record definition is treated as non-module scope.
        at_module_scope = False

        # Each base should be an AnnCastName node
        self.visit_node_list(node.bases, at_module_scope)

        # Each func is an AnnCastFuncDef node
        self.visit_node_list(node.funcs, at_module_scope)

        # Each field (attribute) is an AnnCastVar node
        self.visit_node_list(node.fields, at_module_scope)

    @_visit.register
    def visit_function_def(self, node: AnnCastFunctionDef, at_module_scope: bool):
        # collapse the function id and register the definition under it
        node.name.id = self.collapse_id(node.name.id)
        self.pipeline_state.func_id_to_def[node.name.id] = node

        # Arguments and body are never at module scope.
        at_module_scope = False
        self.visit_node_list(node.func_args, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)

    @_visit.register
    def visit_goto(self, node: AnnCastGoto, at_module_scope: bool):
        # Only the optional expression is visited; the label is not.
        if node.expr is not None:
            self.visit(node.expr, at_module_scope)

    @_visit.register
    def visit_label(self, node: AnnCastLabel, at_module_scope: bool):
        # Labels contain nothing that is visited by this pass.
        pass

    @_visit.register
    def visit_literal_value(self, node: AnnCastLiteralValue, at_module_scope: bool):
        if node.value_type == "List[Any]":
            # The value of a "List[Any]" literal carries:
            #   operator - string
            #   size - Var node or a LiteralValue node (for number)
            #   initial_value - LiteralValue node
            # Only the size is visited. The list literal needs no other
            # changes to the AnnCast at this pass.
            self.visit(node.value.size, at_module_scope)
        elif node.value_type == StructureType.TUPLE:
            self.visit_node_list(node.value, at_module_scope)
        # Any other literal (e.g. integers, floats) has nothing to visit.

    @_visit.register
    def visit_loop(self, node: AnnCastLoop, at_module_scope: bool):
        self.visit_node_list(node.pre, at_module_scope)
        self.visit(node.expr, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)
        self.visit_node_list(node.post, at_module_scope)

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak, at_module_scope: bool):
        pass

    @_visit.register
    def visit_model_continue(
        self, node: AnnCastModelContinue, at_module_scope: bool
    ):
        pass

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf, at_module_scope: bool):
        self.visit(node.expr, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)
        self.visit_node_list(node.orelse, at_module_scope)

    @_visit.register
    def visit_return(self, node: AnnCastModelReturn, at_module_scope: bool):
        self.visit(node.value, at_module_scope)

    @_visit.register
    def visit_model_import(self, node: AnnCastModelImport, at_module_scope: bool):
        pass

    @_visit.register
    def visit_module(self, node: AnnCastModule, at_module_scope: bool):
        # we cache the module node in the pipeline state
        self.pipeline_state.module_node = node
        at_module_scope = True
        self.visit_node_list(node.body, at_module_scope)

    @_visit.register
    def visit_name(self, node: AnnCastName, at_module_scope: bool):
        node.id = self.collapse_id(node.id)

        # we consider name nodes at the module scope to be globals
        # and store them in the `used_vars` attribute of the module_node
        if at_module_scope:
            self.pipeline_state.module_node.used_vars[node.id] = node.name

    @_visit.register
    def visit_operator(self, node: AnnCastOperator, at_module_scope: bool):
        # visit operands
        self.visit_node_list(node.operands, at_module_scope)

    @_visit.register
    def visit_var(self, node: AnnCastVar, at_module_scope: bool):
        self.visit(node.val, at_module_scope)
        if node.default_value is not None:
            self.visit(node.default_value, at_module_scope)

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple, at_module_scope: bool):
        # Tuple of values: visit each one so their ids are collapsed.
        self.visit_node_list(node.values, at_module_scope)