Coverage for skema/program_analysis/comment_extractor/comment_extractor.py: 99%

131 statements  

coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

import argparse
from typing import List, Dict, Tuple, Union, Optional
from pathlib import Path

import yaml
from tree_sitter import Language, Parser, Node
from pydantic import BaseModel, Field

from skema.program_analysis.comment_extractor.model import (
    SingleLineComment,
    MultiLineComment,
    Docstring,
    SingleFileCommentRequest,
    SingleFileCommentResponse,
    MultiFileCommentRequest,
    MultiFileCommentResponse,
    SupportedLanguage,
    SupportedLanguageResponse,
)
from skema.program_analysis.tree_sitter_parsers.build_parsers import (
    INSTALLED_LANGUAGES_FILEPATH,
    LANGUAGES_YAML_FILEPATH,
)

QUERIES_FILEPATH = Path(__file__).parent / "queries.yaml"


def get_identifier(node: Node, source: str) -> str:
    """Given a tree-sitter node object, return the string representing the source between node.start_point and node.end_point"""
    line_num = 0
    column_num = 0
    in_identifier = False
    identifier = ""
    for char in source:
        if line_num == node.start_point[0] and column_num == node.start_point[1]:
            in_identifier = True
        elif line_num == node.end_point[0] and column_num == node.end_point[1]:
            break

        if char == "\n":
            line_num += 1
            column_num = 0
        else:
            column_num += 1

        if in_identifier:
            identifier += char

    return identifier
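
# Illustrative example (an assumption, not from the original module): for
# source = "x = 1\ny = 2" and a node with start_point (1, 0) and
# end_point (1, 5), get_identifier(node, source) returns "y = 2".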


def node_to_single_line_comment(node: Node, source: str) -> SingleLineComment:
    """Converts a tree-sitter node to a SingleLineComment object"""
    content = get_identifier(node, source)
    line_number = node.start_point[0]

    return SingleLineComment(content=content, line_number=line_number)


def node_to_multi_line_comment_partial(
    nodes: List[Node], source: str
) -> MultiLineComment:
    """Converts a list of tree-sitter nodes to a single MultiLineComment object"""
    content = [get_identifier(node, source) for node in nodes]
    start_line_number = nodes[0].start_point[0]
    end_line_number = nodes[-1].end_point[0]

    return MultiLineComment(
        content=content,
        start_line_number=start_line_number,
        end_line_number=end_line_number,
    )


def node_to_multi_line_comment(node: Node, source: str) -> MultiLineComment:
    """Converts a tree-sitter node to a MultiLineComment object"""
    content = get_identifier(node, source).split("\n")
    start_line_number = node.start_point[0]
    end_line_number = node.end_point[0]

    return MultiLineComment(
        content=content,
        start_line_number=start_line_number,
        end_line_number=end_line_number,
    )


def nodes_to_docstring_partial(
    name_node: Node, content_nodes: List[Node], source: str
) -> Docstring:
    """Converts a list of tree-sitter nodes to a single Docstring object"""
    content = [get_identifier(node, source) for node in content_nodes]
    function_name = get_identifier(name_node, source)
    start_line_number = content_nodes[0].start_point[0]
    end_line_number = content_nodes[-1].end_point[0]

    return Docstring(
        content=content,
        function_name=function_name,
        start_line_number=start_line_number,
        end_line_number=end_line_number,
    )
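
# Illustrative example (an assumption, not from the original module): a C
# block comment node spanning "/* line one\n   line two */", with its
# start_point on row 4 and end_point on row 5, converts via
# node_to_multi_line_comment to
# MultiLineComment(content=["/* line one", "   line two */"],
# start_line_number=4, end_line_number=5). The *_partial variants build the
# same models from a run of adjacent single-line comment nodes instead.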


def nodes_to_docstring(name_node: Node, content_node: Node, source: str) -> Docstring:
    """Converts tree-sitter name and docstring body nodes to a Docstring object"""
    content = get_identifier(content_node, source).split("\n")
    function_name = get_identifier(name_node, source)
    start_line_number = content_node.start_point[0]
    end_line_number = content_node.end_point[0]

    return Docstring(
        content=content,
        function_name=function_name,
        start_line_number=start_line_number,
        end_line_number=end_line_number,
    )


def preprocess_captures(captures: List[Tuple[Node, str]]) -> List[Tuple[Node, str]]:
    """Preprocess the list of captures generated by tree-sitter. This preprocessing includes:
    1. Reordering captures to go docstring -> multi -> single
    2. Moving the docstring_name node after the docstring body node in languages with internal function docstrings
    3. Removing duplicate nodes captured by multiple capture groups
    """

    def hash_node(node: Node):
        """Create a hashable tuple of a tree-sitter node."""
        return (node.type, node.start_point, node.end_point)

    # Rearrange the captures to go Docstring -> Multi -> Single
    order = {
        "docstring_body": 1,
        "docstring_name": 1,
        "docstring_body_partial": 1,
        "multi": 2,
        "multi_partial": 2,
        "single": 3,
    }
    ordered = sorted(captures, key=lambda capture: (order[capture[1]]))

    # The order of docstring_body and docstring_name can differ between languages.
    # To standardize this, we move the docstring_name node to the back.
    # We do this by additionally sorting on the start byte in descending order.
    if len(ordered) > 0 and ordered[0][1] == "docstring_name":
        ordered = sorted(
            captures, key=lambda capture: (order[capture[1]], -capture[0].start_byte)
        )

    output = []
    duplicates = set()
    for comment in ordered:
        node_hash = hash_node(comment[0])
        if node_hash in duplicates:
            continue
        output.append(comment)
        duplicates.add(node_hash)

    return output
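
# Illustrative example (an assumption): given captures
# [(<comment>, "single"), (<name>, "docstring_name"), (<string>, "docstring_body")],
# preprocess_captures returns the docstring captures first (body before name)
# and the single-line capture last, dropping any node captured by more than
# one capture group.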


def extract_comments_single(
    request: SingleFileCommentRequest,
) -> Optional[SingleFileCommentResponse]:
    """Extract comments from a single source file; returns None for unsupported languages."""
    # Get tree-sitter queries for given language
    queries_obj = yaml.safe_load(QUERIES_FILEPATH.read_text())
    if request.language not in queries_obj:
        return None
    queries = queries_obj[request.language]

    language_obj = Language(INSTALLED_LANGUAGES_FILEPATH, request.language)

    # Parse source and run query with tree-sitter
    parser = Parser()
    parser.set_language(language_obj)
    tree = parser.parse(bytes(request.source, encoding="UTF-8"))
    captures = language_obj.query(queries).captures(tree.root_node)
    captures = preprocess_captures(captures)

    # Loop over captures, converting them to CodeComment objects.
    # There are currently 5 types of capture groups supported by the comment extractor:
    # 1. single - Single line comment (i.e. Fortran '!')
    # 2. multi - Multi line comment (i.e. C '/* */')
    # 3. multi_partial - Adjacent single line comments in languages that don't have a multi line comment token
    # 4. docstring_body - Docstring comment (def foo():\n""" """)
    # 5. docstring_body_partial - Adjacent single line docstring comments in languages that don't have a multi line comment token
    single, multi, docstring = ([], [], [])
    index = 0
    while index < len(captures):
        node, capture_type = captures[index]
        if capture_type == "single":
            single.append(node_to_single_line_comment(node, request.source))
        elif capture_type == "multi":
            multi.append(node_to_multi_line_comment(node, request.source))
        elif capture_type == "multi_partial":
            # For partial multi line comments, we have to determine the stopping point.
            # We do this by checking the line number of each capture. If it's >1 line
            # away from the previous capture, we have hit the stopping point.
            multi_start_index = index
            multi_end_index = len(captures)  # default: the run extends to the last capture
            for i in range(index, len(captures) - 1):
                current_line_number = captures[i][0].end_point[0]
                next_line_number = captures[i + 1][0].end_point[0]
                if next_line_number - current_line_number != 1:
                    multi_end_index = i + 1
                    break

            multi_body_nodes = [
                node for node, _ in captures[multi_start_index:multi_end_index]
            ]
            multi.append(
                node_to_multi_line_comment_partial(multi_body_nodes, request.source)
            )
            index = multi_end_index
            continue
        elif capture_type == "docstring_body":
            docstring_name_node = captures[index + 1][0]
            docstring.append(
                nodes_to_docstring(docstring_name_node, node, request.source)
            )
            index += 2
            continue
        elif capture_type == "docstring_body_partial":
            # Similar to partial multi line comments, partial docstrings also need a stopping index.
            # Due to the preprocessing earlier, the stop index will always be the index of the next docstring_name node.
            docstring_name_index = -1
            for i in range(index, len(captures)):
                if captures[i][1] == "docstring_name":
                    docstring_name_index = i
                    break
            docstring_name_node = captures[docstring_name_index][0]
            docstring_body_nodes = [
                node for node, _ in captures[index:docstring_name_index]
            ]
            docstring.append(
                nodes_to_docstring_partial(
                    docstring_name_node, docstring_body_nodes, request.source
                )
            )
            index = docstring_name_index + 1
            continue
        index += 1

    return SingleFileCommentResponse(single=single, multi=multi, docstring=docstring)
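
# Example usage (a sketch; assumes a "python" grammar has been built into
# INSTALLED_LANGUAGES_FILEPATH and that queries.yaml contains a "python" entry):
#   request = SingleFileCommentRequest(source="# hello\nx = 1\n", language="python")
#   response = extract_comments_single(request)
#   response.single[0].content  # -> "# hello"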


def extract_comments_multi(
    request: MultiFileCommentRequest,
) -> MultiFileCommentResponse:
    """Wrapper for processing multiple source files at a time."""
    return MultiFileCommentResponse(
        files={
            file_name: extract_comments_single(file_request)
            for file_name, file_request in request.files.items()
        }
    )
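
# Example usage (a sketch; the file name and source are hypothetical):
#   request = MultiFileCommentRequest(
#       files={"example.py": SingleFileCommentRequest(source="# hi\n", language="python")}
#   )
#   response = extract_comments_multi(request)
#   response.files["example.py"].single  # comments extracted from example.py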


def get_supported_languages() -> SupportedLanguageResponse:
    """Return the languages supported for comment extraction"""

    # To determine the supported file extensions for each language, we need to read the languages.yaml file used by the tree-sitter build system.
    languages_obj = yaml.safe_load(LANGUAGES_YAML_FILEPATH.read_text())

    # The other fields can be determined by reading the queries in queries.yaml
    queries_obj = yaml.safe_load(QUERIES_FILEPATH.read_text())
    languages = []
    for language, query in queries_obj.items():
        single = "@single" in query
        multi = "@multi" in query or "@multi_partial" in query
        docstring = "@docstring_body" in query or "@docstring_body_partial" in query
        extensions = languages_obj[language]["extensions"]
        languages.append(
            SupportedLanguage(
                name=language,
                single=single,
                multi=multi,
                docstring=docstring,
                extensions=extensions,
            )
        )

    return SupportedLanguageResponse(languages=languages)
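
# Example usage (a sketch; the exact languages and extensions depend on the
# queries.yaml and languages.yaml shipped with the build):
#   response = get_supported_languages()
#   [(language.name, language.extensions) for language in response.languages]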