Coverage for skema/program_analysis/CAST/matlab/matlab_to_cast.py: 84%
224 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import json
2import os.path
3from pathlib import Path
4from typing import Any, Dict, List, Union
5from tree_sitter import Language, Parser, Node, Tree
6from skema.program_analysis.CAST2FN.cast import CAST
7from skema.program_analysis.CAST2FN.model.cast import (
8 Assignment,
9 AstNode,
10 Attribute,
11 Call,
12 FunctionDef,
13 CASTLiteralValue,
14 Loop,
15 ModelBreak,
16 ModelContinue,
17 ModelIf,
18 ModelImport,
19 ModelReturn,
20 Module,
21 Name,
22 Operator,
23 RecordDef,
24 ScalarType,
25 SourceCodeDataType,
26 SourceRef,
27 StructureType,
28 ValueConstructor,
29 Var,
30 VarType,
31)
32from skema.program_analysis.CAST.matlab.variable_context import (
33 VariableContext
34)
35from skema.program_analysis.CAST.matlab.node_helper import (
36 get_children_by_types,
37 get_control_children,
38 get_first_child_by_type,
39 get_keyword_children,
40 NodeHelper
41)
42from skema.program_analysis.CAST.matlab.tokens import KEYWORDS
43from skema.program_analysis.tree_sitter_parsers.build_parsers import(
44 INSTALLED_LANGUAGES_FILEPATH
45)
# Placeholder metadata strings stamped into the CAST produced below
# (language version, module name, and interpreter identifier).
MATLAB_VERSION = "matlab_version_here"
MODULE_NAME = "module_name_here"
INTERPRETER = "matlab_to_cast"
class MatlabToCast(object):
    """Translate MATLAB source code into a CAST object.

    The source is parsed with the Tree-sitter MATLAB grammar, and the
    resulting syntax tree is walked by the ``visit_*`` methods. After
    construction the translation is available as ``self.out_cast``.
    """

    # Per-node-type visit counts. NOTE: this is a class attribute, so the
    # counts accumulate across every MatlabToCast instance in the process.
    node_visits = dict()

    # Tree-sitter node type -> name of the visitor method that handles it.
    # Node types not listed here go to _visit_passthrough.
    _VISITORS = {
        "assignment": "visit_assignment",
        "boolean": "visit_boolean",
        "command": "visit_command",
        "function_call": "visit_function_call",
        "function_definition": "visit_function_def",
        "identifier": "visit_identifier",
        "if_statement": "visit_if_statement",
        "iterator": "visit_iterator",
        "for_statement": "visit_for_statement",
        "cell": "visit_matrix",
        "matrix": "visit_matrix",
        "source_file": "visit_module",
        "command_name": "visit_name",
        "command_argument": "visit_name",
        "name": "visit_name",
        "number": "visit_number",
        "binary_operator": "visit_operator",
        "boolean_operator": "visit_operator",
        "comparison_operator": "visit_operator",
        "unary_operator": "visit_operator",
        "spread_operator": "visit_operator",
        "postfix_operator": "visit_operator",
        "not_operator": "visit_operator",
        "string": "visit_string",
        "switch_statement": "visit_switch_statement",
    }

    def __init__(self, source_path = "", source = ""):
        """Parse MATLAB source and translate it to CAST.

        Args:
            source_path: path of a MATLAB file to read. When non-empty it
                takes precedence over ``source``.
            source: MATLAB source text, used when no path is given.
        """
        # if a source file path is provided, read source from file
        if source_path != "":
            path = Path(source_path)
            self.filename = path.name
            self.source = path.read_text().strip()
        # otherwise copy the input source and flag the filename unused
        else:
            self.source = source
            self.filename = "None"

        # create MATLAB parser
        parser = Parser()
        parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "matlab"))

        # create a syntax tree using the source file
        self.tree = parser.parse(bytes(self.source, "utf8"))

        # create helper classes
        self.variable_context = VariableContext()
        self.node_helper = NodeHelper(self.source, self.filename)

        # create CAST object
        module = self.run(self.tree.root_node)
        module.name = MODULE_NAME
        self.out_cast = CAST([module], "matlab")

    def log_visit(self, node: Node):
        """Increment the visit count for this node's type ("None" if no node)."""
        key = node.type if node else "None"
        self.node_visits[key] = self.node_visits.get(key, 0) + 1

    def run(self, root) -> Module:
        """Process an entire Tree-sitter syntax tree, starting at its root."""
        return self.visit(root)

    def visit(self, node):
        """Dispatch ``node`` to the visitor registered for its type."""
        self.log_visit(node)
        method_name = self._VISITORS.get(node.type)
        if method_name is None:
            return self._visit_passthrough(node)
        return getattr(self, method_name)(node)

    def visit_assignment(self, node):
        """Translate a Tree-sitter assignment node into a CAST Assignment."""
        children = get_keyword_children(node)
        return Assignment(
            left=self.visit(children[0]),
            right=self.visit(children[1]),
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_boolean(self, node):
        """Translate a Tree-sitter boolean node into a CAST boolean literal.

        The value is stored as a string with Python capitalization
        ("True"/"False"), derived from the type of the last child token.
        """
        if node.children:
            token = node.children[-1].type
        else:
            # No keyword child to read; fall back to the node's own text.
            token = self.node_helper.get_identifier(node)
        value = token[:1].upper() + token[1:].lower()
        return CASTLiteralValue(
            value_type=ScalarType.BOOLEAN,
            value=value,
            source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.BOOLEAN],
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_command(self, node):
        """Translate a Tree-sitter command node (``name arg1 arg2 ...``) into a Call.

        The first keyword child is the command name; every remaining
        keyword child becomes an argument.
        """
        children = get_keyword_children(node)
        func = self.visit(children[0])
        arguments = [self.visit(child) for child in children[1:]]
        return Call(
            func=func,
            source_language="matlab",
            source_language_version=MATLAB_VERSION,
            arguments=arguments,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_function_call(self, node):
        """Translate a Tree-sitter function call node into a CAST Call.

        Calls with no argument list (``f()``) have no ``arguments`` child
        and translate to a Call with an empty argument list.
        """
        func = self.visit(get_keyword_children(node)[0])
        arguments_node = get_first_child_by_type(node, "arguments")
        arguments = (
            []
            if arguments_node is None
            else [self.visit(child) for child in get_keyword_children(arguments_node)]
        )
        return Call(
            func=func,
            source_language="matlab",
            source_language_version=MATLAB_VERSION,
            arguments=arguments,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_function_def(self, node):
        """Return a CAST FunctionDef for a MATLAB function definition.

        The name comes from the definition's own ``identifier`` child (the
        output variable lives inside the separate ``function_output``
        child). Both the output and the argument list are optional in
        MATLAB, so neither is required here.
        """
        name = self.visit(get_first_child_by_type(node, "identifier"))
        body = self.get_block(node)
        arguments_node = get_first_child_by_type(node, "function_arguments")
        func_args = (
            []
            if arguments_node is None
            else [self.visit(child) for child in get_keyword_children(arguments_node)]
        )
        return FunctionDef(
            name=name,
            body=body,
            func_args=func_args,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_identifier(self, node):
        """Return a Var node for an identifier (variable) reference.

        The Var's default value is a character literal holding the
        identifier text.
        """
        identifier = self.node_helper.get_identifier(node)
        # visit_name runs first: it registers an unseen identifier in the
        # variable context, which the type lookup below then sees.
        val = self.visit_name(node)
        if self.variable_context.is_variable(identifier):
            var_type = self.variable_context.get_type(identifier)
        else:
            var_type = "Unknown"
        return Var(
            val=val,
            type=var_type,
            default_value=CASTLiteralValue(
                value_type=ScalarType.CHARACTER,
                value=identifier,
                source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.CHARACTER],
                source_refs=[self.node_helper.get_source_ref(node)],
            ),
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_if_statement(self, node):
        """Return nested ModelIf nodes for if / elseif / else logic.

        Each ``elseif`` becomes a ModelIf in the ``orelse`` of the previous
        one. An ``else`` block becomes the ``orelse`` of the last ModelIf.
        """

        def get_conditional(conditional_node):
            """Return a ModelIf for an if_statement or elseif_clause node."""
            # The condition is the child right after the "if"/"elseif" keyword.
            children = conditional_node.children
            expr = next(
                (following for keyword, following in zip(children, children[1:])
                 if keyword.type in ("if", "elseif")),
                None,
            )
            return ModelIf(
                expr=self.visit(expr),
                body=self.get_block(conditional_node),
                orelse=[],
                source_refs=[self.node_helper.get_source_ref(conditional_node)],
            )

        # start with the if statement
        first = get_conditional(node)
        current = first

        # add 0-n elseif clauses
        for child in get_children_by_types(node, ["elseif_clause"]):
            current.orelse = [get_conditional(child)]
            current = current.orelse[0]

        # add 0-1 else clause
        else_clause = get_first_child_by_type(node, "else_clause")
        if else_clause is not None:
            current.orelse = self.get_block(else_clause)

        return first

    def visit_iterator(self, node) -> Union[Loop, None]:
        """Return a partially completed Loop for a for-loop iterator.

        CAST has no Iterator node, so the iterator becomes a Loop whose
        body is completed by visit_for_statement. MATLAB iterators are
        either matrices or ranges; None is returned for anything else.
        """
        itr_var = self.visit(get_first_child_by_type(node, "identifier"))
        source_ref = self.node_helper.get_source_ref(node)

        matrix_node = get_first_child_by_type(node, "matrix")
        if matrix_node is not None:
            row_node = get_first_child_by_type(matrix_node, "row")
            if row_node is not None:
                return self._matrix_loop(itr_var, row_node, source_ref)

        range_node = get_first_child_by_type(node, "range")
        if range_node is not None:
            return self._range_loop(itr_var, range_node, source_ref)

        return None

    def _matrix_loop(self, itr_var, row_node, source_ref) -> Loop:
        """Build a Loop that steps an index over the elements of one matrix row."""
        mat = [self.visit(child) for child in get_keyword_children(row_node)]
        mat_idx = 0
        mat_len = len(mat)

        pre = [
            Assignment(left="_mat", right=mat, source_refs=[source_ref]),
            Assignment(left="_mat_len", right=mat_len, source_refs=[source_ref]),
            Assignment(left="_mat_idx", right=mat_idx, source_refs=[source_ref]),
        ]
        # An empty matrix (for i = []) has no first element to prime the
        # iterator with; the test _mat_idx < _mat_len is false immediately.
        if mat:
            pre.append(
                Assignment(left=itr_var, right=mat[mat_idx], source_refs=[source_ref])
            )

        return Loop(
            pre=pre,
            expr=self.get_operator(
                op="<",
                operands=["_mat_idx", "_mat_len"],
                source_refs=[source_ref],
            ),
            body=[
                Assignment(
                    left="_mat_idx",
                    right=self.get_operator(
                        op="+",
                        operands=["_mat_idx", 1],
                        source_refs=[source_ref],
                    ),
                    source_refs=[source_ref],
                ),
                Assignment(
                    left=itr_var,
                    right="_mat[_mat_idx]",
                    source_refs=[source_ref],
                ),
            ],
            post=[],
        )

    def _range_loop(self, itr_var, range_node, source_ref) -> Loop:
        """Build a counting Loop for a ``start:stop`` or ``start:step:stop`` range.

        Every named child of the range is a bound, so variables, function
        calls and negative numbers (for i = 1:n, -1:1) are supported. The
        ':' tokens are anonymous and skipped.
        """
        bounds = [
            self.visit(child)
            for child in range_node.children
            if child.is_named and child.type not in ("comment", "line_continuation")
        ]
        start = bounds[0]
        step = 1
        stop = 0

        # two values mean the step is implicitly defined as 1
        if len(bounds) == 2:
            stop = bounds[1]
        # three values mean the step is explicitly defined
        elif len(bounds) == 3:
            step = bounds[1]
            stop = bounds[2]

        # NOTE(review): these results are unused. The calls are kept because
        # they may register names in the variable context and changing that
        # could renumber later variables — TODO confirm whether they can go.
        range_name_node = self.variable_context.get_gromet_function_node("range")
        iter_name_node = self.variable_context.get_gromet_function_node("iter")
        next_name_node = self.variable_context.get_gromet_function_node("next")
        generated_iter_name_node = self.variable_context.generate_iterator()
        stop_condition_name_node = self.variable_context.generate_stop_condition()

        return Loop(
            pre=[
                Assignment(left=itr_var, right=start, source_refs=[source_ref])
            ],
            expr=self.get_operator(
                op="<=",
                operands=[itr_var, stop],
                source_refs=[source_ref],
            ),
            body=[
                Assignment(
                    left=itr_var,
                    right=self.get_operator(
                        op="+",
                        operands=[itr_var, step],
                        source_refs=[source_ref],
                    ),
                    source_refs=[source_ref],
                )
            ],
            post=[],
        )

    def visit_for_statement(self, node) -> Loop:
        """Translate a Tree-sitter for loop node into a CAST Loop.

        The loop statements are placed before the iterator bookkeeping
        that visit_iterator put in the Loop body.
        """
        loop = self.visit(get_first_child_by_type(node, "iterator"))
        loop.source_refs = [self.node_helper.get_source_ref(node)]
        # get_block returns [] for an empty loop, so the concatenation is safe.
        loop.body = self.get_block(node) + loop.body
        return loop

    def visit_matrix(self, node):
        """Translate a Tree-sitter matrix or cell node into a CAST list literal.

        A single-row matrix ([1 2 3]) becomes a flat list of elements. A
        multi-row matrix ([1 2; 3 4]) becomes a list with one list per row.
        An empty matrix becomes an empty list.
        """

        def get_values(element):
            """Return the element's children as CAST, recursing into rows."""
            return [
                get_values(child) if child.type == "row" else self.visit(child)
                for child in get_keyword_children(element)
            ]

        rows = get_values(node)
        # Unwrap a lone row; keep every row otherwise (previously all rows
        # after the first were silently dropped).
        value = rows[0] if len(rows) == 1 else rows

        return CASTLiteralValue(
            value_type=StructureType.LIST,
            value=value,
            source_code_data_type=["matlab", MATLAB_VERSION, StructureType.LIST],
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_module(self, node: Node) -> Module:
        """Visit a program (source_file) node and return a Module.

        List results from child visits are flattened into the body.
        Results that are neither lists nor AstNodes are dropped.
        """
        self.variable_context.push_context()

        program_body = []
        for child in node.children:
            child_cast = self.visit(child)
            if isinstance(child_cast, list):
                program_body.extend(child_cast)
            elif isinstance(child_cast, AstNode):
                program_body.append(child_cast)

        self.variable_context.pop_context()

        return Module(
            name=None,  # TODO: Fill out name field
            body=program_body,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_name(self, node):
        """Return the existing node for this variable name, or create one."""
        identifier = self.node_helper.get_identifier(node)
        # if the identifier exists, return its node
        if self.variable_context.is_variable(identifier):
            return self.variable_context.get_node(identifier)
        # create a new node
        return self.variable_context.add_variable(
            identifier, "Unknown", [self.node_helper.get_source_ref(node)]
        )

    def visit_number(self, node) -> CASTLiteralValue:
        """Translate a number node into an AbstractFloat or Integer literal.

        A literal containing an exponent marker or a decimal point is a
        float; anything else is parsed as an int.
        """
        literal_value = self.node_helper.get_identifier(node)
        if "e" in literal_value.lower() or "." in literal_value:
            value_type = "AbstractFloat"
            value = float(literal_value)
        else:
            value_type = "Integer"
            value = int(literal_value)
        return CASTLiteralValue(
            value_type=value_type,
            value=value,
            source_code_data_type=["matlab", MATLAB_VERSION, value_type],
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_operator(self, node):
        """Return an Operator built from a Tree-sitter operator node.

        The operator symbol is the first control child; the operands are
        the keyword children.
        """
        op = self.node_helper.get_identifier(get_control_children(node)[0])
        operands = [self.visit(child) for child in get_keyword_children(node)]
        return self.get_operator(
            op=op,
            operands=operands,
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_string(self, node):
        """Translate a string node into a Character literal of its text."""
        value_type = "Character"
        return CASTLiteralValue(
            value_type=value_type,
            value=self.node_helper.get_identifier(node),
            source_code_data_type=["matlab", MATLAB_VERSION, ScalarType.CHARACTER],
            source_refs=[self.node_helper.get_source_ref(node)],
        )

    def visit_switch_statement(self, node):
        """Return a chain of ModelIf nodes equivalent to a MATLAB switch.

        Each case clause becomes a ModelIf nested in the ``orelse`` of the
        previous one. An ``otherwise`` block becomes the ``orelse`` of the
        last case. A switch with no case clauses yields the otherwise
        block, or [] when that is absent too.
        """
        # node types used for case comparison
        case_node_types = [
            "boolean",
            "identifier",
            "matrix",
            "number",
            "string",
            "unary_operator"
        ]

        def get_case_expression(case_node, switch_var):
            """Return an operator representing the case test."""
            source_refs = [self.node_helper.get_source_ref(case_node)]
            cell_node = get_first_child_by_type(case_node, "cell")
            # multiple case arguments: switch_var in {a, b, ...}
            if cell_node is not None:
                # NOTE(review): visit(cell_node) already returns a list
                # literal, so this wraps it a second time — confirm whether
                # downstream consumers rely on the double wrapping.
                operand = CASTLiteralValue(
                    value_type=StructureType.LIST,
                    value=self.visit(cell_node),
                    source_code_data_type=["matlab", MATLAB_VERSION, StructureType.LIST],
                    source_refs=[self.node_helper.get_source_ref(cell_node)],
                )
                return self.get_operator(
                    op="in",
                    operands=[switch_var, operand],
                    source_refs=source_refs,
                )
            # single case argument: switch_var == value
            operand_node = get_children_by_types(case_node, case_node_types)[0]
            return self.get_operator(
                op="==",
                operands=[switch_var, self.visit(operand_node)],
                source_refs=source_refs,
            )

        def get_model_if(case_node, switch_var):
            """Return conditional logic representing one case clause."""
            return ModelIf(
                expr=get_case_expression(case_node, switch_var),
                body=self.get_block(case_node),
                orelse=[],
                source_refs=[self.node_helper.get_source_ref(case_node)],
            )

        # The switch variable is usually an identifier, but may be a call.
        switch_var_node = get_first_child_by_type(node, "identifier")
        if switch_var_node is None:
            switch_var_node = get_first_child_by_type(node, "function_call")
        switch_var = self.visit(switch_var_node)

        # n case clauses as 'if then' nodes, each chained to the next
        case_nodes = get_children_by_types(node, ["case_clause"])
        model_ifs = [get_model_if(case_node, switch_var) for case_node in case_nodes]
        for previous, following in zip(model_ifs, model_ifs[1:]):
            previous.orelse = [following]

        otherwise_clause = get_first_child_by_type(node, "otherwise_clause")
        if not model_ifs:
            # No case to hang the otherwise block on; return it directly.
            return [] if otherwise_clause is None else self.get_block(otherwise_clause)

        # otherwise clause as 'else' node after last 'if then' node
        if otherwise_clause is not None:
            model_ifs[-1].orelse = self.get_block(otherwise_clause)

        return model_ifs[0]

    def get_block(self, node):
        """Return the CAST of the children of ``node``'s block child.

        Returns [] when ``node`` has no block (e.g. an empty loop or branch),
        so callers can always treat the result as a list.
        """
        block = get_first_child_by_type(node, "block")
        if block is None:
            return []
        return [self.visit(child) for child in get_keyword_children(block)]

    def get_operator(self, op, operands, source_refs):
        """Return a MATLAB Operator node for ``op`` applied to ``operands``."""
        return Operator(
            source_language="matlab",
            interpreter=INTERPRETER,
            version=MATLAB_VERSION,
            op=op,
            operands=operands,
            source_refs=source_refs,
        )

    def get_gromet_function_node(self, func_name: str):
        """Return the variable-context node for ``func_name``, or None if unknown."""
        if self.variable_context.is_variable(func_name):
            return self.variable_context.get_node(func_name)
        return None

    def _visit_passthrough(self, node):
        """Skip control nodes and other junk without a dedicated visitor.

        Returns [] for a leaf node. Otherwise returns the first truthy
        result from visiting the children, or None if there is none.
        """
        if len(node.children) == 0:
            return []

        for child in node.children:
            child_cast = self.visit(child)
            if child_cast:
                return child_cast
        return None