Coverage for skema/program_analysis/CAST/matlab/matlab_to_cast.py: 84%

224 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import json 

2import os.path 

3from pathlib import Path 

4from typing import Any, Dict, List, Union 

5from tree_sitter import Language, Parser, Node, Tree 

6from skema.program_analysis.CAST2FN.cast import CAST 

7from skema.program_analysis.CAST2FN.model.cast import ( 

8 Assignment, 

9 AstNode, 

10 Attribute, 

11 Call, 

12 FunctionDef, 

13 CASTLiteralValue, 

14 Loop, 

15 ModelBreak, 

16 ModelContinue, 

17 ModelIf, 

18 ModelImport, 

19 ModelReturn, 

20 Module, 

21 Name, 

22 Operator, 

23 RecordDef, 

24 ScalarType, 

25 SourceCodeDataType, 

26 SourceRef, 

27 StructureType, 

28 ValueConstructor, 

29 Var, 

30 VarType, 

31) 

32from skema.program_analysis.CAST.matlab.variable_context import ( 

33 VariableContext 

34) 

35from skema.program_analysis.CAST.matlab.node_helper import ( 

36 get_children_by_types, 

37 get_control_children, 

38 get_first_child_by_type, 

39 get_keyword_children, 

40 NodeHelper 

41) 

42from skema.program_analysis.CAST.matlab.tokens import KEYWORDS 

43from skema.program_analysis.tree_sitter_parsers.build_parsers import( 

44 INSTALLED_LANGUAGES_FILEPATH 

45) 

46 

# Placeholder recorded as the MATLAB version on generated CAST nodes.
MATLAB_VERSION = 'matlab_version_here'

# Placeholder assigned to the name of the generated CAST Module.
MODULE_NAME = 'module_name_here'

# Interpreter tag attached to every Operator built by this translator.
INTERPRETER = 'matlab_to_cast'

50 

51class MatlabToCast(object): 

52 

53 node_visits = dict() 

54 

def __init__(self, source_path="", source=""):
    """Parse MATLAB source and build its CAST representation.

    The source text is read from ``source_path`` when one is given;
    otherwise the ``source`` string is used as-is.  The finished CAST
    object is stored in ``self.out_cast``.
    """
    if source_path != "":
        # a file path was supplied: load (and strip) the text from disk
        source_file = Path(source_path)
        self.filename = source_file.name
        self.source = source_file.read_text().strip()
    else:
        # no file: keep the given text and flag the filename as unused
        self.source = source
        self.filename = "None"

    # Tree-sitter parser configured with the installed MATLAB grammar
    parser = Parser()
    matlab_language = Language(INSTALLED_LANGUAGES_FILEPATH, "matlab")
    parser.set_language(matlab_language)

    # syntax tree for the source text
    encoded_source = bytes(self.source, "utf8")
    self.tree = parser.parse(encoded_source)

    # helpers used by the visitor methods
    self.variable_context = VariableContext()
    self.node_helper = NodeHelper(self.source, self.filename)

    # translate the tree and wrap the resulting module in a CAST object
    module = self.run(self.tree.root_node)
    module.name = MODULE_NAME
    self.out_cast = CAST([module], "matlab")

84 

def log_visit(self, node: Node):
    """Increment the visit counter kept for this node's type.

    A missing (falsy) node is tallied under the key "None".
    """
    visited_type = "None" if not node else node.type
    self.node_visits[visited_type] = self.node_visits.get(visited_type, 0) + 1

90 

def run(self, root) -> Module:
    """Translate a whole Tree-sitter syntax tree, starting at its root node."""
    translated = self.visit(root)
    return translated

94 

def visit(self, node):
    """Dispatch a Tree-sitter node to the visitor registered for its type.

    Every call is recorded with ``log_visit`` first.  Node types without
    a dedicated visitor are handed to ``_visit_passthrough``.

    Returns whatever the selected visitor returns, or None when ``node``
    is None.
    """
    self.log_visit(node)

    # log_visit already tolerates a missing node (tallied as "None");
    # stop here instead of failing on ``node.type`` below.
    if node is None:
        return None

    # node type -> name of the method that translates it
    visitor_names = {
        "assignment": "visit_assignment",
        "boolean": "visit_boolean",
        "command": "visit_command",
        "function_call": "visit_function_call",
        "function_definition": "visit_function_def",
        "identifier": "visit_identifier",
        "if_statement": "visit_if_statement",
        "iterator": "visit_iterator",
        "for_statement": "visit_for_statement",
        "cell": "visit_matrix",
        "matrix": "visit_matrix",
        "source_file": "visit_module",
        "command_name": "visit_name",
        "command_argument": "visit_name",
        "name": "visit_name",
        "number": "visit_number",
        "binary_operator": "visit_operator",
        "boolean_operator": "visit_operator",
        "comparison_operator": "visit_operator",
        "unary_operator": "visit_operator",
        "spread_operator": "visit_operator",
        "postfix_operator": "visit_operator",
        "not_operator": "visit_operator",
        "string": "visit_string",
        "switch_statement": "visit_switch_statement",
    }
    visitor_name = visitor_names.get(node.type, "_visit_passthrough")
    return getattr(self, visitor_name)(node)

145 

def visit_assignment(self, node):
    """Translate a Tree-sitter assignment node into a CAST Assignment."""
    operands = get_keyword_children(node)
    # the first keyword child is the target, the second is the value
    target = self.visit(operands[0])
    value = self.visit(operands[1])
    source_ref = self.node_helper.get_source_ref(node)
    return Assignment(left=target, right=value, source_refs=[source_ref])

154 

def visit_boolean(self, node):
    """Translate a Tree-sitter boolean node into a BOOLEAN literal.

    The literal's value is a string: the type name of the node's child,
    re-cased to Python's Boolean capitalization (first letter upper
    case, the rest lower case).  If there are several children the last
    one wins.
    """
    for keyword in node.children:
        spelling = keyword.type
        value = spelling[0].upper() + spelling[1:].lower()

    return CASTLiteralValue(
        value_type=ScalarType.BOOLEAN,
        value=value,
        source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.BOOLEAN],
        source_refs=[self.node_helper.get_source_ref(node)],
    )

170 

def visit_command(self, node):
    """Translate a Tree-sitter command node into a CAST Call.

    The first keyword child is the command name; every keyword child
    after it is translated as an argument of the call.  (Previously only
    the first argument was kept and any further ones were dropped.)
    """
    children = get_keyword_children(node)
    # Arguments are translated before the command name, as they always
    # were, so the order in which the visitors run is unchanged.
    arguments = [self.visit(child) for child in children[1:]]
    return Call(
        func=self.visit(children[0]),
        source_language="matlab",
        source_language_version=MATLAB_VERSION,
        arguments=arguments,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

182 

def visit_function_call(self, node):
    """Translate a Tree-sitter function call node into a CAST Call."""
    # the callee is the first keyword child of the call node
    callee = self.visit(get_keyword_children(node)[0])

    # the call's arguments are the keyword children of its "arguments" node
    # NOTE(review): assumes an "arguments" child is always present - confirm
    argument_list = get_first_child_by_type(node, "arguments")
    call_arguments = []
    for argument in get_keyword_children(argument_list):
        call_arguments.append(self.visit(argument))

    return Call(
        func=callee,
        source_language="matlab",
        source_language_version=MATLAB_VERSION,
        arguments=call_arguments,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

195 

def visit_function_def(self, node):
    """Return a CAST translation of a MATLAB function definition."""
    # NOTE(review): the name is built from the "function_output" child -
    # confirm that this, and not the function's identifier, is intended.
    name = self.visit(get_first_child_by_type(node, "function_output"))
    body = self.get_block(node)

    # one translated entry per keyword child of "function_arguments"
    parameters = get_first_child_by_type(node, "function_arguments")
    func_args = []
    for parameter in get_keyword_children(parameters):
        func_args.append(self.visit(parameter))

    return FunctionDef(
        name=name,
        body=body,
        func_args=func_args,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

206 

def visit_identifier(self, node):
    """Return a Var node for an identifier (variable).

    The Var wraps the name node from ``visit_name``, takes its type from
    the variable context ("Unknown" if the identifier is not registered)
    and carries the identifier text as a CHARACTER literal default value.
    """
    identifier = self.node_helper.get_identifier(node)
    name_node = self.visit_name(node)

    if self.variable_context.is_variable(identifier):
        var_type = self.variable_context.get_type(identifier)
    else:
        var_type = "Unknown"

    default_value = CASTLiteralValue(
        value_type=ScalarType.CHARACTER,
        value=self.node_helper.get_identifier(node),
        source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.CHARACTER],
        source_refs=[self.node_helper.get_source_ref(node)],
    )

    return Var(
        val=name_node,
        type=var_type,
        default_value=default_value,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

222 

def visit_if_statement(self, node):
    """Return a ModelIf chain describing if / elseif / else logic.

    The "if" becomes the head ModelIf.  Each "elseif_clause" becomes a
    ModelIf stored as the single ``orelse`` entry of the one before it,
    and the block of an "else_clause" (if any) becomes the ``orelse`` of
    the last ModelIf in the chain.
    """

    def build_branch(clause):
        """Return a ModelIf with an empty orelse for an if/elseif clause."""
        # the condition is the child right after the "if"/"elseif" keyword
        for position, child in enumerate(clause.children):
            if child.type in ("if", "elseif"):
                condition = clause.children[position + 1]

        return ModelIf(
            expr=self.visit(condition),
            body=self.get_block(clause),
            orelse=[],
            source_refs=[self.node_helper.get_source_ref(clause)],
        )

    # head of the chain: the "if" itself
    head = build_branch(node)
    tail = head

    # link zero or more elseif clauses, each hanging off the previous one
    for elseif_clause in get_children_by_types(node, ["elseif_clause"]):
        tail.orelse = [build_branch(elseif_clause)]
        tail = tail.orelse[0]

    # an optional else clause closes the chain
    else_clause = get_first_child_by_type(node, "else_clause")
    if else_clause:
        tail.orelse = self.get_block(else_clause)

    return head

256 

257 # CAST has no Iterator node, so we return a partially  

258 # completed Loop object  

259 # MATLAB iterators are either matrices or ranges. 

def visit_iterator(self, node) -> Loop:
    """Translate a MATLAB for-loop iterator into a partially filled Loop.

    Two iterator forms are handled:

    * a matrix with a row: the loop walks the translated row elements
      through the helper names "_mat", "_mat_len" and "_mat_idx";
    * a range of numbers: ``start:stop`` or ``start:step:stop``.

    The returned Loop holds only the iteration bookkeeping (pre,
    expr and the statements advancing the iterator); the caller adds the
    loop's own statements.  Returns None when the iterator is neither of
    the two forms.
    """
    itr_var = self.visit(get_first_child_by_type(node, "identifier"))
    source_ref = self.node_helper.get_source_ref(node)

    def assign(left, right):
        """Return an Assignment tagged with the iterator's source ref."""
        return Assignment(left=left, right=right, source_refs=[source_ref])

    def operate(op, operands):
        """Return an Operator tagged with the iterator's source ref."""
        return self.get_operator(
            op=op, operands=operands, source_refs=[source_ref]
        )

    # process matrix iterator
    matrix_node = get_first_child_by_type(node, "matrix")
    row_node = None
    if matrix_node is not None:
        row_node = get_first_child_by_type(matrix_node, "row")

    # Build the matrix loop only when a row exists, so the row values
    # are always defined before they are used.
    if row_node is not None:
        mat = [self.visit(child) for child in get_keyword_children(row_node)]
        return Loop(
            pre=[
                assign("_mat", mat),
                assign("_mat_len", len(mat)),
                assign("_mat_idx", 0),
                assign(itr_var, mat[0]),
            ],
            expr=operate("<", ["_mat_idx", "_mat_len"]),
            body=[
                assign("_mat_idx", operate("+", ["_mat_idx", 1])),
                assign(itr_var, "_mat[_mat_idx]"),
            ],
            post=[],
        )

    # process range iterator
    range_node = get_first_child_by_type(node, "range")
    if range_node is not None:
        numbers = [
            self.visit(child)
            for child in get_children_by_types(range_node, ["number"])
        ]
        start = numbers[0]
        step = 1
        stop = 0

        if len(numbers) == 2:
            # two values: the step is implicitly 1
            stop = numbers[1]
        elif len(numbers) == 3:
            # three values: the step is given explicitly
            step = numbers[1]
            stop = numbers[2]

        # NOTE(review): the results of these calls are not used here; the
        # calls are kept in their original order in case they register
        # names in the variable context - confirm before removing.
        for helper_name in ("range", "iter", "next"):
            self.variable_context.get_gromet_function_node(helper_name)
        self.variable_context.generate_iterator()
        self.variable_context.generate_stop_condition()

        return Loop(
            pre=[assign(itr_var, start)],
            expr=operate("<=", [itr_var, stop]),
            body=[assign(itr_var, operate("+", [itr_var, step]))],
            post=[],
        )

    # neither a matrix row nor a range was found
    return None

376 

def visit_for_statement(self, node) -> Loop:
    """Translate a Tree-sitter for loop node into a CAST Loop node.

    The Loop comes from the node's "iterator" child.  Its source refs
    are replaced with those of the whole statement, and the statements
    of the loop's block are placed ahead of the iterator's own
    bookkeeping in the body.
    """
    loop = self.visit(get_first_child_by_type(node, "iterator"))
    loop.source_refs = [self.node_helper.get_source_ref(node)]

    # get_block yields None when the statement has no block child; use an
    # empty statement list then, so the concatenation cannot fail.
    statements = self.get_block(node) or []
    loop.body = statements + loop.body

    return loop

385 

386 

def visit_matrix(self, node):
    """Translate a Tree-sitter cell or matrix node into a LIST literal."""

    def collect(element):
        """Return the translated keyword children, recursing into rows."""
        collected = []
        for child in get_keyword_children(element):
            if child.type == "row":
                collected.append(collect(child))
            else:
                collected.append(self.visit(child))
        return collected

    entries = collect(node)
    # NOTE(review): only the first collected entry is kept as the value,
    # so additional rows are not represented - confirm this is intended.
    value = entries[0] if entries else []

    return CASTLiteralValue(
        value_type=StructureType.LIST,
        value=value,
        source_code_data_type=["matlab", MATLAB_VERSION, StructureType.LIST],
        source_refs=[self.node_helper.get_source_ref(node)],
    )

410 

def visit_module(self, node: Node) -> Module:
    """Translate the source file node into a Module.

    Each child is visited inside a fresh variable context.  List results
    are spliced into the module body, single AstNode results are
    appended, and anything else is discarded.
    """
    self.variable_context.push_context()

    body = []
    for child in node.children:
        translated = self.visit(child)
        if isinstance(translated, list):
            body.extend(translated)
        elif isinstance(translated, AstNode):
            body.append(translated)

    self.variable_context.pop_context()

    return Module(
        name=None,  # TODO: Fill out name field
        body=body,
        source_refs=[self.node_helper.get_source_ref(node)],
    )

430 

def visit_name(self, node):
    """Return the node for this variable name, creating it if needed."""
    identifier = self.node_helper.get_identifier(node)
    context = self.variable_context

    if not context.is_variable(identifier):
        # first sighting: register the name with an "Unknown" type
        return context.add_variable(
            identifier, "Unknown", [self.node_helper.get_source_ref(node)]
        )

    # already registered: reuse the existing node
    return context.get_node(identifier)

441 

442 

def visit_number(self, node) -> CASTLiteralValue:
    """Translate a number node into an Integer or AbstractFloat literal.

    The literal text is treated as a real value when it contains an
    exponent marker ("e", in either case) or a decimal point, and as an
    integer otherwise.
    """
    literal_value = self.node_helper.get_identifier(node)

    if "e" in literal_value.lower() or "." in literal_value:
        value_type = "AbstractFloat"
        value = float(literal_value)
    else:
        value_type = "Integer"
        value = int(literal_value)

    return CASTLiteralValue(
        value_type=value_type,
        value=value,
        source_code_data_type=["matlab", MATLAB_VERSION, value_type],
        source_refs=[self.node_helper.get_source_ref(node)],
    )

463 

def visit_operator(self, node):
    """Return an Operator built from a Tree-sitter operator node."""
    # the operator symbol is the text of the first control child
    symbol_node = get_control_children(node)[0]
    op = self.node_helper.get_identifier(symbol_node)

    # the operands are the translated keyword children
    operands = []
    for operand_node in get_keyword_children(node):
        operands.append(self.visit(operand_node))

    source_ref = self.node_helper.get_source_ref(node)
    return self.get_operator(op=op, operands=operands, source_refs=[source_ref])

477 

def visit_string(self, node):
    """Translate a string node into a "Character" literal of its text."""
    text = self.node_helper.get_identifier(node)
    source_ref = self.node_helper.get_source_ref(node)
    return CASTLiteralValue(
        value_type="Character",
        value=text,
        source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.CHARACTER],
        source_refs=[source_ref],
    )

486 

def visit_switch_statement(self, node):
    """Return conditional logic equivalent to a MATLAB switch statement.

    Every "case_clause" becomes a ModelIf testing the switch variable.
    The ModelIfs are chained through ``orelse`` in source order, and the
    block of an "otherwise_clause" (if any) becomes the ``orelse`` of
    the last one.  The first ModelIf of the chain is returned.
    """
    # node types that can appear as the value a case compares against
    comparable_types = [
        "boolean",
        "identifier",
        "matrix",
        "number",
        "string",
        "unary_operator"
    ]

    def build_test(case_node, subject):
        """Return the Operator that tests ``subject`` against the case."""
        refs = [self.node_helper.get_source_ref(case_node)]
        cell_node = get_first_child_by_type(case_node, "cell")

        if cell_node:
            # several case values: test membership in the list of values
            candidates = CASTLiteralValue(
                value_type=StructureType.LIST,
                value=self.visit(cell_node),
                source_code_data_type=["matlab", MATLAB_VERSION, StructureType.LIST],
                source_refs=[self.node_helper.get_source_ref(cell_node)],
            )
            return self.get_operator(
                op="in", operands=[subject, candidates], source_refs=refs
            )

        # a single case value: test equality with the first candidate
        translated = [
            self.visit(candidate)
            for candidate in get_children_by_types(case_node, comparable_types)
        ]
        return self.get_operator(
            op="==", operands=[subject, translated[0]], source_refs=refs
        )

    def build_branch(case_node, subject):
        """Return a ModelIf (empty orelse) for one case clause."""
        return ModelIf(
            expr=build_test(case_node, subject),
            body=self.get_block(case_node),
            orelse=[],
            source_refs=[self.node_helper.get_source_ref(case_node)],
        )

    # the switch variable is an identifier when one is present,
    # otherwise it is taken from a function call
    subject_node = get_first_child_by_type(node, "identifier")
    if subject_node is None:
        subject_node = get_first_child_by_type(node, "function_call")
    switch_var = self.visit(subject_node)

    # one 'if then' node per case clause, each chained to the next
    branches = [
        build_branch(case_node, switch_var)
        for case_node in get_children_by_types(node, ["case_clause"])
    ]
    for earlier, later in zip(branches, branches[1:]):
        earlier.orelse = [later]

    # the otherwise clause is the 'else' of the last 'if then' node
    otherwise_clause = get_first_child_by_type(node, "otherwise_clause")
    if otherwise_clause:
        branches[-1].orelse = self.get_block(otherwise_clause)

    return branches[0]

557 

def get_block(self, node):
    """Return the translated statements of the node's block as a list.

    The list holds one entry per keyword child of the node's "block"
    child, in order.  An empty list is returned when the node has no
    block, so callers can always treat the result as a list.
    """
    block = get_first_child_by_type(node, "block")
    if not block:
        return []
    return [self.visit(child) for child in get_keyword_children(block)]

564 

def get_operator(self, op, operands, source_refs):
    """Return an Operator for ``op`` applied to ``operands``.

    The MATLAB language, interpreter and version tags are filled in from
    the module constants.
    """
    operator_fields = dict(
        source_language="matlab",
        interpreter=INTERPRETER,
        version=MATLAB_VERSION,
        op=op,
        operands=operands,
        source_refs=source_refs,
    )
    return Operator(**operator_fields)

575 

def get_gromet_function_node(self, func_name: str):
    """Return the variable-context node registered under ``func_name``.

    Returns None when the name is not a known variable.
    """
    context = self.variable_context
    if not context.is_variable(func_name):
        return None
    return context.get_node(func_name)

579 

580 # skip control nodes and other junk 

581 def _visit_passthrough(self, node): 

582 if len(node.children) == 0: 

583 return [] 

584 

585 for child in node.children: 

586 child_cast = self.visit(child) 

587 if child_cast: 

588 return child_cast