Coverage for skema/program_analysis/CAST2FN/ann_cast/lambda_expression_pass.py: 71%
272 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import typing
2from functools import singledispatchmethod
4from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import (
5 ELSEBODY,
6 IFBODY,
7 GrfnAssignment,
8 ann_cast_name_to_fullid,
9 cast_op_to_str,
10 lambda_var_from_fullid,
11)
12from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import *
13from skema.program_analysis.CAST2FN.model.cast import (
14 ScalarType,
15 StructureType,
16 ValueConstructor,
17)
def lambda_for_grfn_assignment(
    grfn_assignment: GrfnAssignment, lambda_body: str
) -> str:
    """
    Build the lambda source string for a GrfnAssignment.

    The lambda's parameters are the lambda variable names derived (via
    lambda_var_from_fullid) from the keys of ``grfn_assignment.inputs``, in
    key order; ``lambda_body`` is used verbatim as the body.
    For example: ``"lambda x, y: x + y"``.
    """
    params = ", ".join(
        lambda_var_from_fullid(fullid) for fullid in grfn_assignment.inputs.keys()
    )
    return f"lambda {params}: {lambda_body}"
def lambda_for_condition(condition_in: typing.Dict, lambda_body: str) -> str:
    """
    Build the lambda source string for a condition node.

    The lambda's parameters are the lambda variable names derived from the
    fullids in ``condition_in.values()``, in order; ``lambda_body`` is used
    verbatim as the body.
    """
    params = ", ".join(
        lambda_var_from_fullid(fullid) for fullid in condition_in.values()
    )
    return f"lambda {params}: {lambda_body}"
def lambda_for_decision(
    condition_fullid: str, decision_in: typing.Dict
) -> str:
    """
    Build the lambda source string for an If decision node.

    The decision lambda chooses between the IFBODY and ELSEBODY version of each
    variable based on the value of the condition. It has the form::

        lambda COND, x_if, y_if, x_else, y_else: (x_if, y_if) if COND else (x_else, y_else)

    Args:
        condition_fullid: fullid of the condition variable; its lambda variable
            name becomes the first parameter and the tested value.
        decision_in: mapping whose values are indexed with IFBODY and ELSEBODY
            to obtain the fullids of the if/else versions of each variable.

    Returns:
        The lambda source string, or ``"lambda: None"`` when ``decision_in``
        is empty.
    """
    if not decision_in:
        return "lambda: None"

    cond_name = lambda_var_from_fullid(condition_fullid)

    if_names = []
    else_names = []
    for dec in decision_in.values():
        # suffixes keep the if/else parameters distinct when both branches
        # refer to variables with the same lambda name
        if_names.append(lambda_var_from_fullid(dec[IFBODY]) + "_if")
        else_names.append(lambda_var_from_fullid(dec[ELSEBODY]) + "_else")

    if_names_str = ", ".join(if_names)
    else_names_str = ", ".join(else_names)

    lambda_body = f"({if_names_str}) if {cond_name} else ({else_names_str})"
    return f"lambda {cond_name}, {if_names_str}, {else_names_str}: {lambda_body}"
def lambda_for_interface(interface_in: typing.Dict) -> str:
    """
    Build the lambda source string for a plain interface node.

    A plain interface is a multi-parameter identity function: every lambda
    variable derived from ``interface_in.values()`` is passed straight through,
    e.g. ``"lambda x, y: (x, y)"``. An empty interface yields ``"lambda: None"``.
    """
    if not interface_in:
        return "lambda: None"

    names = ", ".join(
        lambda_var_from_fullid(fullid) for fullid in interface_in.values()
    )
    return f"lambda {names}: ({names})"
def lambda_for_loop_top_interface(
    top_interface_initial: typing.Dict, top_interface_updated: typing.Dict
) -> str:
    """
    Build the lambda source string for a loop top interface.

    The loop top interface chooses between the initial and the updated version
    of each loop variable. LoopTopInterfaces are special LambdaNodes that store
    state on whether the loop body has executed at least once; the
    `use_initial` value comes from that state during execution. The returned
    string has the form::

        lambda use_initial, x_init, y_init, x_update, y_update: (x_init, y_init) if use_initial else (x_update, y_update)

    Both branches return their values in the key order of
    ``top_interface_initial``, so position i of either tuple refers to the same
    variable. A variable with no entry in ``top_interface_updated`` (it is never
    modified in the loop) returns its initial version in both branches.

    Args:
        top_interface_initial: maps each loop variable key to the fullid of its
            initial version.
        top_interface_updated: maps a subset of those keys to the fullid of
            their updated version.

    Returns:
        The lambda source string.
    """

    def init_name(fullid: str) -> str:
        return lambda_var_from_fullid(fullid) + "_init"

    def updt_name(fullid: str) -> str:
        return lambda_var_from_fullid(fullid) + "_update"

    # Every updated variable must also have an initial version; otherwise the
    # "init" and "update" return groups would differ in length.
    assert set(top_interface_updated).issubset(top_interface_initial)

    init_names = [init_name(fullid) for fullid in top_interface_initial.values()]
    updt_names = [updt_name(fullid) for fullid in top_interface_updated.values()]

    # NOTE: the lengths of top_interface_initial and top_interface_updated may
    # differ: some loop variables are never modified, so they have no updated
    # version and the "init" variable is returned in the "update" group too.
    # The group is built in top_interface_initial's key order so it lines up
    # position-by-position with the "init" group.
    return_updt_names = [
        updt_name(top_interface_updated[key])
        if key in top_interface_updated
        else init_name(fullid)
        for key, fullid in top_interface_initial.items()
    ]

    use_initial_str = "use_initial"
    # Join all parameters at once so an empty group cannot leave a stray
    # ", ," in the parameter list (which would not compile).
    params_str = ", ".join([use_initial_str, *init_names, *updt_names])
    init_names_str = ", ".join(init_names)
    return_updt_names_str = ", ".join(return_updt_names)

    lambda_body = (
        f"({init_names_str}) if {use_initial_str} else ({return_updt_names_str})"
    )
    return f"lambda {params_str}: {lambda_body}"
def lambda_for_loop_condition(condition_in: typing.Dict, lambda_body: str) -> str:
    """
    Build the lambda source string for a loop condition node.

    Loop conditions have exactly the same form as If conditions (parameters
    derived from ``condition_in.values()``, ``lambda_body`` verbatim as the
    body), so this delegates to lambda_for_condition. It remains a separate
    entry point for the loop visitor.
    """
    return lambda_for_condition(condition_in, lambda_body)
class LambdaExpressionPass:
    """
    Annotated-CAST pass that builds the lambda expression strings used by GrFN.

    Constructing the pass visits every node in ``pipeline_state.nodes``.
    Visitors record their results on the nodes themselves (``expr_str``,
    ``top_interface_lambda``, ``bot_interface_lambda``, ``condition_lambda``,
    ``decision_lambda`` and ``GrfnAssignment.lambda_expr``) and return a string
    for the visited node, which enclosing visitors use as a lambda body or for
    debug printing.
    """

    def __init__(self, pipeline_state: PipelineState):
        self.pipeline_state = pipeline_state
        self.nodes = self.pipeline_state.nodes
        # Any other state variables that are needed during
        # the pass
        for node in self.pipeline_state.nodes:
            self.visit(node)

    def visit(self, node: AnnCastNode) -> str:
        """
        External visit that calls the type-dispatched internal visit.

        Kept as a single hook for debugging/development, e.g. printing
        ``node.__class__.__name__`` for every visited node.
        """
        return self._visit(node)

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode]
    ) -> typing.List[str]:
        """Visit every node of ``node_list`` in order and return the results."""
        return [self.visit(node) for node in node_list]

    @staticmethod
    def _print_exprs(header: str, exprs: typing.Iterable) -> None:
        """Debug helper: print ``header``, then each expression on its own
        line indented by two tabs."""
        print(header)
        for expr in exprs:
            print(f"\t\t{expr}")

    @singledispatchmethod
    def _visit(self, node: AnnCastNode) -> str:
        """
        Internal visit; one overload is registered per AnnCast node type.

        Raises:
            NameError: if no overload is registered for ``type(node)``.
        """
        raise NameError(f"Unrecognized node type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment) -> str:
        """Build the lambda for an assignment, store it and return it."""
        right = self.visit(node.right)
        # the lambda is stored on the GrfnAssignment and mirrored in expr_str
        lambda_expr = lambda_for_grfn_assignment(node.grfn_assignment, right)
        node.grfn_assignment.lambda_expr = lambda_expr
        node.expr_str = lambda_expr
        return node.expr_str

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute) -> str:
        return node.expr_str

    def _visit_call_arg_assignments(self, node: AnnCastCall) -> None:
        """
        Build the lambda for each argument assignment of a Call.

        For ``func(x + 3, y * 2)``, the GrfnAssignment with index 0 corresponds
        to ``arg_0 = x + 3`` and gets the lambda ``lambda x: x + 3``; the lambda
        body comes from visiting the Call argument with the same index.
        """
        for i, grfn_assignment in node.arg_assignments.items():
            lambda_body = self.visit(node.arguments[i])
            grfn_assignment.lambda_expr = lambda_for_grfn_assignment(
                grfn_assignment, lambda_body
            )

    def visit_call_grfn_2_2(self, node: AnnCastCall):
        """Build the lambdas for a GrFN 2.2 Call, including the body of its
        FunctionDef copy."""
        self._visit_call_arg_assignments(node)
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)
        # build lambda expressions for the function def copy body
        body_expr = self.visit_function_def_copy(node.func_def_copy)
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"Call GrFN 2.2 {node.func.name}")
            self._print_exprs(
                "\t Args Expressions:",
                (arg.lambda_expr for arg in node.arg_assignments.values()),
            )
            self._print_exprs("\t Top Interface:", [node.top_interface_lambda])
            print(f"FunctionDefCopy {node.func_def_copy.name.name}")
            self._print_exprs("\t Body Expressions:", body_expr)
            self._print_exprs("\t Bot Interface:", [node.bot_interface_lambda])

    def visit_call_without_func_copy(self, node: AnnCastCall):
        """Build the argument and interface lambdas for a Call that is not
        expanded with a FunctionDef copy."""
        self._visit_call_arg_assignments(node)
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"Call No FuncDef{node.func.name}")
            self._print_exprs(
                "\t Args Expressions:",
                (arg.lambda_expr for arg in node.arg_assignments.values()),
            )
            self._print_exprs("\t Top Interface:", [node.top_interface_lambda])
            self._print_exprs("\t Bot Interface:", [node.bot_interface_lambda])

    @_visit.register
    def visit_call(self, node: AnnCastCall) -> str:
        """
        Build the lambdas for a Call. When the Call has a return value, its
        expression string is the lambda variable of that return value.
        """
        if node.is_grfn_2_2:
            self.visit_call_grfn_2_2(node)
        else:
            # GrFN 2.3 style Calls and Calls without a FunctionDef share the
            # same lambda expression form
            self.visit_call_without_func_copy(node)
        if node.has_ret_val:
            assert len(node.out_ret_val) == 1
            ret_val_fullid = list(node.out_ret_val.values())[0]
            node.expr_str = lambda_var_from_fullid(ret_val_fullid)
        return node.expr_str

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef) -> str:
        return node.expr_str

    def visit_function_def_copy(self, node: AnnCastFunctionDef) -> typing.List:
        """Visit the body of a FunctionDef copy and return the body results."""
        return self.visit_node_list(node.body)

    @_visit.register
    def visit_function_def(self, node: AnnCastFunctionDef) -> str:
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)
        # NOTE: we do not visit node.func_args because those parameters are
        # included in the outputs of the top interface lambda
        body_expr = self.visit_node_list(node.body)
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"FunctionDef {node.name.name}")
            self._print_exprs("\t Top Interface:", [node.top_interface_lambda])
            self._print_exprs("\t Body Expressions:", body_expr)
            self._print_exprs("\t Bot Interface:", [node.bot_interface_lambda])

        return node.expr_str

    @_visit.register
    def visit_goto(self, node: AnnCastGoto) -> str:
        # TODO: node.expr / node.label are not visited; gotos produce no lambda
        return ""

    @_visit.register
    def visit_label(self, node: AnnCastLabel) -> str:
        # TODO: node.label is not visited; labels produce no lambda
        return ""

    @_visit.register
    def visit_literal_value(self, node: AnnCastLiteralValue) -> str:
        """
        Build the expression string for a literal value.

        "List[Any]" literals become ``"[<initial_value>] <operator> <size>"``;
        scalar literals become ``str(node.value)``; tuple literals return ""
        without updating ``node.expr_str``.
        """
        if node.value_type == "List[Any]":
            # node.value has
            #   operator - string
            #   size - Var node or a LiteralValue node (for a number)
            #   initial_value - node for the initial element value
            #     (presumably a LiteralValue or a Var -- confirm)
            val = node.value
            size_str = self.visit(val.size)
            init_val = self.visit(val.initial_value)
            node.expr_str = f"[{init_val}] {val.operator} {size_str}"
        elif node.value_type == StructureType.TUPLE:
            return ""
        elif node.value_type in (
            ScalarType.INTEGER,
            ScalarType.ABSTRACTFLOAT,
            ScalarType.BOOLEAN,
        ):
            node.expr_str = str(node.value)
        return node.expr_str

    @_visit.register
    def visit_loop(self, node: AnnCastLoop) -> str:
        node.top_interface_lambda = lambda_for_loop_top_interface(
            node.top_interface_initial, node.top_interface_updated
        )
        # init lambdas: results are stored on the pre nodes themselves
        if node.pre:
            self.visit_node_list(node.pre)

        # condition lambda
        loop_expr = self.visit(node.expr)
        node.condition_lambda = lambda_for_loop_condition(
            node.condition_in, loop_expr
        )

        body_expr = self.visit_node_list(node.body)
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("Loop ")
            self._print_exprs("\t Loop Top Interface:", [node.top_interface_lambda])
            self._print_exprs("\t Loop Expression:", [node.condition_lambda])
            self._print_exprs("\t Body Expressions:", body_expr)
            self._print_exprs("\t Loop Bot Interface:", [node.bot_interface_lambda])

        return node.expr_str

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak) -> str:
        return node.expr_str

    @_visit.register
    def visit_model_continue(self, node: AnnCastModelContinue) -> str:
        return node.expr_str

    @_visit.register
    def visit_model_import(self, node: AnnCastModelImport) -> str:
        # imports produce no lambda expression (same as goto/label)
        return ""

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf) -> str:
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)

        # condition lambda
        expr_str = self.visit(node.expr)
        node.condition_lambda = lambda_for_condition(node.condition_in, expr_str)

        body_expr = self.visit_node_list(node.body)
        or_else_expr = self.visit_node_list(node.orelse)

        # decision lambda, keyed on the (single) condition output variable
        cond_fullid = list(node.condition_out.values())[0]
        node.decision_lambda = lambda_for_decision(cond_fullid, node.decision_in)

        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("If ")
            self._print_exprs("\t If Top Interface:", [node.top_interface_lambda])
            self._print_exprs("\t If Expression:", [node.condition_lambda])
            self._print_exprs("\t Body Expressions:", body_expr)
            self._print_exprs("\t OrElse Expressions:", or_else_expr)
            self._print_exprs("\t If Decision Lambda:", [node.decision_lambda])
            self._print_exprs("\t If Bot Interface:", [node.bot_interface_lambda])

        return node.expr_str

    @_visit.register
    def visit_model_return(self, node: AnnCastModelReturn) -> str:
        """Build the lambda for the return-value assignment, store it and
        return it."""
        val = self.visit(node.value)
        lambda_expr = lambda_for_grfn_assignment(node.grfn_assignment, val)
        node.grfn_assignment.lambda_expr = lambda_expr
        node.expr_str = lambda_expr
        return node.expr_str

    @_visit.register
    def visit_module(self, node: AnnCastModule) -> str:
        body_expr = self.visit_node_list(node.body)

        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("Module")
            self._print_exprs("\t Body Expressions:", body_expr)

        return node.expr_str

    @_visit.register
    def visit_name(self, node: AnnCastName) -> str:
        fullid = ann_cast_name_to_fullid(node)
        node.expr_str = lambda_var_from_fullid(fullid)
        return node.expr_str

    @_visit.register
    def visit_operator(self, node: AnnCastOperator) -> str:
        # TODO: operator lambdas are not generated yet; presumably this should
        # visit node.operands and combine them with cast_op_to_str(node.op).
        return ""

    @_visit.register
    def visit_set(self, node: AnnCastSet) -> str:
        return node.expr_str

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple) -> str:
        pieces = self.visit_node_list(node.values)
        node.expr_str = f"({', '.join(pieces)})"
        return node.expr_str

    @_visit.register
    def visit_var(self, node: AnnCastVar) -> str:
        node.expr_str = self.visit(node.val)
        return node.expr_str