Coverage for skema/model_assembly/structures.py: 43%
275 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from __future__ import annotations
2from abc import ABC, abstractmethod
3from dataclasses import dataclass
4from typing import List
5import re
7from .code_types import CodeType
8from .metadata import LambdaType, TypedMetadata, CodeSpanReference
@dataclass(repr=False, frozen=True)
class GenericIdentifier(ABC):
    """Abstract base for identifiers that carry a namespace and a scope.

    Subclasses add their own name fields and must implement ``__str__``;
    ``__repr__`` reuses that string form.
    """

    namespace: str
    scope: str

    @staticmethod
    def from_str(data: str):
        """Build a concrete identifier from a ``::``-delimited string.

        The first component selects the identifier kind:

        * ``@container`` -- with three components the container name is
          ``"--"``; with four components the name is prefixed with
          ``"<scope>."`` unless the scope is ``@global``.
        * ``@type`` -- expects four components.
        * ``@variable`` -- expects five components; the last is parsed
          with ``int``.

        Returns:
            The matching identifier, or ``None`` when the first component
            is none of the kinds above.

        Raises:
            ValueError: from tuple unpacking when a recognised kind has an
                unexpected number of components, or from ``int`` when a
                variable index is not numeric.
        """
        parts = data.split("::")
        kind = parts[0]

        if kind == "@container":
            if len(parts) == 3:
                _, namespace, scope = parts
                return ContainerIdentifier(namespace, scope, "--")
            _, namespace, scope, name = parts
            if scope != "@global":
                name = f"{scope}.{name}"
            return ContainerIdentifier(namespace, scope, name)

        if kind == "@type":
            _, namespace, scope, name = parts
            return TypeIdentifier(namespace, scope, name)

        if kind == "@variable":
            _, namespace, scope, name, index = parts
            return VariableIdentifier(namespace, scope, name, int(index))

        # Unrecognised kinds deliberately yield None rather than raising.
        return None

    def is_global_scope(self):
        """Return True when this identifier's scope is ``@global``."""
        return "@global" == self.scope

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """Return the display form; concrete subclasses must override."""
        return NotImplemented
@dataclass(repr=False, frozen=True)
class ContainerIdentifier(GenericIdentifier):
    """Identifier that adds a container name to the namespace and scope."""

    con_name: str

    def __str__(self):
        """Render as ``Con -- <con_name> (<namespace>.<scope>)``."""
        location = f"{self.namespace}.{self.scope}"
        return f"Con -- {self.con_name} ({location})"
@dataclass(repr=False, frozen=True)
class TypeIdentifier(GenericIdentifier):
    """Identifier that adds a type name to the namespace and scope."""

    type_name: str

    def __str__(self):
        """Render as ``Type -- <type_name> (<namespace>.<scope>)``."""
        location = f"{self.namespace}.{self.scope}"
        return f"Type -- {self.type_name} ({location})"
@dataclass(repr=False, frozen=True)
class VariableIdentifier(GenericIdentifier):
    """Identifier for a variable: namespace, scope, name and integer index."""

    var_name: str
    index: int

    @classmethod
    def from_str_and_con(cls, data: str, con: ContainerIdentifier):
        """Parse a variable id string, using ``con`` for the short form.

        Args:
            data: ``::``-delimited identifier with three or five components.
            con: container supplying namespace and scope (its ``con_name``)
                when ``data`` uses the three-component form.

        Raises:
            ValueError: if ``data`` has neither three nor five components.
        """
        parts = data.split("::")
        count = len(parts)

        if count == 3:
            # Deprecated <id type>::<name>::<version> style: namespace and
            # scope are taken from the enclosing container.
            _, var_name, version = parts
            return cls(con.namespace, con.con_name, var_name, int(version))

        if count == 5:
            # Full <id type>::<module>::<scope>::<name>::<version> style.
            _, namespace, scope, var_name, version = parts
            return cls(namespace, scope, var_name, int(version))

        raise ValueError(f"Unrecognized variable identifier: {data}")

    @classmethod
    def from_str(cls, var_id: str):
        """Parse ``<ns>::<scope>::<name>::<index>``, optionally prefixed.

        A four-component string is read directly; anything else is unpacked
        as five components with the first one discarded, so other lengths
        raise ``ValueError`` from tuple unpacking.
        """
        parts = var_id.split("::")
        # Variable ids may carry an extra "::"-separated component (added to
        # create unique variable nodes for multiple calls to the same
        # function). Both the four- and five-component layouts are accepted
        # for now, to be safe.
        if len(parts) == 4:
            namespace, scope, name, version = parts
        else:
            _, namespace, scope, name, version = parts
        return cls(namespace, scope, name, int(version))

    def __str__(self):
        """Render as ``<namespace>::<scope>::<var_name>::<index>``."""
        fields = (self.namespace, self.scope, self.var_name, self.index)
        return "::".join(str(field) for field in fields)

    def __print(self):
        """Return the verbose ``Var -- <name>::<index> (<ns>.<scope>)`` form."""
        label = f"{self.var_name}::{self.index}"
        return f"Var -- {label} ({self.namespace}.{self.scope})"
@dataclass(frozen=True)
class GenericDefinition(ABC):
    """Base record pairing an identifier with a type string."""

    identifier: GenericIdentifier
    type: str

    @staticmethod
    def from_dict(data: dict):
        """Create a variable or type definition from a dictionary.

        A dictionary with a ``"domain"`` key becomes a
        ``VariableDefinition``; anything else is handed to
        ``TypeDefinition.from_data``.
        """
        if "domain" not in data:
            return TypeDefinition.from_data(data)

        domain = data["domain"]
        if "dimensions" in domain:
            # Domains that declare dimensions get fixed type/name strings.
            type_str, name_str = "type", "list"
        else:
            name_str = domain["name"]
            type_str = domain["type"]

        identifier = GenericIdentifier.from_str(data["name"])
        return VariableDefinition(
            identifier,
            type_str,
            domain["mutable"],
            name_str,
            data["domain_constraint"],
            list(data["source_refs"]),
        )
@dataclass(frozen=True)
class VariableDefinition(GenericDefinition):
    """Definition of a variable: mutability, domain and attached metadata."""

    is_mutable: bool
    domain_name: str
    domain_constraint: str
    metadata: List[TypedMetadata]

    @classmethod
    def from_identifier(cls, id: VariableIdentifier):
        """Build a placeholder definition for ``id`` with default values.

        The result is immutable-flagged (``is_mutable=False``), has domain
        name ``"None"``, an unbounded domain constraint and no metadata.
        """
        return cls(
            id,
            "type",
            False,
            "None",
            "(and (> v -infty) (< v infty))",
            [],
        )

    @classmethod
    def from_data(cls, data: dict) -> VariableDefinition:
        """Build a definition from a variable dictionary.

        Args:
            data: must contain ``"name"``, ``"domain"`` (with ``"mutable"``
                and ``"name"``) and ``"domain_constraint"``. ``"file_uid"``,
                ``"source_refs"`` and ``"metadata"`` are optional.

        Returns:
            A definition whose metadata is every entry of
            ``data["metadata"]`` (if present) followed by one code-span
            reference built from the first source ref and the file uid.
        """
        var_id = VariableIdentifier.from_str(data["name"])

        # A missing or empty "source_refs" falls back to an empty reference
        # instead of raising IndexError on an empty list.
        source_refs = data.get("source_refs")
        src_ref = source_refs[0] if source_refs else ""
        code_span_metadata = [
            CodeSpanReference.from_air_data(
                {
                    "source_ref": src_ref,
                    "file_uid": data.get("file_uid", ""),
                    "code_type": "identifier",
                }
            )
        ]

        # The code-span reference is appended unconditionally. Previously
        # `[] if ... else [...] + code_span_metadata` bound the `+` to the
        # else-branch only, so the reference was silently discarded whenever
        # "metadata" was absent from the input.
        declared_metadata = [
            TypedMetadata.from_data(mdict) for mdict in data.get("metadata", [])
        ]
        metadata = declared_metadata + code_span_metadata

        return cls(
            var_id,
            "type",
            data["domain"]["mutable"],
            data["domain"]["name"],
            data["domain_constraint"],
            metadata,
        )
@dataclass(frozen=True)
class TypeFieldDefinition:
    """A single named, typed field of a type definition, with metadata."""

    name: str
    type: str
    metadata: List[TypedMetadata]

    @classmethod
    def from_air_data(cls, data: dict, file_uid: str) -> TypeFieldDefinition:
        """Build a field from a dictionary that carries a ``"source_ref"``.

        The field's metadata is a single code-span reference created from
        the source ref and ``file_uid`` with code type ``"identifier"``.
        """
        span_info = {
            "source_ref": data["source_ref"],
            "file_uid": file_uid,
            "code_type": "identifier",
        }
        field_name = data["name"]
        field_type = data["type"]
        return cls(
            field_name,
            field_type,
            [CodeSpanReference.from_air_data(span_info)],
        )

    @classmethod
    def from_data(cls, data: dict) -> TypeFieldDefinition:
        """Build a field from a dictionary; ``"metadata"`` is optional."""
        field_name = data["name"]
        field_type = data["type"]
        if "metadata" in data:
            metadata = [TypedMetadata.from_data(entry) for entry in data["metadata"]]
        else:
            metadata = []
        return cls(field_name, field_type, metadata)

    def to_dict(self) -> dict:
        """Return a dictionary with ``name``, ``type`` and ``metadata`` keys."""
        serialized_metadata = [entry.to_dict() for entry in self.metadata]
        return dict(
            name=self.name,
            type=self.type,
            metadata=serialized_metadata,
        )
@dataclass(frozen=True)
class TypeDefinition(GenericDefinition):
    """Definition of a named type with a metatype, fields and metadata."""

    name: str
    metatype: str
    fields: List[TypeFieldDefinition]
    metadata: List[TypedMetadata]

    @classmethod
    def from_air_data(cls, data: dict) -> TypeDefinition:
        """Build a type definition from a dictionary with source references.

        Args:
            data: must contain ``"name"``, ``"metatype"`` and ``"fields"``;
                ``"file_uid"`` and ``"source_ref"`` are optional and default
                to the empty string.

        Returns:
            A definition with empty ``identifier`` and ``type`` strings and
            a single code-span reference (code type ``"block"``) as metadata.
        """
        file_ref = data.get("file_uid", "")
        src_ref = data.get("source_ref", "")
        metadata = [
            CodeSpanReference.from_air_data(
                {
                    "source_ref": src_ref,
                    "file_uid": file_ref,
                    "code_type": "block",
                }
            )
        ]
        # Use the defaulted file_ref for the fields as well: reading
        # data["file_uid"] directly raised KeyError when the key was missing,
        # even though the method otherwise treats it as optional.
        fields = [
            TypeFieldDefinition.from_air_data(field_data, file_ref)
            for field_data in data["fields"]
        ]
        return cls(
            "",
            "",
            data["name"],
            data["metatype"],
            fields,
            metadata,
        )

    @classmethod
    def from_data(cls, data: dict) -> TypeDefinition:
        """Build a type definition from a dictionary.

        ``"metadata"`` is optional (matching
        ``TypeFieldDefinition.from_data``); ``"name"``, ``"metatype"`` and
        ``"fields"`` are required. The name is also stored as the identifier.
        """
        metadata = [
            TypedMetadata.from_data(entry) for entry in data.get("metadata", [])
        ]
        return cls(
            data["name"],
            "",
            data["name"],
            data["metatype"],
            [TypeFieldDefinition.from_data(d) for d in data["fields"]],
            metadata,
        )

    def to_dict(self) -> dict:
        """Return ``name``, ``metatype``, ``fields`` and ``metadata`` as a dict."""
        return {
            "name": self.name,
            "metatype": self.metatype,
            "fields": [fdef.to_dict() for fdef in self.fields],
            "metadata": [d.to_dict() for d in self.metadata],
        }
@dataclass(frozen=True)
class ObjectDefinition(GenericDefinition):
    """Definition that adds no fields or methods to ``GenericDefinition``."""
class GenericContainer(ABC):
    """Abstract container holding arguments, outputs and body statements."""

    def __init__(self, data: dict):
        """Populate the container from its dictionary description.

        Reads ``"name"``, ``"arguments"``, ``"updated"``, ``"return_value"``
        and ``"body"`` from ``data``; ``"file_uid"`` and
        ``"body_source_ref"`` are optional and default to ``""``.
        """
        self.identifier = GenericIdentifier.from_str(data["name"])
        file_uid = data.get("file_uid", "")

        def variable_ids(key):
            # Resolve every id string under `key` against this container.
            return [
                VariableIdentifier.from_str_and_con(raw, self.identifier)
                for raw in data[key]
            ]

        self.arguments = variable_ids("arguments")
        self.updated = variable_ids("updated")
        self.returns = variable_ids("return_value")

        # Statements receive this (still partially initialised) container.
        self.statements = [
            GenericStmt.create_statement(stmt_data, self, file_uid)
            for stmt_data in data["body"]
        ]

        span_info = {
            "source_ref": data.get("body_source_ref", ""),
            "file_uid": file_uid,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(span_info)]

        # NOTE: store base name as key and update index during wiring
        self.variables = {}
        self.code_type = CodeType.UNKNOWN
        # Every statistic starts at zero.
        self.code_stats = dict.fromkeys(
            (
                "num_calls",
                "max_call_depth",
                "num_math_assgs",
                "num_data_changes",
                "num_var_access",
                "num_assgs",
                "num_switches",
                "num_loops",
                "max_loop_depth",
                "num_conditionals",
                "max_conditional_depth",
            ),
            0,
        )

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """Return the shared inputs/variables listing used by subclasses."""
        input_lines = [f"\t{arg}" for arg in self.arguments]
        output_lines = [f"\t{var}" for var in self.returns + self.updated]
        args_str = "\n".join(input_lines)
        outputs_str = "\n".join(output_lines)
        return f"Inputs:\n{args_str}\nVariables:\n{outputs_str}"

    @staticmethod
    def from_dict(data: dict):
        """Instantiate the container subclass selected by ``data["type"]``.

        A missing ``"type"`` is treated as ``"function"``.

        Raises:
            ValueError: if the type value is not recognised.
        """
        con_type = data["type"] if "type" in data else "function"

        if con_type == "function":
            return FuncContainer(data)
        if con_type == "loop":
            return LoopContainer(data)
        # Both conditional flavours share one container class.
        if con_type in ("if-block", "select-block"):
            return CondContainer(data)
        raise ValueError(f"Unrecognized container type value: {con_type}")
class CondContainer(GenericContainer):
    """Container created for ``"if-block"`` and ``"select-block"`` types."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic listing with a ``<COND Con>`` header line."""
        listing = super().__str__()
        header = f"<COND Con> -- {self.identifier.con_name}"
        return header + "\n" + listing + "\n"
class FuncContainer(GenericContainer):
    """Container created for the ``"function"`` type (also the default)."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic listing with a ``<FUNC Con>`` header line."""
        listing = super().__str__()
        header = f"<FUNC Con> -- {self.identifier.con_name}"
        return header + "\n" + listing + "\n"
class LoopContainer(GenericContainer):
    """Container created for the ``"loop"`` type."""

    def __init__(self, data: dict):
        super().__init__(data)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic listing with a ``<LOOP Con>`` header line."""
        listing = super().__str__()
        header = f"<LOOP Con> -- {self.identifier.con_name}"
        return header + "\n" + listing + "\n"
class GenericStmt(ABC):
    """Abstract statement with input and output variable identifiers."""

    def __init__(self, stmt: dict, p: GenericContainer):
        """Resolve the statement's variables against container ``p``.

        Inputs come from ``stmt["input"]``; outputs are ``stmt["output"]``
        followed by ``stmt["updated"]``.
        """
        self.container = p

        def resolve(raw_ids):
            # Each id string is interpreted relative to the owning container.
            return [
                VariableIdentifier.from_str_and_con(raw, self.container.identifier)
                for raw in raw_ids
            ]

        self.inputs = resolve(stmt["input"])
        self.outputs = resolve(stmt["output"] + stmt["updated"])

    def __repr__(self):
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """Return the shared ``Inputs: ... / Outputs: ...`` summary."""

        def summarize(identifiers):
            return ", ".join(f"{v.var_name} ({v.index})" for v in identifiers)

        inputs_str = summarize(self.inputs)
        outputs_str = summarize(self.outputs)
        return f"Inputs: {inputs_str}\nOutputs: {outputs_str}"

    @staticmethod
    def create_statement(
        stmt_data: dict, container: GenericContainer, file_ref: str
    ):
        """Instantiate the statement class named by the function type.

        Raises:
            ValueError: if ``stmt_data["function"]["type"]`` is not
                ``"lambda"``, ``"container"`` or ``"operator"``.
        """
        func_type = stmt_data["function"]["type"]
        if func_type == "lambda":
            return LambdaStmt(stmt_data, container, file_ref)
        if func_type == "container":
            return CallStmt(stmt_data, container, file_ref)
        if func_type == "operator":
            return OperatorStmt(stmt_data, container, file_ref)
        raise ValueError(f"Undefined statement type: {func_type}")

    # def correct_input_list(
    #     self, alt_inputs: Dict[VariableIdentifier, VariableNode]
    # ) -> List[VariableNode]:
    #     return [v if v.index != -1 else alt_inputs[v] for v in self.inputs]
class CallStmt(GenericStmt):
    """Statement created when the function type is ``"container"``."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str):
        """Record the called identifier and a code-span reference.

        ``stmt["source_ref"]`` is optional and defaults to ``""``.
        """
        super().__init__(stmt, con)
        self.call_id = GenericIdentifier.from_str(stmt["function"]["name"])
        span_info = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(span_info)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic summary with ``<CallStmt>: <call_id>``."""
        summary = super().__str__()
        return f"<CallStmt>: {self.call_id}\n{summary}"
class OperatorStmt(GenericStmt):
    """Statement created when the function type is ``"operator"``."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str = ""):
        """Record the operator identifier and a code-span reference.

        Args:
            stmt: statement dictionary; ``stmt["function"]["name"]`` is
                parsed with ``GenericIdentifier.from_str`` and
                ``stmt["source_ref"]`` is optional.
            con: container that owns this statement.
            file_ref: file uid for the code-span reference. It defaults to
                ``""`` so two-argument callers keep working, while
                ``GenericStmt.create_statement`` -- which passes it as a
                third positional argument -- no longer raises ``TypeError``.
        """
        super().__init__(stmt, con)
        self.call_id = GenericIdentifier.from_str(stmt["function"]["name"])
        # Mirror CallStmt/LambdaStmt so every statement exposes `metadata`.
        code_span_data = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(code_span_data)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic summary with ``<OperatorStmt>: <call_id>``."""
        generic_str = super().__str__()
        return f"<OperatorStmt>: {self.call_id}\n{generic_str}"
class LambdaStmt(GenericStmt):
    """Statement created when the function type is ``"lambda"``."""

    def __init__(self, stmt: dict, con: GenericContainer, file_ref: str):
        """Derive the lambda type from the function name and store its code.

        ``stmt["source_ref"]`` is optional and defaults to ``""``.
        """
        super().__init__(stmt, con)
        function_info = stmt["function"]
        # NOTE Want to use the form below eventually
        # type_str = stmt["function"]["lambda_type"]
        type_str = self.type_str_from_name(function_info["name"])

        # NOTE: we shouldn't need this since we will use UUIDs
        # self.lambda_node_name = f"{self.parent.name}::" + self.name
        self.type = LambdaType.get_lambda_type(type_str, len(self.inputs))
        self.func_str = function_info["code"]
        span_info = {
            "source_ref": stmt.get("source_ref", ""),
            "file_uid": file_ref,
            "code_type": "block",
        }
        self.metadata = [CodeSpanReference.from_air_data(span_info)]

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Prefix the generic summary with ``<LambdaStmt>: <type>``."""
        summary = super().__str__()
        return f"<LambdaStmt>: {self.type}\n{summary}"

    @staticmethod
    def type_str_from_name(name: str) -> str:
        """Return the lambda kind embedded in ``name`` as ``__<kind>__``.

        Kinds are checked in the order assign, condition, decision, pack,
        extract; the first one found wins.

        Raises:
            ValueError: if none of the kinds appears in ``name``.
        """
        for kind in ("assign", "condition", "decision", "pack", "extract"):
            if re.search(f"__{kind}__", name) is not None:
                return kind
        raise ValueError(
            f"No recognized lambda type found from name string: {name}"
        )
class GrFNExecutionException(Exception):
    """Exception type for GrFN execution errors; adds nothing to Exception."""