Coverage for skema/program_analysis/CAST2FN/ann_cast/id_collapse_pass.py: 93%

158 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from re import A 

2import typing 

3from collections import defaultdict 

4from functools import singledispatchmethod 

5 

6from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import ( 

7 call_container_name, 

8) 

9from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import * 

10from skema.program_analysis.CAST2FN.model.cast import ( 

11 ScalarType, 

12 StructureType, 

13 ValueConstructor, 

14) 

15 

16 

class IdCollapsePass:
    """Pass that collapses AnnCastName ids into a range starting from zero.

    Constructing the pass runs it. Every node in ``pipeline_state.nodes`` is
    visited; along the way

    * Name ids are replaced by collapsed ids, assigned in first-seen order,
    * each visited call receives an ``invocation_index``,
    * function definitions are recorded in ``pipeline_state.func_id_to_def``,
    * names seen at module scope are recorded in the module node's
      ``used_vars``.

    After the traversal each cached call is flagged with ``has_func_def`` and
    the final id counter is stored on the pipeline state.
    """

    def __init__(self, pipeline_state: PipelineState):
        """Run the id-collapse pass over ``pipeline_state.nodes``."""
        self.pipeline_state = pipeline_state
        # Cache Call nodes so that, after visiting, we can determine which
        # Calls have associated FunctionDefs.
        # Maps call container name -> AnnCastCall node.
        self.cached_call_nodes: typing.Dict[str, AnnCastCall] = {}
        # During the pass, Name ids are collapsed to a range starting from zero.
        self.old_id_to_collapsed_id: typing.Dict[int, int] = {}
        # Tracks which collapsed ids have been used so far (next free id).
        self.collapsed_id_counter = 0
        # Maps collapsed function id -> number of invocations seen so far.
        # Used to populate `invocation_index` of AnnCastCall nodes.
        self.func_invocation_counter = defaultdict(int)

        # Top-level nodes start outside module scope; `visit_module` turns
        # the flag on for the module body.
        for node in self.pipeline_state.nodes:
            self.visit(node, False)
        self.nodes = self.pipeline_state.nodes
        self.determine_function_defs_for_calls()
        self.store_highest_id()

    def store_highest_id(self):
        """Copy the current collapsed id counter onto the pipeline state."""
        self.pipeline_state.collapsed_id_counter = self.collapsed_id_counter

    def collapse_id(self, id: int) -> int:
        """
        Returns the collapsed id for `id` if it already exists,
        otherwise creates a collapsed id for it
        """
        if id not in self.old_id_to_collapsed_id:
            self.old_id_to_collapsed_id[id] = self.collapsed_id_counter
            self.collapsed_id_counter += 1

        return self.old_id_to_collapsed_id[id]

    def next_function_invocation(self, coll_func_id: int) -> int:
        """
        Returns the next invocation index for function with collapsed id
        `coll_func_id`, and advances that function's counter
        """
        index = self.func_invocation_counter[coll_func_id]
        self.func_invocation_counter[coll_func_id] += 1

        return index

    def determine_function_defs_for_calls(self):
        """Set `has_func_def` on every cached call node.

        For attribute calls the id of the attribute is used; otherwise the id
        of the called name. The lookup is delegated to
        `pipeline_state.func_def_exists`.
        """
        for call_name, call in self.cached_call_nodes.items():
            if isinstance(call.func, AnnCastAttribute):
                func_id = call.func.attr.id
            else:
                func_id = call.func.id
            call.has_func_def = self.pipeline_state.func_def_exists(func_id)

            # DEBUG printing
            if self.pipeline_state.PRINT_DEBUGGING_INFO:
                print(f"{call_name} has FunctionDef: {call.has_func_def}")

    def visit(self, node: AnnCastNode, at_module_scope: bool):
        """Dispatch `node` to its handler, reporting the node on failure.

        Any exception raised by the handler is re-raised unchanged after a
        diagnostic line naming the node type and its source refs is printed.
        """
        try:
            return self._visit(node, at_module_scope)
        except Exception:
            # Read source_refs defensively: if `node` lacks the attribute, a
            # direct access here would raise a new error and hide the real one.
            source_refs = getattr(node, "source_refs", None)
            print(
                f"id_collapse_pass.py: Error for {type(node)} which has source ref information {source_refs}"
            )
            # Bare raise re-raises the active exception as-is.
            raise

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode], at_module_scope: bool
    ):
        """Visit every node in `node_list` in order; return the results."""
        return [self.visit(node, at_module_scope) for node in node_list]

    @singledispatchmethod
    def _visit(self, node: AnnCastNode, at_module_scope: bool):
        """
        Visit each AnnCastNode, collapsing AnnCastName ids along the way

        This fallback is reached only for node types without a registered
        handler and always raises.
        """
        # Print the first source ref when one is available; never let a
        # missing/empty source_refs replace the exception raised below.
        source_refs = getattr(node, "source_refs", None)
        if source_refs:
            print(source_refs[0])
        raise Exception(f"Unimplemented AST node of type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment, at_module_scope: bool):
        """Visit the right-hand side, then the assignment target."""
        self.visit(node.right, at_module_scope)
        # The tuple LiteralValue case handles assignments made to a tuple of
        # values, as opposed to one singular value.
        target = node.left
        assert (
            isinstance(target, (AnnCastVar, AnnCastAttribute, AnnCastCall))
            or (
                isinstance(target, AnnCastLiteralValue)
                and target.value_type == StructureType.TUPLE
            )
        ), f"id_collapse: visit_assigment: node.left is {type(node.left)}"
        self.visit(target, at_module_scope)

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute, at_module_scope: bool):
        """Visit the attribute's value and then the attribute itself."""
        self.visit(node.value, at_module_scope)
        self.visit(node.attr, at_module_scope)

    @_visit.register
    def visit_call(self, node: AnnCastCall, at_module_scope: bool):
        """Collapse the callee id, assign an invocation index, cache the call.

        Calls whose `func` is a LiteralValue are skipped entirely: they are
        not cached and their arguments are not visited.
        """
        func = node.func
        if isinstance(func, AnnCastLiteralValue):
            return

        assert isinstance(
            func, (AnnCastName, AnnCastAttribute)
        ), f"node.func is type {type(func)}"

        if isinstance(func, AnnCastName):
            func.id = self.collapse_id(func.id)
            node.invocation_index = self.next_function_invocation(func.id)
        else:
            # Attribute call: first handle whatever the attribute hangs off.
            receiver = func.value
            if isinstance(
                receiver,
                (
                    AnnCastCall,
                    AnnCastAttribute,
                    AnnCastOperator,
                    AnnCastAssignment,
                ),
            ):
                self.visit(receiver, at_module_scope)
            elif not isinstance(receiver, AnnCastLiteralValue):
                # NOTE(review): any other receiver is assumed to carry an
                # `id` (i.e. be Name-like); AnnCastSubscript receivers are
                # not given a dedicated branch here -- confirm.
                receiver.id = self.collapse_id(receiver.id)

            # The invocation counter for attribute calls is keyed on the
            # collapsed id of the attribute.
            func.attr.id = self.collapse_id(func.attr.id)
            node.invocation_index = self.next_function_invocation(
                func.attr.id
            )

        # cache Call node to later determine if this Call has a FunctionDef
        call_name = call_container_name(node)
        self.cached_call_nodes[call_name] = node

        self.visit_node_list(node.arguments, at_module_scope)

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef, at_module_scope: bool):
        """Visit the bases, funcs and fields of a record, outside module scope."""
        at_module_scope = False

        # Each base should be an AnnCastName node
        self.visit_node_list(node.bases, at_module_scope)

        # Each func is an AnnCastFuncDef node
        self.visit_node_list(node.funcs, at_module_scope)

        # Each field (attribute) is an AnnCastVar node
        self.visit_node_list(node.fields, at_module_scope)

    @_visit.register
    def visit_function_def(
        self, node: AnnCastFunctionDef, at_module_scope: bool
    ):
        """Collapse the function id, register the def, visit args and body."""
        # collapse the function id
        node.name.id = self.collapse_id(node.name.id)
        self.pipeline_state.func_id_to_def[node.name.id] = node

        # Everything inside the function is visited outside module scope.
        at_module_scope = False
        self.visit_node_list(node.func_args, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)

    @_visit.register
    def visit_goto(self, node: AnnCastGoto, at_module_scope: bool):
        """Visit the goto's expression when present; the label is not visited."""
        if node.expr is not None:
            self.visit(node.expr, at_module_scope)

    @_visit.register
    def visit_label(self, node: AnnCastLabel, at_module_scope: bool):
        """Labels need no processing in this pass."""
        pass

    @_visit.register
    def visit_literal_value(
        self, node: AnnCastLiteralValue, at_module_scope: bool
    ):
        """Visit nodes nested in list/tuple literals; other literals are no-ops."""
        if node.value_type == "List[Any]":
            # node.value holds:
            #   operator      - string
            #   size          - Var node or a LiteralValue node (for number)
            #   initial_value - LiteralValue node
            # Only `size` is visited; a list literal needs no other changes
            # to the anncast at this pass.
            self.visit(node.value.size, at_module_scope)
        elif node.value_type == StructureType.TUPLE:
            self.visit_node_list(node.value, at_module_scope)
        # All remaining literal types (integers, floats, ...) need no action.

    @_visit.register
    def visit_loop(self, node: AnnCastLoop, at_module_scope: bool):
        """Visit the loop's pre, expr, body and post, in that order."""
        self.visit_node_list(node.pre, at_module_scope)
        self.visit(node.expr, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)
        self.visit_node_list(node.post, at_module_scope)

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak, at_module_scope: bool):
        """Break nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_model_continue(
        self, node: AnnCastModelContinue, at_module_scope: bool
    ):
        """Continue nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf, at_module_scope: bool):
        """Visit the condition, then the body, then the orelse branch."""
        self.visit(node.expr, at_module_scope)
        self.visit_node_list(node.body, at_module_scope)
        self.visit_node_list(node.orelse, at_module_scope)

    @_visit.register
    def visit_return(self, node: AnnCastModelReturn, at_module_scope: bool):
        """Visit the returned value."""
        self.visit(node.value, at_module_scope)

    @_visit.register
    def visit_model_import(
        self, node: AnnCastModelImport, at_module_scope: bool
    ):
        """Import nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_module(self, node: AnnCastModule, at_module_scope: bool):
        """Record the module node and visit its body at module scope."""
        # we cache the module node in the AnnCast object
        self.pipeline_state.module_node = node
        at_module_scope = True
        self.visit_node_list(node.body, at_module_scope)

    @_visit.register
    def visit_name(self, node: AnnCastName, at_module_scope: bool):
        """Collapse the name's id; record it as a global at module scope."""
        node.id = self.collapse_id(node.id)

        # we consider name nodes at the module scope to be globals
        # and store them in the `used_vars` attribute of the module_node
        if at_module_scope:
            self.pipeline_state.module_node.used_vars[node.id] = node.name

    @_visit.register
    def visit_operator(self, node: AnnCastOperator, at_module_scope: bool):
        """Visit every operand of the operator."""
        self.visit_node_list(node.operands, at_module_scope)

    @_visit.register
    def visit_var(self, node: AnnCastVar, at_module_scope: bool):
        """Visit the variable's `val` and, when present, its default value."""
        self.visit(node.val, at_module_scope)
        if node.default_value is not None:
            self.visit(node.default_value, at_module_scope)

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple, at_module_scope: bool):
        """Visit every value of the tuple so their ids get collapsed."""
        self.visit_node_list(node.values, at_module_scope)