Coverage for skema/program_analysis/CAST2FN/ann_cast/lambda_expression_pass.py: 71%

272 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import typing 

2from functools import singledispatchmethod 

3 

4from skema.program_analysis.CAST2FN.ann_cast.ann_cast_helpers import ( 

5 ELSEBODY, 

6 IFBODY, 

7 GrfnAssignment, 

8 ann_cast_name_to_fullid, 

9 cast_op_to_str, 

10 lambda_var_from_fullid, 

11) 

12from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import * 

13from skema.program_analysis.CAST2FN.model.cast import ( 

14 ScalarType, 

15 StructureType, 

16 ValueConstructor, 

17) 

18 

19 

def lambda_for_grfn_assignment(
    grfn_assignment: GrfnAssignment, lambda_body: str
) -> str:
    """
    Build the lambda-expression string for a GrfnAssignment.

    Each key of ``grfn_assignment.inputs`` (in dict iteration order) is turned
    into one lambda parameter via ``lambda_var_from_fullid``; ``lambda_body``
    is inserted verbatim after the colon.

    Example result: ``"lambda x, y: x + y"``
    """
    params = [lambda_var_from_fullid(fullid) for fullid in grfn_assignment.inputs.keys()]
    return f"lambda {', '.join(params)}: {lambda_body}"

29 

30 

def lambda_for_condition(condition_in: typing.Dict, lambda_body: str) -> str:
    """
    Build the lambda-expression string for a conditional's condition.

    Unlike ``lambda_for_grfn_assignment``, the parameters come from the
    *values* of ``condition_in`` (in dict iteration order), each converted with
    ``lambda_var_from_fullid``. ``lambda_body`` is inserted verbatim.
    """
    params = [lambda_var_from_fullid(fullid) for fullid in condition_in.values()]
    return f"lambda {', '.join(params)}: {lambda_body}"

38 

39 

def lambda_for_decision(
    condition_fullid: str, decision_in: typing.Dict
) -> str:
    """
    Build the lambda string for a decision node.

    A decision node selects, per variable, either the IFBODY or the ELSEBODY
    version based on the condition value. The result has the form

        lambda COND, x_if, y_if, x_else, y_else: (x_if, y_if) if COND else (x_else, y_else)

    Each value of ``decision_in`` is indexed with the IFBODY and ELSEBODY keys
    to get the fullids for the two branches. When ``decision_in`` is empty,
    ``"lambda: None"`` is returned.
    """
    if not decision_in:
        return "lambda: None"

    cond_name = lambda_var_from_fullid(condition_fullid)

    if_vars = []
    else_vars = []
    for branch_ids in decision_in.values():
        if_vars.append(lambda_var_from_fullid(branch_ids[IFBODY]) + "_if")
        else_vars.append(lambda_var_from_fullid(branch_ids[ELSEBODY]) + "_else")

    if_group = ", ".join(if_vars)
    else_group = ", ".join(else_vars)

    body = f"({if_group}) if {cond_name} else ({else_group})"
    return f"lambda {cond_name}, {if_group}, {else_group}: {body}"

74 

75 

def lambda_for_interface(interface_in: typing.Dict) -> str:
    """
    Build the lambda string for a plain interface node.

    The lambda is a multi-parameter identity function over the values of
    ``interface_in`` (in dict iteration order), e.g. ``"lambda x, y: (x, y)"``.
    An empty interface yields ``"lambda: None"``.
    """
    if not interface_in:
        return "lambda: None"

    names = ", ".join(lambda_var_from_fullid(fullid) for fullid in interface_in.values())
    return f"lambda {names}: ({names})"

89 

90 

def lambda_for_loop_top_interface(
    top_interface_initial: typing.Dict, top_interface_updated: typing.Dict
) -> str:
    """
    Build the lambda string for a loop's top interface.

    The lambda chooses between the initial and the updated version of each
    loop variable. Its parameters are ``use_initial``, then one ``<name>_init``
    per value of ``top_interface_initial``, then one ``<name>_update`` per value
    of ``top_interface_updated`` (both in dict iteration order). Example:

        lambda use_initial, x_init, y_init, x_update: (x_init, y_init) if use_initial else (x_update, y_init)

    Both branches of the body return exactly one entry per key of
    ``top_interface_initial``, in that dict's order, so position ``i`` of either
    tuple always refers to the same variable. A variable that has no updated
    version (its key is absent from ``top_interface_updated``) is never
    modified in the loop, so its ``_init`` name is reused in the "update"
    branch.

    The ``use_initial`` value is supplied at execution time by the loop top
    interface, which tracks whether the loop body has run at least once.

    Raises:
        AssertionError: if ``top_interface_updated`` contains a key that is not
            in ``top_interface_initial``.
    """

    def init_name(fullid):
        return lambda_var_from_fullid(fullid) + "_init"

    def updt_name(fullid):
        return lambda_var_from_fullid(fullid) + "_update"

    # every updated variable must also have an initial version
    assert set(top_interface_updated.keys()).issubset(top_interface_initial.keys())

    init_names = [init_name(fullid) for fullid in top_interface_initial.values()]
    updt_names = [updt_name(fullid) for fullid in top_interface_updated.values()]

    # Build the "update" branch by walking the initial keys so each position
    # lines up with the "init" branch; fall back to the init name for
    # variables that have no updated version.
    return_updt_names = [
        updt_name(top_interface_updated[key])
        if key in top_interface_updated
        else init_name(init_fullid)
        for key, init_fullid in top_interface_initial.items()
    ]

    use_initial_str = "use_initial"
    # joining a list avoids empty slots (e.g. "use_initial, , :") when a
    # group has no variables
    params_str = ", ".join([use_initial_str, *init_names, *updt_names])
    init_names_str = ", ".join(init_names)
    return_updt_names_str = ", ".join(return_updt_names)

    lambda_body = f"({init_names_str}) if {use_initial_str} else ({return_updt_names_str})"

    return f"lambda {params_str}: {lambda_body}"

138 

139 

def lambda_for_loop_condition(condition_in, lambda_body):
    """
    Build the lambda-expression string for a loop's condition.

    Parameters are the values of ``condition_in`` (in dict iteration order),
    each converted with ``lambda_var_from_fullid``; ``lambda_body`` is inserted
    verbatim after the colon.
    """
    arg_list = ", ".join(
        lambda_var_from_fullid(fullid) for fullid in condition_in.values()
    )
    return f"lambda {arg_list}: {lambda_body}"

147 

148 

class LambdaExpressionPass:
    """
    Annotated-CAST pass that builds GrFN lambda-expression strings.

    Constructing an instance immediately visits every node in
    ``pipeline_state.nodes``. Visiting stores the generated strings on the
    nodes themselves:

    - ``grfn_assignment.lambda_expr`` for assignments, returns and call
      argument assignments,
    - ``top_interface_lambda`` / ``bot_interface_lambda`` for calls, function
      definitions, loops and ifs,
    - ``condition_lambda`` for loops and ifs, and ``decision_lambda`` for ifs,

    and sets/returns ``expr_str`` on expression nodes so parent nodes can embed
    it in their own lambda bodies.
    """

    def __init__(self, pipeline_state: PipelineState):
        self.pipeline_state = pipeline_state
        self.nodes = self.pipeline_state.nodes
        # Any other state variables needed during the pass go here.
        for node in self.nodes:
            self.visit(node)

    def visit(self, node: AnnCastNode) -> str:
        """
        External visit that dispatches to the internal ``_visit``.

        Kept as a separate entry point so debugging hooks (for example,
        printing the class of every visited node) can be added in one place.
        """
        # Uncomment to trace the nodes being visited:
        # print(f"\nProcessing node type {node.__class__.__name__}")
        return self._visit(node)

    def visit_node_list(
        self, node_list: typing.List[AnnCastNode]
    ) -> typing.List[str]:
        """Visit each node of ``node_list`` in order and return the results."""
        return [self.visit(node) for node in node_list]

    @singledispatchmethod
    def _visit(self, node: AnnCastNode) -> str:
        """
        Internal visit, dispatched on the runtime type of ``node``.

        Raises:
            NameError: if no handler is registered for the node's type.
        """
        raise NameError(f"Unrecognized node type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: AnnCastAssignment) -> str:
        """Build and store the lambda for an assignment; the right side is the body."""
        right = self.visit(node.right)
        # build the lambda expression for the assignment
        # and store it in the GrfnAssignment
        lambda_expr = lambda_for_grfn_assignment(node.grfn_assignment, right)
        node.grfn_assignment.lambda_expr = lambda_expr
        node.expr_str = lambda_expr

        return node.expr_str

    @_visit.register
    def visit_attribute(self, node: AnnCastAttribute) -> str:
        return node.expr_str

    def _visit_call_arg_assignments(self, node: AnnCastCall) -> None:
        """
        Build and store the lambda of every argument assignment of a Call.

        Example: for ``func(x + 3, y * 2)`` the GrfnAssignment with index 0
        models ``arg_0 = x + 3`` and receives the lambda ``lambda x: x + 3``;
        the body comes from visiting the corresponding call argument.
        """
        for i, grfn_assignment in node.arg_assignments.items():
            lambda_body = self.visit(node.arguments[i])
            grfn_assignment.lambda_expr = lambda_for_grfn_assignment(
                grfn_assignment, lambda_body
            )

    def visit_call_grfn_2_2(self, node: AnnCastCall):
        """
        Build lambdas for a GrFN 2.2 style Call, which carries a copy of the
        called FunctionDef in ``node.func_def_copy``.

        Sets the argument-assignment lambdas, ``node.top_interface_lambda``,
        the lambdas inside the copied function body, and
        ``node.bot_interface_lambda``.
        """
        self._visit_call_arg_assignments(node)

        # top interface lambda
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)

        # build lambda expressions for the function def copy body
        body_expr = self.visit_function_def_copy(node.func_def_copy)

        # bot interface lambda
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"Call GrFN 2.2 {node.func.name}")
            print("\t Args Expressions:")
            for arg in node.arg_assignments.values():
                print(f"\t\t{arg.lambda_expr}")
            print("\t Top Interface:")
            print(f"\t\t{node.top_interface_lambda}")
            print(f"FunctionDefCopy {node.func_def_copy.name.name}")
            print("\t Body Expressions:")
            for e in body_expr:
                print(f"\t\t{e}")
            print("\t Bot Interface:")
            print(f"\t\t{node.bot_interface_lambda}")

    def visit_call_without_func_copy(self, node: AnnCastCall):
        """
        Build lambdas for a Call that has no FunctionDef copy (GrFN 2.3 style,
        or a call whose FunctionDef is unavailable).

        Sets the argument-assignment lambdas and the top/bot interface lambdas.
        """
        self._visit_call_arg_assignments(node)

        # top interface lambda
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)

        # bot interface lambda
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"Call No FuncDef {node.func.name}")
            print("\t Args Expressions:")
            for arg in node.arg_assignments.values():
                print(f"\t\t{arg.lambda_expr}")
            print("\t Top Interface:")
            print(f"\t\t{node.top_interface_lambda}")
            print("\t Bot Interface:")
            print(f"\t\t{node.bot_interface_lambda}")

    @_visit.register
    def visit_call(self, node: AnnCastCall) -> str:
        """
        Build the lambdas for a Call.

        If the call has a return value, its ``expr_str`` becomes the lambda
        variable name of that return value, so enclosing expressions can
        refer to it.
        """
        if node.is_grfn_2_2:
            self.visit_call_grfn_2_2(node)
        else:
            # GrFN 2.3 style Calls and Calls without a FunctionDef
            # share the same lambda form
            self.visit_call_without_func_copy(node)

        if node.has_ret_val:
            assert len(node.out_ret_val) == 1
            ret_val_fullid = list(node.out_ret_val.values())[0]
            node.expr_str = lambda_var_from_fullid(ret_val_fullid)

        return node.expr_str

    @_visit.register
    def visit_record_def(self, node: AnnCastRecordDef) -> str:
        return node.expr_str

    def visit_function_def_copy(self, node: AnnCastFunctionDef) -> typing.List:
        """Visit the body of a FunctionDef copy and return the body's strings."""
        return self.visit_node_list(node.body)

    @_visit.register
    def visit_function_def(self, node: AnnCastFunctionDef) -> str:
        """Build the top/bot interface lambdas of a FunctionDef and visit its body."""
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)
        # NOTE: node.func_args is not visited because those parameters are
        # included in the outputs of the top interface lambda
        body_expr = self.visit_node_list(node.body)
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print(f"FunctionDef {node.name.name}")
            print("\t Top Interface:")
            print(f"\t\t{node.top_interface_lambda}")
            print("\t Body Expressions:")
            for e in body_expr:
                print(f"\t\t{e}")
            print("\t Bot Interface:")
            print(f"\t\t{node.bot_interface_lambda}")

        return node.expr_str

    @_visit.register
    def visit_goto(self, node: AnnCastGoto) -> str:
        # node.expr and node.label are not visited by this pass
        return ""

    @_visit.register
    def visit_label(self, node: AnnCastLabel) -> str:
        # node.label is not visited by this pass
        return ""

    @_visit.register
    def visit_literal_value(self, node: AnnCastLiteralValue) -> str:
        """
        Build the expression string for a literal.

        - ``"List[Any]"`` literals become ``"[<initial>] <op> <size>"``.
        - Tuple literals produce ``""`` and leave ``node.expr_str`` untouched.
        - Integer, float and boolean literals become ``str(node.value)``.
        - Any other type returns the existing ``node.expr_str``.
        """
        if node.value_type == "List[Any]":
            # node.value is expected to carry:
            #   operator      - a string
            #   size          - a Var node or a LiteralValue node (for a number)
            #   initial_value - presumably a literal value (or perhaps a Var)
            val = node.value
            size_str = self.visit(val.size)
            init_val = self.visit(val.initial_value)
            node.expr_str = f"[{init_val}] {val.operator} {size_str}"
            return node.expr_str

        if node.value_type == StructureType.TUPLE:
            return ""

        if node.value_type in (
            ScalarType.INTEGER,
            ScalarType.ABSTRACTFLOAT,
            ScalarType.BOOLEAN,
        ):
            node.expr_str = str(node.value)

        return node.expr_str

    @_visit.register
    def visit_loop(self, node: AnnCastLoop) -> str:
        """Build the top interface, condition and bot interface lambdas of a loop."""
        # top interface lambda
        node.top_interface_lambda = lambda_for_loop_top_interface(
            node.top_interface_initial, node.top_interface_updated
        )

        # init (pre-loop) statements; visited only for their side effects
        self.visit_node_list(node.pre)

        # condition lambda
        loop_expr = self.visit(node.expr)
        node.condition_lambda = lambda_for_loop_condition(
            node.condition_in, loop_expr
        )

        body_expr = self.visit_node_list(node.body)

        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("Loop ")
            print("\t Loop Top Interface:")
            print(f"\t\t{node.top_interface_lambda}")
            print("\t Loop Expression:")
            print(f"\t\t{node.condition_lambda}")
            print("\t Body Expressions:")
            for e in body_expr:
                print(f"\t\t{e}")
            print("\t Loop Bot Interface:")
            print(f"\t\t{node.bot_interface_lambda}")

        return node.expr_str

    @_visit.register
    def visit_model_break(self, node: AnnCastModelBreak) -> str:
        return node.expr_str

    @_visit.register
    def visit_model_continue(self, node: AnnCastModelContinue) -> str:
        return node.expr_str

    @_visit.register
    def visit_model_import(self, node: AnnCastModelImport) -> str:
        # imports produce no lambda; return "" (like goto/label) rather than None
        return ""

    @_visit.register
    def visit_model_if(self, node: AnnCastModelIf) -> str:
        """Build the top interface, condition, decision and bot interface lambdas of an if."""
        # top interface lambda
        node.top_interface_lambda = lambda_for_interface(node.top_interface_in)

        # condition lambda
        expr_str = self.visit(node.expr)
        node.condition_lambda = lambda_for_condition(
            node.condition_in, expr_str
        )

        body_expr = self.visit_node_list(node.body)
        or_else_expr = self.visit_node_list(node.orelse)

        # decision lambda, keyed on the first condition output
        cond_fullid = list(node.condition_out.values())[0]
        node.decision_lambda = lambda_for_decision(
            cond_fullid, node.decision_in
        )

        # bot interface lambda
        node.bot_interface_lambda = lambda_for_interface(node.bot_interface_in)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("If ")
            print("\t If Top Interface:")
            print(f"\t\t{node.top_interface_lambda}")
            print("\t If Expression:")
            print(f"\t\t{node.condition_lambda}")
            print("\t Body Expressions:")
            for e in body_expr:
                print(f"\t\t{e}")
            print("\t OrElse Expressions:")
            for e in or_else_expr:
                print(f"\t\t{e}")
            print("\t If Decision Lambda:")
            print(f"\t\t{node.decision_lambda}")
            print("\t If Bot Interface:")
            print(f"\t\t{node.bot_interface_lambda}")

        return node.expr_str

    @_visit.register
    def visit_model_return(self, node: AnnCastModelReturn) -> str:
        """Build and store the lambda for the return-value assignment."""
        val = self.visit(node.value)
        lambda_expr = lambda_for_grfn_assignment(node.grfn_assignment, val)
        node.grfn_assignment.lambda_expr = lambda_expr
        node.expr_str = lambda_expr

        return node.expr_str

    @_visit.register
    def visit_module(self, node: AnnCastModule) -> str:
        """Visit every top-level statement of a module."""
        body_expr = self.visit_node_list(node.body)

        # DEBUG printing
        if self.pipeline_state.PRINT_DEBUGGING_INFO:
            print("Module")
            print("\t Body Expressions:")
            for e in body_expr:
                print(f"\t\t{e}")

        return node.expr_str

    @_visit.register
    def visit_name(self, node: AnnCastName) -> str:
        """A name's expression string is its lambda variable name."""
        fullid = ann_cast_name_to_fullid(node)
        node.expr_str = lambda_var_from_fullid(fullid)
        return node.expr_str

    @_visit.register
    def visit_operator(self, node: AnnCastOperator) -> str:
        # TODO: build an expression string from the operator and its operands
        # (presumably node.op via cast_op_to_str and node.operands). Until then
        # operators contribute "", so an assignment whose right side is an
        # operator gets an empty lambda body.
        return ""

    @_visit.register
    def visit_set(self, node: AnnCastSet) -> str:
        return node.expr_str

    @_visit.register
    def visit_tuple(self, node: AnnCastTuple) -> str:
        """A tuple's expression string is its visited elements in parentheses."""
        pieces = self.visit_node_list(node.values)
        node.expr_str = f"({', '.join(pieces)})"

        return node.expr_str

    @_visit.register
    def visit_var(self, node: AnnCastVar) -> str:
        """A Var's expression string is that of its wrapped value."""
        node.expr_str = self.visit(node.val)
        return node.expr_str