Coverage for skema/program_analysis/CAST/python/node_helper.py: 73%

60 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import itertools 

2from typing import List, Dict 

3from skema.program_analysis.CAST2FN.model.cast import SourceRef 

4 

5from tree_sitter import Node 

6 

# Node types treated as punctuation / operator tokens. get_control_children
# keeps only children of these types; get_non_control_children drops them.
# NOTE: every entry needs its trailing comma -- without one, Python silently
# concatenates adjacent string literals (e.g. "!=" ">" becomes "!=>").
CONTROL_CHARACTERS = [
    ",",
    "=",
    "==",
    "[",
    "]",
    "(",
    ")",
    ":",
    "+",
    "-",
    "*",
    "**",
    "/",
    "!=",
    ">",
    "<",
    "<=",
    ">=",
    "in",
    "not",
]

29 

# Node types handled on the left-hand side of a for loop,
# i.e. LEFT in `for LEFT in RIGHT:`.
FOR_LOOP_LEFT_TYPES = [
    "identifier",
    "tuple_pattern",
    "pattern_list",
    "list_pattern",
]

# Node types handled on the right-hand side of a for loop,
# i.e. RIGHT in `for LEFT in RIGHT:`.
FOR_LOOP_RIGHT_TYPES = [
    "call",
    "identifier",
    "list",
    "tuple",
    "attribute",
]

# Node types handled as the condition of a while loop.
WHILE_COND_TYPES = [
    "boolean_operator",
    "call",
    "comparison_operator",
    "binary_operator",
]

# Node types handled inside list/dict comprehensions.
COMPREHENSION_OPERATORS = [
    "binary_operator",
    "call",
    "identifier",
    "attribute",
    "pair",
]

69 

70 

class NodeHelper:
    """Map tree-sitter nodes of a single source file back to source text and CAST SourceRefs."""

    def __init__(self, source: str, source_file_name: str):
        self.source = source
        self.source_file_name = source_file_name

        # get_identifier optimization variables.
        # Lines are split on "\n" only, because tree-sitter increments a
        # point's row solely on "\n". str.splitlines() would also break on
        # "\r", "\v", "\f", "\x1c"-"\x1e", "\x85", "\u2028" and "\u2029",
        # shifting every later row in the prefix-sum table. A "\r\n" pair is
        # still handled correctly: the "\r" simply stays at the end of its line.
        raw_lines = source.split("\n")
        self.source_lines = [line + "\n" for line in raw_lines[:-1]]
        if raw_lines[-1]:
            self.source_lines.append(raw_lines[-1])
        self.line_lengths = [len(line) for line in self.source_lines]
        # line_length_sums[r] is the index into self.source where row r begins;
        # the leading 0 makes row 0 need no special case.
        self.line_length_sums = [0] + list(itertools.accumulate(self.line_lengths))

    def _point_to_index(self, point) -> int:
        """Convert a tree-sitter (row, column) point into an index into self.source.

        Tree-sitter columns count bytes, not characters. On lines containing
        non-ASCII text, the column is converted by decoding that many bytes
        of the line. This assumes the parser was fed ``source`` encoded as
        UTF-8 (tree-sitter's default) -- NOTE(review): confirm at the parse site.
        """
        row, column = point
        if row < len(self.source_lines):
            line = self.source_lines[row]
            if not line.isascii():
                # For ASCII lines bytes == characters; otherwise count the
                # characters contained in the first `column` UTF-8 bytes.
                column = len(line.encode("utf-8")[:column].decode("utf-8", errors="ignore"))
        return self.line_length_sums[row] + column

    def get_identifier(self, node: Node) -> str:
        """Given a node, return the identifier it represents.

        i.e. the code between node.start_point and node.end_point.
        """
        start_index = self._point_to_index(node.start_point)
        end_index = self._point_to_index(node.end_point)
        return self.source[start_index:end_index]

    def get_source_ref(self, node: Node) -> SourceRef:
        """Given a node, return a CAST SourceRef object for it in this file."""
        row_start, col_start = node.start_point
        row_end, col_end = node.end_point
        # NOTE: tree-sitter point columns are byte offsets, so on lines with
        # non-ASCII text col_start/col_end are byte-based, not character-based.
        return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)

    def get_operator(self, node: Node) -> str:
        """Given a unary/binary operator node, return the operator it contains (its node type)."""
        return node.type

105 

def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of ``node`` whose node type equals ``type``; otherwise return None.

    All direct children are checked first. When ``recurse`` is set and no
    direct child matches, each child's subtree is then searched in child order,
    applying the same rule recursively.
    """
    for child in node.children:
        if child.type == type:
            return child

    if recurse:
        for child in node.children:
            out = get_first_child_by_type(child, type, True)
            # Test against None explicitly so a found node is never discarded
            # just because it happens to evaluate as falsy.
            if out is not None:
                return out
    return None

120 

def get_children_by_types(node: Node, types: List):
    """Collect the direct children of ``node`` whose node type appears in ``types``.

    Children keep their original order; an empty list is returned when none match.
    """
    matches = []
    for child in node.children:
        if child.type in types:
            matches.append(child)
    return matches

124 

def get_first_child_index(node, type: str):
    """Return the position of the first child of ``node`` whose node type is ``type``.

    Returns None when no child has that type.
    """
    return next(
        (index for index, child in enumerate(node.children) if child.type == type),
        None,
    )

130 

131 

def get_last_child_index(node, type: str):
    """Return the position (index) of the last child of ``node`` whose node type is ``type``.

    Returns None when no child has that type. The original implementation
    returned the matching child node itself, contradicting this function's
    name and the index returned by get_first_child_index.
    """
    children = node.children
    # Scan from the end so the first hit is the last match.
    for index in range(len(children) - 1, -1, -1):
        if children[index].type == type:
            return index
    return None

139 

140 

def get_control_children(node: Node):
    """Return the direct children of ``node`` whose node type is listed in CONTROL_CHARACTERS, in order."""
    return [child for child in node.children if child.type in CONTROL_CHARACTERS]

143 

144 

def get_non_control_children(node: Node):
    """Return the direct children of ``node`` whose node type is NOT in CONTROL_CHARACTERS.

    The children's original order is preserved.
    """
    return [
        child
        for child in node.children
        if child.type not in CONTROL_CHARACTERS
    ]

152