Coverage for skema/program_analysis/CAST/python/ts2cast.py: 88%
597 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import json
2import os.path
3from pathlib import Path
4from typing import Any, Dict, List, Union
6from tree_sitter import Language, Parser, Node
8from skema.program_analysis.CAST2FN.cast import CAST
9from skema.program_analysis.CAST2FN.model.cast import (
10 Module,
11 SourceRef,
12 Assignment,
13 CASTLiteralValue,
14 Var,
15 VarType,
16 Name,
17 Operator,
18 AstNode,
19 SourceCodeDataType,
20 ModelImport,
21 FunctionDef,
22 Loop,
23 Call,
24 ModelReturn,
25 ModelIf,
26 RecordDef,
27 Attribute,
28 ScalarType,
29 StructureType
30)
32from skema.program_analysis.CAST.python.node_helper import (
33 NodeHelper,
34 get_first_child_by_type,
35 get_children_by_types,
36 get_first_child_index,
37 get_last_child_index,
38 get_control_children,
39 get_non_control_children,
40 FOR_LOOP_LEFT_TYPES,
41 FOR_LOOP_RIGHT_TYPES,
42 WHILE_COND_TYPES,
43 COMPREHENSION_OPERATORS
44)
45from skema.program_analysis.CAST.python.util import (
46 generate_dummy_source_refs,
47 get_op
48)
49from skema.program_analysis.CAST.fortran.variable_context import VariableContext
51from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH
# Python version string stamped into generated CAST metadata (e.g. the
# `source_code_data_type` of literals and the `version` of loop-condition
# operators built by TS2CAST).
PYTHON_VERSION = "3.10"
class TS2CAST(object):
    """Translate Python source code into CAST using a tree-sitter parse tree.

    Translation runs eagerly in ``__init__``; the result is stored in
    ``self.out_cast``.
    """

    # tree-sitter node type -> name of the visitor method that handles it.
    # "parenthesized_expression" is handled inline in visit(); any other type
    # that is not listed here falls through to _visit_passthrough.
    _NODE_VISITORS: Dict[str, str] = {
        "module": "visit_module",
        "expression_statement": "visit_expression",
        "function_definition": "visit_function_def",
        "return_statement": "visit_return",
        "call": "visit_call",
        "if_statement": "visit_if_statement",
        "comparison_operator": "visit_comparison_op",
        "assignment": "visit_assignment",
        "attribute": "visit_attribute",
        "identifier": "visit_identifier",
        "unary_operator": "visit_unary_op",
        "binary_operator": "visit_binary_op",
        # visit_literal has branches for all of these; previously only
        # integer/list were routed to it, so floats, booleans and tuples
        # silently translated to None.
        "integer": "visit_literal",
        "float": "visit_literal",
        "true": "visit_literal",
        "false": "visit_literal",
        "list": "visit_literal",
        "tuple": "visit_literal",
        "list_pattern": "visit_pattern",
        "pattern_list": "visit_pattern",
        "tuple_pattern": "visit_pattern",
        "list_comprehension": "visit_list_comprehension",
        "dictionary_comprehension": "visit_dict_comprehension",
        "lambda": "visit_lambda",
        "subscript": "visit_subscript",
        "slice": "visit_slice",
        "pair": "visit_pair",
        "while_statement": "visit_while",
        "for_statement": "visit_for",
        "import_statement": "visit_import",
        "import_from_statement": "visit_import_from",
        "class_definition": "visit_class_definition",
        "yield": "visit_yield",
        "assert_statement": "visit_assert",
    }

    def __init__(self, source_file_path: str, from_file: bool = True):
        """Parse the source and generate its CAST.

        Args:
            source_file_path: Path to a Python file or, when ``from_file`` is
                False, the Python source text itself.
            from_file: When False, ``source_file_path`` is treated as raw
                source text (used for testing without real files).
        """
        if from_file:
            self.path = Path(source_file_path)
            self.source_file_name = self.path.name
            # Python has no preprocessing step (unlike Fortran); read as-is.
            self.source = self.path.read_text()
        else:
            self.path = "None"
            self.source_file_name = "Temp"
            self.source = source_file_path

        parser = Parser()
        parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "python"))

        # FunctionDefs synthesized for comprehensions/lambdas; they are
        # prepended to the module body by visit_module.
        self.generated_fns = []

        # Counter for additional generated variables (not read elsewhere in
        # this class).
        self.var_count = 0

        # Import aliases: alias -> original name, filled by
        # `import x as y` / `from x import y as z` and consulted by
        # check_alias.
        self.aliases = {}

        # Tree-walking state.
        self.variable_context = VariableContext()
        self.node_helper = NodeHelper(self.source, self.source_file_name)

        self.tree = parser.parse(bytes(self.source, "utf8"))
        self.out_cast = self.generate_cast()

    def generate_cast(self) -> CAST:
        """Visit the parse tree and wrap the resulting Module in a CAST.

        The module is named after the source file and post-processed by
        ``generate_dummy_source_refs`` before wrapping.
        """
        module = self.run(self.tree.root_node)
        module.name = self.source_file_name
        return CAST([generate_dummy_source_refs(module)], "Python")

    def run(self, root) -> Module:
        """Visit the root of the parse tree, which is a tree-sitter ``module``."""
        return self.visit(root)

    # TODO: node helper for ignoring comments

    def check_alias(self, name):
        """Return the original name if ``name`` is a recorded import alias,
        otherwise return ``name`` unchanged."""
        return self.aliases.get(name, name)

    def visit(self, node: Node):
        """Dispatch ``node`` to the visitor registered for its tree-sitter type."""
        if node.type == "parenthesized_expression":
            # "( expr )": the wrapped expression is the middle child.
            return self.visit(node.children[1])
        visitor_name = self._NODE_VISITORS.get(node.type)
        if visitor_name is None:
            return self._visit_passthrough(node)
        return getattr(self, visitor_name)(node)

    def _visit_nodes(self, nodes) -> List[AstNode]:
        """Visit each node and flatten the results into one list.

        List results are spliced in, single AstNode results are appended and
        anything else (e.g. None for punctuation) is dropped.
        """
        casts = []
        for child in nodes:
            cast = self.visit(child)
            if isinstance(cast, list):
                casts.extend(cast)
            elif isinstance(cast, AstNode):
                casts.append(cast)
        return casts

    def visit_module(self, node: Node) -> Module:
        """Translate the top-level module inside its own variable context."""
        self.variable_context.push_context()
        body = self._visit_nodes(node.children)
        self.variable_context.pop_context()

        return Module(
            name=None,
            body=self.generated_fns + body,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_expression(self, node: Node):
        """Translate an expression statement; returns a (possibly empty) list."""
        return self._visit_nodes(node.children)

    def visit_function_def(self, node: Node) -> FunctionDef:
        """Translate a function definition, scoping its names in a new context."""
        ref = self.node_helper.get_source_ref(node)

        name = self.visit(get_first_child_by_type(node, "identifier"))

        self.variable_context.push_context()

        parameter_nodes = get_non_control_children(
            get_children_by_types(node, ["parameters"])[0]
        )
        # The function body lives in a 'block' node.
        body_nodes = get_children_by_types(node, ["block"])[0].children

        func_params = self._visit_nodes(parameter_nodes)
        func_body = self._visit_nodes(body_nodes)

        self.variable_context.pop_context()

        return FunctionDef(
            name=name.val,
            func_args=func_params,
            body=func_body,
            source_refs=[ref],
        )

    def visit_return(self, node: Node) -> ModelReturn:
        """Translate ``return [expr]``; a bare ``return`` gets ``value=None``."""
        ref = self.node_helper.get_source_ref(node)
        if len(node.children) < 2:
            # Bare `return` has only the keyword child.
            return ModelReturn(value=None, source_refs=[ref])
        ret_cast = self.visit(node.children[1])
        return ModelReturn(value=get_operand_node(ret_cast), source_refs=[ref])

    def _int_literal(self, value: str, ref) -> CASTLiteralValue:
        """Build an integer CAST literal with the given textual value."""
        return CASTLiteralValue(
            ScalarType.INTEGER,
            value=value,
            source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
            source_refs=[ref],
        )

    def visit_call(self, node: Node) -> Call:
        """Translate a call; ``range`` calls are normalised to 3 arguments."""
        ref = self.node_helper.get_source_ref(node)

        func_cast = self.visit(node.children[0])
        # Calls only want the Name part of the Var the visit returns.
        func_name = get_func_name_node(func_cast)

        arg_list = get_first_child_by_type(node, "argument_list")
        func_args = []
        for arg in get_non_control_children(arg_list):
            cast = get_name_node(self.visit(arg))
            if isinstance(cast, list):
                func_args.extend(cast)
            elif isinstance(cast, AstNode):
                func_args.append(cast)

        # The callee may be e.g. a Call or literal with no `.name`.
        if getattr(get_name_node(func_cast), "name", None) == "range":
            # Expand to range(start, stop, step) using Python's defaults
            # start=0 and step=1.
            if len(func_args) == 1:
                func_args.insert(0, self._int_literal("0", ref))
            if len(func_args) == 2:
                func_args.append(self._int_literal("1", ref))

        return Call(func=func_name, arguments=func_args, source_refs=[ref])

    def visit_comparison_op(self, node: Node):
        """Translate a single binary comparison ``left OP right``."""
        ref = self.node_helper.get_source_ref(node)
        op = get_op(self.node_helper.get_operator(node.children[1]))
        left, _, right = node.children

        left_cast = get_name_node(self.visit(left))
        right_cast = get_name_node(self.visit(right))

        return Operator(op=op, operands=[left_cast, right_cast], source_refs=[ref])

    def visit_if_statement(self, node: Node) -> ModelIf:
        """Translate an if/elif/else construct (or an elif clause) into nested ModelIfs.

        tree-sitter stores elif/else clauses as flat siblings, so they are
        folded from the end: the trailing else (if any) becomes the innermost
        ``orelse`` and each elif wraps what follows it.
        """
        # Computed up front: the loops below must not change which node this
        # ref describes.
        ref = self.node_helper.get_source_ref(node)

        # The grammar exposes the test as the "condition" field for both
        # if_statement and elif_clause, whatever expression type it is.
        if_condition = self.visit(node.child_by_field_name("condition"))

        if_true_cast = self._visit_nodes(
            get_children_by_types(node, ["block"])[0].children
        )

        alternatives = get_children_by_types(node, ["elif_clause", "else_clause"])

        # An else clause, when present, is always last; an elif chain may
        # also end without one.
        final_else_cast = []
        if alternatives and alternatives[-1].type == "else_clause":
            final_else = alternatives.pop()
            final_else_cast = self._visit_nodes(
                get_children_by_types(final_else, ["block"])[0].children
            )

        if_false_cast = final_else_cast
        for elif_node in reversed(alternatives):
            elif_cast = self.visit_if_statement(elif_node)
            elif_cast.orelse = if_false_cast
            if_false_cast = [elif_cast]

        return ModelIf(
            expr=if_condition,
            body=if_true_cast,
            orelse=if_false_cast,
            source_refs=[ref],
        )

    def visit_assignment(self, node: Node) -> Assignment:
        """Translate ``left = right``, including annotated ``left: T = right``."""
        ref = self.node_helper.get_source_ref(node)

        left_node = node.child_by_field_name("left")
        if left_node is None:
            left_node = node.children[0]
        right_node = node.child_by_field_name("right")
        if right_node is None:
            # Annotation-only statements (`x: int`) have no "right" field;
            # fall back to the last child so they still translate.
            right_node = node.children[-1]

        # For the RHS we want the Name node, not the whole Var (as in x = y).
        right_cast = get_name_node(self.visit(right_node))

        return Assignment(left=self.visit(left_node), right=right_cast, source_refs=[ref])

    def visit_unary_op(self, node: Node) -> Operator:
        """Translate ``OP operand``; unary minus is mapped to ``ast.USub``."""
        ref = self.node_helper.get_source_ref(node)
        op = get_op(self.node_helper.get_operator(node.children[0]))
        operand = node.children[1]

        if op == 'ast.Sub':
            op = 'ast.USub'

        # The operand is the Name node, not the whole Var (as in -x).
        operand_cast = get_name_node(self.visit(operand))
        if isinstance(operand_cast, Var):
            operand_cast = operand_cast.val

        return Operator(op=op, operands=[operand_cast], source_refs=[ref])

    def visit_binary_op(self, node: Node) -> Operator:
        """Translate ``left OP right``, where each side may be an operator or literal."""
        ref = self.node_helper.get_source_ref(node)
        op = get_op(self.node_helper.get_operator(node.children[1]))
        left, _, right = node.children

        left_cast = get_operand_node(self.visit(left))
        right_cast = get_operand_node(self.visit(right))

        return Operator(op=op, operands=[left_cast, right_cast], source_refs=[ref])

    def visit_pattern(self, node: Node):
        """Translate a destructuring pattern into a TUPLE literal of its elements."""
        return CASTLiteralValue(
            value_type=StructureType.TUPLE,
            value=self._visit_nodes(node.children),
        )

    def visit_identifier(self, node: Node) -> Var:
        """Translate an identifier into a Var, typed from the variable context when known."""
        identifier = self.node_helper.get_identifier(node)

        if self.variable_context.is_variable(identifier):
            var_type = self.variable_context.get_type(identifier)
        else:
            var_type = "unknown"

        # TODO: Python default values
        return Var(
            val=self.visit_name(node),
            type=var_type,
            default_value=None,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_literal(self, node: Node) -> Any:
        """Translate integer, float, boolean, list and tuple literals.

        Returns None for any other node type.
        """
        literal_type = node.type
        literal_value = self.node_helper.get_identifier(node)
        literal_source_ref = self.node_helper.get_source_ref(node)

        if literal_type == "integer":
            return CASTLiteralValue(
                value_type=ScalarType.INTEGER,
                value=literal_value,
                source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
                source_refs=[literal_source_ref],
            )
        if literal_type == "float":
            return CASTLiteralValue(
                value_type=ScalarType.ABSTRACTFLOAT,
                value=literal_value,
                source_code_data_type=["Python", PYTHON_VERSION, str(type(1.0))],
                source_refs=[literal_source_ref],
            )
        if literal_type in ("true", "false"):
            return CASTLiteralValue(
                value_type=ScalarType.BOOLEAN,
                value="True" if literal_type == "true" else "False",
                source_code_data_type=["Python", PYTHON_VERSION, str(type(True))],
                source_refs=[literal_source_ref],
            )
        if literal_type == "list":
            return CASTLiteralValue(
                value_type=StructureType.LIST,
                value=self._visit_nodes(node.children),
                source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))],
                source_refs=[literal_source_ref],
            )
        if literal_type == "tuple":
            # Note `(0,)`: `(0)` is just the int 0.
            return CASTLiteralValue(
                value_type=StructureType.TUPLE,
                value=self._visit_nodes(node.children),
                source_code_data_type=["Python", PYTHON_VERSION, str(type((0,)))],
                source_refs=[literal_source_ref],
            )
        return None

    def handle_dotted_name(self, import_stmt) -> str:
        """Return the text of a dotted import name.

        The node is also visited for its effect on the variable context; the
        visit result is discarded.
        """
        name = self.node_helper.get_identifier(import_stmt)
        self.visit(import_stmt)
        return name

    def handle_aliased_import(self, import_stmt) -> tuple:
        """Return ``(dotted_name, alias)`` for an ``x as y`` import clause."""
        dot_name = get_children_by_types(import_stmt, ["dotted_name"])[0]
        name = self.handle_dotted_name(dot_name)
        alias = get_children_by_types(import_stmt, ["identifier"])[0]
        self.visit(alias)
        return (name, self.node_helper.get_identifier(alias))

    def visit_import(self, node: Node):
        """Translate ``import a, b as c`` into ModelImports, recording aliases."""
        ref = self.node_helper.get_source_ref(node)
        to_ret = []

        for name in get_children_by_types(node, ["dotted_name", "aliased_import"]):
            if name.type == "dotted_name":
                resolved_name = self.handle_dotted_name(name)
                to_ret.append(ModelImport(name=resolved_name, alias=None, symbol=None, all=False, source_refs=[ref]))
            elif name.type == "aliased_import":
                original, alias = self.handle_aliased_import(name)
                self.aliases[alias] = original
                to_ret.append(ModelImport(name=original, alias=alias, symbol=None, all=False, source_refs=[ref]))

        return to_ret

    def visit_import_from(self, node: Node):
        """Translate ``from m import ...`` (including ``*``) into ModelImports."""
        ref = self.node_helper.get_source_ref(node)
        to_ret = []

        names_list = get_children_by_types(node, ["dotted_name", "aliased_import"])
        wild_card = get_children_by_types(node, ["wildcard_import"])
        # The first dotted name is the module being imported from.
        module_name = self.node_helper.get_identifier(names_list[0])

        if len(wild_card) == 1:
            to_ret.append(ModelImport(name=module_name, alias=None, symbol=None, all=True, source_refs=[ref]))
        else:
            for name in names_list[1:]:
                if name.type == "dotted_name":
                    resolved_name = self.handle_dotted_name(name)
                    to_ret.append(ModelImport(name=module_name, alias=None, symbol=resolved_name, all=False, source_refs=[ref]))
                elif name.type == "aliased_import":
                    original, alias = self.handle_aliased_import(name)
                    self.aliases[alias] = original
                    to_ret.append(ModelImport(name=module_name, alias=alias, symbol=original, all=False, source_refs=[ref]))

        return to_ret

    def visit_attribute(self, node: Node):
        """Translate ``obj.attr`` into an Attribute of Name nodes."""
        ref = self.node_helper.get_source_ref(node)
        obj, _, attr = node.children
        obj_cast = self.visit(obj)
        attr_cast = self.visit(attr)

        return Attribute(value=get_name_node(obj_cast), attr=get_name_node(attr_cast), source_refs=[ref])

    def visit_subscript(self, node: Node):
        """Translate ``x[i, ...]`` into a call to the ``_get`` function node."""
        ref = self.node_helper.get_source_ref(node)
        values = get_non_control_children(node)
        name_cast = self.visit(values[0])

        subscript_casts = []
        for subscript in values[1:]:
            cast = self.visit(subscript)
            if isinstance(cast, list):
                # Unwrap each element (previously the whole list was
                # appended once per element).
                subscript_casts.extend(get_func_name_node(elem) for elem in cast)
            else:
                subscript_casts.append(get_func_name_node(cast))

        get_func = self.get_gromet_function_node("_get")

        return Call(
            func=get_func,
            arguments=[get_func_name_node(name_cast)] + subscript_casts,
            source_refs=[ref],
        )

    def visit_slice(self, node: Node):
        """Translate ``start:end[:step]`` into a LIST literal ``[start, end, step]``.

        A missing step defaults to the integer literal 1. Expects both start
        and end to be present.
        """
        ref = self.node_helper.get_source_ref(node)
        index_cast = []
        for index in get_non_control_children(node):
            cast = self.visit(index)
            if isinstance(cast, list):
                index_cast.extend(cast)
            else:
                index_cast.append(cast)

        start = index_cast[0]
        end = index_cast[1]
        if len(index_cast) == 3:
            step = index_cast[2]
        else:
            step = self._int_literal("1", ref)

        return CASTLiteralValue(
            value_type=StructureType.LIST,
            value=[start, end, step],
            source_code_data_type=["Python", "3.8", "List"],
            source_refs=[ref],
        )

    def _next_assignment(self, target, iterator_name, stop_cond_name, next_func, ref) -> Assignment:
        """Build ``(target, iterator, stop_cond) = next(iterator)`` used to drive loops."""
        return Assignment(
            left=CASTLiteralValue(
                value_type="Tuple",
                value=[
                    target,
                    Var(iterator_name, "Iterator"),
                    Var(stop_cond_name, "Boolean"),
                ],
                source_code_data_type=["Python", PYTHON_VERSION, "Tuple"],
                source_refs=[ref],
            ),
            right=Call(next_func, arguments=[Var(iterator_name, "Iterator")]),
        )

    def _stop_condition_check(self, stop_cond_name, ref) -> Operator:
        """Build the loop test ``stop_cond == False``."""
        return Operator(
            source_language="Python",
            interpreter="Python",
            version=PYTHON_VERSION,
            op="ast.Eq",
            operands=[
                stop_cond_name,
                CASTLiteralValue(
                    ScalarType.BOOLEAN,
                    False,
                    ["Python", PYTHON_VERSION, "boolean"],
                    source_refs=[ref],
                ),
            ],
            source_refs=[ref],
        )

    def handle_for_clause(self, node: Node):
        """Translate a comprehension's ``for x in seq`` into an iterator Loop.

        ``body[0]`` is left as a None placeholder for the comprehension
        builder to fill; ``body[1]`` advances the iterator.
        """
        assert node.type == "for_in_clause"
        ref = self.node_helper.get_source_ref(node)

        # NOTE: Assumes the loop target is the 2nd child and the iterable is
        # the last child of the clause.
        left = self.visit(node.children[1])
        right = self.visit(node.children[-1])

        iterator_name = self.variable_context.generate_iterator()
        stop_cond_name = self.variable_context.generate_stop_condition()
        iter_func = self.get_gromet_function_node("iter")
        next_func = self.get_gromet_function_node("next")

        iter_call = Assignment(
            left=Var(iterator_name, "Iterator"),
            right=Call(iter_func, arguments=[right]),
        )

        loop_pre = [
            iter_call,
            self._next_assignment(left, iterator_name, stop_cond_name, next_func, ref),
        ]
        loop_expr = self._stop_condition_check(stop_cond_name, ref)
        loop_body = [
            None,
            self._next_assignment(left, iterator_name, stop_cond_name, next_func, ref),
        ]

        return Loop(pre=loop_pre, expr=loop_expr, body=loop_body, post=[], source_refs=[ref])

    def handle_if_clause(self, node: Node):
        """Translate a comprehension's ``if cond`` into a ModelIf with an empty body."""
        assert node.type == "if_clause"
        ref = self.node_helper.get_source_ref(node)
        conditional = get_children_by_types(node, WHILE_COND_TYPES)[0]
        cond_cast = self.visit(conditional)

        return ModelIf(expr=cond_cast, body=[], orelse=[], source_refs=[ref])

    def construct_loop_construct(self, node: Node):
        """Placeholder; currently always returns an empty list."""
        return []

    @staticmethod
    def _attach_to_loop(loop, first_if, last_if, target):
        """Place ``target`` in ``loop.body[0]``, routed through the pending
        if-chain ``first_if`` ... ``last_if`` when there is one."""
        if first_if is None:
            loop.body[0] = target
        else:
            loop.body[0] = first_if
            last_if.body = [target]

    def _chain_comprehension_clauses(self, node: Node, innermost: AstNode) -> Loop:
        """Nest a comprehension's for/if clauses and return the outermost Loop.

        Clauses nest in source order. Consecutive if clauses are chained
        through their bodies and the FIRST if of a run is placed in the
        enclosing loop; ``innermost`` goes at the deepest level.
        """
        outer_loop = None
        current_loop = None
        first_if = None
        last_if = None

        for clause in get_children_by_types(node, ["for_in_clause", "if_clause"]):
            if clause.type == "for_in_clause":
                new_loop = self.handle_for_clause(clause)
                if outer_loop is None:
                    outer_loop = new_loop
                else:
                    self._attach_to_loop(current_loop, first_if, last_if, new_loop)
                    first_if = last_if = None
                current_loop = new_loop
            elif clause.type == "if_clause":
                new_if = self.handle_if_clause(clause)
                if first_if is None:
                    first_if = new_if
                else:
                    last_if.body = [new_if]
                last_if = new_if

        self._attach_to_loop(current_loop, first_if, last_if, innermost)
        return outer_loop

    def visit_list_comprehension(self, node: Node) -> Call:
        """Lower a list comprehension to a generated function and return a call to it.

        The generated FunctionDef builds a temporary list via ``append`` inside
        the nested loops and returns it; it is recorded in ``generated_fns``.
        """
        ref = self.node_helper.get_source_ref(node)

        temp_list_name = self.variable_context.add_variable(
            "list__temp_", "Unknown", [ref]
        )
        temp_asg_cast = Assignment(
            left=Var(val=temp_list_name),
            right=CASTLiteralValue(value=[], value_type=StructureType.LIST),
            source_refs=[ref],
        )

        append_call = self.get_gromet_function_node("append")
        computation = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]
        computation_cast = self.visit(computation)

        append_cast = Call(
            func=Attribute(temp_list_name, append_call),
            arguments=[computation_cast],
            source_refs=[ref],
        )
        loop_start = self._chain_comprehension_clauses(node, append_cast)

        return_cast = ModelReturn(temp_list_name)

        func_name = self.variable_context.generate_func("%comprehension_list")
        func_def_cast = FunctionDef(
            name=func_name,
            func_args=[],
            body=[temp_asg_cast, loop_start, return_cast],
            source_refs=[ref],
        )
        self.generated_fns.append(func_def_cast)

        return Call(func=func_name, arguments=[], source_refs=[ref])

    def visit_pair(self, node: Node):
        """Translate ``key: value`` into a ``(key_cast, value_cast)`` tuple."""
        key = self.visit(node.children[0])
        value = self.visit(node.children[2])
        return key, value

    def visit_dict_comprehension(self, node: Node) -> Call:
        """Lower a dict comprehension to a generated function and return a call to it.

        The generated FunctionDef fills a temporary map via ``_set`` inside the
        nested loops and returns it; it is recorded in ``generated_fns``.
        """
        ref = self.node_helper.get_source_ref(node)

        temp_dict_name = self.variable_context.add_variable(
            "dict__temp_", "Unknown", [ref]
        )
        temp_asg_cast = Assignment(
            left=Var(val=temp_dict_name),
            right=CASTLiteralValue(value={}, value_type=StructureType.MAP),
            source_refs=[ref],
        )

        set_call = self.get_gromet_function_node("_set")
        computation = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]
        computation_cast = self.visit(computation)

        # The set assignment is attached whether or not if clauses are
        # present (previously it was dropped when an if clause existed).
        set_cast = Assignment(
            left=Var(val=temp_dict_name),
            right=Call(
                func=set_call,
                arguments=[
                    temp_dict_name,
                    get_operand_node(computation_cast[0]),
                    computation_cast[1],
                ],
                source_refs=[ref],
            ),
            source_refs=[ref],
        )
        loop_start = self._chain_comprehension_clauses(node, set_cast)

        return_cast = ModelReturn(temp_dict_name)

        func_name = self.variable_context.generate_func("%comprehension_dict")
        func_def_cast = FunctionDef(
            name=func_name,
            func_args=[],
            body=[temp_asg_cast, loop_start, return_cast],
            source_refs=[ref],
        )
        self.generated_fns.append(func_def_cast)

        return Call(func=func_name, arguments=[], source_refs=[ref])

    def visit_lambda(self, node: Node) -> Call:
        """Lower a lambda to a generated function and return a call to it.

        The call passes the lambda's own parameters as arguments.
        """
        # TODO: we have to determine how to grab the variables that are being
        # used in the lambda that aren't part of the lambda's arguments
        ref = self.node_helper.get_source_ref(node)
        param_nodes = get_children_by_types(node, ["lambda_parameters"])
        body = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]

        # `lambda: x` has no lambda_parameters node; comma tokens between
        # parameters visit to None and are dropped.
        parameters = self._visit_nodes(param_nodes[0].children) if param_nodes else []

        body_cast = self.visit(body)

        func_name = self.variable_context.generate_func("%lambda")
        func_def_cast = FunctionDef(
            name=func_name,
            func_args=parameters,
            body=[ModelReturn(value=body_cast)],
            source_refs=[ref],
        )
        self.generated_fns.append(func_def_cast)

        # Call arguments use the Name of each Var parameter.
        args = [par.val if isinstance(par, Var) else par for par in parameters]

        return Call(func=func_name, arguments=args, source_refs=[ref])

    def visit_while(self, node: Node) -> Loop:
        """Translate a while loop inside its own variable context."""
        ref = self.node_helper.get_source_ref(node)

        # A loop can create variables only it can see.
        self.variable_context.push_context()

        loop_cond_node = get_children_by_types(node, WHILE_COND_TYPES)[0]
        loop_body_nodes = get_children_by_types(node, ["block"])[0].children

        loop_cond = self.visit(loop_cond_node)
        loop_body = self._visit_nodes(loop_body_nodes)

        self.variable_context.pop_context()

        return Loop(pre=[], expr=loop_cond, body=loop_body, post=[], source_refs=[ref])

    def visit_for(self, node: Node) -> Loop:
        """Translate ``for target in iterable`` into an iterator-driven Loop.

        ``pre`` creates the iterator and fetches the first element, ``expr``
        tests the stop flag, and a trailing ``next`` assignment is appended to
        the body to facilitate looping in GroMEt.
        """
        ref = self.node_helper.get_source_ref(node)

        loop_cond_left = get_children_by_types(node, FOR_LOOP_LEFT_TYPES)[0]
        loop_cond_right = get_children_by_types(node, FOR_LOOP_RIGHT_TYPES)[-1]

        self.variable_context.push_context()
        iterator_name = self.variable_context.generate_iterator()
        stop_cond_name = self.variable_context.generate_stop_condition()
        iter_func = self.get_gromet_function_node("iter")
        next_func = self.get_gromet_function_node("next")

        loop_cond_left_cast = self.visit(loop_cond_left)
        loop_cond_right_cast = self.visit(loop_cond_right)

        loop_pre = [
            Assignment(
                left=Var(iterator_name, "Iterator"),
                right=Call(iter_func, arguments=[loop_cond_right_cast]),
            ),
            self._next_assignment(loop_cond_left_cast, iterator_name, stop_cond_name, next_func, ref),
        ]

        loop_expr = self._stop_condition_check(stop_cond_name, ref)

        loop_body = self._visit_nodes(get_children_by_types(node, ["block"])[0].children)
        loop_body.append(
            self._next_assignment(loop_cond_left_cast, iterator_name, stop_cond_name, next_func, ref)
        )

        self.variable_context.pop_context()
        return Loop(pre=loop_pre, expr=loop_expr, body=loop_body, post=[], source_refs=[ref])

    def retrieve_init_func(self, functions: List[FunctionDef]):
        """Return the CAST FunctionDef named ``__init__``, or None if absent."""
        for func in functions:
            if func.name.name == "__init__":
                return func
        return None

    def retrieve_class_attrs(self, init_func: FunctionDef):
        """Return the attribute names assigned as ``self.<attr> = ...`` in ``init_func``.

        Returns an empty list when ``init_func`` is None (class without __init__).
        """
        if init_func is None:
            return []
        return [
            stmt.left.attr
            for stmt in init_func.body
            if isinstance(stmt, Assignment)
            and isinstance(stmt.left, Attribute)
            and getattr(stmt.left.value, "name", None) == "self"
        ]

    def visit_class_definition(self, node):
        """Translate a class into a RecordDef of its methods and ``self`` attributes."""
        class_name_node = get_first_child_by_type(node, "identifier")
        class_cast = self.visit(class_name_node)

        class_block = get_children_by_types(node, ["block"])[0]
        func_defs_cast = []
        for func in get_children_by_types(class_block, ["function_definition"]):
            func_cast = self.visit(func)
            if isinstance(func_cast, list):
                func_defs_cast.extend(func_cast)
            else:
                func_defs_cast.append(func_cast)

        init_func = self.retrieve_init_func(func_defs_cast)
        attributes = self.retrieve_class_attrs(init_func)

        return RecordDef(
            name=get_name_node(class_cast).name,
            bases=[],
            funcs=func_defs_cast,
            fields=attributes,
        )

    def visit_name(self, node):
        """Return the Name node for an identifier, resolving import aliases.

        Reuses the node of an already-known variable; otherwise registers a
        new variable of type "Unknown".
        """
        # NOTE: the call to check_alias is a crucial change, to resolve aliasing
        # need to make sure nothing breaks
        identifier = self.check_alias(self.node_helper.get_identifier(node))
        if self.variable_context.is_variable(identifier):
            return self.variable_context.get_node(identifier)

        return self.variable_context.add_variable(
            identifier, "Unknown", [self.node_helper.get_source_ref(node)]
        )

    def _visit_passthrough(self, node):
        """Return the first truthy result of visiting ``node``'s children, else None."""
        for child in node.children:
            child_cast = self.visit(child)
            if child_cast:
                return child_cast
        return None

    def get_gromet_function_node(self, func_name: str) -> Name:
        """Return the Name node for a built-in function, registering it if new."""
        # Ideally we would create a dummy node and call the name visitor, but
        # tree-sitter does not allow creating or modifying nodes.
        if self.variable_context.is_variable(func_name):
            return self.variable_context.get_node(func_name)

        return self.variable_context.add_variable(func_name, "function", None)

    def visit_yield(self, node):
        """Placeholder translation for ``yield`` (not yet supported)."""
        source_code_data_type = ["Python", "3.8", "List"]
        ref = self.node_helper.get_source_ref(node)
        return [
            CASTLiteralValue(
                StructureType.LIST,
                "YieldNotImplemented",
                source_code_data_type,
                [ref],
            )
        ]

    def visit_assert(self, node):
        """Placeholder translation for ``assert`` (not yet supported)."""
        source_code_data_type = ["Python", "3.8", "List"]
        ref = self.node_helper.get_source_ref(node)
        return [
            CASTLiteralValue(
                StructureType.LIST,
                "AssertNotImplemented",
                source_code_data_type,
                [ref],
            )
        ]
def get_name_node(node):
    """Return the Name carried by a CAST node, unwrapping common containers.

    A list is reduced to its first element, an Attribute is resolved through
    its ``attr``, and a Var yields its ``val``. Any other node is returned
    unchanged.
    """
    target = node[0] if isinstance(node, list) else node
    if isinstance(target, Attribute):
        return get_name_node(target.attr)
    return target.val if isinstance(target, Var) else target
def get_func_name_node(node):
    """Return the Name inside ``node`` when it is a Var; otherwise return ``node`` itself."""
    return node.val if isinstance(node, Var) else node
def get_operand_node(node):
    """Return the operator operand for a CAST node or list of nodes.

    A list is reduced to its first element; a Var is replaced by its Name
    (``val``). Anything else is returned as-is.
    """
    operand = node[0] if isinstance(node, list) else node
    if not isinstance(operand, Var):
        return operand
    return operand.val