Coverage for skema/program_analysis/CAST2FN/ann_cast/grfn_assignment_pass.py: 77%

161 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import typing 

2from functools import singledispatchmethod 

3 

4from skema.model_assembly.metadata import LambdaType 

5from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import ( 

6 GrfnAssignment, 

7 ann_cast_name_to_fullid, 

8 create_grfn_assign_node, 

9 create_grfn_literal_node, 

10 create_grfn_pack_node, 

11 create_grfn_unpack_node, 

12 create_lambda_node_metadata, 

13 is_literal_assignment, 

14) 

15from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import * 

16from skema.program_analysis.CAST2FN.model.cast import ( 

17 ScalarType, 

18 StructureType, 

19 ValueConstructor, 

20) 

21 

22 

class GrfnAssignmentPass:
    """Attach GrFN assignment information to the annotated CAST.

    Constructing the pass immediately walks every node held in
    ``pipeline_state.nodes``.  During the walk:

    * assignment and return nodes receive a ``grfn_assignment`` attribute
      (a ``GrfnAssignment`` wrapping a literal, assign, pack or unpack
      lambda node) whose ``inputs`` and ``outputs`` dicts are filled in;
    * call nodes get one ``GrfnAssignment`` per argument, stored in
      ``arg_assignments`` under the argument's position, and add their
      return value(s) to the dict that is currently being populated.

    Every visitor takes an ``add_to`` dict.  It is the ``inputs`` or
    ``outputs`` dict of the GrFN assignment currently being built; name
    nodes and call nodes add ``fullid -> GrFN id`` entries to it.
    """

    def __init__(self, pipeline_state: PipelineState):
        """Store ``pipeline_state`` and run the pass over all of its nodes."""
        self.pipeline_state = pipeline_state
        self.nodes = self.pipeline_state.nodes
        # Any other state variables that are needed during
        # the pass
        for node in self.pipeline_state.nodes:
            # every top-level node starts with its own, initially empty, dict
            self.visit(node, {})

    def visit(self, node: AnnCastNode, add_to: typing.Dict):
        """
        `add_to` is either the input or outputs to an GrFN Assignment/Literal node
        When visiting variable nodes, we add the variable to this `add_to` dict.
        When visiting call nodes, we add the function return value to this `add_to` dict.
        """
        # print current node being visited.
        # this can be useful for debugging
        # class_name = node.__class__.__name__
        # print(f"\nProcessing node type {class_name}")

        # call internal visit
        return self._visit(node, add_to)

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode], add_to: typing.Dict
    ):
        """Visit each node of `node_list` in order with the same `add_to` dict.

        Returns the list of the individual visit results.
        """
        return [self.visit(node, add_to) for node in node_list]

    @singledispatchmethod
    def _visit(self, node: AnnCastNode, add_to: typing.Dict):
        """
        Fallback of the type-dispatched visitor.

        `add_to` is either the input or outputs to an GrFN Assignment/Literal node
        When visiting variable nodes, we add the variable to this `add_to` dict.
        When visiting call nodes, we add the function return value to this `add_to` dict.

        Raises NameError when no visitor is registered for the type of `node`.
        """
        raise NameError(f"Unrecognized node type: {type(node)}")

    # ------------------------------------------------------------------
    # private helpers shared by `visit_call` and `visit_call_grfn_2_2`
    # ------------------------------------------------------------------
    def _add_return_values(self, node: AnnCastCall, add_to: typing.Dict):
        """Add the GrFN variable uid of each return value of `node` to `add_to`."""
        for fullid in node.out_ret_val.values():
            grfn_var = self.pipeline_state.get_grfn_var(fullid)
            add_to[fullid] = grfn_var.uid

    def _assign_argument(self, node: AnnCastCall, index: int, arg: AnnCastNode):
        """Create and store the GrfnAssignment for the argument at `index`.

        Raises KeyError if `index` has no entry in `node.arg_index_to_fullid`.
        """
        # grab GrFN variable for argument
        arg_fullid = node.arg_index_to_fullid[index]
        arg_grfn_var = self.pipeline_state.get_grfn_var(arg_fullid)

        # create GrfnAssignment based on assignment type
        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(arg):
            arg_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        else:
            arg_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        # store argument as output to GrfnAssignment
        arg_assignment.outputs[arg_fullid] = arg_grfn_var.uid
        # populate GrfnAssignment inputs for arguments
        self.visit(arg, arg_assignment.inputs)
        # store GrfnAssignment for this argument
        node.arg_assignments[index] = arg_assignment

    def _print_arg_assignments(self, node: AnnCastCall):
        """Print the argument assignments of `node` when debugging is enabled."""
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("Call after processing arguments:")
            for pos, grfn_assignment in node.arg_assignments.items():
                print(f" {pos} : {grfn_assignment!s}")

    # ------------------------------------------------------------------
    # visitors
    # ------------------------------------------------------------------
    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment, add_to: typing.Dict):
        """Create `node.grfn_assignment` and fill in its inputs and outputs.

        The right-hand side populates the inputs, the left-hand side the
        outputs.  The `add_to` dict passed in is not used.
        """
        # create the LambdaNode
        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(node.right):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        elif isinstance(node.left, AnnCastTuple):
            # NOTE(review): the assertion below does not list AnnCastTuple,
            # so unless AnnCastTuple derives from one of the accepted
            # classes this branch is followed by an AssertionError, and
            # tuple-literal targets get an ASSIGN node -- confirm intent.
            node.grfn_assignment = GrfnAssignment(
                create_grfn_unpack_node(metadata), LambdaType.UNPACK
            )
        else:
            node.grfn_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        self.visit(node.right, node.grfn_assignment.inputs)
        # A tuple-valued AnnCastLiteralValue is accepted on the left to handle
        # scenarios where an assignment is made by assigning to a tuple of
        # values, as opposed to one singular value
        left_is_tuple_literal = (
            isinstance(node.left, AnnCastLiteralValue)
            and node.left.value_type == StructureType.TUPLE
        )
        assert (
            isinstance(node.left, AnnCastVar)
            or left_is_tuple_literal
            or isinstance(node.left, AnnCastAttribute)
        ), (
            "GrfnAssignmentPass: visit_assignment: node.left is not AnnCastVar, "
            "a tuple AnnCastLiteralValue or AnnCastAttribute; "
            f"it is {type(node.left)}"
        )
        self.visit(node.left, node.grfn_assignment.outputs)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(
                f"GrFN {node.grfn_assignment.assignment_type} after visiting children:"
            )
            print(
                f" grfn_assignment.inputs: {node.grfn_assignment.inputs}"
            )
            print(
                f" grfn_assignment.outputs: {node.grfn_assignment.outputs}"
            )

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute, add_to: typing.Dict):
        """Attribute nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_call(self, node: AnnCastCall, add_to: typing.Dict):
        """Record the call's return value(s) in `add_to` and build its argument assignments.

        Calls flagged `is_grfn_2_2` are delegated to `visit_call_grfn_2_2`.
        """
        if node.is_grfn_2_2:
            self.visit_call_grfn_2_2(node, add_to)
            return

        # add ret_val to add_to dict
        self._add_return_values(node, add_to)

        # populate `arg_assignments` attribute of node
        for i, arg in enumerate(node.arguments):
            # arguments without an entry in `arg_index_to_fullid` are skipped
            if i in node.arg_index_to_fullid:  # NOTE: M7 Placeholder
                self._assign_argument(node, i, arg)

        # DEBUG printing
        self._print_arg_assignments(node)

    def visit_call_grfn_2_2(self, node: AnnCastCall, add_to: typing.Dict):
        """Handle a GrFN 2.2 call.

        Same as `visit_call`, except that every argument must have an entry
        in `node.arg_index_to_fullid` and the body of `node.func_def_copy`
        is visited afterwards with a fresh dict.
        """
        assert isinstance(node.func, AnnCastName)
        # add ret_val to add_to dict
        self._add_return_values(node, add_to)

        # populate `arg_assignments` attribute of node
        for i, arg in enumerate(node.arguments):
            self._assign_argument(node, i, arg)

        self.visit_function_def(node.func_def_copy, {})

        # DEBUG printing
        self._print_arg_assignments(node)

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef, add_to: typing.Dict):
        """Record definitions need no processing in this pass."""
        pass

    @_visit.register
    def visit_function_def(
        self, node: AnnCastFunctionDef, add_to: typing.Dict
    ):
        """Visit the statements of the function body."""
        # linking function arguments to formal parameters through the top
        # interface is handled during VariableVersionPass, so we don't need to visit
        # func_args here
        self.visit_node_list(node.body, add_to)

    @_visit.register
    def visit_goto(self, node: AnnCastGoto, add_to: typing.Dict):
        """Visit the goto's expression, if it has one; the label is not visited."""
        if node.expr is not None:
            self.visit(node.expr, add_to)
        # self.visit(node.label, add_to)

    @_visit.register
    def visit_label(self, node: AnnCastLabel, add_to: typing.Dict):
        """Label nodes need no processing in this pass."""
        # self.visit(node.label, add_to)
        pass

    @_visit.register
    def visit_literal_value(
        self, node: AnnCastLiteralValue, add_to: typing.Dict
    ):
        """Visit the children of list and tuple literals.

        All other literal kinds (integers, floats, ...) have nothing to add.
        """
        if node.value_type == "List[Any]":
            # node.value has
            #   operator      - string
            #   size          - Var node or a LiteralValue node (for number)
            #   initial_value - LiteralValue node
            # Only `size` is visited; a list literal doesn't need any other
            # changes to the anncast at this pass
            self.visit(node.value.size, add_to)
        elif node.value_type == StructureType.TUPLE:  # or node.value_type == StructureType.LIST:
            self.visit_node_list(node.value, add_to)

    @_visit.register
    def visit_loop(self, node: AnnCastLoop, add_to: typing.Dict):
        """Visit the loop's pre statements, condition, body and post statements."""
        self.visit_node_list(node.pre, add_to)
        self.visit(node.expr, add_to)
        self.visit_node_list(node.body, add_to)
        self.visit_node_list(node.post, add_to)

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak, add_to: typing.Dict):
        """Break nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_model_continue(
        self, node: AnnCastModelContinue, add_to: typing.Dict
    ):
        """Continue nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_model_import(
        self, node: AnnCastModelImport, add_to: typing.Dict
    ):
        """Import nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf, add_to: typing.Dict):
        """Visit the condition, then the body and the else branch."""
        self.visit(node.expr, add_to)
        self.visit_node_list(node.body, add_to)
        self.visit_node_list(node.orelse, add_to)

    @_visit.register
    def visit_model_return(
        self, node: AnnCastModelReturn, add_to: typing.Dict
    ):
        """Create `node.grfn_assignment` for a return statement.

        The returned value populates the inputs; the outputs are the
        `in_ret_val` variables of the owning function definition.  The
        `add_to` dict passed in is not used.
        """
        # create the assignment LambdaNode for this return statement
        metadata = create_lambda_node_metadata(node.source_refs)
        if is_literal_assignment(node.value):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_literal_node(metadata), LambdaType.LITERAL
            )
        elif isinstance(node.value, AnnCastTuple):
            node.grfn_assignment = GrfnAssignment(
                create_grfn_pack_node(metadata), LambdaType.PACK
            )
        else:
            node.grfn_assignment = GrfnAssignment(
                create_grfn_assign_node(metadata), LambdaType.ASSIGN
            )

        self.visit(node.value, node.grfn_assignment.inputs)

        for fullid in node.owning_func_def.in_ret_val.values():
            grfn_var = self.pipeline_state.get_grfn_var(fullid)
            node.grfn_assignment.outputs[fullid] = grfn_var.uid

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(
                f"GrFN RETURN with type {node.grfn_assignment.assignment_type} after visiting children:"
            )
            print(
                f" grfn_assignment.inputs: {node.grfn_assignment.inputs}"
            )
            print(
                f" grfn_assignment.outputs: {node.grfn_assignment.outputs}"
            )

    @_visit.register
    def visit_module(self, node: AnnCastModule, add_to: typing.Dict):
        """Visit the module body with a fresh dict; the `add_to` passed in is ignored."""
        add_to = {}
        self.visit_node_list(node.body, add_to)

    @_visit.register
    def visit_name(self, node: AnnCastName, add_to: typing.Dict):
        """Record the name in `add_to` as `fullid -> node.grfn_id`."""
        fullid = ann_cast_name_to_fullid(node)

        # store fullid/grfn id in add_to
        add_to[fullid] = node.grfn_id

    @_visit.register
    def visit_operator(self, node: AnnCastOperator, add_to: typing.Dict):
        """Visit every operand of the operator."""
        # visit operands
        self.visit_node_list(node.operands, add_to)

    @_visit.register
    def visit_set(self, node: AnnCastSet, add_to: typing.Dict):
        """Set nodes need no processing in this pass."""
        pass

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple, add_to: typing.Dict):
        """Visit every value of the tuple."""
        self.visit_node_list(node.values, add_to)

    @_visit.register
    def visit_var(self, node: AnnCastVar, add_to: typing.Dict):
        """Visit the value node wrapped by the variable."""
        self.visit(node.val, add_to)