Coverage for skema/program_analysis/CAST/fortran/node_helper.py: 86%

63 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import itertools 

2from typing import List, Dict 

3 

4from tree_sitter import Node 

5 

6from skema.program_analysis.CAST2FN.model.cast import SourceRef 

7 

# Node types treated as syntactic "control" tokens (punctuation, operators and
# the "only" keyword) rather than meaningful children; get_control_children and
# get_non_control_children filter a node's children by membership in this list.
#
# The dotted Fortran operators are written as raw strings: each entry is the
# literal text with backslash-dot (e.g. r"\.not\."), presumably matching how the
# tree-sitter-fortran grammar names these regex-defined tokens.
# NOTE(review): confirm against the grammar's node-types.json.
# Without the r-prefix, "\." is an invalid escape sequence (SyntaxWarning on
# Python 3.12+); the resulting string values are identical either way.
CONTROL_CHARACTERS = [
    ",",
    "=",
    "==",
    "(",
    ")",
    "(/",
    "/)",
    ":",
    "::",
    "+",
    "-",
    "*",
    "**",
    "/",
    "/=",
    ">",
    "<",
    "<=",
    ">=",
    "only",
    r"\.not\.",
    r"\.gt\.",
    r"\.ge\.",
    r"\.lt\.",
    r"\.le\.",
    r"\.eq\.",
    r"\.ne\.",
]

37 

class NodeHelper():
    """Map tree-sitter nodes from one parsed file back to its source text.

    tree-sitter nodes only carry (row, column) positions, so this helper keeps
    the original source string and file name to recover a node's text
    (get_identifier) and to build a CAST SourceRef (get_source_ref).
    """

    def __init__(self, source: str, source_file_name: str):
        self.source = source
        self.source_file_name = source_file_name

        # Legacy get_identifier optimization variables, kept for any external
        # users; get_identifier now uses line_start_offsets below.
        self.source_lines = source.splitlines(keepends=True)
        self.line_lengths = [len(line) for line in self.source_lines]
        self.line_length_sums = list(itertools.accumulate(self.line_lengths))

        # line_start_offsets[r] is the index into self.source where row r starts.
        # tree-sitter advances its row counter only on "\n", so rows are derived
        # from split("\n"). splitlines() also breaks on "\r", "\x0b", "\x0c", ...
        # and would desynchronise row numbers. The "+ 1" restores the "\n" that
        # split() removed; the final entry (one past the end) is never used as a
        # row start.
        self.line_start_offsets = [
            0,
            *itertools.accumulate(len(line) + 1 for line in source.split("\n")),
        ]

    def get_source_ref(self, node: Node) -> SourceRef:
        """Return a CAST SourceRef for ``node`` in this helper's source file.

        The file name, start/end column and start/end row are passed to
        SourceRef positionally, in that order.
        """
        row_start, col_start = node.start_point
        row_end, col_end = node.end_point
        return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)

    def get_identifier(self, node: Node) -> str:
        """Return the source text spanned by ``node``.

        That is the code between node.start_point and node.end_point (end
        exclusive), e.g. the identifier a name node represents.

        NOTE(review): tree-sitter reports columns as byte offsets, while this
        indexes characters of ``self.source``; the two agree only when the text
        preceding the node on its row is single-byte (ASCII). Confirm that
        non-ASCII sources are not expected.
        """
        start_row, start_column = node.start_point
        end_row, end_column = node.end_point

        # Absolute offset = offset of the row's first character + column in the row.
        start_index = self.line_start_offsets[start_row] + start_column
        end_index = self.line_start_offsets[end_row] + end_column

        return self.source[start_index:end_index]

67 

def remove_comments(node: Node):
    """Remove comment nodes from tree-sitter parse tree.

    Deletes every direct child of ``node`` whose ``type`` is "comment" from
    ``node.children``, then recurses into each remaining child. Returns
    ``node`` itself (the same object that was passed in).
    """
    # NOTE: tree-sitter Node objects are read-only, so we have to be careful about how we remove comments
    # The below has been carefully designed to work around this restriction.
    # NOTE(review): the deletions only persist if every access to node.children
    # returns the same list object; that is an assumption about the tree-sitter
    # binding in use -- confirm it when upgrading tree-sitter.
    # Collect comment indices highest-first so deleting one does not shift the
    # indices of those still waiting to be deleted.
    to_remove = sorted([index for index,child in enumerate(node.children) if child.type == "comment"], reverse=True)
    for index in to_remove:
        del node.children[index]

    # Recurse into each remaining child. remove_comments returns its argument,
    # so this stores the same child object back at the same index.
    for i in range(len(node.children)):
        node.children[i] = remove_comments(node.children[i])

    return node

80 

def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of ``node`` whose ``type`` equals ``type``, else None.

    Direct children are checked first. When ``recurse`` is true and no direct
    child matches, each child is searched in order (applying the same rule
    recursively) and the first match found is returned.
    """
    # ``type`` shadows the builtin but is kept because callers may pass it by keyword.
    for child in node.children:
        if child.type == type:
            return child

    if recurse:
        for child in node.children:
            found = get_first_child_by_type(child, type, True)
            # Compare against None explicitly: a real match must never be
            # discarded just because the node object happens to be falsy.
            if found is not None:
                return found
    return None

95 

96 

def get_children_by_types(node: Node, types: List):
    """Collect the direct children of ``node`` whose ``type`` appears in ``types``.

    Order is preserved; when nothing matches an empty list is returned.
    """
    return list(filter(lambda kid: kid.type in types, node.children))

100 

def get_children_except_types(node: Node, types: List):
    """Collect the direct children of ``node`` whose ``type`` is NOT in ``types``.

    Order is preserved; when every child is excluded an empty list is returned.
    """
    return list(itertools.filterfalse(lambda kid: kid.type in types, node.children))

104 

def get_first_child_index(node, type: str):
    """Return the position of the earliest child of ``node`` with the given type.

    Returns None when no child has that type.
    """
    matching_positions = (
        position
        for position, kid in enumerate(node.children)
        if kid.type == type
    )
    return next(matching_positions, None)

110 

111 

def get_last_child_index(node, type: str):
    """Get the index of the last child of node with type type.

    Returns None when no child has that type (mirroring get_first_child_index).
    """
    last = None
    for i, child in enumerate(node.children):
        if child.type == type:
            # Record the position, not the node: callers ask for an index.
            last = i
    return last

119 

120 

def get_control_children(node: Node):
    """Return the direct children of ``node`` whose type is listed in CONTROL_CHARACTERS.

    Order is preserved; the result is a new list.
    """
    return [kid for kid in node.children if kid.type in CONTROL_CHARACTERS]

123 

124 

def get_non_control_children(node: Node):
    """Return the direct children of ``node`` whose type is NOT in CONTROL_CHARACTERS.

    Order is preserved; the result is a new list (empty if every child is a
    control token).
    """
    return get_children_except_types(node, CONTROL_CHARACTERS)