Coverage for skema/program_analysis/CAST/fortran/node_helper.py: 86%
63 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import itertools
2from typing import List, Dict
4from tree_sitter import Node
6from skema.program_analysis.CAST2FN.model.cast import SourceRef
# Node types that get_control_children / get_non_control_children use to split
# a node's children into "control" tokens and everything else.
# The dotted operator entries keep their backslashes as part of the string
# value; they are written as raw strings because "\." is not a valid escape
# sequence in an ordinary string literal and triggers a warning at compile time.
CONTROL_CHARACTERS = [
    ",",
    "=",
    "==",
    "(",
    ")",
    "(/",
    "/)",
    ":",
    "::",
    "+",
    "-",
    "*",
    "**",
    "/",
    "/=",
    ">",
    "<",
    "<=",
    ">=",
    "only",
    r"\.not\.",
    r"\.gt\.",
    r"\.ge\.",
    r"\.lt\.",
    r"\.le\.",
    r"\.eq\.",
    r"\.ne\.",
]
class NodeHelper():
    """Helper bound to one source text for mapping tree-sitter nodes back to it.

    The constructor precomputes cumulative line lengths so that get_identifier
    can turn a (row, column) point into an index into the source string without
    rescanning the text.
    """

    def __init__(self, source: str, source_file_name: str):
        """Store the source text and file name and build the line-length tables.

        Args:
            source: Full text of the source being analysed.
            source_file_name: Name passed to every SourceRef this helper creates.
        """
        self.source = source
        self.source_file_name = source_file_name

        # get_identifier optimization variables
        # keepends=True keeps the line terminators, so the lengths sum to len(source).
        self.source_lines = source.splitlines(keepends=True)
        self.line_lengths = [len(line) for line in self.source_lines]
        # line_length_sums[i] is the total length of lines 0..i inclusive, which
        # is also the index in source at which line i + 1 begins.
        self.line_length_sums = list(itertools.accumulate(self.line_lengths))

    def get_source_ref(self, node: Node) -> SourceRef:
        """Given a node, return a CAST SourceRef object for this helper's file."""
        row_start, col_start = node.start_point
        row_end, col_end = node.end_point
        return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)

    def get_identifier(self, node: Node) -> str:
        """Return the source text between node.start_point and node.end_point.

        Both points are (row, column) pairs with zero-based rows. Each one is
        converted to an index into self.source and the slice between the two
        indices is returned.

        Args:
            node: Object exposing start_point and end_point (row, column) pairs.

        Returns:
            The substring of self.source covered by the node.
        """
        # NOTE(review): columns are used as character offsets here; if the
        # parser reports byte offsets this will drift on non-ASCII lines -- confirm.
        start_line, start_column = node.start_point
        end_line, end_column = node.end_point

        start_index = self._line_start_index(start_line) + start_column
        end_index = self._line_start_index(end_line) + end_column

        return self.source[start_index:end_index]

    def _line_start_index(self, line: int) -> int:
        """Return the index in self.source at which the zero-based line begins."""
        # Line 0 begins at index 0. Looking it up as line_length_sums[-1] would
        # wrap around to the total source length (or fail on empty source).
        if line == 0:
            return 0
        # The start of any later line is the combined length of all lines before it.
        return self.line_length_sums[line - 1]
def remove_comments(node: Node):
    """Strip every child of type "comment" from the tree rooted at node.

    The same node object that was passed in is returned, after its child list
    and the child lists of all its descendants have been filtered.
    """
    # NOTE: tree-sitter Node objects are read-only, so the tree is never rebuilt;
    # entries are only deleted from, or reassigned within, the existing child lists.
    # Walking the positions from the back means a deletion never shifts an
    # entry that still has to be examined.
    comment_positions = [
        position
        for position, child in enumerate(node.children)
        if child.type == "comment"
    ]
    for position in reversed(comment_positions):
        del node.children[position]

    # Recurse into what is left and store each processed child back in its slot.
    for position, child in enumerate(node.children):
        node.children[position] = remove_comments(child)

    return node
def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of node whose type equals type, or None.

    Direct children are always checked first. When recurse is set and no direct
    child matches, each child's subtree is searched in order and the first
    match found is returned.
    """
    direct_match = next(
        (child for child in node.children if child.type == type), None
    )
    if direct_match is not None:
        return direct_match

    if not recurse:
        return None

    # No direct hit: descend into each child in turn.
    for child in node.children:
        nested_match = get_first_child_by_type(child, type, True)
        if nested_match:
            return nested_match
    return None
def get_children_by_types(node: Node, types: List):
    """Collect, in order, the children of node whose type appears in types.

    The result is an empty list when no child matches.
    """
    matching = []
    for child in node.children:
        if child.type in types:
            matching.append(child)
    return matching
def get_children_except_types(node: Node, types: List):
    """Collect, in order, the children of node whose type does not appear in types.

    The result is an empty list when every child's type is listed in types.
    """
    remaining = []
    for child in node.children:
        if child.type in types:
            continue
        remaining.append(child)
    return remaining
def get_first_child_index(node, type: str):
    """Return the position of the first child of node whose type equals type.

    None is returned when no child has that type.
    """
    matching_positions = (
        position
        for position, child in enumerate(node.children)
        if child.type == type
    )
    return next(matching_positions, None)
def get_last_child_index(node, type: str):
    """Get the index of the last child of node with type type.

    Args:
        node: Object exposing a children sequence whose items have a type attribute.
        type: Node type to look for.

    Returns:
        The position of the last matching child in node.children, or None when
        no child has that type.
    """
    children = node.children
    # Scan from the back so the first hit is the last match. The position is
    # returned (not the child itself), mirroring get_first_child_index.
    for index in range(len(children) - 1, -1, -1):
        if children[index].type == type:
            return index
    return None
def get_control_children(node: Node):
    """Return, in order, the children of node whose type is listed in CONTROL_CHARACTERS."""
    return [child for child in node.children if child.type in CONTROL_CHARACTERS]
def get_non_control_children(node: Node):
    """Return, in order, the children of node whose type is not listed in CONTROL_CHARACTERS."""
    return [
        child for child in node.children if child.type not in CONTROL_CHARACTERS
    ]