Coverage for skema/program_analysis/CAST/python/node_helper.py: 73%
60 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import itertools
2from typing import List, Dict
3from skema.program_analysis.CAST2FN.model.cast import SourceRef
5from tree_sitter import Node
# Child node types that are treated as control tokens (punctuation,
# operators and keywords) by get_control_children and
# get_non_control_children.
#
# Every entry carries a trailing comma on purpose: Python silently
# concatenates adjacent string literals, so a missing comma after "!="
# used to fuse it with ">" into the single bogus entry "!=>", leaving
# both real operators out of the list.
CONTROL_CHARACTERS = [
    ",",
    "=",
    "==",
    "[",
    "]",
    "(",
    ")",
    ":",
    "+",
    "-",
    "*",
    "**",
    "/",
    "!=",
    ">",
    "<",
    "<=",
    ">=",
    "in",
    "not",
]
# Node types seen so far in the LEFT (target) position of a for loop:
#     for LEFT in RIGHT:
FOR_LOOP_LEFT_TYPES = ["identifier", "tuple_pattern", "pattern_list", "list_pattern"]
# Node types seen so far in the RIGHT (iterable) position of a for loop:
#     for LEFT in RIGHT:
FOR_LOOP_RIGHT_TYPES = ["call", "identifier", "list", "tuple", "attribute"]
# Node types seen so far in the condition of a while loop.
WHILE_COND_TYPES = [
    "boolean_operator",
    "call",
    "comparison_operator",
    "binary_operator",
]
# Node types seen so far inside list/dict comprehensions.
COMPREHENSION_OPERATORS = [
    "binary_operator",
    "call",
    "identifier",
    "attribute",
    "pair",
]
class NodeHelper():
    """Map tree-sitter nodes back to the source text they were parsed from.

    The helper is built once per source file and precomputes line offset
    tables so that the text of a node can be sliced out without rescanning
    the source.
    """

    def __init__(self, source: str, source_file_name: str):
        """Store the source and precompute the line offset tables.

        Args:
            source: Full text of the source file.
            source_file_name: File name recorded in every SourceRef created
                by get_source_ref.
        """
        self.source = source
        self.source_file_name = source_file_name

        # Character-based line tables. get_identifier no longer reads them;
        # they are kept so that any external code relying on these
        # attributes keeps working.
        self.source_lines = source.splitlines(keepends=True)
        self.line_lengths = [len(line) for line in self.source_lines]
        self.line_length_sums = [0] + list(itertools.accumulate(self.line_lengths))

        # get_identifier optimization variables.
        # tree-sitter points are (row, column) pairs where rows are delimited
        # by "\n" only and the column is a BYTE offset within the row, so the
        # lookup table is built over the encoded source split on "\n".
        # str.splitlines() also breaks on "\r", form feed, etc. and counts
        # characters, which misplaces nodes in such files.
        # NOTE(review): assumes the tree was parsed from the UTF-8 encoding
        # of `source` - confirm against the parser call site.
        self._source_bytes = source.encode("utf-8")
        self._line_byte_offsets = [0] + list(
            itertools.accumulate(
                len(line) + 1  # +1 for the "\n" removed by split()
                for line in self._source_bytes.split(b"\n")
            )
        )

    def get_identifier(self, node: Node) -> str:
        """Return the source text spanned by a node.

        The text is the code between node.start_point and node.end_point,
        which may cover several lines.

        Args:
            node: Node exposing start_point and end_point as (row, column).

        Returns:
            The slice of the source covered by the node.

        Raises:
            IndexError: If a row lies beyond the rows of the source.
        """
        start_row, start_column = node.start_point
        end_row, end_column = node.end_point

        # One formula covers single-line and multi-line nodes alike:
        # absolute offset = offset of the row start + column within the row.
        offsets = self._line_byte_offsets
        start_index = offsets[start_row] + start_column
        end_index = offsets[end_row] + end_column

        # errors="replace" keeps this best-effort (never raising on decode)
        # should a point ever fall inside a multi-byte character.
        return self._source_bytes[start_index:end_index].decode("utf-8", errors="replace")

    def get_source_ref(self, node: Node) -> SourceRef:
        """Return a CAST SourceRef describing where a node sits in the file.

        The SourceRef is built from this helper's source_file_name followed
        by the start column, end column, start row and end row of the node,
        taken unchanged from node.start_point and node.end_point.
        """
        row_start, col_start = node.start_point
        row_end, col_end = node.end_point
        return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)

    def get_operator(self, node: Node) -> str:
        """Return the operator held by a unary/binary operator node.

        The operator is reported as the node's type string.
        """
        return node.type
def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of node whose type equals type, or None.

    All direct children are checked first. Only when none of them matches
    and recurse is set is each child's subtree searched in turn, in child
    order, using the same rule.
    """
    direct_match = next((kid for kid in node.children if kid.type == type), None)
    if direct_match is not None:
        return direct_match

    if not recurse:
        return None

    # No direct hit: descend into each child's subtree in order.
    for kid in node.children:
        found = get_first_child_by_type(kid, type, True)
        if found:
            return found
    return None
def get_children_by_types(node: Node, types: List):
    """Collect every direct child of node whose type appears in types.

    Children are returned in their original order; the result is an empty
    list when nothing matches.
    """
    matches = []
    for child in node.children:
        if child.type in types:
            matches.append(child)
    return matches
def get_first_child_index(node, type: str):
    """Return the position of the first child of node whose type equals type.

    Returns None when no child has that type.
    """
    return next(
        (index for index, child in enumerate(node.children) if child.type == type),
        None,
    )
def get_last_child_index(node, type: str):
    """Get the index of the last child of node with type type.

    Returns None when no child has that type. This previously returned the
    matching child node itself rather than its index, contradicting both
    the function name and get_first_child_index.
    """
    last = None
    for i, child in enumerate(node.children):
        if child.type == type:
            # Keep overwriting so the final value is the last match.
            last = i
    return last
def get_control_children(node: Node):
    """Return the direct children of node whose type is listed in CONTROL_CHARACTERS."""
    return [child for child in node.children if child.type in CONTROL_CHARACTERS]
def get_non_control_children(node: Node):
    """Return the direct children of node whose type is not listed in CONTROL_CHARACTERS."""
    return [child for child in node.children if child.type not in CONTROL_CHARACTERS]