Coverage for skema/program_analysis/CAST/python/ts2cast.py: 88%

597 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import json 

2import os.path 

3from pathlib import Path 

4from typing import Any, Dict, List, Union 

5 

6from tree_sitter import Language, Parser, Node 

7 

8from skema.program_analysis.CAST2FN.cast import CAST 

9from skema.program_analysis.CAST2FN.model.cast import ( 

10 Module, 

11 SourceRef, 

12 Assignment, 

13 CASTLiteralValue, 

14 Var, 

15 VarType, 

16 Name, 

17 Operator, 

18 AstNode, 

19 SourceCodeDataType, 

20 ModelImport, 

21 FunctionDef, 

22 Loop, 

23 Call, 

24 ModelReturn, 

25 ModelIf, 

26 RecordDef, 

27 Attribute, 

28 ScalarType, 

29 StructureType 

30) 

31 

32from skema.program_analysis.CAST.python.node_helper import ( 

33 NodeHelper, 

34 get_first_child_by_type, 

35 get_children_by_types, 

36 get_first_child_index, 

37 get_last_child_index, 

38 get_control_children, 

39 get_non_control_children, 

40 FOR_LOOP_LEFT_TYPES, 

41 FOR_LOOP_RIGHT_TYPES, 

42 WHILE_COND_TYPES, 

43 COMPREHENSION_OPERATORS 

44) 

45from skema.program_analysis.CAST.python.util import ( 

46 generate_dummy_source_refs, 

47 get_op 

48) 

49from skema.program_analysis.CAST.fortran.variable_context import VariableContext 

50 

51from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH 

52 

53 

# Python version string recorded in the source_code_data_type metadata
# and Operator.version of the CAST nodes generated below.
PYTHON_VERSION: str = "3.10"

55 

56class TS2CAST(object): 

def __init__(self, source_file_path: str, from_file=True):
    """Parse Python source with tree-sitter and generate its CAST.

    Args:
        source_file_path: Path to a Python source file or, when ``from_file``
            is False, the Python source text itself.
        from_file: Set to False (used for testing, when there is no actual
            file) to treat ``source_file_path`` as raw source code.

    The generated CAST is stored in ``self.out_cast``.
    """
    if from_file:
        self.path = Path(source_file_path)
        self.source_file_name = self.path.name

        # Python doesn't have a preprocessing step like Fortran.
        # Python sources are UTF-8 by default (PEP 3120), and the text is
        # re-encoded as UTF-8 for tree-sitter below, so decode it as UTF-8
        # explicitly instead of with the platform's locale encoding.
        self.source = self.path.read_text(encoding="utf-8")
    else:
        self.path = "None"
        self.source_file_name = "Temp"
        self.source = source_file_path

    # Build a tree-sitter parser for the Python grammar in the installed library
    parser = Parser()
    parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "python"))

    # Functions generated for comprehensions/lambdas; visit_module prepends
    # them to the module body
    self.generated_fns = []

    # Additional variables used in generation
    self.var_count = 0

    # Maps import aliases to the names they stand for
    # (like `import x as y`, or `from x import y as z`); read by check_alias
    self.aliases = {}

    # Tree walking structures
    self.variable_context = VariableContext()
    self.node_helper = NodeHelper(self.source, self.source_file_name)

    self.tree = parser.parse(bytes(self.source, "utf8"))

    self.out_cast = self.generate_cast()

97 

def generate_cast(self) -> CAST:
    """Interface for generating CAST.

    Visits the tree-sitter root node, names the resulting Module after the
    source file, passes it through ``generate_dummy_source_refs``, and wraps
    it in a single CAST object tagged as "Python".

    Returns:
        One ``CAST`` object (not a list) holding exactly one Module.
    """
    module = self.run(self.tree.root_node)
    module.name = self.source_file_name
    return CAST([generate_dummy_source_refs(module)], "Python")

103 

def run(self, root) -> "Module":
    """Visit the root node of the parse tree and return its CAST.

    In Python there is generally a single module at the top level. Visiting
    the root (a ``module`` node) yields one Module whose body holds all the
    top-level statements, so a single node is returned, not a list.
    """
    return self.visit(root)

109 

110 # TODO: node helper for ignoring comments 

111 

def check_alias(self, name):
    """Resolve ``name`` through the import-alias table.

    Returns the name recorded for ``name`` in ``self.aliases`` (populated
    from ``import x as y`` / ``from x import y as z``), or ``name`` itself
    when it is not a known alias.
    """
    return self.aliases.get(name, name)

122 

def visit(self, node: Node):
    """Dispatch ``node`` to the ``visit_*`` handler for its tree-sitter type.

    Node types without a dedicated handler go to ``self._visit_passthrough``.
    Returns whatever the selected handler returns.
    """
    node_type = node.type
    if node_type == "module":
        return self.visit_module(node)
    elif node_type == "parenthesized_expression":
        # Node for "( op )", extract op
        # The actual op is in the middle of the list of nodes
        return self.visit(node.children[1])
    elif node_type == "expression_statement":
        return self.visit_expression(node)
    elif node_type == "function_definition":
        return self.visit_function_def(node)
    elif node_type == "return_statement":
        return self.visit_return(node)
    elif node_type == "call":
        return self.visit_call(node)
    elif node_type == "if_statement":
        return self.visit_if_statement(node)
    elif node_type == "comparison_operator":
        return self.visit_comparison_op(node)
    elif node_type == "assignment":
        return self.visit_assignment(node)
    elif node_type == "attribute":
        return self.visit_attribute(node)
    elif node_type == "identifier":
        return self.visit_identifier(node)
    elif node_type == "unary_operator":
        return self.visit_unary_op(node)
    elif node_type == "binary_operator":
        return self.visit_binary_op(node)
    elif node_type in ("integer", "float", "true", "false", "list"):
        # visit_literal has dedicated branches for float and boolean
        # literals; without routing them here those branches were
        # unreachable and such literals fell through to the passthrough.
        return self.visit_literal(node)
    elif node_type in ("list_pattern", "pattern_list", "tuple_pattern"):
        return self.visit_pattern(node)
    elif node_type == "list_comprehension":
        return self.visit_list_comprehension(node)
    elif node_type == "dictionary_comprehension":
        return self.visit_dict_comprehension(node)
    elif node_type == "lambda":
        return self.visit_lambda(node)
    elif node_type == "subscript":
        return self.visit_subscript(node)
    elif node_type == "slice":
        return self.visit_slice(node)
    elif node_type == "pair":
        return self.visit_pair(node)
    elif node_type == "while_statement":
        return self.visit_while(node)
    elif node_type == "for_statement":
        return self.visit_for(node)
    elif node_type == "import_statement":
        return self.visit_import(node)
    elif node_type == "import_from_statement":
        return self.visit_import_from(node)
    elif node_type == "class_definition":
        return self.visit_class_definition(node)
    elif node_type == "yield":
        return self.visit_yield(node)
    elif node_type == "assert_statement":
        return self.visit_assert(node)
    else:
        return self._visit_passthrough(node)

185 

def visit_module(self, node: Node) -> Module:
    """Convert the top-level ``module`` node into a CAST Module.

    A module holds one or more statements/expressions at the global level.
    They are visited inside a freshly pushed variable context. List results
    are spliced into the body, AstNode results are appended, and anything
    else is ignored. Functions generated during the walk
    (``self.generated_fns``) are placed before those statements. The Module
    is created with ``name=None``.
    """
    self.variable_context.push_context()

    statements = []
    for child_cast in map(self.visit, node.children):
        if isinstance(child_cast, list):
            statements += child_cast
        elif isinstance(child_cast, AstNode):
            statements.append(child_cast)

    self.variable_context.pop_context()

    module_ref = self.node_helper.get_source_ref(node)
    return Module(
        name=None,
        body=[*self.generated_fns, *statements],
        source_refs=[module_ref],
    )

206 

def visit_expression(self, node: Node):
    """Visit every child of an ``expression_statement`` node.

    Returns a flat list of CAST nodes: list results are spliced in, AstNode
    results are appended, and any other result is dropped.
    """
    # NOTE: It is unclear whether an 'expression statement' node can have
    # more than one child, so every child is visited.
    flattened = []
    for result in map(self.visit, node.children):
        if isinstance(result, list):
            flattened += result
        elif isinstance(result, AstNode):
            flattened.append(result)
    return flattened

220 

def visit_function_def(self, node: Node) -> FunctionDef:
    """Convert a ``function_definition`` node into a CAST FunctionDef.

    The function name is visited in the enclosing scope. Parameters and body
    statements are visited inside a new variable context, which is popped
    once the function has been processed.
    """
    ref = self.node_helper.get_source_ref(node)
    name = self.visit(get_first_child_by_type(node, "identifier"))

    # Create new variable context
    self.variable_context.push_context()

    param_nodes = get_non_control_children(get_children_by_types(node, ["parameters"])[0])
    # The body of the function is stored in a 'block' type node
    body_nodes = get_children_by_types(node, ["block"])[0].children

    def gather(ts_nodes):
        # Visit each node, splicing list results and keeping AstNodes only
        gathered = []
        for ts_node in ts_nodes:
            result = self.visit(ts_node)
            if isinstance(result, list):
                gathered.extend(result)
            elif isinstance(result, AstNode):
                gathered.append(result)
        return gathered

    func_params = gather(param_nodes)
    # TODO: Do we need to handle return statements in any special way?
    func_body = gather(body_nodes)

    self.variable_context.pop_context()

    return FunctionDef(
        name=name.val,
        func_args=func_params,
        body=func_body,
        source_refs=[ref],
    )

262 

def visit_return(self, node: Node) -> ModelReturn:
    """Convert a ``return_statement`` node into a CAST ModelReturn.

    The returned expression (``children[1]``) is visited and normalized with
    get_operand_node. A bare ``return`` has no such child; it produces a
    ModelReturn whose value is None instead of raising IndexError.
    """
    ref = self.node_helper.get_source_ref(node)

    # Bare `return`: only the 'return' keyword is present
    if len(node.children) < 2:
        return ModelReturn(value=None, source_refs=[ref])

    ret_cast = self.visit(node.children[1])
    return ModelReturn(value=get_operand_node(ret_cast), source_refs=[ref])

269 

def visit_call(self, node: Node) -> Call:
    """Convert a ``call`` node into a CAST Call.

    Arguments are visited and reduced to their Name nodes. Calls to ``range``
    are normalized to the three-argument form (start, stop, step), using
    Python's defaults: start 0 and step 1.
    """
    ref = self.node_helper.get_source_ref(node)

    func_cast = self.visit(node.children[0])
    func_name = get_func_name_node(func_cast)

    arg_list = get_first_child_by_type(node, "argument_list")

    func_args = []
    for arg in get_non_control_children(arg_list):
        cast = get_name_node(self.visit(arg))
        if isinstance(cast, list):
            func_args.extend(cast)
        elif isinstance(cast, AstNode):
            func_args.append(cast)

    def int_literal(value):
        # Integer literal used to fill in range's implicit arguments
        return CASTLiteralValue(
            ScalarType.INTEGER,
            value=value,
            source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
            source_refs=[ref],
        )

    # The callee is not always a plain name (e.g. `fns[0]()` or `f()()`), so
    # its name node may have no `.name` attribute.
    if getattr(get_name_node(func_cast), "name", None) == "range":
        if len(func_args) == 2:
            # range(start, stop): add the default step
            func_args.append(int_literal("1"))
        elif len(func_args) == 1:
            # range(stop): Python starts at 0 with step 1
            func_args.insert(0, int_literal("0"))
            func_args.append(int_literal("1"))

    # Function calls only want the 'Name' part of the 'Var' that the visit returns
    return Call(
        func=func_name,
        arguments=func_args,
        source_refs=[ref],
    )

309 

def visit_comparison_op(self, node: Node):
    """Convert a ``comparison_operator`` node (``left OP right``) into an Operator.

    Only the three-child form is handled. Both operands are reduced to their
    Name nodes via get_name_node.
    """
    ref = self.node_helper.get_source_ref(node)
    operator_text = self.node_helper.get_operator(node.children[1])
    op = get_op(operator_text)
    lhs, _op_node, rhs = node.children

    operands = [get_name_node(self.visit(side)) for side in (lhs, rhs)]

    return Operator(op=op, operands=operands, source_refs=[ref])

323 

def visit_if_statement(self, node: Node) -> ModelIf:
    """Convert an ``if_statement`` (or ``elif_clause``) node into a ModelIf.

    Tree-sitter keeps ``elif``/``else`` clauses as flat siblings of the
    ``if``. They are folded here into the nested form CAST expects:
    ``if a: A elif b: B else: C`` becomes
    ``ModelIf(a, A, orelse=[ModelIf(b, B, orelse=C)])``.
    """
    ref = self.node_helper.get_source_ref(node)

    # The condition is tree-sitter's 'condition' field, so `if x:` and
    # `if a and b:` work too, not only comparison operators.
    condition_node = node.child_by_field_name("condition")
    if condition_node is None:
        condition_node = get_first_child_by_type(node, "comparison_operator")
    if_condition = self.visit(condition_node)

    def collect(ts_nodes):
        # Visit each statement, splicing list results and keeping AstNodes only
        collected = []
        for ts_node in ts_nodes:
            cast = self.visit(ts_node)
            if isinstance(cast, list):
                collected.extend(cast)
            elif isinstance(cast, AstNode):
                collected.append(cast)
        return collected

    # Get the body of the if-true part
    if_true_cast = collect(get_children_by_types(node, ["block"])[0].children)

    alternatives = get_children_by_types(node, ["elif_clause", "else_clause"])

    # An else clause is optional; only the last alternative can be one.
    final_else_cast = []
    if alternatives and alternatives[-1].type == "else_clause":
        final_else = alternatives.pop()
        final_else_cast = collect(get_children_by_types(final_else, ["block"])[0].children)

    # Build the elif chain from the tail: each elif's orelse is whatever
    # follows it (the next elif, or the final else body).
    if_false_cast = final_else_cast
    for elif_node in reversed(alternatives):
        elif_cast = self.visit_if_statement(elif_node)
        elif_cast.orelse = if_false_cast
        if_false_cast = [elif_cast]

    return ModelIf(
        expr=if_condition,
        body=if_true_cast,
        orelse=if_false_cast,
        source_refs=[ref],
    )

388 

def visit_assignment(self, node: Node) -> Assignment:
    """Convert an ``assignment`` node into a CAST Assignment.

    Handles both plain (``x = y``) and annotated (``x: int = y``)
    assignments by using tree-sitter's ``left``/``right`` fields.
    """
    ref = self.node_helper.get_source_ref(node)

    left = node.child_by_field_name("left")
    right = node.child_by_field_name("right")
    if left is None or right is None:
        # No 'right' field (e.g. a bare annotation `x: int`): fall back to
        # the positional first/last children.
        left, right = node.children[0], node.children[-1]

    # For the RHS of an assignment we want the Name CAST node
    # and not the entire Var CAST node if we're doing an
    # assignment like x = y
    right_cast = get_name_node(self.visit(right))

    return Assignment(
        left=self.visit(left),
        right=right_cast,
        source_refs=[ref],
    )

403 

def visit_unary_op(self, node: Node) -> Operator:
    """Convert a ``unary_operator`` node (``OP operand``) into an Operator.

    A leading minus ('ast.Sub') is re-tagged as unary minus ('ast.USub').
    The operand may be any expression; only its Name node is kept (not the
    whole Var), e.g. for ``-x``.
    """
    ref = self.node_helper.get_source_ref(node)
    op = get_op(self.node_helper.get_operator(node.children[0]))
    operand_node = node.children[1]

    op = "ast.USub" if op == "ast.Sub" else op

    operand_cast = get_name_node(self.visit(operand_node))
    operand_cast = operand_cast.val if isinstance(operand_cast, Var) else operand_cast

    return Operator(op=op, operands=[operand_cast], source_refs=[ref])

430 

def visit_binary_op(self, node: Node) -> Operator:
    """Convert a ``binary_operator`` node (``left OP right``) into an Operator.

    Either side may be an operator or a literal. Each side is normalized
    through get_operand_node.
    """
    ref = self.node_helper.get_source_ref(node)
    op = get_op(self.node_helper.get_operator(node.children[1]))
    lhs, _op_node, rhs = node.children

    operands = [get_operand_node(self.visit(side)) for side in (lhs, rhs)]

    return Operator(op=op, operands=operands, source_refs=[ref])

449 

def visit_pattern(self, node: Node):
    """Convert a list/tuple pattern node (e.g. the ``a, b`` in ``a, b = ...``)
    into a tuple CASTLiteralValue holding the visited elements.

    List results are spliced in, AstNodes appended, other results dropped.
    """
    elements = []
    for child_cast in map(self.visit, node.children):
        if isinstance(child_cast, list):
            elements += child_cast
        elif isinstance(child_cast, AstNode):
            elements.append(child_cast)

    return CASTLiteralValue(value_type=StructureType.TUPLE, value=elements)

460 

def visit_identifier(self, node: Node) -> Var:
    """Convert an ``identifier`` node into a CAST Var.

    The type comes from the variable context when the identifier is a known
    variable, and is "unknown" otherwise.
    """
    identifier = self.node_helper.get_identifier(node)
    known = self.variable_context.is_variable(identifier)
    var_type = self.variable_context.get_type(identifier) if known else "unknown"

    # TODO: Python default values
    return Var(
        val=self.visit_name(node),
        type=var_type,
        default_value=None,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

480 

def visit_literal(self, node: Node) -> Any:
    """Convert a literal node into a CASTLiteralValue.

    Supports ``integer``, ``float``, ``true``/``false``, ``list`` and
    ``tuple`` nodes. Any other node type yields None.
    """
    literal_type = node.type
    literal_value = self.node_helper.get_identifier(node)
    literal_source_ref = self.node_helper.get_source_ref(node)

    def visit_elements():
        # Visit every child node, splicing list results and keeping AstNodes
        items = []
        for elem in node.children:
            cast = self.visit(elem)
            if isinstance(cast, list):
                items.extend(cast)
            elif isinstance(cast, AstNode):
                items.append(cast)
        return items

    if literal_type == "integer":
        return CASTLiteralValue(
            value_type=ScalarType.INTEGER,
            value=literal_value,
            source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
            source_refs=[literal_source_ref],
        )
    elif literal_type == "float":
        return CASTLiteralValue(
            value_type=ScalarType.ABSTRACTFLOAT,
            value=literal_value,
            source_code_data_type=["Python", PYTHON_VERSION, str(type(1.0))],
            source_refs=[literal_source_ref],
        )
    elif literal_type in ("true", "false"):
        return CASTLiteralValue(
            value_type=ScalarType.BOOLEAN,
            value="True" if literal_type == "true" else "False",
            source_code_data_type=["Python", PYTHON_VERSION, str(type(True))],
            source_refs=[literal_source_ref],
        )
    elif literal_type == "list":
        return CASTLiteralValue(
            value_type=StructureType.LIST,
            value=visit_elements(),
            source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))],
            source_refs=[literal_source_ref],
        )
    elif literal_type == "tuple":
        # Each child element is visited; the literal is tagged as a tuple
        # (note `(0,)`: `(0)` is just the int 0).
        return CASTLiteralValue(
            value_type=StructureType.TUPLE,
            value=visit_elements(),
            source_code_data_type=["Python", PYTHON_VERSION, str(type((0,)))],
            source_refs=[literal_source_ref],
        )
    return None

537 

def handle_dotted_name(self, import_stmt) -> str:
    """Return the name text of a ``dotted_name`` node (e.g. ``os.path``).

    The name is whatever ``node_helper.get_identifier`` extracts. The node is
    also visited; that visit's result is discarded.
    """
    name = self.node_helper.get_identifier(import_stmt)
    self.visit(import_stmt)
    return name

544 

def handle_aliased_import(self, import_stmt) -> tuple:
    """Resolve an ``aliased_import`` node (``x.y as z``).

    Returns:
        ``(name, alias)``: the dotted name being imported and the identifier
        it is bound to.
    """
    dot_name = get_children_by_types(import_stmt, ["dotted_name"])[0]
    name = self.handle_dotted_name(dot_name)
    alias = get_children_by_types(import_stmt, ["identifier"])[0]
    self.visit(alias)

    return (name, self.node_helper.get_identifier(alias))

553 

def visit_import(self, node: Node):
    """Convert an ``import`` statement into a list of ModelImport nodes.

    Produces one ModelImport per imported module. For ``import x as y`` the
    alias is also recorded in ``self.aliases`` (``y -> x``).
    """
    ref = self.node_helper.get_source_ref(node)
    to_ret = []

    names_list = get_children_by_types(node, ["dotted_name", "aliased_import"])
    for name in names_list:
        if name.type == "dotted_name":
            resolved_name = self.handle_dotted_name(name)
            to_ret.append(ModelImport(name=resolved_name, alias=None, symbol=None, all=False, source_refs=[ref]))
        elif name.type == "aliased_import":
            module, alias = self.handle_aliased_import(name)
            self.aliases[alias] = module
            to_ret.append(ModelImport(name=module, alias=alias, symbol=None, all=False, source_refs=[ref]))

    return to_ret

569 

def visit_import_from(self, node: Node):
    """Convert a ``from M import ...`` statement into ModelImport nodes.

    ``from M import *`` yields one ModelImport with ``all=True``. Otherwise
    one ModelImport is produced per imported symbol, and ``as`` aliases are
    recorded in ``self.aliases``.
    """
    ref = self.node_helper.get_source_ref(node)
    to_ret = []

    names_list = get_children_by_types(node, ["dotted_name", "aliased_import"])
    wild_card = get_children_by_types(node, ["wildcard_import"])

    # The imported module is tree-sitter's 'module_name' field. For relative
    # imports (`from . import x`, `from .m import *`) that node is a
    # relative_import, which is not in names_list, so names_list[0] would be
    # the first imported symbol, or missing altogether.
    module_node = node.child_by_field_name("module_name")
    if module_node is None:
        module_node = names_list[0]
    module_name = self.node_helper.get_identifier(module_node)

    # Every dotted/aliased name other than the module itself is a symbol
    symbols = [name for name in names_list if name.start_byte != module_node.start_byte]

    # if "wildcard_import" exists then it'll be in the list
    if len(wild_card) == 1:
        to_ret.append(ModelImport(name=module_name, alias=None, symbol=None, all=True, source_refs=[ref]))
    else:
        for name in symbols:
            if name.type == "dotted_name":
                resolved_name = self.handle_dotted_name(name)
                to_ret.append(ModelImport(name=module_name, alias=None, symbol=resolved_name, all=False, source_refs=[ref]))
            elif name.type == "aliased_import":
                symbol, alias = self.handle_aliased_import(name)
                self.aliases[alias] = symbol
                to_ret.append(ModelImport(name=module_name, alias=alias, symbol=symbol, all=False, source_refs=[ref]))

    return to_ret

592 

def visit_attribute(self, node: Node):
    """Convert an ``attribute`` node (``obj.attr``) into a CAST Attribute.

    Both the object and the attribute are reduced to their Name nodes.
    ``source_refs`` is a list, as for the other CAST nodes in this file.
    """
    ref = self.node_helper.get_source_ref(node)
    obj, _, attr = node.children
    obj_cast = self.visit(obj)
    attr_cast = self.visit(attr)

    return Attribute(
        value=get_name_node(obj_cast),
        attr=get_name_node(attr_cast),
        source_refs=[ref],
    )

600 

def visit_subscript(self, node: Node):
    """Convert a ``subscript`` node (``x[i]``, ``x[i, j]``, ``x[a:b]``) into a
    Call of the GroMEt ``_get`` function: ``_get(x, i, ...)``.
    """
    ref = self.node_helper.get_source_ref(node)
    values = get_non_control_children(node)
    name_cast = self.visit(values[0])

    subscript_casts = []
    for subscript in values[1:]:
        cast = self.visit(subscript)
        if isinstance(cast, list):
            # Convert each element individually, not the whole list
            subscript_casts.extend(get_func_name_node(elem) for elem in cast)
        else:
            subscript_casts.append(get_func_name_node(cast))

    get_func = self.get_gromet_function_node("_get")

    return Call(
        func=get_func,
        arguments=[get_func_name_node(name_cast)] + subscript_casts,
        source_refs=[ref],
    )

624 

def visit_slice(self, node: Node):
    """Convert a ``slice`` node (``start:stop[:step]``) into the list literal
    ``[start, stop, step]``. The step defaults to the integer literal 1.

    NOTE(review): indices are taken positionally from the non-control
    children, so slices omitting start or stop (e.g. ``a[1:]``, ``a[:2]``)
    are not supported and raise IndexError. TODO.
    """
    ref = self.node_helper.get_source_ref(node)

    index_cast = []
    for index in get_non_control_children(node):
        cast = self.visit(index)
        if isinstance(cast, list):
            index_cast.extend(cast)
        else:
            index_cast.append(cast)

    start = index_cast[0]
    end = index_cast[1]
    if len(index_cast) == 3:
        step = index_cast[2]
    else:
        # Default step is the int 1; tag it like other integer literals
        step = CASTLiteralValue(
            value_type=ScalarType.INTEGER,
            value="1",
            source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
            source_refs=[ref],
        )

    return CASTLiteralValue(
        value_type=StructureType.LIST,
        value=[start, end, step],
        source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))],
        source_refs=[ref],
    )

654 

655 

def handle_for_clause(self, node: Node):
    """Translate the ``for x in seq`` clause of a comprehension into a CAST Loop.

    The loop follows the iterator protocol::

        pre:  it = iter(seq); (x, it, stop) = next(it)
        expr: stop == False
        body: [None, (x, it, stop) = next(it)]

    ``body[0]`` is a None placeholder that the comprehension handler later
    replaces with the computation, a nested loop, or an if chain.
    """
    assert node.type == "for_in_clause"
    ref = self.node_helper.get_source_ref(node)

    # NOTE: Assumes the loop variable is always the 2nd child and the
    # iterable (e.g. a function call) is the last child.
    left = self.visit(node.children[1])
    right = self.visit(node.children[-1])

    iterator_name = self.variable_context.generate_iterator()
    stop_cond_name = self.variable_context.generate_stop_condition()
    iter_func = self.get_gromet_function_node("iter")
    next_func = self.get_gromet_function_node("next")

    iter_call = Assignment(
        left=Var(iterator_name, "Iterator"),
        right=Call(iter_func, arguments=[right]),
    )

    next_call = Call(next_func, arguments=[Var(iterator_name, "Iterator")])

    next_assign = Assignment(
        left=CASTLiteralValue(
            "Tuple",
            [
                left,
                Var(iterator_name, "Iterator"),
                Var(stop_cond_name, "Boolean"),
            ],
            source_code_data_type=["Python", PYTHON_VERSION, "Tuple"],
            source_refs=[ref],
        ),
        right=next_call,
    )

    loop_pre = [iter_call, next_assign]

    loop_expr = Operator(
        source_language="Python",
        interpreter="Python",
        version=PYTHON_VERSION,
        op="ast.Eq",
        operands=[
            stop_cond_name,
            CASTLiteralValue(
                ScalarType.BOOLEAN,
                False,
                ["Python", PYTHON_VERSION, "boolean"],
                source_refs=[ref],
            ),
        ],
        source_refs=[ref],
    )

    # body[0] is filled in by the comprehension handler
    loop_body = [None, next_assign]

    return Loop(pre=loop_pre, expr=loop_expr, body=loop_body, post=[])

725 

def handle_if_clause(self, node: Node):
    """Translate the ``if cond`` clause of a comprehension into a ModelIf.

    Body and orelse start empty; the comprehension handler fills in the body.
    """
    assert node.type == "if_clause"
    ref = self.node_helper.get_source_ref(node)
    conditional = get_children_by_types(node, WHILE_COND_TYPES)[0]
    cond_cast = self.visit(conditional)

    return ModelIf(expr=cond_cast, body=[], orelse=[], source_refs=[ref])

733 

def construct_loop_construct(self, node: Node):
    """Placeholder: produces no CAST and always returns a new empty list."""
    return list()

736 

def visit_list_comprehension(self, node: Node) -> Call:
    """Lower a list comprehension into a generated function plus a call to it.

    ``[expr for a in A if c1 if c2 for b in B]`` becomes, roughly::

        def <generated>():
            list__temp_ = []
            <loop over A>:
                if c1:
                    if c2:
                        <loop over B>:
                            list__temp_.append(expr)
            return list__temp_

    The generated FunctionDef is appended to ``self.generated_fns`` and a
    zero-argument Call to it is returned in place of the comprehension.
    """
    ref = self.node_helper.get_source_ref(node)

    temp_list_name = self.variable_context.add_variable(
        "list__temp_", "Unknown", [ref]
    )

    temp_asg_cast = Assignment(
        left=Var(val=temp_list_name),
        right=CASTLiteralValue(value=[], value_type=StructureType.LIST),
        source_refs=[ref],
    )

    append_call = self.get_gromet_function_node("append")
    computation = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]
    computation_cast = self.visit(computation)

    # Each for_in_clause opens a new nested loop; a run of if_clauses between
    # loops becomes a chain of nested ModelIfs.
    clauses = get_children_by_types(node, ["for_in_clause", "if_clause"])
    loop_start = None  # outermost loop
    prev_loop = None   # innermost loop so far
    if_start = None    # first if of the current run of if_clauses
    prev_if = None     # last if of the current run

    for clause in clauses:
        if clause.type == "for_in_clause":
            new_loop = self.handle_for_clause(clause)
            if loop_start is None:
                loop_start = new_loop
            elif if_start is None:
                prev_loop.body[0] = new_loop
            else:
                # Attach the whole if chain (from its first if), so no
                # condition is dropped, and nest the new loop inside it
                prev_loop.body[0] = if_start
                prev_if.body = [new_loop]
                if_start = prev_if = None
            prev_loop = new_loop
        elif clause.type == "if_clause":
            new_if = self.handle_if_clause(clause)
            if if_start is None:
                if_start = new_if
            else:
                prev_if.body = [new_if]
            prev_if = new_if

    append_cast = Call(
        func=Attribute(temp_list_name, append_call),
        arguments=[computation_cast],
        source_refs=[ref],
    )
    if if_start is None:
        prev_loop.body[0] = append_cast
    else:
        prev_loop.body[0] = if_start
        prev_if.body = [append_cast]

    return_cast = ModelReturn(temp_list_name)

    func_name = self.variable_context.generate_func("%comprehension_list")
    func_def_cast = FunctionDef(
        name=func_name,
        func_args=[],
        body=[temp_asg_cast, loop_start, return_cast],
        source_refs=[ref],
    )

    self.generated_fns.append(func_def_cast)

    return Call(func=func_name, arguments=[], source_refs=[ref])

802 

def visit_pair(self, node: Node):
    """Visit a dictionary ``pair`` node (``key: value``).

    Returns:
        A ``(key_cast, value_cast)`` tuple built from children 0 and 2
        (child 1 is the ':' separator).
    """
    return tuple(self.visit(node.children[position]) for position in (0, 2))

808 

def visit_dict_comprehension(self, node: Node) -> Call:
    """Lower a dict comprehension into a generated function plus a call to it.

    ``{k: v for a in A if c}`` becomes, roughly::

        def <generated>():
            dict__temp_ = {}
            <loop over A>:
                if c:
                    dict__temp_ = _set(dict__temp_, k, v)
            return dict__temp_

    The generated FunctionDef is appended to ``self.generated_fns`` and a
    zero-argument Call to it is returned in place of the comprehension.
    """
    ref = self.node_helper.get_source_ref(node)

    temp_dict_name = self.variable_context.add_variable(
        "dict__temp_", "Unknown", [ref]
    )

    temp_asg_cast = Assignment(
        left=Var(val=temp_dict_name),
        right=CASTLiteralValue(value={}, value_type=StructureType.MAP),
        source_refs=[ref],
    )

    set_call = self.get_gromet_function_node("_set")
    computation = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]
    # visit_pair yields a (key, value) tuple
    computation_cast = self.visit(computation)

    # Each for_in_clause opens a new nested loop; a run of if_clauses between
    # loops becomes a chain of nested ModelIfs.
    clauses = get_children_by_types(node, ["for_in_clause", "if_clause"])
    loop_start = None  # outermost loop
    prev_loop = None   # innermost loop so far
    if_start = None    # first if of the current run of if_clauses
    prev_if = None     # last if of the current run

    for clause in clauses:
        if clause.type == "for_in_clause":
            new_loop = self.handle_for_clause(clause)
            if loop_start is None:
                loop_start = new_loop
            elif if_start is None:
                prev_loop.body[0] = new_loop
            else:
                # Attach the whole if chain (from its first if), so no
                # condition is dropped, and nest the new loop inside it
                prev_loop.body[0] = if_start
                prev_if.body = [new_loop]
                if_start = prev_if = None
            prev_loop = new_loop
        elif clause.type == "if_clause":
            new_if = self.handle_if_clause(clause)
            if if_start is None:
                if_start = new_if
            else:
                prev_if.body = [new_if]
            prev_if = new_if

    set_assign = Assignment(
        left=Var(val=temp_dict_name),
        right=Call(
            func=set_call,
            arguments=[temp_dict_name, computation_cast[0].val, computation_cast[1]],
            source_refs=[ref],
        ),
        source_refs=[ref],
    )
    if if_start is None:
        prev_loop.body[0] = set_assign
    else:
        # The _set assignment belongs inside the innermost if
        prev_loop.body[0] = if_start
        prev_if.body = [set_assign]

    return_cast = ModelReturn(temp_dict_name)

    func_name = self.variable_context.generate_func("%comprehension_dict")
    func_def_cast = FunctionDef(
        name=func_name,
        func_args=[],
        body=[temp_asg_cast, loop_start, return_cast],
        source_refs=[ref],
    )

    self.generated_fns.append(func_def_cast)

    return Call(func=func_name, arguments=[], source_refs=[ref])

874 

875 

def visit_lambda(self, node: Node) -> Call:
    """Lower a ``lambda`` into a generated FunctionDef plus a Call to it.

    The generated function (named via ``generate_func("%lambda")``) takes
    the lambda's parameters and returns its body expression. It is appended
    to ``self.generated_fns``. The returned Call passes the parameters (their
    Name nodes, for Vars) as arguments.
    """
    # TODO: we have to determine how to grab the variables that are being
    # used in the lambda that aren't part of the lambda's arguments
    ref = self.node_helper.get_source_ref(node)
    # `lambda: expr` has no lambda_parameters child at all
    params_nodes = get_children_by_types(node, ["lambda_parameters"])
    body = get_children_by_types(node, COMPREHENSION_OPERATORS)[0]

    parameters = []
    for param in (params_nodes[0].children if params_nodes else []):
        cast = self.visit(param)
        if isinstance(cast, list):
            parameters.extend(cast)
        elif isinstance(cast, AstNode):
            # Keep CAST nodes only; other visit results (e.g. whatever the
            # ',' separator tokens produce) must not become parameters.
            parameters.append(cast)

    func_body = self.visit(body)

    func_name = self.variable_context.generate_func("%lambda")
    func_def_cast = FunctionDef(
        name=func_name,
        func_args=parameters,
        body=[ModelReturn(value=func_body)],
        source_refs=[ref],
    )

    self.generated_fns.append(func_def_cast)

    # Collect all the Name node instances to use as arguments for the lambda call
    args = [par.val if isinstance(par, Var) else par for par in parameters]

    return Call(func=func_name, arguments=args, source_refs=[ref])

903 

def visit_while(self, node: Node) -> Loop:
    """Convert a ``while`` statement into a CAST Loop with empty pre/post.

    The condition and body are visited inside their own variable context.
    """
    ref = self.node_helper.get_source_ref(node)

    # Push a variable context since a loop
    # can create variables that only it can see
    self.variable_context.push_context()

    loop_cond_node = get_children_by_types(node, WHILE_COND_TYPES)[0]
    loop_body_nodes = get_children_by_types(node, ["block"])[0].children

    loop_cond = self.visit(loop_cond_node)

    loop_body = []
    for stmt in loop_body_nodes:
        cast = self.visit(stmt)
        if isinstance(cast, list):
            loop_body.extend(cast)
        elif isinstance(cast, AstNode):
            loop_body.append(cast)

    self.variable_context.pop_context()

    return Loop(
        pre=[],
        expr=loop_cond,
        body=loop_body,
        post=[],
        source_refs=[ref],
    )

933 

def visit_for(self, node: Node) -> Loop:
    """Lower a tree-sitter ``for`` statement into an iterator-driven CAST Loop.

    The generated structure is, in pseudo-Python::

        pre:  <iterator> = iter(<iterable>)
              (<target>, <iterator>, <stop>) = next(<iterator>)
        expr: <stop> == False
        body: <visited block statements>
              (<target>, <iterator>, <stop>) = next(<iterator>)

    As the 3-element tuple on the left of each ``next`` assignment implies,
    the ``next`` function used here is expected to yield the loop target,
    the advanced iterator, and a stop flag.  The iterator and stop-condition
    names are generated inside a variable context that is pushed for the
    loop and popped before returning.
    """
    ref = self.node_helper.get_source_ref(node)

    # Pre: left, right
    # The loop target comes from the first child matching FOR_LOOP_LEFT_TYPES
    # and the iterable from the LAST child matching FOR_LOOP_RIGHT_TYPES.
    # NOTE(review): presumably [-1] is used because the target can also match
    # the right-hand types — confirm against node_helper.
    loop_cond_left = get_children_by_types(node, FOR_LOOP_LEFT_TYPES)[0]
    loop_cond_right = get_children_by_types(node, FOR_LOOP_RIGHT_TYPES)[-1]

    # Construct pre and expr value using left and right as needed
    # need calls to "_Iterator"

    # Order matters here: the context is pushed before the iterator /
    # stop-condition names are generated and before the target and iterable
    # are visited, so all of them are registered in the loop's context.
    self.variable_context.push_context()
    iterator_name = self.variable_context.generate_iterator()
    stop_cond_name = self.variable_context.generate_stop_condition()
    iter_func = self.get_gromet_function_node("iter")
    next_func = self.get_gromet_function_node("next")

    loop_cond_left_cast = self.visit(loop_cond_left)
    loop_cond_right_cast = self.visit(loop_cond_right)

    loop_pre = []
    # pre[0]: <iterator> = iter(<iterable>)
    loop_pre.append(
        Assignment(
            left = Var(iterator_name, "Iterator"),
            right = Call(
                iter_func,
                arguments=[loop_cond_right_cast]
            )
        )
    )

    # pre[1]: (<target>, <iterator>, <stop>) = next(<iterator>)
    loop_pre.append(
        Assignment(
            left=CASTLiteralValue(
                "Tuple",
                [
                    loop_cond_left_cast,
                    Var(iterator_name, "Iterator"),
                    Var(stop_cond_name, "Boolean"),
                ],
                source_code_data_type = ["Python",PYTHON_VERSION,"Tuple"],
                source_refs=ref
            ),
            right=Call(
                next_func,
                arguments=[Var(iterator_name, "Iterator")],
            ),
        )
    )

    # Loop condition: keep iterating while <stop> == False.
    loop_expr = Operator(
        source_language="Python",
        interpreter="Python",
        version=PYTHON_VERSION,
        op="ast.Eq",
        operands=[
            stop_cond_name,
            CASTLiteralValue(
                ScalarType.BOOLEAN,
                False,
                ["Python", PYTHON_VERSION, "boolean"],
                source_refs=ref,
            )
        ],
        source_refs=ref
    )

    loop_body_node = get_children_by_types(node, ["block"])[0].children
    loop_body = []
    # NOTE: this loop rebinds `node`, shadowing the parameter; the parameter
    # is not used again after this point.
    for node in loop_body_node:
        cast = self.visit(node)
        if isinstance(cast, List):
            loop_body.extend(cast)
        elif isinstance(cast, AstNode):
            loop_body.append(cast)

    # Insert an additional call to 'next' at the end of the loop body,
    # to facilitate looping in GroMEt
    # NOTE(review): unlike pre[1], this tuple is built without
    # source_code_data_type / source_refs, and it reuses the same
    # loop_cond_left_cast object as pre[1] — confirm both are intended.
    loop_body.append(
        Assignment(
            left=CASTLiteralValue(
                "Tuple",
                [
                    loop_cond_left_cast,
                    Var(iterator_name, "Iterator"),
                    Var(stop_cond_name, "Boolean"),
                ],
            ),
            right=Call(
                next_func,
                arguments=[Var(iterator_name, "Iterator")],
            ),
        )
    )

    self.variable_context.pop_context()
    return Loop(
        pre=loop_pre,
        expr=loop_expr,
        body=loop_body,
        post=[],
        source_refs = ref
    )

1036 

def retrieve_init_func(self, functions: List[FunctionDef]):
    """Return the first CAST function in *functions* named ``__init__``.

    Each candidate's name is read through ``func.name.name``.  Returns None
    when no ``__init__`` is present (e.g. a class with no constructor).
    """
    return next(
        (candidate for candidate in functions if candidate.name.name == "__init__"),
        None,
    )

1045 

def retrieve_class_attrs(self, init_func: Union[FunctionDef, None]):
    """Collect the attributes that ``__init__`` assigns on ``self``.

    Only the top-level statements of ``init_func.body`` are inspected;
    assignments nested inside other statements are not considered.

    Args:
        init_func: CAST FunctionDef for the class's ``__init__``, or None
            when the class defines no ``__init__``.

    Returns:
        The ``attr`` of every ``self.<attr> = ...`` assignment, in source
        order; an empty list when *init_func* is None.
    """
    # retrieve_init_func returns None for classes without __init__;
    # such classes simply have no constructor-assigned attributes.
    if init_func is None:
        return []

    attrs = []
    for stmt in init_func.body:
        if not isinstance(stmt, Assignment) or not isinstance(stmt.left, Attribute):
            continue
        # The receiver of the attribute assignment is not guaranteed to carry
        # a `name` (presumably a compound receiver such as `self.a.b = ...`);
        # skip such targets instead of raising AttributeError.
        if getattr(stmt.left.value, "name", None) == "self":
            attrs.append(stmt.left.attr)

    return attrs

1055 

def visit_class_definition(self, node):
    """Convert a tree-sitter ``class_definition`` node into a CAST RecordDef.

    The record's name comes from the class's ``identifier`` child.  Its
    ``funcs`` are the visited direct ``function_definition`` children of the
    class body (list results are flattened in).  Its ``fields`` are the
    ``self.<attr>`` targets assigned in ``__init__``, or empty when the class
    has no ``__init__``.  ``bases`` is always empty.
    """
    class_name_node = get_first_child_by_type(node, "identifier")
    class_cast = self.visit(class_name_node)

    # Pass node-type *lists*, matching every other get_children_by_types call
    # in this visitor, rather than bare strings.
    class_body = get_children_by_types(node, ["block"])[0]
    # NOTE(review): decorated methods may be wrapped in a different node type
    # by the grammar and would not be collected here — confirm.
    function_defs = get_children_by_types(class_body, ["function_definition"])

    func_defs_cast = []
    for func in function_defs:
        func_cast = self.visit(func)
        if isinstance(func_cast, List):
            func_defs_cast.extend(func_cast)
        else:
            func_defs_cast.append(func_cast)

    # retrieve_init_func yields None for a class without __init__; such a
    # class has no constructor-assigned attributes, so don't inspect one.
    init_func = self.retrieve_init_func(func_defs_cast)
    attributes = self.retrieve_class_attrs(init_func) if init_func is not None else []

    return RecordDef(
        name=get_name_node(class_cast).name,
        bases=[],
        funcs=func_defs_cast,
        fields=attributes,
    )

1073 

1074 

def visit_name(self, node):
    """Resolve a tree-sitter identifier node to its CAST name node.

    The identifier text is first passed through ``check_alias`` so aliased
    names resolve to what they alias.  If the variable context already knows
    the resulting identifier, the previously created node is returned;
    otherwise a new variable of type "Unknown" is registered with this
    node's source reference and returned.
    """
    # NOTE: the call to check_alias is a crucial change, to resolve aliasing;
    # need to make sure nothing breaks
    identifier = self.check_alias(self.node_helper.get_identifier(node))
    context = self.variable_context

    if not context.is_variable(identifier):
        source_ref = self.node_helper.get_source_ref(node)
        return context.add_variable(identifier, "Unknown", [source_ref])

    return context.get_node(identifier)

1086 

1087 def _visit_passthrough(self, node): 

1088 if len(node.children) == 0: 

1089 return None 

1090 

1091 for child in node.children: 

1092 child_cast = self.visit(child) 

1093 if child_cast: 

1094 return child_cast 

1095 

def get_gromet_function_node(self, func_name: str) -> Name:
    """Return the CAST name node for the function *func_name* (e.g. "iter").

    Ideally a dummy tree-sitter node would be built and handed to the name
    visitor, but tree-sitter does not allow creating or modifying nodes, so
    the lookup-or-register logic is recreated here (without alias checking).
    An existing entry in the variable context is reused; otherwise
    *func_name* is registered with type "function" and no source refs.
    """
    context = self.variable_context
    if not context.is_variable(func_name):
        return context.add_variable(func_name, "function", None)
    return context.get_node(func_name)

1103 

def visit_yield(self, node):
    """Return a placeholder for a ``yield`` expression, which is not supported yet.

    The result is a one-element list holding a LIST-typed CASTLiteralValue
    whose value is the marker string "YieldNotImplemented", tagged with
    *node*'s source reference.
    """
    # NOTE(review): the version tag is the literal "3.8" although the module
    # constant PYTHON_VERSION is "3.10"; kept unchanged to preserve output.
    placeholder = CASTLiteralValue(
        StructureType.LIST,
        "YieldNotImplemented",
        ["Python", "3.8", "List"],
        self.node_helper.get_source_ref(node),
    )
    return [placeholder]

1115 

def visit_assert(self, node):
    """Return a placeholder for an ``assert`` statement, which is not supported yet.

    The result is a one-element list holding a LIST-typed CASTLiteralValue
    whose value is the marker string "AssertNotImplemented", tagged with
    *node*'s source reference.
    """
    # NOTE(review): the version tag is the literal "3.8" although the module
    # constant PYTHON_VERSION is "3.10"; kept unchanged to preserve output.
    placeholder = CASTLiteralValue(
        StructureType.LIST,
        "AssertNotImplemented",
        ["Python", "3.8", "List"],
        self.node_helper.get_source_ref(node),
    )
    return [placeholder]

1127 

1128 

def get_name_node(node):
    """Extract the name node from a CAST node.

    A list is reduced to its first element.  An Attribute is resolved
    recursively through its ``attr``.  A Var yields its ``val``.  Anything
    else is returned unchanged.
    """
    target = node[0] if isinstance(node, list) else node
    if isinstance(target, Attribute):
        return get_name_node(target.attr)
    return target.val if isinstance(target, Var) else target

1141 

def get_func_name_node(node):
    """Return ``node.val`` when *node* is a CAST Var, otherwise *node* itself.

    Unlike get_name_node and get_operand_node, lists and Attributes are not
    unwrapped here.
    """
    return node.val if isinstance(node, Var) else node

1150 

def get_operand_node(node):
    """Reduce a CAST node (or a list of them) to an operand for an Operator.

    A list contributes its first element.  A Var is replaced by its ``val``.
    Anything else is returned as-is.
    """
    candidate = node[0] if isinstance(node, list) else node
    if isinstance(candidate, Var):
        return candidate.val
    return candidate