Coverage for skema/model_assembly/structures.py: 43%

275 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from __future__ import annotations 

2from abc import ABC, abstractmethod 

3from dataclasses import dataclass 

4from typing import List 

5import re 

6 

7from .code_types import CodeType 

8from .metadata import LambdaType, TypedMetadata, CodeSpanReference 

9 

10 

@dataclass(repr=False, frozen=True)
class GenericIdentifier(ABC):
    """Abstract base for identifiers qualified by a namespace and a scope.

    Concrete subclasses must implement ``__str__``; ``__repr__`` delegates
    to it.
    """

    namespace: str
    scope: str

    @staticmethod
    def from_str(data: str):
        """Build a concrete identifier from a ``::``-delimited string.

        The leading component selects the identifier kind:

        * ``@container::<ns>::<scope>`` -> ``ContainerIdentifier`` named ``"--"``
        * ``@container::<ns>::<scope>::<name>`` -> ``ContainerIdentifier``; the
          name is prefixed with ``"<scope>."`` unless the scope is ``@global``
        * ``@type::<ns>::<scope>::<name>`` -> ``TypeIdentifier``
        * ``@variable::<ns>::<scope>::<name>::<index>`` -> ``VariableIdentifier``

        Any other leading component yields ``None``. A recognized kind with
        the wrong number of components raises ``ValueError`` (from tuple
        unpacking), as does a non-integer variable index (from ``int``).
        """
        parts = data.split("::")
        kind = parts[0]

        if kind == "@type":
            _, ns, sc, type_name = parts
            return TypeIdentifier(ns, sc, type_name)

        if kind == "@variable":
            _, ns, sc, var_name, index = parts
            return VariableIdentifier(ns, sc, var_name, int(index))

        if kind != "@container":
            # Unrecognized prefixes are not treated as errors; callers get None.
            return None

        if len(parts) == 3:
            _, ns, sc = parts
            return ContainerIdentifier(ns, sc, "--")

        _, ns, sc, con_name = parts
        qualified_name = con_name if sc == "@global" else f"{sc}.{con_name}"
        return ContainerIdentifier(ns, sc, qualified_name)

    def is_global_scope(self):
        """Return True when this identifier lives in the ``@global`` scope."""
        return self.scope == "@global"

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        return NotImplemented

44 

45 

@dataclass(repr=False, frozen=True)
class ContainerIdentifier(GenericIdentifier):
    """Identifier naming a container.

    ``con_name`` may already be scope-qualified (``"<scope>.<name>"``) when
    produced by ``GenericIdentifier.from_str``.
    """

    con_name: str

    def __str__(self):
        location = f"{self.namespace}.{self.scope}"
        return f"Con -- {self.con_name} ({location})"

52 

53 

@dataclass(repr=False, frozen=True)
class TypeIdentifier(GenericIdentifier):
    """Identifier naming a type within a namespace and scope."""

    type_name: str

    def __str__(self):
        location = f"{self.namespace}.{self.scope}"
        return f"Type -- {self.type_name} ({location})"

60 

61 

@dataclass(repr=False, frozen=True)
class VariableIdentifier(GenericIdentifier):
    """Identifier for a versioned variable (``namespace::scope::name::index``)."""

    var_name: str
    index: int

    @classmethod
    def from_str_and_con(cls, data: str, con: ContainerIdentifier):
        """Parse a variable id, borrowing namespace/scope from ``con`` if needed.

        Accepted forms:

        * ``<id type>::<name>::<version>`` (deprecated) -- the namespace is
          ``con.namespace`` and the scope is ``con.con_name``.
        * ``<id type>::<module>::<scope>::<name>::<version>``

        Raises:
            ValueError: if ``data`` has any other number of ``::`` components,
                or if the version is not an integer.
        """
        parts = data.split("::")
        if len(parts) == 3:
            # Deprecated short form: the scope comes from the enclosing container.
            _, name, idx = parts
            return cls(con.namespace, con.con_name, name, int(idx))
        if len(parts) == 5:
            _, ns, sc, name, idx = parts
            return cls(ns, sc, name, int(idx))
        raise ValueError(f"Unrecognized variable identifier: {data}")

    @classmethod
    def from_str(cls, var_id: str):
        """Parse a 4-part or 5-part variable id.

        * ``<ns>::<scope>::<name>::<index>`` (the form produced by ``__str__``)
        * ``<extra>::<ns>::<scope>::<name>::<index>`` -- the first component
          is discarded.

        Raises:
            ValueError: if ``var_id`` does not have 4 or 5 ``::`` components,
                or if the index is not an integer.
        """
        parts = var_id.split("::")
        # A uid component was introduced to make variable nodes unique across
        # repeated calls to the same function. The original note says it is
        # *appended*, but the 5-part form below discards the FIRST component.
        # NOTE(review): confirm which end the uid is on.
        if len(parts) == 4:
            ns, sc, var_name, idx = parts
        elif len(parts) == 5:
            _, ns, sc, var_name, idx = parts
        else:
            raise ValueError(f"Unrecognized variable identifier: {var_id}")
        return cls(ns, sc, var_name, int(idx))

    def __str__(self):
        """Serialize as ``namespace::scope::var_name::index``."""
        return f"{self.namespace}::{self.scope}::{self.var_name}::{self.index}"

    def __print(self):
        # Human-readable form, e.g. "Var -- x::0 (ns.scope)". Name-mangled to
        # ``_VariableIdentifier__print``; not called anywhere in this module.
        var_str = f"{self.var_name}::{self.index}"
        return f"Var -- {var_str} ({self.namespace}.{self.scope})"

102 

103 

@dataclass(frozen=True)
class GenericDefinition(ABC):
    """Base for definitions keyed by an identifier and a type string."""

    identifier: GenericIdentifier
    type: str

    @staticmethod
    def from_dict(data: dict):
        """Build a definition from a dict.

        Dicts with a ``"domain"`` key become a ``VariableDefinition``; all
        others are delegated to ``TypeDefinition.from_data``.

        For variables, a domain containing ``"dimensions"`` is recorded as
        type ``"type"`` with domain name ``"list"``; otherwise the domain's
        own ``name`` and ``type`` are used.
        """
        if "domain" not in data:
            return TypeDefinition.from_data(data)

        domain = data["domain"]
        if "dimensions" in domain:
            type_str, name_str = "type", "list"
        else:
            name_str = domain["name"]
            type_str = domain["type"]

        # NOTE(review): the raw ``source_refs`` entries are stored as the
        # ``metadata`` list here, whereas ``VariableDefinition.from_data``
        # builds metadata objects -- confirm this difference is intended.
        return VariableDefinition(
            GenericIdentifier.from_str(data["name"]),
            type_str,
            domain["mutable"],
            name_str,
            data["domain_constraint"],
            list(data["source_refs"]),
        )

128 

129 

@dataclass(frozen=True)
class VariableDefinition(GenericDefinition):
    """Definition of a variable: mutability, domain, constraint and metadata."""

    is_mutable: bool
    domain_name: str
    domain_constraint: str  # e.g. "(and (> v -infty) (< v infty))"
    metadata: List[TypedMetadata]

    @classmethod
    def from_identifier(cls, id: VariableIdentifier):
        """Create a placeholder definition for ``id``.

        The placeholder has type ``"type"``, is immutable, has domain name
        ``"None"``, an unbounded domain constraint and no metadata.
        """
        return cls(
            id,
            "type",
            False,
            "None",
            "(and (> v -infty) (< v infty))",
            [],
        )

    @classmethod
    def from_data(cls, data: dict) -> VariableDefinition:
        """Build a definition from a serialized variable dict.

        Reads ``name``, ``domain["mutable"]``, ``domain["name"]`` and
        ``domain_constraint``. The optional ``file_uid`` and the first entry
        of the optional ``source_refs`` feed an identifier-level code span.
        """
        var_id = VariableIdentifier.from_str(data["name"])
        file_ref = data.get("file_uid", "")
        # A missing *or empty* source_refs list yields an empty reference;
        # indexing [0] on an empty list previously raised IndexError.
        source_refs = data.get("source_refs") or [""]
        code_span_data = {
            "source_ref": source_refs[0],
            "file_uid": file_ref,
            "code_type": "identifier",
        }
        code_span_metadata = [CodeSpanReference.from_air_data(code_span_data)]

        # A conditional expression binds looser than "+", so the code span is
        # attached only when "metadata" is present. That behavior is kept
        # explicitly here. NOTE(review): confirm that dropping the span when
        # "metadata" is absent is intended.
        if "metadata" in data:
            metadata = [
                TypedMetadata.from_data(mdict) for mdict in data["metadata"]
            ] + code_span_metadata
        else:
            metadata = []

        return cls(
            var_id,
            "type",
            data["domain"]["mutable"],
            data["domain"]["name"],
            data["domain_constraint"],
            metadata,
        )

174 

175 

@dataclass(frozen=True)
class TypeFieldDefinition:
    """A single named, typed field of a type definition, plus its metadata."""

    name: str
    type: str
    metadata: List[TypedMetadata]

    @classmethod
    def from_air_data(cls, data: dict, file_uid: str) -> TypeFieldDefinition:
        """Build a field from AIR data with an identifier-level code span.

        ``data`` must provide ``source_ref``, ``name`` and ``type``.
        """
        span_data = {
            "source_ref": data["source_ref"],
            "file_uid": file_uid,
            "code_type": "identifier",
        }
        field_name = data["name"]
        field_type = data["type"]
        return cls(
            field_name,
            field_type,
            [CodeSpanReference.from_air_data(span_data)],
        )

    @classmethod
    def from_data(cls, data: dict) -> TypeFieldDefinition:
        """Build a field from serialized data; ``metadata`` is optional."""
        field_name = data["name"]
        field_type = data["type"]
        parsed_metadata = [
            TypedMetadata.from_data(entry) for entry in data.get("metadata", [])
        ]
        return cls(field_name, field_type, parsed_metadata)

    def to_dict(self) -> dict:
        """Serialize to a plain dict, calling ``to_dict`` on each metadata entry."""
        serialized_metadata = [entry.to_dict() for entry in self.metadata]
        return {
            "name": self.name,
            "type": self.type,
            "metadata": serialized_metadata,
        }

211 

212 

@dataclass(frozen=True)
class TypeDefinition(GenericDefinition):
    """Definition of a named type: metatype, fields and metadata."""

    name: str
    metatype: str
    fields: List[TypeFieldDefinition]
    metadata: List[TypedMetadata]

    @classmethod
    def from_air_data(cls, data: dict) -> TypeDefinition:
        """Build a type definition from AIR data.

        Requires ``name``, ``metatype`` and ``fields``. ``file_uid`` and
        ``source_ref`` are optional and default to ``""``. The inherited
        ``identifier`` and ``type`` fields are set to ``""``.
        """
        file_ref = data.get("file_uid", "")
        src_ref = data.get("source_ref", "")
        code_span_data = {
            "source_ref": src_ref,
            "file_uid": file_ref,
            "code_type": "block",
        }
        metadata = [CodeSpanReference.from_air_data(code_span_data)]
        return cls(
            "",
            "",
            data["name"],
            data["metatype"],
            # Reuse the defaulted file_ref: indexing data["file_uid"] here
            # raised KeyError whenever the optional key was absent.
            [TypeFieldDefinition.from_air_data(d, file_ref) for d in data["fields"]],
            metadata,
        )

    @classmethod
    def from_data(cls, data: dict) -> TypeDefinition:
        """Build a type definition from serialized data.

        ``name`` is used both as the identifier and as the type name.
        ``metadata`` is optional, matching ``TypeFieldDefinition.from_data``.
        """
        metadata = [TypedMetadata.from_data(d) for d in data.get("metadata", [])]
        return cls(
            data["name"],
            "",
            data["name"],
            data["metatype"],
            [TypeFieldDefinition.from_data(d) for d in data["fields"]],
            metadata,
        )

    def to_dict(self) -> dict:
        """Serialize name, metatype, fields and metadata to a plain dict."""
        return {
            "name": self.name,
            "metatype": self.metatype,
            "fields": [fdef.to_dict() for fdef in self.fields],
            "metadata": [d.to_dict() for d in self.metadata],
        }

261 

262 

@dataclass(frozen=True)
class ObjectDefinition(GenericDefinition):
    """Definition kind that adds nothing beyond ``GenericDefinition``'s fields."""

266 

267 

class GenericContainer(ABC):
    """Base class for containers (functions, loops, conditionals) built from dicts.

    ``from_dict`` selects the concrete subclass from the dict's ``type`` key.
    """

    def __init__(self, data: dict):
        """Populate the container from ``data``.

        Required keys: ``name``, ``arguments``, ``updated``, ``return_value``
        and ``body``. ``file_uid`` and ``body_source_ref`` are optional and
        default to ``""``.
        """
        self.identifier = GenericIdentifier.from_str(data["name"])
        # Read the optional file uid once; it is shared by the statements and
        # by the block-level code span below.
        file_ref = data.get("file_uid", "")

        def parse_vars(key: str):
            # Short-form variable ids borrow namespace/scope from this
            # container's identifier.
            return [
                VariableIdentifier.from_str_and_con(var_str, self.identifier)
                for var_str in data[key]
            ]

        self.arguments = parse_vars("arguments")
        self.updated = parse_vars("updated")
        self.returns = parse_vars("return_value")
        # Statements receive ``self`` while it is only partially built; they
        # read ``self.identifier``, which is already set at this point.
        self.statements = [
            GenericStmt.create_statement(stmt, self, file_ref)
            for stmt in data["body"]
        ]
        code_span_data = {
            "source_ref": data.get("body_source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(code_span_data)]

        # NOTE: store base name as key and update index during wiring
        self.variables = {}
        self.code_type = CodeType.UNKNOWN
        self.code_stats = {
            "num_calls": 0,
            "max_call_depth": 0,
            "num_math_assgs": 0,
            "num_data_changes": 0,
            "num_var_access": 0,
            "num_assgs": 0,
            "num_switches": 0,
            "num_loops": 0,
            "max_loop_depth": 0,
            "num_conditionals": 0,
            "max_conditional_depth": 0,
        }

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """List tab-indented inputs, then returned and updated variables."""
        args_str = "\n".join(f"\t{arg}" for arg in self.arguments)
        outputs_str = "\n".join(f"\t{var}" for var in self.returns + self.updated)
        return f"Inputs:\n{args_str}\nVariables:\n{outputs_str}"

    @staticmethod
    def from_dict(data: dict):
        """Instantiate the container subclass named by ``data["type"]``.

        A missing ``type`` defaults to ``"function"``. ``"if-block"`` and
        ``"select-block"`` both map to ``CondContainer``.

        Raises:
            ValueError: for any other ``type`` value.
        """
        con_type = data.get("type", "function")
        if con_type == "function":
            return FuncContainer(data)
        if con_type == "loop":
            return LoopContainer(data)
        if con_type in ("if-block", "select-block"):
            return CondContainer(data)
        raise ValueError(f"Unrecognized container type value: {con_type}")

341 

342 

class CondContainer(GenericContainer):
    """Container for conditional blocks (``if-block`` / ``select-block``)."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        body = super().__str__()
        name = self.identifier.con_name
        return f"<COND Con> -- {name}\n{body}\n"

353 

354 

class FuncContainer(GenericContainer):
    """Container for a function (the default when ``type`` is absent)."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        body = super().__str__()
        name = self.identifier.con_name
        return f"<FUNC Con> -- {name}\n{body}\n"

365 

366 

class LoopContainer(GenericContainer):
    """Container for a loop body (``type`` == ``"loop"``)."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        body = super().__str__()
        name = self.identifier.con_name
        return f"<LOOP Con> -- {name}\n{body}\n"

377 

378 

class GenericStmt(ABC):
    """Base class for a statement inside a container body.

    ``inputs`` are parsed from ``stmt["input"]``; ``outputs`` are parsed from
    ``stmt["output"]`` followed by ``stmt["updated"]``.
    """

    def __init__(self, stmt: dict, p: GenericContainer):
        self.container = p
        self.inputs = [
            VariableIdentifier.from_str_and_con(var_str, p.identifier)
            for var_str in stmt["input"]
        ]
        produced = stmt["output"] + stmt["updated"]
        self.outputs = [
            VariableIdentifier.from_str_and_con(var_str, p.identifier)
            for var_str in produced
        ]

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """Describe inputs and outputs as ``name (index)`` lists."""

        def describe(var_ids):
            return ", ".join(f"{v.var_name} ({v.index})" for v in var_ids)

        return f"Inputs: {describe(self.inputs)}\nOutputs: {describe(self.outputs)}"

    @staticmethod
    def create_statement(
        stmt_data: dict, container: GenericContainer, file_ref: str
    ):
        """Build the statement subclass selected by ``stmt_data["function"]["type"]``.

        ``"lambda"`` -> ``LambdaStmt``, ``"container"`` -> ``CallStmt``,
        ``"operator"`` -> ``OperatorStmt``. Each is called with
        ``(stmt_data, container, file_ref)``.
        NOTE(review): ``OperatorStmt.__init__`` must accept that third argument.

        Raises:
            ValueError: for any other function type.
        """
        func_type = stmt_data["function"]["type"]
        if func_type == "lambda":
            return LambdaStmt(stmt_data, container, file_ref)
        if func_type == "container":
            return CallStmt(stmt_data, container, file_ref)
        if func_type == "operator":
            return OperatorStmt(stmt_data, container, file_ref)
        raise ValueError(f"Undefined statement type: {func_type}")

422 

423 

class CallStmt(GenericStmt):
    """Statement that calls another container; ``call_id`` names the callee."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str):
        super().__init__(stmt, con)
        callee_name = stmt["function"]["name"]
        self.call_id = GenericIdentifier.from_str(callee_name)
        span_data = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(span_data)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        details = super().__str__()
        return f"<CallStmt>: {self.call_id}\n{details}"

442 

443 

class OperatorStmt(GenericStmt):
    """Statement applying an operator; ``call_id`` is parsed from the function name."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str = ""):
        """Initialize from statement data.

        ``GenericStmt.create_statement`` passes ``file_ref`` positionally for
        every statement kind. Without this parameter, building an operator
        statement raised ``TypeError``. It defaults to ``""`` so two-argument
        callers keep working.

        ``call_id`` is whatever ``GenericIdentifier.from_str`` returns for the
        function name, which may be ``None`` for an unrecognized prefix.
        """
        super().__init__(stmt, con)
        self.call_id = GenericIdentifier.from_str(stmt["function"]["name"])
        # Attach a block-level code span, consistent with CallStmt and LambdaStmt.
        code_span_data = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(code_span_data)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        generic_str = super().__str__()
        return f"<OperatorStmt>: {self.call_id}\n{generic_str}"

455 

456 

class LambdaStmt(GenericStmt):
    """Statement wrapping a lambda whose kind is inferred from its function name."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str):
        super().__init__(stmt, con)
        function = stmt["function"]
        # NOTE: eventually the kind should come from function["lambda_type"]
        # instead of being parsed out of the name.
        kind = self.type_str_from_name(function["name"])
        # NOTE: no parent-qualified lambda node name is stored; UUIDs are
        # expected to make it unnecessary.
        self.type = LambdaType.get_lambda_type(kind, len(self.inputs))
        self.func_str = function["code"]
        span_data = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(span_data)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        details = super().__str__()
        return f"<LambdaStmt>: {self.type}\n{details}"

    @staticmethod
    def type_str_from_name(name: str) -> str:
        """Return the lambda kind for the first ``__<kind>__`` marker in ``name``.

        Markers are tried in the order assign, condition, decision, pack,
        extract; the first match wins.

        Raises:
            ValueError: if no marker occurs in ``name``.
        """
        markers = (
            (r"__assign__", "assign"),
            (r"__condition__", "condition"),
            (r"__decision__", "decision"),
            (r"__pack__", "pack"),
            (r"__extract__", "extract"),
        )
        for pattern, kind in markers:
            if re.search(pattern, name):
                return kind
        raise ValueError(
            f"No recognized lambda type found from name string: {name}"
        )

500 

501 

class GrFNExecutionException(Exception):
    """Exception type for GrFN execution failures (not raised in this module)."""