Coverage for skema/program_analysis/CAST/fortran/ts2cast.py: 65%
457 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import json
2import os.path
3import time
4from pathlib import Path
5from typing import Any, Dict, List, Union
7from tree_sitter import Language, Parser, Node
9from skema.program_analysis.CAST2FN.cast import CAST
10from skema.program_analysis.CAST2FN.model.cast import (
11 Module,
12 SourceRef,
13 ModelBreak,
14 Assignment,
15 CASTLiteralValue,
16 Var,
17 VarType,
18 Name,
19 Operator,
20 AstNode,
21 SourceCodeDataType,
22 ModelImport,
23 FunctionDef,
24 Loop,
25 Call,
26 ModelReturn,
27 ModelIf,
28 RecordDef,
29 Attribute,
30 Label,
31 Goto,
32)
34from skema.program_analysis.CAST.fortran.variable_context import VariableContext
35from skema.program_analysis.CAST.fortran.node_helper import (
36 NodeHelper,
37 remove_comments,
38 get_children_by_types,
39 get_children_except_types,
40 get_first_child_by_type,
41 get_control_children,
42 get_non_control_children,
43 get_first_child_index,
44 get_last_child_index,
45)
46from skema.program_analysis.CAST.fortran.util import generate_dummy_source_refs
48from skema.program_analysis.CAST.fortran.preprocessor.preprocess import preprocess
49from skema.program_analysis.tree_sitter_parsers.build_parsers import (
50 INSTALLED_LANGUAGES_FILEPATH,
51)
# Fortran statements that tree-sitter does not classify as keyword_statement.
# visit_fortran_builtin_statement lowers each of them to a call of the builtin
# named after the node type with its "_statement" suffix removed.
builtin_statements = {
    "read_statement",
    "write_statement",
    "rewind_statement",
    "open_statement",
    "print_statement",
}
64class TS2CAST(object):
65 def __init__(self, source_file_path: str):
66 # Prepare source with preprocessor
67 self.path = Path(source_file_path)
68 self.source_file_name = self.path.name
69 self.source = preprocess(self.path)
71 # Run tree-sitter on preprocessor output to generate parse tree
72 parser = Parser()
73 parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "fortran"))
74 self.tree = parser.parse(bytes(self.source, "utf8"))
75 self.root_node = remove_comments(self.tree.root_node)
77 # Walking data
78 self.variable_context = VariableContext()
79 self.node_helper = NodeHelper(self.source, self.source_file_name)
81 # Start visiting
82 self.out_cast = self.generate_cast()
83 #print(self.out_cast[0].to_json_str())
85 def generate_cast(self) -> List[CAST]:
86 """Interface for generating CAST."""
87 modules = self.run(self.root_node)
88 return [
89 CAST([generate_dummy_source_refs(module)], "Fortran") for module in modules
90 ]
92 def run(self, root) -> List[Module]:
93 """Top level visitor function. Will return between 1-3 Module objects."""
94 # A program can have between 1-3 modules
95 # 1. A module body
96 # 2. A program body
97 # 3. Everything else (defined functions)
98 modules = []
99 contexts = get_children_by_types(root, ["module", "program"])
100 for context in contexts:
101 modules.append(self.visit(context))
103 # Currently, we are supporting functions and subroutines defined outside of programs and modules
104 # Other than comments, it is unclear if anything else is allowed.
105 # TODO: Research the above
106 outer_body_nodes = get_children_by_types(root, ["function", "subroutine"])
107 if len(outer_body_nodes) > 0:
108 body = self.generate_cast_body(outer_body_nodes)
109 modules.append(
110 Module(
111 name=None,
112 body=body,
113 source_refs=[self.node_helper.get_source_ref(root)],
114 )
115 )
117 return modules
119 def visit(self, node: Node):
120 if node.type in ["program", "module"]:
121 return self.visit_module(node)
122 elif node.type == "internal_procedures":
123 return self.visit_internal_procedures(node)
124 elif node.type in ["subroutine", "function"]:
125 return self.visit_function_def(node)
126 elif node.type in ["subroutine_call", "call_expression"]:
127 return self.visit_function_call(node)
128 elif node.type == "use_statement":
129 return self.visit_use_statement(node)
130 elif node.type == "variable_declaration":
131 return self.visit_variable_declaration(node)
132 elif node.type == "assignment_statement":
133 return self.visit_assignment_statement(node)
134 elif node.type == "identifier":
135 return self.visit_identifier(node)
136 elif node.type == "name":
137 return self.visit_name(node)
138 elif node.type in [
139 "unary_expression",
140 "math_expression",
141 "relational_expression",
142 ]:
143 return self.visit_math_expression(node)
144 elif node.type in [
145 "number_literal",
146 "array_literal",
147 "string_literal",
148 "boolean_literal",
149 ]:
150 return self.visit_literal(node)
151 elif node.type == "keyword_statement":
152 return self.visit_keyword_statement(node)
153 elif node.type == "statement_label":
154 return self.visit_statement_label(node)
155 elif node.type in builtin_statements:
156 return self.visit_fortran_builtin_statement(node)
157 elif node.type == "extent_specifier":
158 return self.visit_extent_specifier(node)
159 elif node.type in ["do_loop_statement"]:
160 return self.visit_do_loop_statement(node)
161 elif node.type in ["if_statement", "else_if_clause", "else_clause"]:
162 return self.visit_if_statement(node)
163 elif node.type == "logical_expression":
164 return self.visit_logical_expression(node)
165 elif node.type == "derived_type_definition":
166 return self.visit_derived_type(node)
167 elif node.type == "derived_type_member_expression":
168 return self.visit_derived_type_member_expression(node)
169 else:
170 return self._visit_passthrough(node)
172 def visit_module(self, node: Node) -> Module:
173 """Visitor for program and module statement. Returns a Module object"""
174 self.variable_context.push_context()
176 program_body = self.generate_cast_body(node.children[1:-1])
178 self.variable_context.pop_context()
180 return Module(
181 name=None, # TODO: Fill out name field
182 body=program_body,
183 source_refs=[self.node_helper.get_source_ref(node)],
184 )
186 def visit_internal_procedures(self, node: Node) -> List[FunctionDef]:
187 """Visitor for internal procedures. Returns list of FunctionDef"""
188 internal_procedures = get_children_by_types(node, ["function", "subroutine"])
189 return [self.visit(procedure) for procedure in internal_procedures]
191 def visit_name(self, node):
192 # Node structure
193 # (name)
195 # First, we will check if this name is already defined, and if it is return the name node generated previously
196 identifier = self.node_helper.get_identifier(node)
197 if self.variable_context.is_variable(identifier):
198 return self.variable_context.get_node(identifier)
200 return self.variable_context.add_variable(
201 identifier, "Unknown", [self.node_helper.get_source_ref(node)]
202 )
204 def visit_function_def(self, node):
205 # TODO: Refactor function def code to use new helper functions
206 # Node structure
207 # (subroutine)
208 # (subroutine_statement)
209 # (subroutine)
210 # (name)
211 # (parameters) - Optional
212 # (body_node) ...
213 # (function)
214 # (function_statement)
215 # (function)
216 # (intrinsic_type) - Optional
217 # (name)
218 # (parameters) - Optional
219 # (function_result) - Optional
220 # (identifier)
221 # (body_node) ...
223 # Create a new variable context
224 self.variable_context.push_context()
226 # Top level statement node
228 statement_node = get_children_by_types(
229 node, ["subroutine_statement", "function_statement"]
230 )[0]
232 name_node = get_first_child_by_type(statement_node, "name")
233 name = self.visit(
234 name_node
235 ) # Visit the name node to add it to the variable context
237 # If this is a function, check for return type and return value
238 if node.type == "function":
239 intrinsic_type = None
240 return_value = None
241 signature_qualifiers = get_children_by_types(
242 statement_node, ["intrinsic_type", "function_result"]
243 )
244 for qualifier in signature_qualifiers:
245 if qualifier.type == "intrinsic_type":
246 intrinsic_type = self.node_helper.get_identifier(qualifier)
247 self.variable_context.add_variable(
248 self.node_helper.get_identifier(name_node), intrinsic_type, None
249 )
250 elif qualifier.type == "function_result":
251 return_value = self.visit(
252 get_first_child_by_type(qualifier, "identifier")
253 ).val
254 self.variable_context.add_return_value(return_value.name)
256 # NOTE: In the case of a function specifically, if there is no explicit return value, the return value will be the name of the function
257 # TODO: Should this be a node instead
258 if not return_value:
259 self.variable_context.add_return_value(
260 self.node_helper.get_identifier(name_node)
261 )
262 return_value = self.visit(name_node)
264 # If funciton has both an explicit intrinsic type, then we also need to update the type of the return value in the variable context
265 if intrinsic_type:
266 self.variable_context.update_type(return_value.name, intrinsic_type)
268 # Generating the function arguments by walking the parameters node
269 func_args = []
270 if parameters_node := get_first_child_by_type(statement_node, "parameters"):
271 for parameter in get_non_control_children(parameters_node):
272 # For both subroutine and functions, all arguments are assumes intent(inout) by default unless otherwise specified with intent(in)
273 # The variable declaration visitor will check for this and remove any arguments that are input only from the return values
274 self.variable_context.add_return_value(
275 self.node_helper.get_identifier(parameter)
276 )
277 func_args.append(self.visit(parameter))
279 # The first child of function will be the function statement, the rest will be body nodes
280 body = self.generate_cast_body(node.children[1:-1])
282 # After creating the body, we can go back and update the var nodes we created for the arguments
283 # We do this by looking for intent,in nodes
284 for i, arg in enumerate(func_args):
285 func_args[i].type = self.variable_context.get_type(arg.val.name)
287 # TODO:
288 # This logic can be made cleaner
289 # Fortran doesn't require a return statement, so we need to check if there is a top-level return statement
290 # If there is not, then we will create a dummy one
291 return_found = False
292 for child in body:
293 if isinstance(child, ModelReturn):
294 return_found = True
295 if not return_found:
296 body.append(self.visit_keyword_statement(node))
298 # Pop variable context off of stack before leaving this scope
299 self.variable_context.pop_context()
302 # If this is a class function, we need to associate the function def with the class
303 # We should also return None here so we don't duplicate the function def
304 if self.variable_context.is_class_function(name.name):
305 self.variable_context.copy_class_function(name.name,
306 FunctionDef(
307 name=name,
308 func_args=func_args,
309 body=body,
310 source_refs=[self.node_helper.get_source_ref(node)],
311 ))
312 return None
314 return FunctionDef(
315 name=name,
316 func_args=func_args,
317 body=body,
318 source_refs=[self.node_helper.get_source_ref(node)],
319 )
321 def visit_function_call(self, node):
322 # Pull relevent nodes
323 # A subroutine and function won't neccessarily have an arguments node.
324 # So we should be careful about trying to access it.
326 function_node = get_children_by_types(
327 node,
328 [
329 "unary_expression",
330 "subroutine",
331 "identifier",
332 "derived_type_member_expression",
333 ],
334 )[0]
335 if function_node.type == "derived_type_member_expression":
336 return self.visit_derived_type_member_expression(function_node)
338 arguments_node = get_first_child_by_type(node, "argument_list")
340 # If this is a unary expression (+foo()) the identifier will be nested.
341 # TODO: If this is a non '+' unary expression, how do we add it to the CAST?
342 if function_node.type == "unary_expression":
343 function_node = get_first_child_by_type(node, "identifier", recurse=True)
345 function_identifier = self.node_helper.get_identifier(function_node)
347 # Tree-Sitter incorrectly parses mutlidimensional array accesses as function calls
348 # We will need to check if this is truly a function call or a subscript
349 if self.variable_context.is_variable(function_identifier):
350 if self.variable_context.get_type(function_identifier) == "List":
351 return self._visit_get(
352 node
353 ) # This overrides the visitor and forces us to visit another
355 # TODO: What should get a name node? Instrincit functions? Imported functions?
356 # Judging from the Gromet generation pipeline, it appears that all functions need Name nodes.
357 if self.variable_context.is_variable(function_identifier):
358 func = self.variable_context.get_node(function_identifier)
359 else:
360 func = Name(function_identifier, -1) # TODO: REFACTOR
362 # Add arguments to arguments list
363 arguments = []
364 if arguments_node:
365 for argument in arguments_node.children:
366 child_cast = self.visit(argument)
367 if child_cast:
368 arguments.append(child_cast)
370 return Call(
371 func=func,
372 source_language="Fortran",
373 source_language_version="2008",
374 arguments=arguments,
375 source_refs=[self.node_helper.get_source_ref(node)],
376 )
378 """
379 (keyword_statement [6, 6] - [6, 61]
380 (statement_label_reference [6, 13] - [6, 16])
381 (statement_label_reference [6, 18] - [6, 21])
382 (statement_label_reference [6, 23] - [6, 26])
383 (statement_label_reference [6, 28] - [6, 31])
384 (math_expression [6, 34] - [6, 61]
385 left: (call_expression [6, 34] - [6, 57]
386 (identifier [6, 34] - [6, 37])
387 (argument_list [6, 37] - [6, 57]
388 (math_expression [6, 38] - [6, 53]
389 left: (math_expression [6, 38] - [6, 49]
390 left: (parenthesized_expression [6, 38] - [6, 45]
391 (math_expression [6, 39] - [6, 44]
392 left: (identifier [6, 39] - [6, 40])
393 right: (identifier [6, 43] - [6, 44])))
394 right: (identifier [6, 48] - [6, 49]))
395 right: (number_literal [6, 52] - [6, 53]))
396 (number_literal [6, 55] - [6, 56])))
397 right: (number_literal [6, 60] - [6, 61])))
398 """
400 def visit_keyword_statement(self, node):
401 # NOTE: RETURN is not the only Fortran keyword. GO TO and CONTINUE are also considered keywords
402 identifier = self.node_helper.get_identifier(node).lower()
403 if node.type == "keyword_statement":
404 if "go to" in identifier:
405 statement_labels = [
406 self.node_helper.get_identifier(child)
407 for child in get_children_by_types(
408 node, ["statement_label_reference"]
409 )
410 ]
411 # If there are multiple statement labels, then this is a COMPUTED GO TO
412 # Those are handled as a "_get" access into a List of statement labels with the index determined by the expression
413 if len(statement_labels) > 1:
414 expr = Call(
415 func=self.get_gromet_function_node("_get"),
416 arguments=[
417 CASTLiteralValue(value_type="List", value=[CASTLiteralValue(value=label, value_type="List") for label in statement_labels]),
418 self.visit(node.children[-1]),
419 ],
420 )
421 return Goto(label=None, expr=expr)
422 return Goto(
423 label=statement_labels[0],
424 expr=None,
425 )
426 if "continue" in identifier:
427 return self._visit_no_op(node)
428 if "exit" in identifier:
429 return ModelBreak(source_refs=[self.node_helper.get_source_ref(node)])
431 # In Fortran the return statement doesn't return a value (there is the obsolete "alternative return")
432 # We keep track of values that need to be returned in the variable context
433 return_values = self.variable_context.context_return_values[
434 -1
435 ] # TODO: Make function for this
437 if len(return_values) == 1:
438 value = self.variable_context.get_node(list(return_values)[0])
439 elif len(return_values) > 1:
440 value = CASTLiteralValue(
441 value_type="Tuple",
442 value=[self.variable_context.get_node(ret) for ret in return_values],
443 source_code_data_type=None,
444 source_refs=None,
445 )
446 else:
447 value = CASTLiteralValue(value=None, value_type=None, source_refs=None)
449 return ModelReturn(
450 value=value, source_refs=[self.node_helper.get_source_ref(node)]
451 )
453 def visit_statement_label(self, node):
454 """Visitor for fortran statement labels"""
455 return Label(label=self.node_helper.get_identifier(node))
457 def visit_fortran_builtin_statement(self, node):
458 """Visitor for Fortran keywords that are not classified as keyword_statement by tree-sitter"""
459 # All of the node types that fall into this category end with _statment.
460 # So the function name will be the node type with _statement removed (write, read, open, ...)
461 func = self.get_gromet_function_node(node.type.replace("_statement", ""))
463 arguments = []
465 return Call(
466 func=func,
467 arguments=arguments,
468 source_language="Fortran",
469 source_language_version=None,
470 source_refs=[self.node_helper.get_source_ref(node)],
471 )
473 def visit_print_statement(self, node):
474 func = self.get_gromet_function_node("print")
476 arguments = []
478 return Call(
479 func=func,
480 arguments=arguments,
481 source_language=None,
482 source_language_version=None,
483 )
485 def visit_use_statement(self, node):
486 # (use)
487 # (use)
488 # (module_name)
490 ## Pull relevent child nodes
491 module_name_node = get_first_child_by_type(node, "module_name")
492 module_name = self.node_helper.get_identifier(module_name_node)
493 included_items_node = get_first_child_by_type(node, "included_items")
495 import_all = included_items_node is None
496 import_alias = None # TODO: Look into local-name and use-name fields
498 # We need to check if this import is a full import of a module, i.e. use module
499 # Or a partial import i.e. use module,only: sub1, sub2
500 if import_all:
501 return ModelImport(
502 name=module_name,
503 alias=import_alias,
504 all=import_all,
505 symbol=None,
506 source_refs=[self.node_helper.get_source_ref(node)],
507 )
508 else:
509 imports = []
510 for symbol in get_non_control_children(included_items_node):
511 symbol_identifier = self.node_helper.get_identifier(symbol)
512 symbol_source_refs = [self.node_helper.get_source_ref(symbol)]
513 imports.append(
514 ModelImport(
515 name=module_name,
516 alias=import_alias,
517 all=import_all,
518 symbol=symbol_identifier,
519 source_refs=symbol_source_refs,
520 )
521 )
522 return imports
524 def visit_do_loop_statement(self, node) -> Loop:
525 """Visitor for Loops. Do to complexity, this visitor logic only handles the range-based do loop.
526 The do while loop will be passed off to a seperate visitor. Returns a Loop object.
527 """
528 """
529 Node structure
530 Do loop
531 (do_loop_statement)
532 (loop_control_expression)
533 (...) ...
534 (body) ...
536 Do while
537 (do_loop_statement)
538 (while_statement)
539 (parenthesized_expression)
540 (...) ...
541 (body) ...
542 """
544 loop_control_node = get_first_child_by_type(node, "loop_control_expression")
545 if not loop_control_node:
546 return self._visit_while(node)
548 # If there is a loop control expression, the first body node will be the node after the loop_control_expression
549 # It is valid Fortran to have a single itteration do loop as well.
550 # NOTE: This code is for the creation of the main body. The do loop will still add some additional nodes at the end of this body.
551 body_start_index = 1 + get_first_child_index(node, "loop_control_expression")
552 body = self.generate_cast_body(node.children[body_start_index:])
554 # For the init and expression fields, we first need to determine if we are in a regular "do" or a "do while" loop
555 # PRE:
556 # _next(_iter(range(start, stop, step)))
557 loop_control_node = get_first_child_by_type(node, "loop_control_expression")
558 loop_control_children = get_non_control_children(loop_control_node)
559 if len(loop_control_children) == 3:
560 itterator, start, stop = [
561 self.visit(child) for child in loop_control_children
562 ]
563 step = CASTLiteralValue("Integer", "1")
564 elif len(loop_control_children) == 4:
565 itterator, start, stop, step = [
566 self.visit(child) for child in loop_control_children
567 ]
568 else:
569 itterator = None
570 start = None
571 stop = None
572 step = None
574 range_name_node = self.get_gromet_function_node("range")
575 iter_name_node = self.get_gromet_function_node("iter")
576 next_name_node = self.get_gromet_function_node("next")
577 generated_iter_name_node = self.variable_context.generate_iterator()
578 stop_condition_name_node = self.variable_context.generate_stop_condition()
580 # generated_iter_0 = iter(range(start, stop, step))
581 pre = []
582 pre.append(
583 Assignment(
584 left=Var(generated_iter_name_node, "Iterator"),
585 right=Call(
586 iter_name_node,
587 arguments=[Call(range_name_node, arguments=[start, stop, step])],
588 ),
589 )
590 )
592 # (i, generated_iter_0, sc_0) = next(generated_iter_0)
593 pre.append(
594 Assignment(
595 left=CASTLiteralValue(
596 "Tuple",
597 [
598 itterator,
599 Var(generated_iter_name_node, "Iterator"),
600 Var(stop_condition_name_node, "Boolean"),
601 ],
602 ),
603 right=Call(
604 next_name_node,
605 arguments=[Var(generated_iter_name_node, "Iterator")],
606 ),
607 )
608 )
610 # EXPR
611 expr = []
612 expr = Operator(
613 op="!=", # TODO: Should this be == or !=
614 operands=[
615 stop_condition_name_node,
616 CASTLiteralValue("Boolean", True),
617 ],
618 )
620 # BODY
621 # At this point, the body nodes have already been visited
622 # We just need to append the iterator next call
623 body.append(
624 Assignment(
625 left=CASTLiteralValue(
626 "Tuple",
627 [
628 itterator,
629 Var(generated_iter_name_node, "Iterator"),
630 Var(stop_condition_name_node, "Boolean"),
631 ],
632 ),
633 right=Call(
634 next_name_node,
635 arguments=[Var(generated_iter_name_node, "Iterator")],
636 ),
637 )
638 )
640 # POST
641 post = []
642 post.append(
643 Assignment(
644 left=itterator,
645 right=Operator(op="+", operands=[itterator, step]),
646 )
647 )
649 return Loop(
650 pre=pre,
651 expr=expr,
652 body=body,
653 post=post,
654 source_refs=[self.node_helper.get_source_ref(node)],
655 )
657 def visit_if_statement(self, node):
658 # (if_statement)
659 # (if)
660 # (parenthesised_expression)
661 # (then)
662 # (body_nodes) ...
663 # (elseif_clauses) ..
664 # (else_clause)
665 # (end_if_statement)
667 # TODO: Can you have a parenthesized expression as a body node
668 body_nodes = get_children_except_types(
669 node,
670 [
671 "if",
672 "elseif",
673 "else",
674 "then",
675 "parenthesized_expression",
676 "elseif_clause",
677 "else_clause",
678 "end_if_statement",
679 ],
680 )
681 body = self.generate_cast_body(body_nodes)
683 expr_node = get_first_child_by_type(node, "parenthesized_expression")
684 expr = None
685 if expr_node:
686 expr = self.visit(expr_node)
688 elseif_nodes = get_children_by_types(node, ["elseif_clause"])
689 elseif_cast = [self.visit(elseif_clause) for elseif_clause in elseif_nodes]
690 for i in range(len(elseif_cast) - 1):
691 elseif_cast[i].orelse = [elseif_cast[i + 1]]
693 else_node = get_first_child_by_type(node, "else_clause")
694 else_cast = None
695 if else_node:
696 else_cast = self.visit(else_node)
698 orelse = []
699 if len(elseif_cast) > 0:
700 orelse = [elseif_cast[0]]
701 elif else_cast:
702 orelse = else_cast.body
704 return ModelIf(expr=expr, body=body, orelse=orelse)
706 def visit_logical_expression(self, node):
707 """Visitior for logical expression (i.e. true and false) which is used in compound conditional"""
708 # If this is a .not. operator, we need to pass it on to the math_expression visitor
709 if len(node.children) < 3:
710 return self.visit_math_expression(node)
712 literal_value_false = CASTLiteralValue("Boolean", False)
713 literal_value_true = CASTLiteralValue("Boolean", True)
715 # AND: Right side goes in body if, left side in condition
716 # OR: Right side goes in body else, left side in condition
717 left, operator, right = node.children
719 # First we need to check if this is logical and or a logical or
720 # The tehcnical types for these are \.or\. and \.and\. so to simplify things we can use the in keyword
721 is_or = "or" in operator.type
723 top_if = ModelIf()
724 top_if_expr = self.visit(left)
725 top_if.expr = top_if_expr
727 bottom_if_expr = self.visit(right)
728 if is_or:
729 top_if.orelse = [bottom_if_expr]
730 top_if.body = [literal_value_true]
731 else:
732 top_if.orelse = [literal_value_false]
733 top_if.body = [bottom_if_expr]
735 return top_if
737 def visit_assignment_statement(self, node):
738 left, _, right = node.children
740 # We need to check if the left side is a multidimensional array,
741 # Since tree-sitter incorrectly shows this assignment as a call_expression
742 if left.type == "call_expression":
743 return self._visit_set(node)
745 return Assignment(
746 left=self.visit(left),
747 right=self.visit(right),
748 source_refs=[self.node_helper.get_source_ref(node)],
749 )
751 def visit_literal(self, node) -> CASTLiteralValue:
752 """Visitor for literals. Returns a CASTLiteralValue"""
753 literal_type = node.type
754 literal_value = self.node_helper.get_identifier(node)
755 literal_source_ref = self.node_helper.get_source_ref(node)
757 if literal_type == "number_literal":
758 # Check if this is a real value, or an Integer
759 if "e" in literal_value.lower() or "." in literal_value:
760 return CASTLiteralValue(
761 value_type="AbstractFloat",
762 value=literal_value,
763 source_code_data_type=["Fortran", "Fortran95", "real"],
764 source_refs=[literal_source_ref],
765 )
766 else:
767 return CASTLiteralValue(
768 value_type="Integer",
769 value=literal_value,
770 source_code_data_type=["Fortran", "Fortran95", "integer"],
771 source_refs=[literal_source_ref],
772 )
774 elif literal_type == "string_literal":
775 return CASTLiteralValue(
776 value_type="Character",
777 value=literal_value,
778 source_code_data_type=["Fortran", "Fortran95", "character"],
779 source_refs=[literal_source_ref],
780 )
782 elif literal_type == "boolean_literal":
783 return CASTLiteralValue(
784 value_type="Boolean",
785 value=literal_value,
786 source_code_data_type=["Fortran", "Fortran95", "logical"],
787 source_refs=[literal_source_ref],
788 )
790 elif literal_type == "array_literal":
791 # There are a multiple ways to create an array literal. This visitor is for the traditional explicit creation (/ 1,2,3 /)
792 # For the do loop based version, we pass it off to another visitor
793 implied_do_loop_expression_node = get_first_child_by_type(
794 node, "implied_do_loop_expression"
795 )
796 if implied_do_loop_expression_node:
797 return self._visit_implied_do_loop(implied_do_loop_expression_node)
799 return CASTLiteralValue(
800 value_type="List",
801 value=[
802 self.visit(element) for element in get_non_control_children(node)
803 ],
804 source_code_data_type=["Fortran", "Fortran95", "dimension"],
805 source_refs=[literal_source_ref],
806 )
808 def visit_identifier(self, node):
809 # By default, this is unknown, but can be updated by other visitors
810 identifier = self.node_helper.get_identifier(node)
811 if self.variable_context.is_variable(identifier):
812 var_type = self.variable_context.get_type(identifier)
813 else:
814 var_type = "Unknown"
816 # Default value comes from Pytohn keyword arguments i.e. def foo(a, b=10)
817 # Fortran does have optional arguments introduced in F90, but these do not specify a default
818 default_value = None
820 # This is another case where we need to override the visitor to explicitly visit another node
821 value = self.visit_name(node)
823 return Var(
824 val=value,
825 type=var_type,
826 default_value=default_value,
827 source_refs=[self.node_helper.get_source_ref(node)],
828 )
830 def visit_math_expression(self, node):
831 op = self.node_helper.get_identifier(
832 get_control_children(node)[0]
833 ) # The operator will be the first control character
834 operands = []
835 for operand in get_non_control_children(node):
836 operands.append(self.visit(operand))
838 # For operators, we will only need the name node since we are not allocating space
839 if operand.type == "identifier":
840 operands[-1] = operands[-1].val
842 return Operator(
843 source_language="Fortran",
844 interpreter=None,
845 version=None,
846 op=op,
847 operands=operands,
848 source_refs=[self.node_helper.get_source_ref(node)],
849 )
851 def visit_variable_declaration(self, node) -> List:
852 """Visitor for variable declaration. Will return a List of Var and Assignment nodes."""
853 """
854 # Node structure
855 (variable_declaration)
856 (intrinsic_type)
857 (type_qualifier)
858 (qualifier)
859 (value)
860 (identifier) ...
861 (assignment_statement) ...
863 (variable_declaration)
864 (derived_type)
865 (type_name)
866 """
867 # A variable can be declared with an intrinsic_type if its built-in, or a derived_type if it is user defined.
868 intrinsic_type_node = get_first_child_by_type(node, "intrinsic_type")
869 derived_type_node = get_first_child_by_type(node, "derived_type")
871 variable_type = ""
872 variable_intent = ""
874 if intrinsic_type_node:
875 type_map = {
876 "integer": "Integer",
877 "real": "AbstractFloat",
878 "double precision": "AbstractFloat",
879 "complex": "Tuple", # Complex is a Tuple (rational,irrational),
880 "logical": "Boolean",
881 "character": "String",
882 }
883 # NOTE: Identifiers are case sensitive, so we always need to make sure we are comparing to the lower() version
884 variable_type = type_map[
885 self.node_helper.get_identifier(intrinsic_type_node).lower()
886 ]
887 elif derived_type_node:
888 variable_type = self.node_helper.get_identifier(
889 get_first_child_by_type(derived_type_node, "type_name", recurse=True),
890 )
892 # There are multiple type qualifiers that change the way we generate a variable
893 # For example, we need to determine if we are creating an array (dimension) or a single variable
894 type_qualifiers = get_children_by_types(node, ["type_qualifier"])
895 for qualifier in type_qualifiers:
896 field = self.node_helper.get_identifier(qualifier.children[0])
898 if field == "dimension":
899 variable_type = "List"
900 elif field == "intent":
901 variable_intent = self.node_helper.get_identifier(qualifier.children[1])
903 # You can declare multiple variables of the same type in a single statement, so we need to create a Var or Assignment node for each instance
904 definied_variables = get_children_by_types(
905 node,
906 [
907 "identifier", # Variable declaration
908 "assignment_statement", # Variable assignment
909 "call_expression", # Dimension without intent
910 ],
911 )
912 vars = []
913 for variable in definied_variables:
914 if variable.type == "assignment_statement":
915 if variable.children[0].type == "call_expression":
916 vars.append(
917 Assignment(
918 left=self.visit(
919 get_first_child_by_type(
920 variable.children[0], "identifier"
921 )
922 ),
923 right=self.visit(variable.children[2]),
924 source_refs=[self.node_helper.get_source_ref(variable)],
925 )
926 )
927 vars[-1].left.type = "List"
928 self.variable_context.update_type(vars[-1].left.val.name, "List")
929 else:
930 # If its a regular assignment, we can update the type normally
931 vars.append(self.visit(variable))
932 vars[-1].left.type = variable_type
933 self.variable_context.update_type(
934 vars[-1].left.val.name, variable_type
935 )
937 elif variable.type == "identifier":
938 # A basic variable declaration, we visit the identifier and then update the type
939 vars.append(self.visit(variable))
940 vars[-1].type = variable_type
941 self.variable_context.update_type(vars[-1].val.name, variable_type)
942 elif variable.type == "call_expression":
943 # Declaring a dimension variable using the x(1:5) format. It will look like a call expression in tree-sitter.
944 # We treat it like an identifier by visiting its identifier node. Then the type gets overridden by "dimension"
945 vars.append(self.visit(get_first_child_by_type(variable, "identifier")))
946 vars[-1].type = "List"
947 self.variable_context.update_type(vars[-1].val.name, "List")
949 # By default, all variables are added to a function's list of return values
950 # If the intent is actually in, then we need to remove them from the list
951 if variable_intent == "in":
952 for var in vars:
953 self.variable_context.remove_return_value(var.val.name)
955 return vars
def visit_extent_specifier(self, node):
    """Translate a Fortran extent specifier into a gromet ``slice`` call.

    Node structure:
        (extent_specifier)
            (identifier)
            (identifier)

    An extent specifier behaves like a slice with up to three positions
    (start, stop, step) separated by ``:`` tokens. Each non-``:`` child is
    visited and stored at the position given by the number of ``:`` tokens
    seen so far. Positions with no value keep a
    ``CASTLiteralValue("None", "None")`` placeholder.
    """
    slots = [CASTLiteralValue("None", "None") for _ in range(3)]
    position = 0
    for child in node.children:
        if child.type != ":":
            slots[position] = self.visit(child)
            continue
        # Each ':' advances to the next slice position
        position += 1

    return Call(
        func=self.get_gromet_function_node("slice"),
        source_language="Fortran",
        source_language_version="Fortran95",
        arguments=slots,
        source_refs=[self.node_helper.get_source_ref(node)],
    )
def visit_derived_type(self, node: Node) -> RecordDef:
    """Visitor function for a derived type definition. Returns a RecordDef.

    Node structure:
        (derived_type_definition)
            (derived_type_statement)
                (base)
                    (base_type_specifier)
                        (identifier)
                (type_name)
            (BODY_NODES)
            ...

    Procedures listed under ``derived_type_procedures`` are registered as
    module functions and become the record's ``funcs``. Variable declarations
    in the body become its ``fields``; an initialized declaration sets the
    field's ``default_value``.

    The record-definition scope opened on ``self.variable_context`` is always
    closed again (``try``/``finally``). An exception while visiting the body
    therefore cannot leave the record-name prefix active for nodes visited
    afterwards.
    """
    record_name = self.node_helper.get_identifier(
        get_first_child_by_type(node, "type_name", recurse=True)
    )

    # There is no multiple inheritance in Fortran, so a type may extend at most 1 other type
    bases = []
    derived_type_statement_node = get_first_child_by_type(
        node, "derived_type_statement"
    )
    base_node = get_first_child_by_type(
        derived_type_statement_node, "identifier", recurse=True
    )
    if base_node:
        # NOTE(review): the base name is wrapped in its own list, so `bases` is a
        # list of single-element lists. Confirm this is the shape RecordDef consumers expect.
        bases.append([self.node_helper.get_identifier(base_node)])

    # A derived type can contain symbols with the same name as those already in the main program body.
    # While inside a record definition, the variable context prefixes defined variables with the type name.
    self.variable_context.enter_record_definition(record_name)
    try:
        # Note: In derived type declarations, functions are only declared. The actual definition will be in the outer module.
        funcs = []
        if derived_type_procedures_node := get_first_child_by_type(
            node, "derived_type_procedures"
        ):
            for procedure_node in get_children_by_types(
                derived_type_procedures_node, ["procedure_statement"]
            ):
                function_name = self.node_helper.get_identifier(
                    get_first_child_by_type(
                        procedure_node, "method_name", recurse=True
                    )
                )
                funcs.append(
                    self.variable_context.register_module_function(function_name)
                )

        # A derived type can only have variable declarations in its body.
        fields = []
        variable_declarations = [
            self.visit(variable_declaration)
            for variable_declaration in get_children_by_types(
                node, ["variable_declaration"]
            )
        ]
        for declaration in variable_declarations:
            # Variable declarations always return a list of Var or Assignment, even when only one var is being created
            for var in declaration:
                if isinstance(var, Var):
                    fields.append(var)
                elif isinstance(var, Assignment):
                    # In a record definition, an assignment sets the field's default value
                    var.left.default_value = var.right
                    fields.append(var.left)
                # TODO: Handle dimension type (Call type)
                elif isinstance(var, Call):
                    pass
    finally:
        # Leaving the record definition sets the prefix back to an empty string,
        # even if visiting the body raised.
        self.variable_context.exit_record_definition()

    return RecordDef(
        name=record_name,
        bases=bases,
        funcs=funcs,
        fields=fields,
        source_refs=[self.node_helper.get_source_ref(node)],
    )
def visit_derived_type_member_expression(self, node) -> Attribute:
    """Translate a derived type member access into an Attribute.

    Node structure:
        Scalar access
            (derived_type_member_expression)
                (identifier)
                (type_member)

        Dimensional access
            (derived_type_member_expression)
                (call_expression)
                    (identifier)
                    (argument_list)
                (type_member)
    """
    # A dimensional access (the accessed object is an element of a List) is
    # converted into a call to `get`; a scalar access uses the identifier itself.
    indexed_node = get_first_child_by_type(node, "call_expression")
    if not indexed_node:
        # Visit the identifier instead of reading the variable context directly:
        # the node may not exist yet (e.g. for an import), and visiting adds it.
        base = self.visit(get_first_child_by_type(node, "identifier", recurse=True))
    else:
        base = self._visit_get(indexed_node)

    # The attribute must be a Name node, NOT a string or a Var node.
    member = self.visit_name(get_first_child_by_type(node, "type_member"))

    return Attribute(
        value=base,
        attr=member,
        source_refs=[self.node_helper.get_source_ref(node)],
    )
# NOTE: This function starts with _ because it is never dispatched to directly;
# the tree-sitter parse tree has no "get" node. Callers decide from context when
# an expression accesses an element of a List and call this function themselves.
def _visit_get(self, node):
    """Translate an element access on a List into a gromet ``get`` call.

    Node structure:
        (call_expression)
            (identifier)
            (argument_list)

    The first argument of the resulting call is the List itself (from visiting
    the identifier). A single index is passed through as the second argument.
    Several indices (a multi-dimensional access) are wrapped in a List literal
    and passed to an ``ext_slice`` call instead.
    """
    index_nodes = get_non_control_children(node.children[1])
    arguments = [self.visit(node.children[0])]

    if len(index_nodes) <= 1:
        arguments.append(self.visit(index_nodes[0]))
    else:
        # Multi-dimensional access: use an extended slice over all indices
        dimension_list = CASTLiteralValue()
        dimension_list.value_type = "List"
        dimension_list.value = [self.visit(index) for index in index_nodes]

        extslice_call = Call()
        extslice_call.func = self.get_gromet_function_node("ext_slice")
        extslice_call.arguments = [dimension_list]

        arguments.append(extslice_call)

    return Call(
        func=self.get_gromet_function_node("get"),
        source_language="Fortran",
        source_language_version="Fortran95",
        arguments=arguments,
        source_refs=[self.node_helper.get_source_ref(node)],
    )
1142 def _visit_set(self, node):
1143 # Node structure
1144 # (assignment_statement)
1145 # (call_expression)
1146 # (right side)
1148 left, _, right = node.children
1150 # The left side is equivilent to a call gromet "get", so we will first pass the left side to the get visitor
1151 # Then we can easily convert this to a "set" call by modifying the fields and then appending the right side to the function arguments
1152 cast_call = self._visit_get(left)
1153 cast_call.func = self.get_gromet_function_node("set")
1154 cast_call.arguments.append(self.visit(right))
1156 return cast_call
def _visit_while(self, node) -> Loop:
    """Custom visitor for a ``do while`` loop. Returns a Loop object.

    Node structure:
        Do while
        (do_loop_statement)
            (while_statement)
                (parenthesized_expression)
                    (...) ...
            (body) ...

    A do loop without a while_statement is translated with a literal
    ``True`` condition, and all of its children form the body.
    """
    condition_node = get_first_child_by_type(node, "while_statement")

    if condition_node:
        # The body starts with the node right after the while_statement.
        first_body_index = get_first_child_index(node, "while_statement") + 1
        # parenthesized_expression has no explicit visitor; the passthrough
        # handler takes care of visiting the wrapped expression.
        condition = self.visit(
            get_first_child_by_type(condition_node, "parenthesized_expression")
        )
    else:
        # Fortran has certain while(True) constructs that have no while_statement node
        first_body_index = 0
        condition = CASTLiteralValue(
            value_type="Boolean",
            value="True",
        )

    return Loop(
        pre=[],
        expr=condition,
        body=self.generate_cast_body(node.children[first_body_index:]),
        post=[],
        source_refs=[self.node_helper.get_source_ref(node)],
    )
def _visit_implied_do_loop(self, node) -> Call:
    """Custom visitor for an implied-do-loop array literal; emits a call to ``range``.

    The loop_control_expression supplies ``iterator, start, stop`` and an
    optional ``step``. A missing step becomes ``CASTLiteralValue("Integer", "1")``.
    Any other number of control values produces ``range(None, None, None)``
    without visiting them.
    """
    # TODO: This loop_control is the same as the do loop. Can we turn this into one visitor?
    control_node = get_first_child_by_type(
        node, "loop_control_expression", recurse=True
    )
    control_children = get_non_control_children(control_node)

    start = stop = step = None
    if len(control_children) in (3, 4):
        # The iterator is still visited even though range() only takes
        # start/stop/step; visiting an identifier adds it to the variable
        # context if it is missing.
        _iterator, start, stop, *maybe_step = [
            self.visit(child) for child in control_children
        ]
        step = maybe_step[0] if maybe_step else CASTLiteralValue("Integer", "1")

    # NOTE(review): Fortran implied-do bounds include `stop`; confirm the gromet
    # `range` primitive uses matching stop semantics.
    return Call(
        func=self.get_gromet_function_node("range"),
        source_language=None,
        source_language_version=None,
        arguments=[start, stop, step],
        source_refs=[self.node_helper.get_source_ref(node)],
    )
1228 def _visit_passthrough(self, node):
1229 if len(node.children) == 0:
1230 return None
1232 for child in node.children:
1233 child_cast = self.visit(child)
1234 if child_cast:
1235 return child_cast
def _visit_no_op(self, node):
    """Return a ``no_op`` call used as a placeholder for unsupported idioms.

    Gromet doesn't support empty bodies, so this call keeps a body non-empty.
    ``node`` is not used.
    """
    no_op_name = self.get_gromet_function_node("no_op")
    return Call(
        func=no_op_name,
        source_language=None,
        source_language_version=None,
        arguments=[],
    )
def get_gromet_function_node(self, func_name: str) -> Name:
    """Return the Name node for a gromet primitive, registering it on first use.

    Ideally a dummy tree-sitter node would be passed to the name visitor, but
    tree-sitter does not allow creating or modifying nodes, so the lookup is
    repeated here. An already-known name is returned from the variable context.
    Otherwise it is added with type ``"function"`` and no source node.
    """
    context = self.variable_context
    if not context.is_variable(func_name):
        return context.add_variable(func_name, "function", None)
    return context.get_node(func_name)
def generate_cast_body(self, body_nodes: List):
    """Visit each node in ``body_nodes`` and collect the resulting CAST into a flat list.

    A visitor may return a single AstNode (appended), a list of nodes
    (extended, with ``None`` entries dropped), or anything else such as
    ``None`` (ignored). If nothing is collected, a single ``no_op`` call is
    returned because Gromet doesn't support empty bodies.
    """
    body = []

    for node in body_nodes:
        cast = self.visit(node)

        if isinstance(cast, AstNode):
            body.append(cast)
        elif isinstance(cast, list):
            # Runtime type checks use the builtin `list`; `typing.List` is a
            # deprecated alias intended for annotations only.
            body.extend(element for element in cast if element is not None)

    # Gromet doesn't support empty bodies, so we should create a no_op instead
    if not body:
        body.append(self._visit_no_op(None))

    # TODO: How to add more support for source references
    return body