Coverage for skema/program_analysis/CAST2FN/ann_cast/grfn_assignment_pass.py: 77%
161 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import typing
2from functools import singledispatchmethod
4from skema.model_assembly.metadata import LambdaType
5from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import (
6 GrfnAssignment,
7 ann_cast_name_to_fullid,
8 create_grfn_assign_node,
9 create_grfn_literal_node,
10 create_grfn_pack_node,
11 create_grfn_unpack_node,
12 create_lambda_node_metadata,
13 is_literal_assignment,
14)
15from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import *
16from skema.program_analysis.CAST2FN.model.cast import (
17 ScalarType,
18 StructureType,
19 ValueConstructor,
20)
class GrfnAssignmentPass:
    """Attach GrFN assignment information to the annotated CAST.

    The whole pass runs from ``__init__``. Each node in
    ``pipeline_state.nodes`` is walked, and:

    * ``AnnCastAssignment`` and ``AnnCastModelReturn`` nodes get a
      ``node.grfn_assignment``. This is a ``GrfnAssignment`` wrapping a
      literal / unpack / pack / assign lambda node, whose ``inputs`` and
      ``outputs`` dicts map variable fullids to GrFN variable ids.
    * ``AnnCastCall`` nodes get ``node.arg_assignments``, with one
      ``GrfnAssignment`` per argument, keyed by argument position.
    """

    def __init__(self, pipeline_state: PipelineState):
        self.pipeline_state = pipeline_state
        self.nodes = self.pipeline_state.nodes
        # Any other state variables that are needed during
        # the pass
        for node in self.pipeline_state.nodes:
            # every top-level node is visited with its own, empty `add_to`
            self.visit(node, {})

    def visit(self, node: AnnCastNode, add_to: typing.Dict):
        """Dispatch `node` to its type-specific visitor.

        `add_to` is either the inputs or the outputs dict of a GrFN
        Assignment/Literal node. Visiting a variable adds that variable to
        `add_to`. Visiting a call adds the call's return value(s) to it.
        """
        return self._visit(node, add_to)

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode], add_to: typing.Dict
    ):
        """Visit every node of `node_list` with the same `add_to` dict."""
        return [self.visit(node, add_to) for node in node_list]

    @singledispatchmethod
    def _visit(self, node: AnnCastNode, add_to: typing.Dict):
        """Fallback for node types that have no registered visitor.

        Raises:
            NameError: always, because `node`'s type has no
                ``@_visit.register`` handler.
        """
        raise NameError(f"Unrecognized node type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment, add_to: typing.Dict):
        """Create the GrFN lambda node for an assignment and fill its I/O.

        Variables reached from the right-hand side become the assignment's
        inputs, and the left-hand-side target becomes its outputs.
        `add_to` is not used, because an assignment collects into its own
        dicts.
        """
        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(node.right):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        elif isinstance(node.left, AnnCastTuple):
            # NOTE(review): the assert below does not accept AnnCastTuple, so
            # this branch ends in an AssertionError unless AnnCastTuple
            # subclasses one of the accepted types -- confirm it is reachable.
            node.grfn_assignment = GrfnAssignment(
                create_grfn_unpack_node(metadata), LambdaType.UNPACK
            )
        else:
            node.grfn_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        self.visit(node.right, node.grfn_assignment.inputs)

        # A tuple literal on the left handles assignments made to a tuple of
        # values, as opposed to one singular value. Note that visit_attribute
        # is a no-op, so an attribute target contributes no outputs.
        assert (
            isinstance(node.left, AnnCastVar)
            or (
                isinstance(node.left, AnnCastLiteralValue)
                and node.left.value_type == StructureType.TUPLE
            )
            or isinstance(node.left, AnnCastAttribute)
        ), (
            "grfn_assignment_pass: visit_assignment: node.left is not "
            "AnnCastVar, a tuple AnnCastLiteralValue, or AnnCastAttribute; "
            f"it is {type(node.left)}"
        )
        self.visit(node.left, node.grfn_assignment.outputs)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            self._print_grfn_assignment("GrFN", node.grfn_assignment)

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute, add_to: typing.Dict):
        pass

    @_visit.register
    def visit_call(self, node: AnnCastCall, add_to: typing.Dict):
        """Record the call's return value(s) and build argument assignments.

        GrFN 2.2 calls are delegated to `visit_call_grfn_2_2`. Otherwise each
        return-value variable is added to `add_to`. Each argument that has an
        entry in ``node.arg_index_to_fullid`` also gets a GrfnAssignment in
        ``node.arg_assignments``.
        """
        if node.is_grfn_2_2:
            self.visit_call_grfn_2_2(node, add_to)
            return

        self._add_ret_vals(node, add_to)

        for i, arg in enumerate(node.arguments):
            # NOTE: M7 Placeholder -- arguments with no recorded fullid are
            # skipped and get no GrfnAssignment.
            if i in node.arg_index_to_fullid:
                self._add_arg_assignment(node, i, arg)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            self._print_arg_assignments(node)

    def visit_call_grfn_2_2(self, node: AnnCastCall, add_to: typing.Dict):
        """Handle a GrFN 2.2 call.

        Return values and argument assignments are built as in `visit_call`,
        then the body of ``node.func_def_copy`` is visited.
        """
        assert isinstance(node.func, AnnCastName)
        self._add_ret_vals(node, add_to)

        # Unlike visit_call, every argument index is looked up unguarded.
        for i, arg in enumerate(node.arguments):
            self._add_arg_assignment(node, i, arg)

        # the copied callee body is visited with a fresh, discarded `add_to`
        self.visit_function_def(node.func_def_copy, {})

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            self._print_arg_assignments(node)

    def _add_ret_vals(self, node: AnnCastCall, add_to: typing.Dict):
        """Add each return-value variable of call `node` to `add_to`.

        Entries map the variable's fullid to its GrFN variable uid.
        """
        for fullid in node.out_ret_val.values():
            add_to[fullid] = self.pipeline_state.get_grfn_var(fullid).uid

    def _add_arg_assignment(
        self, node: AnnCastCall, index: int, arg: AnnCastNode
    ):
        """Build and store the GrfnAssignment for argument `index` of `node`.

        The assignment's output is the argument's GrFN variable. Its inputs
        are whatever visiting the argument expression `arg` records. It is a
        LITERAL lambda when `arg` is a literal and an ASSIGN lambda otherwise.
        """
        arg_fullid = node.arg_index_to_fullid[index]
        arg_grfn_var = self.pipeline_state.get_grfn_var(arg_fullid)

        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(arg):
            arg_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        else:
            arg_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        arg_assignment.outputs[arg_fullid] = arg_grfn_var.uid
        self.visit(arg, arg_assignment.inputs)
        node.arg_assignments[index] = arg_assignment

    def _print_grfn_assignment(
        self, prefix: str, grfn_assignment: GrfnAssignment
    ):
        """Debugging aid: print a GrfnAssignment's type, inputs and outputs."""
        print(
            f"{prefix} {grfn_assignment.assignment_type} after visiting children:"
        )
        print(f" grfn_assignment.inputs: {grfn_assignment.inputs}")
        print(f" grfn_assignment.outputs: {grfn_assignment.outputs}")

    def _print_arg_assignments(self, node: AnnCastCall):
        """Debugging aid: print the argument assignments built for a call."""
        print("Call after processing arguments:")
        for pos, grfn_assignment in node.arg_assignments.items():
            print(f" {pos} : {grfn_assignment!s}")

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef, add_to: typing.Dict):
        pass

    @_visit.register
    def visit_function_def(
        self, node: AnnCastFunctionDef, add_to: typing.Dict
    ):
        # linking function arguments to formal parameters through the top
        # interface is handled during VariableVersionPass, so we don't need to
        # visit func_args here
        self.visit_node_list(node.body, add_to)

    @_visit.register
    def visit_goto(self, node: AnnCastGoto, add_to: typing.Dict):
        # only the optional expression is visited; node.label is not
        if node.expr is not None:
            self.visit(node.expr, add_to)

    @_visit.register
    def visit_label(self, node: AnnCastLabel, add_to: typing.Dict):
        pass

    @_visit.register
    def visit_literal_value(
        self, node: AnnCastLiteralValue, add_to: typing.Dict
    ):
        """Add variables referenced inside a literal to `add_to`.

        * ``"List[Any]"`` literals: ``node.value`` has ``operator`` (string),
          ``size`` (a Var node or a LiteralValue node) and ``initial_value``.
          Only ``size`` is visited, and nothing else in the AnnCast is
          changed by this pass.
        * Tuple literals: every element is visited.

        Other value types are not traversed.
        """
        if node.value_type == "List[Any]":
            self.visit(node.value.size, add_to)
        elif node.value_type == StructureType.TUPLE:
            self.visit_node_list(node.value, add_to)

    @_visit.register
    def visit_loop(self, node: AnnCastLoop, add_to: typing.Dict):
        self.visit_node_list(node.pre, add_to)
        self.visit(node.expr, add_to)
        self.visit_node_list(node.body, add_to)
        self.visit_node_list(node.post, add_to)

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak, add_to: typing.Dict):
        pass

    @_visit.register
    def visit_model_continue(
        self, node: AnnCastModelContinue, add_to: typing.Dict
    ):
        pass

    @_visit.register
    def visit_model_import(
        self, node: AnnCastModelImport, add_to: typing.Dict
    ):
        pass

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf, add_to: typing.Dict):
        self.visit(node.expr, add_to)
        self.visit_node_list(node.body, add_to)
        self.visit_node_list(node.orelse, add_to)

    @_visit.register
    def visit_model_return(
        self, node: AnnCastModelReturn, add_to: typing.Dict
    ):
        """Create the GrFN lambda node for a return statement.

        Variables reached from the returned expression become inputs. The
        owning function's ``in_ret_val`` variables become outputs. The
        lambda is LITERAL for a literal, PACK for a tuple and ASSIGN
        otherwise.
        """
        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(node.value):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        elif isinstance(node.value, AnnCastTuple):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_pack_node(metadata), LambdaType.PACK
            )
        else:
            node.grfn_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        self.visit(node.value, node.grfn_assignment.inputs)

        for fullid in node.owning_func_def.in_ret_val.values():
            grfn_var = self.pipeline_state.get_grfn_var(fullid)
            node.grfn_assignment.outputs[fullid] = grfn_var.uid

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            self._print_grfn_assignment(
                "GrFN RETURN with type", node.grfn_assignment
            )

    @_visit.register
    def visit_module(self, node: AnnCastModule, add_to: typing.Dict):
        # the module body is visited with a fresh dict; the caller's
        # `add_to` is ignored
        self.visit_node_list(node.body, {})

    @_visit.register
    def visit_name(self, node: AnnCastName, add_to: typing.Dict):
        # store fullid -> GrFN id in add_to
        add_to[ann_cast_name_to_fullid(node)] = node.grfn_id

    @_visit.register
    def visit_operator(self, node: AnnCastOperator, add_to: typing.Dict):
        self.visit_node_list(node.operands, add_to)

    @_visit.register
    def visit_set(self, node: AnnCastSet, add_to: typing.Dict):
        pass

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple, add_to: typing.Dict):
        self.visit_node_list(node.values, add_to)

    @_visit.register
    def visit_var(self, node: AnnCastVar, add_to: typing.Dict):
        self.visit(node.val, add_to)