Coverage for skema/model_assembly/networks.py: 18%

939 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from __future__ import annotations 

2from typing import List, Dict, Iterable, Set, Any, Tuple, NoReturn, Optional 

3from abc import ABC, abstractmethod 

4from functools import singledispatch 

5from dataclasses import dataclass 

6from itertools import product 

7from copy import deepcopy 

8 

9import datetime 

10import json 

11import re 

12import os 

13 

14import networkx as nx 

15import numpy as np 

16from networkx.algorithms.simple_paths import all_simple_paths 

17 

18from .sandbox import load_lambda_function 

19from .air import AutoMATES_IR 

20from .structures import ( 

21 GenericContainer, 

22 LoopContainer, 

23 GenericStmt, 

24 CallStmt, 

25 OperatorStmt, 

26 LambdaStmt, 

27 GenericIdentifier, 

28 ContainerIdentifier, 

29 VariableIdentifier, 

30 TypeIdentifier, 

31 ObjectDefinition, 

32 VariableDefinition, 

33 TypeDefinition, 

34 GrFNExecutionException, 

35) 

36from .metadata import ( 

37 TypedMetadata, 

38 ProvenanceData, 

39 MeasurementType, 

40 LambdaType, 

41 DataType, 

42 DomainSet, 

43 DomainInterval, 

44 SuperSet, 

45 MetadataType, 

46 MetadataMethod, 

47 Domain, 

48) 

49from ..utils.misc import choose_font, uuid 

50 

51 

# Font selected once, at import time, via choose_font().
FONT = choose_font()

# Named hex color constants.
dodgerblue3, forestgreen = "#1874CD", "#228b22"

56 

57 

@dataclass(frozen=False, repr=False)
class GenericNode(ABC):
    """Abstract base class for every node of a GrFN graph.

    A node is identified by its ``uid``. Unless a subclass overrides
    ``__str__``, both ``str()`` and ``repr()`` render the node as its uid.
    Subclasses must implement ``get_kwargs`` and ``get_label``.
    """

    uid: str

    @staticmethod
    def create_node_id() -> str:
        """Return a new random identifier string made from ``uuid.uuid4()``."""
        fresh = uuid.uuid4()
        return str(fresh)

    def __str__(self):
        return self.uid

    def __repr__(self):
        # Route through __str__ so that subclasses overriding __str__ also
        # change their repr.
        return self.__str__()

    @abstractmethod
    def get_kwargs(self):
        """Return the keyword arguments (colors, label, ...) used to draw the node."""
        return NotImplemented

    @abstractmethod
    def get_label(self):
        """Return the display label of the node."""
        return NotImplemented

79 

80 

@dataclass(repr=False, frozen=False)
class VariableNode(GenericNode):
    """A variable (data) node of a GrFN graph.

    Besides its identifier and metadata, the node carries the runtime
    ``value`` computed during execution and an optional user-supplied
    ``input_value``.
    """

    # Identifier of the variable; exposes ``var_name`` and ``index``.
    identifier: VariableIdentifier
    metadata: List[TypedMetadata]
    object_ref: Optional[str] = None
    # Value computed during execution (None until computed).
    value: Any = None
    # Value supplied from outside the graph (None when not supplied).
    input_value: Any = None
    # Exit variables are drawn filled (see get_kwargs).
    is_exit: bool = False

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other) -> bool:
        # Nodes compare by uid. Non-node operands defer to Python's default
        # handling instead of raising AttributeError on ``other.uid``.
        if not isinstance(other, GenericNode):
            return NotImplemented
        return self.uid == other.uid

    def __str__(self):
        return f"{str(self.identifier)}::{str(self.uid)}"

    @classmethod
    def from_id(cls, idt: VariableIdentifier, data: VariableDefinition):
        """Create a node with a fresh uid from an identifier and definition.

        A ``Domain`` metadata entry is derived from ``data.domain_name`` and
        prepended to ``data.metadata``. Categorical measurement types get a
        ``DomainSet`` element, numerical ones an unbounded
        ``DomainInterval``, and any other type gets no domain elements.
        """
        # TODO: use domain constraint information in the future
        d_type = DataType.from_str(data.domain_name)
        m_type = MeasurementType.from_name(data.domain_name)

        def create_domain_elements():
            if MeasurementType.isa_categorical(m_type):
                set_type = SuperSet.from_data_type(d_type)
                return [
                    DomainSet(
                        d_type,
                        set_type,
                        "lambda x: SuperSet.ismember(x, set_type)",
                    )
                ]
            elif MeasurementType.isa_numerical(m_type):
                return [
                    DomainInterval(-float("inf"), float("inf"), False, False)
                ]
            else:
                return []

        dom = Domain(
            MetadataType.DOMAIN,
            ProvenanceData(
                MetadataMethod.PROGRAM_ANALYSIS_PIPELINE,
                ProvenanceData.get_dt_timestamp(),
            ),
            d_type,
            m_type,
            create_domain_elements(),
        )

        metadata = [dom] + data.metadata
        return cls(GenericNode.create_node_id(), idt, metadata)

    def get_fullname(self):
        """Return ``"<var_name>\\n(<index>)"`` for this variable."""
        # The name and index live on the identifier; VariableNode itself has
        # no ``name``/``index`` attributes.
        return f"{self.identifier.var_name}\n({self.identifier.index})"

    def get_name(self):
        return str(self.identifier)

    def get_kwargs(self):
        """Return drawing attributes; exit variables are drawn filled crimson."""
        return {
            "color": "crimson",
            "fontcolor": "white" if self.is_exit else "black",
            "fillcolor": "crimson" if self.is_exit else "white",
            "style": "filled" if self.is_exit else "",
            "padding": 15,
            "label": self.get_label(),
        }

    @staticmethod
    def get_node_label(base_name):
        """Split ``base_name`` into words and wrap them into a compact label.

        snake_case names are split on ``_`` (except names starting with
        ``IF_`` or ``COND_``). Mixed-case tokens are split on camelCase /
        acronym boundaries, except derivative-style tokens such as ``dX``.
        Words are joined with spaces, and a newline is started before a word
        once the current line holds at least 8 characters.
        """
        if "_" in base_name and not base_name.startswith(("IF_", "COND_")):
            snake_case_tokens = base_name.split("_")
        else:
            snake_case_tokens = [base_name]

        # If the identifier is an all uppercase acronym (like `WMA`) then we
        # need to capture that. But we also the case where the first two
        # letters are capitals but the first letter does not belong with
        # the second (like `AVariable`). We also need to capture the case of
        # an acronym followed by a capital for the next word (like `AWSVars`).
        camel_case_tokens = list()
        for token in snake_case_tokens:
            if token.islower() or token.isupper():
                camel_case_tokens.append(token)
            elif re.match(r"^d[A-Z]", token) is not None:
                # NOTE: special case rule for derivatives
                camel_case_tokens.append(token)
            else:
                camel_split = re.split(
                    r"([A-Z]+|[A-Z]?[a-z]+)(?=[A-Z]|\b)", token
                )
                camel_case_tokens.extend(camel_split)

        clean_tokens = [t for t in camel_case_tokens if t != ""]
        label = ""
        cur_length = 0
        for token in clean_tokens:
            tok_len = len(token)
            if cur_length == 0:
                label += token + " "
                cur_length += tok_len + 1
                continue

            if cur_length >= 8:
                label += "\n"
                cur_length = 0

            if cur_length + tok_len < 8:
                label += token + " "
                cur_length += tok_len + 1
            else:
                label += token
                cur_length += tok_len

        return label

    def get_label(self):
        return self.get_node_label(self.identifier.var_name)

    @classmethod
    def from_dict(cls, data: dict):
        """Rebuild a node from the dict produced by ``to_dict``."""
        return cls(
            data["uid"],
            VariableIdentifier.from_str(data["identifier"]),
            [TypedMetadata.from_data(mdict) for mdict in data.get("metadata", [])],
            data.get("object_ref", ""),
        )

    def to_dict(self) -> dict:
        """Serialize uid, identifier, object_ref and metadata (not values)."""
        return {
            "uid": self.uid,
            "identifier": str(self.identifier),
            "object_ref": self.object_ref,
            "metadata": [m.to_dict() for m in self.metadata],
        }

226 

227 

@dataclass(repr=False, frozen=False)
class LambdaNode(GenericNode):
    """A function node of a GrFN graph that wraps an executable lambda.

    NOTE(review): ``parse_result`` reads ``self.np_shape``, which is not a
    dataclass field; presumably it is assigned by the owning network before
    execution. Confirm.
    """

    func_type: LambdaType
    # Source text of the lambda (``load_lambda_function`` compiles it).
    func_str: str
    function: callable
    metadata: List[TypedMetadata]

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other) -> bool:
        # Nodes compare by uid; non-node operands are not compared by uid.
        if not isinstance(other, GenericNode):
            return NotImplemented
        return self.uid == other.uid

    def __str__(self):
        return f"{self.get_label()}: {self.uid}"

    def __call__(self, *values) -> Iterable[np.ndarray]:
        """Apply the lambda element-wise across the given input columns.

        Raises:
            RuntimeError: the number of inputs differs from the lambda's
                positional parameter count.
            GrFNExecutionException: wraps any error raised while running the
                lambda or parsing its result.
        """
        expected_num_args = len(self.get_signature())
        input_num_args = len(values)
        if expected_num_args != input_num_args:
            raise RuntimeError(
                f"""Incorrect number of inputs
                (expected {expected_num_args} found {input_num_args})
                for lambda:\n{self.func_str}"""
            )

        try:
            if len(values) != 0:
                # In vectorized execution, we would have a values list that looks like:
                # [ [x_1, x_2, ... x_N] [y_1, y_2, ... y_N] [z_1, z_2, ... z_N]]
                # where the root lists length is the # of inputs to the lambda function
                # (in this case 3). We want to turn this into a list of length N where
                # each sub list is length of inputs (3 in this case) with the corresponding
                # x/y/z variables. I.e. it should look like:
                # [ [x_1, y_1, z_1] [x_2, y_2, z_2] ... [x_N, y_N, z_N]]
                res = [self.function(*inputs) for inputs in zip(*values)]
            else:
                res = self.function()
            return self.parse_result(values, res)
        except Exception as e:
            raise GrFNExecutionException(e) from e

    def parse_result(self, values, res):
        """Normalize a lambda result into a list of output values."""
        if self.func_type in (
            LambdaType.INTERFACE,
            LambdaType.DECISION,
            LambdaType.EXTRACT,
        ):
            # Interfaces and decision nodes should output a tuple of the
            # correct variables. However, if there is only one var in the
            # tuple it is outputting, python collapses this to a single
            # var, so handle this scenario
            if not isinstance(res[0], tuple):
                return [res]
            # Transpose per-sample tuples into one array per output variable.
            return [np.array(column) for column in zip(*res)]

        if isinstance(res, dict):
            res = {k: self.parse_result(values, v) for k, v in res.items()}
        elif len(values) == 0:
            # Broadcast a scalar/list produced without inputs to np_shape.
            # ``bool`` must be tested before ``int`` because bool is a
            # subclass of int; otherwise booleans would become int64 arrays.
            if isinstance(res, bool):
                res = np.full(self.np_shape, res, dtype=bool)
            elif isinstance(res, int):
                res = np.full(self.np_shape, res, dtype=np.int64)
            elif isinstance(res, float):
                res = np.full(self.np_shape, res, dtype=np.float64)
            elif isinstance(res, list):
                res = np.array([res] * self.np_shape[0])
            else:
                res = np.full(self.np_shape, res)

        return [res]

    def get_kwargs(self):
        return {"shape": "rectangle", "padding": 10, "label": self.get_label()}

    def get_label(self):
        return self.func_type.shortname()

    def get_signature(self):
        """Return the names of the lambda's positional parameters."""
        code = self.function.__code__
        # co_varnames also lists local variables after the parameters, so it
        # must be truncated to co_argcount to count only the arguments.
        return code.co_varnames[: code.co_argcount]

    @classmethod
    def from_AIR(
        cls, lm_id: str, lm_type: str, lm_str: str, mdata: List[TypedMetadata]
    ):
        """Create a node from AIR data, compiling ``lm_str`` into a function."""
        lambda_fn = load_lambda_function(lm_str)
        if mdata is None:
            mdata = list()
        return cls(lm_id, lm_type, lm_str, lambda_fn, mdata)

    @classmethod
    def from_dict(cls, data: dict):
        """Rebuild a node from the dict produced by ``to_dict``."""
        lambda_fn = load_lambda_function(data["lambda"])
        lambda_type = LambdaType.from_str(data["type"])
        metadata = [TypedMetadata.from_data(d) for d in data.get("metadata", [])]
        return cls(
            data["uid"],
            lambda_type,
            data["lambda"],
            lambda_fn,
            metadata,
        )

    def to_dict(self) -> dict:
        return {
            "uid": self.uid,
            "type": str(self.func_type),
            "lambda": self.func_str,
            "metadata": [m.to_dict() for m in self.metadata],
        }

343 

344 

@dataclass(repr=False, eq=False)
class LoopTopInterface(LambdaNode):
    """Loop-top interface (LTI) lambda node.

    ``repr=False`` keeps the uid-based repr inherited from the node classes
    instead of a generated repr that dumps every field (including the
    compiled function). ``eq=False`` keeps LambdaNode's uid-based
    ``__eq__``/``__hash__``.
    """

    # Extra flag serialized with this node under the "use_initial" key.
    use_initial: bool = False

    def parse_result(self, values, res):
        # The top interfaces node (LTI) should output a tuple of the
        # correct variables. However, if there is only one var in the
        # tuple it is outputting, python collapses this to a single
        # var, so handle this scenario
        if not isinstance(res[0], tuple):
            return [res]
        # Transpose per-sample tuples into one array per output variable.
        return [np.array(column) for column in zip(*res)]

    @classmethod
    def from_dict(cls, data: dict):
        """Rebuild the node from ``to_dict`` output.

        ``use_initial`` falls back to the field default (False) when absent.
        """
        lambda_fn = load_lambda_function(data["lambda"])
        lambda_type = LambdaType.from_str(data["type"])
        metadata = [TypedMetadata.from_data(d) for d in data.get("metadata", [])]
        return cls(
            data["uid"],
            lambda_type,
            data["lambda"],
            lambda_fn,
            metadata,
            data.get("use_initial", False),
        )

    def to_dict(self) -> dict:
        return {
            "uid": self.uid,
            "type": str(self.func_type),
            "lambda": self.func_str,
            "metadata": [m.to_dict() for m in self.metadata],
            "use_initial": self.use_initial,
        }

385 

386 

@dataclass(repr=False, eq=False)
class UnpackNode(LambdaNode):
    """An UnpackNode is used to represent the process of 'unpacking' a
    sequence of variables in an assignment, for example
        x,y,z,w = foo(1,2,3)
    The return value of foo is unpacked into variables x,y,z, and w
    This is a new operation that the GrFN execution handles differently
    An UnpackNode does not contain a lambda expression
    """

    # input: A single tuple string name to unpack
    inputs: str = ""
    # output: A string holding a list of variable names
    output: str = ""

    @classmethod
    def from_dict(cls, data: Dict):
        """Rebuild an UnpackNode from the dict produced by ``to_dict``.

        ``to_dict`` stores no lambda source, so the rebuilt node has an empty
        ``func_str`` and ``function`` set to None.
        """
        metadata = [TypedMetadata.from_data(d) for d in data.get("metadata", [])]
        return cls(
            data["uid"],
            LambdaType.from_str(data["type"]),
            "",
            None,
            metadata,
            data.get("inputs", ""),
            data.get("output", ""),
        )

    def to_dict(self) -> dict:
        return {
            "uid": self.uid,
            "type": str(self.func_type),
            "inputs": self.inputs,
            "output": self.output,
            "metadata": [m.to_dict() for m in self.metadata],
        }

414 

415 

@dataclass(repr=False, eq=False)
class PackNode(LambdaNode):
    """A lambda node for a 'pack' operation.

    Like an unpack node, it carries ``inputs``/``output`` strings and no
    serialized lambda expression. NOTE(review): presumably this packs several
    variables into one tuple-like value; confirm against the GrFN executor.
    """

    # Input variable name(s), stored as a string.
    inputs: str = ""
    # Output variable name(s), stored as a string.
    output: str = ""

    @classmethod
    def from_dict(cls, data: Dict):
        """Rebuild a PackNode from the dict produced by ``to_dict``.

        ``to_dict`` stores no lambda source, so the rebuilt node has an empty
        ``func_str`` and ``function`` set to None.
        """
        metadata = [TypedMetadata.from_data(d) for d in data.get("metadata", [])]
        return cls(
            data["uid"],
            LambdaType.from_str(data["type"]),
            "",
            None,
            metadata,
            data.get("inputs", ""),
            data.get("output", ""),
        )

    def to_dict(self) -> dict:
        return {
            "uid": self.uid,
            "type": str(self.func_type),
            "inputs": self.inputs,
            "output": self.output,
            "metadata": [m.to_dict() for m in self.metadata],
        }

435 

436 

@dataclass
class HyperEdge:
    """Connects input variable nodes, through one lambda node, to outputs.

    Calling the edge runs the lambda on the inputs' current values and
    stores the results on the output variable nodes.
    """

    inputs: Iterable[VariableNode]
    lambda_fn: LambdaNode
    outputs: Iterable[VariableNode]

    @staticmethod
    def _current_value(var):
        """Return ``var.value``, else ``var.input_value``, else ``[None]``."""
        if var.value is not None:
            return var.value
        if var.input_value is not None:
            return var.input_value
        return [None]

    def __call__(self):
        """Execute the lambda and write its results to the output nodes.

        For a vectorized loop-exit decision (a DECISION lambda with an
        ``EXIT`` input and ``np_shape != (1,)``), outputs are merged position
        by position using ``seen_exits``. Otherwise each output receives its
        result, except that LITERAL edges keep an output's ``input_value``
        when one is set.
        """
        inputs = [self._current_value(var) for var in self.inputs]
        result = self.lambda_fn(*inputs)

        # If we are in the exit decision hyper edge and in vectorized execution
        if (
            self.lambda_fn.func_type == LambdaType.DECISION
            and any(o.identifier.var_name == "EXIT" for o in self.inputs)
            and self.lambda_fn.np_shape != (1,)
        ):
            self._update_vectorized_exit_outputs(result)
            return

        for i, out_val in enumerate(result):
            variable = self.outputs[i]
            if (
                self.lambda_fn.func_type == LambdaType.LITERAL
                and variable.input_value is not None
            ):
                variable.value = variable.input_value
            else:
                variable.value = out_val

    def _update_vectorized_exit_outputs(self, result):
        """Merge ``result`` into the outputs of a vectorized exit decision."""
        # Lazily create the per-position record of exits seen so far. The
        # builtin ``bool`` dtype is used because the ``np.bool`` alias was
        # removed in NumPy 1.24.
        if not hasattr(self, "seen_exits"):
            self.seen_exits = np.full(
                self.lambda_fn.np_shape, False, dtype=bool
            )

        # Gather the exit conditions for this execution
        exit_var_values = next(
            o for o in self.inputs if o.identifier.var_name == "EXIT"
        ).value

        for res_index, out_val in enumerate(result):
            output = self.outputs[res_index]
            if output.value is None:
                # ``np.nan``: the ``np.NaN`` alias was removed in NumPy 2.0.
                output.value = np.full(out_val.shape, np.nan)
            # Positions already marked in seen_exits take the new value; all
            # other positions keep their current value.
            # NOTE(review): confirm this direction is intended.
            # GrFNLoopSubgraph sets seen_exits to all True before its final
            # call to this edge, which relies on it.
            for j in range(len(output.value)):
                if self.seen_exits[j]:
                    output.value[j] = out_val[j]

        # Update seen_exits with any vectorized positions that may have
        # exited during this execution
        self.seen_exits = np.copy(self.seen_exits) | exit_var_values

    def __eq__(self, other) -> bool:
        if not isinstance(other, HyperEdge):
            return NotImplemented
        # Compare whole sequences: zip() alone would truncate and treat
        # edges with different numbers of inputs/outputs as equal.
        return (
            self.lambda_fn == other.lambda_fn
            and tuple(self.inputs) == tuple(other.inputs)
            and tuple(self.outputs) == tuple(other.outputs)
        )

    def __hash__(self):
        return hash(
            (
                self.lambda_fn.uid,
                tuple(inp.uid for inp in self.inputs),
                tuple(out.uid for out in self.outputs),
            )
        )

    @classmethod
    def from_dict(cls, data: dict, all_nodes: Dict[str, GenericNode]):
        """Rebuild an edge from uids, resolving them through ``all_nodes``."""
        return cls(
            [all_nodes[n_id] for n_id in data["inputs"]],
            all_nodes[data["function"]],
            [all_nodes[n_id] for n_id in data["outputs"]],
        )

    def to_dict(self) -> dict:
        return {
            "inputs": [n.uid for n in self.inputs],
            "function": self.lambda_fn.uid,
            "outputs": [n.uid for n in self.outputs],
        }

528 

529 

@dataclass(repr=False)
class GrFNSubgraph:
    """A container-level subgraph (function, loop, conditional, ...) of a GrFN.

    Calling a subgraph executes the hyper edges it contains and recurses into
    other subgraphs through their interface lambda nodes.
    """

    uid: str
    namespace: str
    scope: str
    basename: str
    basename_id: int
    occurrence_num: int
    # uid of the parent subgraph; None for the root subgraph.
    parent: str
    # Container class name, e.g. "LoopContainer" (see get_border_color).
    type: str
    border_color: str
    nodes: Iterable[GenericNode]
    metadata: List[TypedMetadata]

    def __hash__(self):
        return hash(self.__str__())

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        context = f"{self.namespace}.{self.scope}"
        return f"{self.basename}::{self.occurrence_num} ({context})"

    def __eq__(self, other) -> bool:
        if not isinstance(other, GrFNSubgraph):
            return NotImplemented
        return (
            self.occurrence_num == other.occurrence_num
            and self.border_color == other.border_color
            and {n.uid for n in self.nodes} == {n.uid for n in other.nodes}
        )

    def __call__(
        self,
        grfn: GroundedFunctionNetwork,
        subgraphs_to_hyper_edges: Dict[GrFNSubgraph, List[HyperEdge]],
        node_to_subgraph: Dict[LambdaNode, GrFNSubgraph],
        all_nodes_visited: Set[VariableNode],
        vars_to_compute: Optional[List[VariableNode]] = None,
    ) -> Set[GenericNode]:
        """
        Handles the execution of the lambda functions within a subgraph of
        GrFN. We place the logic in this function versus directly in __call__
        so the logic can be shared in the loop subgraph type.

        Args:
            grfn (GroundedFunctionNetwork):
                The GrFN we are operating on. Used to find successors of nodes.
            subgraphs_to_hyper_edges (Dict[GrFNSubgraph, List[HyperEdge]]):
                Maps each subgraph to the hyper edges with nodes in that
                subgraph.
            node_to_subgraph (Dict[LambdaNode, GrFNSubgraph]):
                Maps nodes to the subgraph they are contained in.
            all_nodes_visited (Set[VariableNode]):
                Holds the set of all nodes that have been visited; updated
                in place.
            vars_to_compute (List[VariableNode], optional):
                Nodes whose values should be computed. When non-empty,
                execution starts from their predecessors and walks backwards
                (reverse path execution).

        Raises:
            GrFNExecutionException: Raised when interface nodes are
                inconsistent, when a non-interface node from another subgraph
                is reached, or when no literal node can start execution.

        Returns:
            Set[GenericNode]: The output nodes of this subgraph's output
            interface (empty for the root subgraph), i.e. the nodes the
            parent container should treat as updated.
        """
        if vars_to_compute is None:
            vars_to_compute = []

        # Grab all hyper edges in this subgraph
        hyper_edges = subgraphs_to_hyper_edges[self]
        nodes_to_hyper_edge = {e.lambda_fn: e for e in hyper_edges}

        # There should be only one lambda node of type interface with outputs
        # all in the same subgraph. Identify this node as the entry point of
        # execution within this subgraph. Will be none if no input.
        input_interface_hyper_edge_node = self.get_input_interface_hyper_edge(
            hyper_edges
        )
        output_interface_hyper_edge_node = self.get_output_interface_node(
            hyper_edges
        )

        # Add nodes that must be configured via user input as they have no
        # input edge
        standalone_vars = [
            n
            for n in self.nodes
            if isinstance(n, VariableNode) and grfn.in_degree(n) == 0
        ]
        all_nodes_visited.update(standalone_vars)

        reverse_path_execution = len(vars_to_compute) > 0
        if not reverse_path_execution:
            # Start from lambdas with no inputs and from the successors of
            # the standalone variables.
            node_execute_queue = [
                e.lambda_fn for e in hyper_edges if len(e.inputs) == 0
            ]
            node_execute_queue.extend(
                [s for n in standalone_vars for s in grfn.successors(n)]
            )

            if input_interface_hyper_edge_node is not None:
                node_execute_queue.insert(
                    0, input_interface_hyper_edge_node.lambda_fn
                )

            # Need to recurse to a different subgraph if no nodes to execute here
            if not node_execute_queue:
                node_execute_queue.extend(
                    self._execute_deepest_literal_subgraph(
                        grfn,
                        subgraphs_to_hyper_edges,
                        node_to_subgraph,
                        all_nodes_visited,
                    )
                )
        else:
            node_execute_queue = [
                pred
                for var in vars_to_compute
                for pred in grfn.predecessors(var)
            ]

        while node_execute_queue:
            executed = True
            executed_visited_variables = set()
            node_to_execute = node_execute_queue.pop(0)
            # A node may have been queued more than once.
            if node_to_execute in all_nodes_visited:
                continue

            if node_to_execute not in nodes_to_hyper_edge:
                # Node is not in current subgraph
                if node_to_execute.func_type != LambdaType.INTERFACE:
                    raise GrFNExecutionException(
                        "Error: Attempting to execute non-interface node"
                        + f" {node_to_execute} found in another subgraph."
                    )
                subgraph = node_to_subgraph[node_to_execute]
                subgraph_input_interface = (
                    subgraph.get_input_interface_hyper_edge(
                        subgraphs_to_hyper_edges[subgraph]
                    )
                )
                # Either the subgraph has no input interface or all the
                # inputs must be set.
                if subgraph_input_interface is None or all(
                    n in all_nodes_visited
                    for n in subgraph_input_interface.inputs
                ):
                    # Recurse into the other subgraph. Its execution returns
                    # the updated output nodes so we can mark them as visited
                    # here in the parent and continue execution.
                    subgraph_execution_result = subgraph(
                        grfn,
                        subgraphs_to_hyper_edges,
                        node_to_subgraph,
                        all_nodes_visited,
                    )
                    executed_visited_variables.update(subgraph_execution_result)
                else:
                    node_to_execute = subgraph_input_interface.lambda_fn
                    executed = False
            elif all(
                n in all_nodes_visited
                for n in nodes_to_hyper_edge[node_to_execute].inputs
            ):
                # All of the input nodes have been visited, so the input values
                # are initialized and we can execute. In the case of literal
                # nodes, inputs is empty and all() will default to True.
                to_execute = nodes_to_hyper_edge[node_to_execute]
                to_execute()
                executed_visited_variables.update(to_execute.outputs)
            else:
                # We still are waiting on input values to be computed, push to
                # the back of the queue
                executed = False

            if executed:
                all_nodes_visited.update(executed_visited_variables)
                all_nodes_visited.add(node_to_execute)
                if not reverse_path_execution:
                    node_execute_queue.extend(
                        [
                            succ
                            for var in executed_visited_variables
                            for succ in grfn.successors(var)
                            if (
                                succ in self.nodes
                                and succ not in all_nodes_visited
                            )
                            or (
                                var in self.nodes
                                and succ.func_type == LambdaType.INTERFACE
                            )
                        ]
                    )
            else:
                # Queue the lambdas producing this node's inputs, then retry
                # this node after them.
                node_execute_queue.extend(
                    [
                        lambda_pred
                        for var_pred in grfn.predecessors(node_to_execute)
                        for lambda_pred in grfn.predecessors(var_pred)
                        if (
                            lambda_pred in self.nodes
                            and lambda_pred not in all_nodes_visited
                        )
                        or lambda_pred.func_type == LambdaType.INTERFACE
                    ]
                )
                node_execute_queue.append(node_to_execute)

        if output_interface_hyper_edge_node is None:
            return set()
        return set(output_interface_hyper_edge_node.outputs)

    @staticmethod
    def _execute_deepest_literal_subgraph(
        grfn, subgraphs_to_hyper_edges, node_to_subgraph, all_nodes_visited
    ):
        """Run the subgraph of the literal node farthest from any output.

        Returns the not-yet-visited successors of that subgraph's outputs,
        to be used as the next execution queue entries.
        """
        global_literal_nodes = [
            n
            for n in grfn.nodes
            if isinstance(n, LambdaNode)
            and grfn.in_degree(n) == 0
            and n.func_type == LambdaType.LITERAL
        ]
        global_output_vars = [
            n
            for n in grfn.nodes
            if isinstance(n, VariableNode) and grfn.out_degree(n) == 0
        ]

        # Longest simple path from each literal to ANY output. Pairs with no
        # connecting path are skipped (max() of an empty list would raise).
        lit_node_to_max_dist = {}
        for l_node, o_node in product(global_literal_nodes, global_output_vars):
            path_lengths = [
                len(path) for path in all_simple_paths(grfn, l_node, o_node)
            ]
            if not path_lengths:
                continue
            lit_node_to_max_dist[l_node] = max(
                lit_node_to_max_dist.get(l_node, 0), max(path_lengths)
            )

        if not lit_node_to_max_dist:
            raise GrFNExecutionException(
                "Unable to find a literal node connected to an output"
                + " variable to begin execution."
            )

        # Ties resolve to the first literal found, as a stable sort would.
        start_literal = max(lit_node_to_max_dist, key=lit_node_to_max_dist.get)
        subgraph = node_to_subgraph[start_literal]
        subgraph_outputs = subgraph(
            grfn,
            subgraphs_to_hyper_edges,
            node_to_subgraph,
            all_nodes_visited,
        )
        return set(
            f_node
            for o_node in subgraph_outputs
            for f_node in grfn.successors(o_node)
            if f_node not in all_nodes_visited
        )

    @classmethod
    def from_container(
        cls,
        con: GenericContainer,
        occ: int,
        parent_subgraph: GrFNSubgraph,
        basename_id: Optional[int] = None,
    ):
        """Create a subgraph (GrFNLoopSubgraph for loops) for a container.

        Args:
            con: The container to build the subgraph from.
            occ: Occurrence number of the container.
            parent_subgraph: Enclosing subgraph, or None for the root.
            basename_id: Value for the ``basename_id`` field. Previously
                no value was passed for it, which shifted every later
                argument and made construction fail.
        """
        con_id = con.identifier
        type_str = con.__class__.__name__
        class_to_create = (
            GrFNLoopSubgraph if isinstance(con, LoopContainer) else cls
        )

        return class_to_create(
            str(uuid.uuid4()),
            con_id.namespace,
            con_id.scope,
            con_id.con_name,
            basename_id,
            occ,
            None if parent_subgraph is None else parent_subgraph.uid,
            type_str,
            cls.get_border_color(type_str),
            [],
            con.metadata,
        )

    def get_input_interface_hyper_edge(self, hyper_edges):
        """
        Get the interface hyper edge for input in this subgraph

        Args:
            hyper_edges (List[HyperEdge]): All hyper edges with nodes in this
                subgraph.

        Raises:
            GrFNExecutionException: More than one candidate was found in a
                non-root subgraph.

        Returns:
            HyperEdge: The INTERFACE hyper edge whose outputs all lie in this
            subgraph, or None if there is no input for this subgraph. In the
            root subgraph the first candidate is returned even if several
            exist.
        """
        input_interfaces = [
            e
            for e in hyper_edges
            if e.lambda_fn.func_type == LambdaType.INTERFACE
            and all(o in self.nodes for o in e.outputs)
        ]

        if len(input_interfaces) > 1 and self.parent:
            raise GrFNExecutionException(
                "Found multiple input interface nodes"
                + " in subgraph during execution."
                + f" Expected 1 but {len(input_interfaces)} were found."
            )

        return input_interfaces[0] if input_interfaces else None

    def get_output_interface_node(self, hyper_edges):
        """
        Get the interface hyper edge for output in this subgraph

        Args:
            hyper_edges (List[HyperEdge]): All hyper edges with nodes in this
                subgraph.

        Raises:
            GrFNExecutionException: A non-root subgraph does not have exactly
                one output interface.

        Returns:
            HyperEdge: The INTERFACE hyper edge whose inputs all lie in this
            subgraph, or None for the root subgraph (which has no output
            interface).
        """
        if not self.parent:
            # The root subgraph has no output interface
            return None

        output_interfaces = [
            e
            for e in hyper_edges
            if e.lambda_fn.func_type == LambdaType.INTERFACE
            and all(o in self.nodes for o in e.inputs)
        ]
        if len(output_interfaces) != 1:
            raise GrFNExecutionException(
                "Found incorrect number of output interface nodes"
                + " in subgraph during execution."
                + f" Expected 1 but {len(output_interfaces)} were found."
            )
        return output_interfaces[0]

    @staticmethod
    def get_border_color(type_str):
        """Map a container class name to its border color.

        Raises:
            TypeError: ``type_str`` is not a known container type.
        """
        border_colors = {
            "CondContainer": "orange",
            "FuncContainer": "forestgreen",
            "LoopContainer": "navyblue",
            "CallContainer": "purple",
            "ModuleContainer": "grey",
        }
        if type_str not in border_colors:
            raise TypeError(f"Unrecognized subgraph type: {type_str}")
        return border_colors[type_str]

    @classmethod
    def from_dict(cls, data: dict, all_nodes: Dict[str, GenericNode]):
        """Rebuild a subgraph from ``to_dict`` output, resolving node uids."""
        subgraph_nodes = [all_nodes[n_id] for n_id in data["nodes"]]
        type_str = data["type"]
        class_to_create = (
            GrFNLoopSubgraph if type_str == "LoopContainer" else cls
        )

        return class_to_create(
            data["uid"],
            data["namespace"],
            data["scope"],
            data["basename"],
            data["basename_id"],
            data["occurrence_num"],
            data["parent"],
            type_str,
            cls.get_border_color(type_str),
            subgraph_nodes,
            [TypedMetadata.from_data(d) for d in data.get("metadata", [])],
        )

    def to_dict(self):
        return {
            "uid": self.uid,
            "namespace": self.namespace,
            "scope": self.scope,
            "basename": self.basename,
            "basename_id": self.basename_id,
            "occurrence_num": self.occurrence_num,
            "parent": self.parent,
            "type": self.type,
            "border_color": self.border_color,
            "nodes": [n.uid for n in self.nodes],
            "metadata": [m.to_dict() for m in self.metadata],
        }

946 

947 

@dataclass(repr=False, eq=False)
class GrFNLoopSubgraph(GrFNSubgraph):
    """Subgraph of a loop container; executes its body until EXIT is true."""

    def __call__(
        self,
        grfn: GroundedFunctionNetwork,
        subgraphs_to_hyper_edges: Dict[GrFNSubgraph, List[HyperEdge]],
        node_to_subgraph: Dict[LambdaNode, GrFNSubgraph],
        all_nodes_visited: Set[VariableNode],
    ):
        """
        Execute the loop subgraph until its EXIT variable is true.

        Each iteration first computes only the path to the input-interface
        outputs and the EXIT variable (reverse path execution). If EXIT is
        true (for every position, when vectorized), the output decision edge
        runs with every position marked as exited, the output interface runs,
        and the loop stops. Otherwise the whole loop body is executed once
        more.

        Args:
            grfn (GroundedFunctionNetwork):
                The GrFN we are operating on. Used to find successors of nodes.
            subgraphs_to_hyper_edges (Dict[GrFNSubgraph, List[HyperEdge]]):
                Maps each subgraph to the hyper edges with nodes in that
                subgraph.
            node_to_subgraph (Dict[LambdaNode, GrFNSubgraph]):
                Maps nodes to the subgraph they are contained in.
            all_nodes_visited (Set[VariableNode]):
                Holds the set of all nodes that have been visited; updated
                in place.

        Raises:
            GrFNExecutionException: The subgraph does not contain exactly one
                EXIT variable.

        Returns:
            Set of variable nodes produced by the last full iteration, or the
            post-loop output variables when the body never ran.
        """
        # First, find exit node within the subgraph
        exit_var_nodes = [
            n
            for n in self.nodes
            if isinstance(n, VariableNode) and n.identifier.var_name == "EXIT"
        ]
        if len(exit_var_nodes) != 1:
            raise GrFNExecutionException(
                "Found incorrect number of exit var nodes in"
                + " loop subgraph during execution."
                + f" Expected 1 but {len(exit_var_nodes)} were found."
            )
        exit_var_node = exit_var_nodes[0]

        hyper_edges = subgraphs_to_hyper_edges[self]
        input_interface = self.get_input_interface_hyper_edge(hyper_edges)
        output_interface = self.get_output_interface_node(hyper_edges)

        # Decision lambda feeding the output interface, and its hyper edge.
        output_decision = [
            n
            for v in output_interface.inputs
            for n in grfn.predecessors(v)
            if n.func_type == LambdaType.DECISION
        ][0]
        output_decision_edge = [
            e for e in hyper_edges if e.lambda_fn == output_decision
        ][0]

        # Find the first decision nodes and mark their input variables as
        # visited so we can execute the cyclic portion of the loop
        initial_decision = {
            n
            for v in input_interface.outputs
            for n in grfn.successors(v)
            if n.func_type == LambdaType.DECISION
        }
        first_decision_vars = {
            v
            for lm_node in initial_decision
            for v in grfn.predecessors(lm_node)
            if isinstance(v, VariableNode)
        }

        # Keep only the highest-index version of each decision input name.
        updated_decision_input_vars_map = {}
        for v in first_decision_vars:
            name = v.identifier.var_name
            ver = v.identifier.index
            if (
                name not in updated_decision_input_vars_map
                or updated_decision_input_vars_map[name].identifier.index < ver
            ):
                updated_decision_input_vars_map[name] = v

        updated_decision_input_vars = updated_decision_input_vars_map.values()
        for v in updated_decision_input_vars:
            if v.value is None:
                v.value = [None] * grfn.np_shape[0]

        var_results = set()
        prev_all_nodes_visited = all_nodes_visited
        iterations = 0
        # Loop until the exit value becomes true
        while True:
            initial_visited_nodes = all_nodes_visited.copy()
            initial_visited_nodes.update(updated_decision_input_vars)

            # Compute JUST the path to the exit variable so we can prevent
            # computing all paths on the n+1 step
            super().__call__(
                grfn,
                subgraphs_to_hyper_edges,
                node_to_subgraph,
                initial_visited_nodes,
                vars_to_compute=input_interface.outputs + [exit_var_node],
            )

            if (
                isinstance(exit_var_node.value, bool) and exit_var_node.value
            ) or (
                isinstance(exit_var_node.value, (np.ndarray, list))
                and all(exit_var_node.value)
            ):
                # Builtin ``bool`` dtype: ``np.bool`` was removed in
                # NumPy 1.24.
                output_decision_edge.seen_exits = np.full(
                    grfn.np_shape, True, dtype=bool
                )
                output_decision_edge()
                output_interface()
                break

            iterations += 1
            initial_visited_nodes = all_nodes_visited.copy()
            initial_visited_nodes.update(updated_decision_input_vars)

            var_results = super().__call__(
                grfn,
                subgraphs_to_hyper_edges,
                node_to_subgraph,
                initial_visited_nodes,
            )

            prev_all_nodes_visited = initial_visited_nodes

        # Initialize all of the post loop output variables in case there are
        # no iterations
        if iterations == 0:
            # networkx returns an iterator here. Materialize it so it can be
            # scanned once per input variable and then added to the results.
            output_var_successors = list(
                grfn.successors(output_interface.lambda_fn)
            )
            for in_var in grfn.predecessors(input_interface.lambda_fn):
                for out_var in output_var_successors:
                    if (
                        in_var.identifier.var_name
                        == out_var.identifier.var_name
                    ):
                        out_var.value = in_var.value
            var_results.update(output_var_successors)
            all_nodes_visited.add(output_interface.lambda_fn)

        all_nodes_visited.update(prev_all_nodes_visited - all_nodes_visited)
        return var_results

1100 

1101 

class GrFNType:
    """Lightweight description of a named GrFN type and its fields."""

    name: str
    fields: List[Tuple[str, str]]

    def __init__(self, name, fields):
        self.name, self.fields = name, fields

    def get_initial_dict(self):
        """Return a new dict keyed by each entry of ``self.fields``.

        Every entry of ``self.fields`` is used as-is as a key (duplicates
        collapse to one key) and mapped to ``None``.
        """
        return dict.fromkeys(self.fields)

1115 

1116 

1117class GroundedFunctionNetwork(nx.DiGraph): 

1118 def __init__( 

1119 self, 

1120 uid: str, 

1121 id: ContainerIdentifier, 

1122 timestamp: str, 

1123 G: nx.DiGraph, 

1124 H: List[HyperEdge], 

1125 S: nx.DiGraph, 

1126 T: List[TypeDefinition], 

1127 M: List[TypedMetadata], 

1128 ): 

1129 super().__init__(G) 

1130 self.hyper_edges = H 

1131 self.subgraphs = S 

1132 

1133 self.uid = uid 

1134 self.timestamp = timestamp 

1135 self.namespace = id.namespace 

1136 self.scope = id.scope 

1137 self.name = id.con_name 

1138 self.label = f"{self.name} ({self.namespace}.{self.scope})" 

1139 self.metadata = M 

1140 

1141 self.variables = [n for n in self.nodes if isinstance(n, VariableNode)] 

1142 self.lambdas = [n for n in self.nodes if isinstance(n, LambdaNode)] 

1143 self.types = T 

1144 

1145 # NOTE: removing detached variables from GrFN 

1146 del_indices = list() 

1147 for idx, var_node in enumerate(self.variables): 

1148 found_var = False 

1149 for edge in self.hyper_edges: 

1150 if var_node in edge.inputs or var_node in edge.outputs: 

1151 found_var = True 

1152 break 

1153 if not found_var: 

1154 self.remove_node(var_node) 

1155 del_indices.append(idx) 

1156 

1157 for idx, del_idx in enumerate(del_indices): 

1158 del self.variables[del_idx - idx] 

1159 

1160 root_subgraphs = [s for s in self.subgraphs if not s.parent] 

1161 if len(root_subgraphs) != 1: 

1162 raise Exception( 

1163 "Error: Incorrect number of root subgraphs found in GrFN." 

1164 + f"Should be 1 and found {len(root_subgraphs)}." 

1165 ) 

1166 self.root_subgraph = root_subgraphs[0] 

1167 

1168 for lambda_node in self.lambdas: 

1169 lambda_node.v_function = np.vectorize(lambda_node.function) 

1170 

1171 # TODO update inputs/literal_vars to be required_inputs and 

1172 # configurable_inputs 

1173 

1174 # TODO decide how we detect configurable inputs for execution 

1175 # Configurable inputs are all variables assigned to a literal in the 

1176 # root level subgraph AND input args to the root level subgraph 

1177 self.inputs = [ 

1178 n 

1179 for e in self.hyper_edges 

1180 for n in e.outputs 

1181 if ( 

1182 n in self.root_subgraph.nodes 

1183 and e.lambda_fn.func_type == LambdaType.LITERAL 

1184 ) 

1185 ] 

1186 self.inputs.extend( 

1187 [ 

1188 n 

1189 for n, d in self.in_degree() 

1190 if d == 0 and isinstance(n, VariableNode) 

1191 ] 

1192 ) 

1193 self.literal_vars = list() 

1194 for var_node in self.variables: 

1195 preds = list(self.predecessors(var_node)) 

1196 if len(preds) > 0 and preds[0].func_type == LambdaType.LITERAL: 

1197 self.literal_vars.append(var_node) 

1198 self.outputs = [ 

1199 n 

1200 for n, d in self.out_degree() 

1201 if d == 0 

1202 and isinstance(n, VariableNode) 

1203 and n in self.root_subgraph.nodes 

1204 ] 

1205 

1206 self.uid2varnode = {v.uid: v for v in self.variables} 

1207 

1208 self.input_names = [var_node.identifier for var_node in self.inputs] 

1209 

1210 self.output_names = [var_node.identifier for var_node in self.outputs] 

1211 

1212 self.input_name_map = { 

1213 var_node.identifier.var_name: var_node for var_node in self.inputs 

1214 } 

1215 self.input_identifier_map = { 

1216 var_node.identifier: var_node for var_node in self.inputs 

1217 } 

1218 self.literal_identifier_map = { 

1219 var_node.identifier: var_node for var_node in self.literal_vars 

1220 } 

1221 

1222 self.output_name_map = { 

1223 var_node.identifier.var_name: var_node for var_node in self.outputs 

1224 } 

1225 self.FCG = self.to_FCG() 

1226 self.function_sets = self.build_function_sets() 

1227 

1228 def __repr__(self): 

1229 return self.__str__() 

1230 

1231 def __eq__(self, other) -> bool: 

1232 # FUTURE: for debugging and testing purposes 

1233 # it might be good to convert to loop e.g. 

1234 # for edge in self.hyper_edges: 

1235 # if edge not in other.hyper_edges: 

1236 # print(f"\nSelf HEs:") 

1237 # for e in self.hyper_edges: 

1238 # print(f"{e}") 

1239 # print(f"Input uids: {[v.uid for v in e.inputs]}") 

1240 # print(f"Output uids: {[v.uid for v in e.outputs]}") 

1241 # print(f"\n\nOther HEs:") 

1242 # for e in other.hyper_edges: 

1243 # print(f"{e}") 

1244 # print(f"Input uids: {[v.uid for v in e.inputs]}") 

1245 # print(f"Output uids: {[v.uid for v in e.outputs]}") 

1246 # return False 

1247 # 

1248 # return True 

1249 

1250 return ( 

1251 self.hyper_edges == other.hyper_edges 

1252 and list(self.subgraphs) == list(other.subgraphs) 

1253 and self.inputs == other.inputs 

1254 and self.outputs == other.outputs 

1255 ) 

1256 

1257 def __str__(self): 

1258 L_sz = str(len(self.lambdas)) 

1259 V_sz = str(len(self.variables)) 

1260 I_sz = str(len(self.inputs)) 

1261 O_sz = str(len(self.outputs)) 

1262 size_str = f"< |L|: {L_sz}, |V|: {V_sz}, |I|: {I_sz}, |O|: {O_sz} >" 

1263 return f"{self.label}\n{size_str}" 

1264 

1265 def __call__( 

1266 self, 

1267 inputs: Dict[str, Any], 

1268 literals: Dict[str, Any] = None, 

1269 desired_outputs: List[str] = None, 

1270 ) -> Iterable[Any]: 

1271 """Executes the GrFN over a particular set of inputs and returns the 

1272 result. 

1273 

1274 Args: 

1275 inputs: Input set where keys are the identifier strings of input 

1276 nodes in the GrFN and each key points to a set of input values 

1277 (or just one) 

1278 literals: Input set where keys are the identifier strings of 

1279 variable nodes in the GrFN that inherit directly from a literal 

1280 node and each key points to a set of input values (or just one) 

1281 desired_outputs: A list of variable names to customize the 

1282 desired variable nodes whose values we should output after 

1283 execution. Will find the max version var node in the root container 

1284 for each name given and return their values. 

1285 

1286 Returns: 

1287 A set of outputs from executing the GrFN, one for every set of 

1288 inputs. 

1289 """ 

1290 self.np_shape = (1,) 

1291 # TODO: update this function to work with new GrFN object 

1292 full_inputs = { 

1293 self.input_identifier_map[VariableIdentifier.from_str(n)]: v 

1294 for n, v in inputs.items() 

1295 } 

1296 

1297 # Check if vectorized input is given and configure the numpy shape 

1298 for input_node in [n for n in self.inputs if n in full_inputs]: 

1299 value = full_inputs[input_node] 

1300 if isinstance(value, np.ndarray): 

1301 if self.np_shape != value.shape and self.np_shape != (1,): 

1302 raise GrFNExecutionException( 

1303 f"Error: Given two vectorized inputs with different shapes: '{value.shape}' and '{self.np_shape}'" 

1304 ) 

1305 self.np_shape = value.shape 

1306 

1307 # Set the values of input var nodes given in the inputs dict 

1308 for input_node in [n for n in self.inputs if n in full_inputs]: 

1309 value = full_inputs[input_node] 

1310 # TODO: need to find a way to incorporate a 32/64 bit check here 

1311 if isinstance(value, (float, np.float64)): 

1312 value = np.full(self.np_shape, value, dtype=np.float64) 

1313 if isinstance(value, (int, np.int64)): 

1314 value = np.full(self.np_shape, value, dtype=np.int64) 

1315 elif isinstance(value, (dict, list)): 

1316 value = np.array([value] * self.np_shape[0]) 

1317 

1318 input_node.input_value = value 

1319 

1320 if literals is not None: 

1321 literal_ids = set( 

1322 [ 

1323 VariableIdentifier.from_str(var_id) 

1324 for var_id in literals.keys() 

1325 ] 

1326 ) 

1327 lit_id2val = { 

1328 lit_id: literals[str(lit_id)] for lit_id in literal_ids 

1329 } 

1330 literal_overrides = [ 

1331 (var_node, lit_id2val[identifier]) 

1332 for identifier, var_node in self.literal_identifier_map.items() 

1333 if identifier in literal_ids 

1334 ] 

1335 for input_node, value in literal_overrides: 

1336 # TODO: need to find a way to incorporate a 32/64 bit check here 

1337 if isinstance(value, float): 

1338 value = np.array([value], dtype=np.float64) 

1339 if isinstance(value, int): 

1340 value = np.array([value], dtype=np.int64) 

1341 elif isinstance(value, list): 

1342 value = np.array(value) 

1343 self.np_shape = value.shape 

1344 elif isinstance(value, np.ndarray): 

1345 self.np_shape = value.shape 

1346 

1347 input_node.input_value = value 

1348 

1349 # Configure the np array shape for all lambda nodes 

1350 for n in self.lambdas: 

1351 n.np_shape = self.np_shape 

1352 

1353 subgraph_to_hyper_edges = { 

1354 s: [h for h in self.hyper_edges if h.lambda_fn in s.nodes] 

1355 for s in self.subgraphs 

1356 } 

1357 node_to_subgraph = {n: s for s in self.subgraphs for n in s.nodes} 

1358 self.root_subgraph( 

1359 self, subgraph_to_hyper_edges, node_to_subgraph, set() 

1360 ) 

1361 # Return the output 

1362 if desired_outputs is not None and len(desired_outputs) > 0: 

1363 root_var_nodes = [ 

1364 n 

1365 for n, _ in self.out_degree() 

1366 if isinstance(n, VariableNode) 

1367 and n in self.root_subgraph.nodes 

1368 ] 

1369 

1370 desired_output_values = {} 

1371 for n in root_var_nodes: 

1372 n_name = n.identifier.var_name 

1373 if n_name in set(desired_outputs) and ( 

1374 n_name not in desired_output_values 

1375 or desired_output_values[n_name].identifier.index 

1376 < n.identifier.index 

1377 ): 

1378 desired_output_values[n_name] = n 

1379 return { 

1380 k: np.array(v.value) for k, v in desired_output_values.items() 

1381 } 

1382 

1383 return { 

1384 output.identifier.var_name: output.value for output in self.outputs 

1385 } 

1386 

1387 @classmethod 

1388 def from_AIR(cls, air: AutoMATES_IR): 

1389 network = nx.DiGraph() 

1390 hyper_edges = list() 

1391 Occs = dict() 

1392 subgraphs = nx.DiGraph() 

1393 

1394 def add_variable_node( 

1395 v_id: VariableIdentifier, v_data: VariableDefinition 

1396 ) -> VariableNode: 

1397 node = VariableNode.from_id(v_id, v_data) 

1398 network.add_node(node, **(node.get_kwargs())) 

1399 return node 

1400 

1401 variable_nodes = dict() 

1402 

1403 def add_lambda_node( 

1404 lambda_type: LambdaType, 

1405 lambda_str: str, 

1406 metadata: List[TypedMetadata] = None, 

1407 ) -> LambdaNode: 

1408 lambda_id = GenericNode.create_node_id() 

1409 node = LambdaNode.from_AIR( 

1410 lambda_id, lambda_type, lambda_str, metadata 

1411 ) 

1412 network.add_node(node, **(node.get_kwargs())) 

1413 return node 

1414 

1415 def add_hyper_edge( 

1416 inputs: Iterable[VariableNode], 

1417 lambda_node: LambdaNode, 

1418 outputs: Iterable[VariableNode], 

1419 ) -> None: 

1420 network.add_edges_from( 

1421 [(in_node, lambda_node) for in_node in inputs] 

1422 ) 

1423 network.add_edges_from( 

1424 [(lambda_node, out_node) for out_node in outputs] 

1425 ) 

1426 hyper_edges.append(HyperEdge(inputs, lambda_node, outputs)) 

1427 

1428 def translate_container( 

1429 con: GenericContainer, 

1430 inputs: Iterable[VariableNode], 

1431 parent: GrFNSubgraph = None, 

1432 ) -> Iterable[VariableNode]: 

1433 con_name = con.identifier 

1434 if con_name not in Occs: 

1435 Occs[con_name] = 0 

1436 else: 

1437 Occs[con_name] += 1 

1438 

1439 for k, v in air.variables.items(): 

1440 if con.identifier.con_name == k.scope: 

1441 variable_nodes[k] = add_variable_node(k, v) 

1442 

1443 con_subgraph = GrFNSubgraph.from_container( 

1444 con, Occs[con_name], parent 

1445 ) 

1446 live_variables = dict() 

1447 if len(inputs) > 0: 

1448 in_var_names = [n.identifier.var_name for n in inputs] 

1449 

1450 seen_input_vars = {} 

1451 in_var_name_with_occurence = [] 

1452 for v in in_var_names: 

1453 if v not in seen_input_vars: 

1454 seen_input_vars[v] = 0 

1455 in_var_name_with_occurence.append( 

1456 f"{v}_{seen_input_vars[v]}" 

1457 ) 

1458 seen_input_vars[v] += 1 

1459 in_var_str = ",".join(in_var_name_with_occurence) 

1460 

1461 interface_func_str = f"lambda {in_var_str}:({in_var_str})" 

1462 func = add_lambda_node( 

1463 LambdaType.INTERFACE, interface_func_str 

1464 ) 

1465 out_nodes = [variable_nodes[v_id] for v_id in con.arguments] 

1466 add_hyper_edge(inputs, func, out_nodes) 

1467 con_subgraph.nodes.append(func) 

1468 

1469 live_variables.update( 

1470 {id: node for id, node in zip(con.arguments, out_nodes)} 

1471 ) 

1472 else: 

1473 live_variables.update( 

1474 {v_id: variable_nodes[v_id] for v_id in con.arguments} 

1475 ) 

1476 

1477 con_subgraph.nodes.extend(list(live_variables.values())) 

1478 

1479 revisit = [] 

1480 for stmt in con.statements: 

1481 if translate_stmt(stmt, live_variables, con_subgraph): 

1482 revisit.append(stmt) 

1483 for stmt in revisit: 

1484 translate_stmt(stmt, live_variables, con_subgraph) 

1485 

1486 subgraphs.add_node(con_subgraph) 

1487 

1488 if parent is not None: 

1489 subgraphs.add_edge(parent, con_subgraph) 

1490 

1491 # If this container identifier is not the root con_id passed into from_AIR 

1492 if con.identifier != air.entrypoint: 

1493 # Do this only if this is not the starting container 

1494 returned_vars = [variable_nodes[v_id] for v_id in con.returns] 

1495 update_vars = [variable_nodes[v_id] for v_id in con.updated] 

1496 output_vars = returned_vars + update_vars 

1497 

1498 out_var_names = [n.identifier.var_name for n in output_vars] 

1499 out_var_str = ",".join(out_var_names) 

1500 interface_func_str = f"lambda {out_var_str}:({out_var_str})" 

1501 func = add_lambda_node( 

1502 LambdaType.INTERFACE, interface_func_str 

1503 ) 

1504 con_subgraph.nodes.append(func) 

1505 return (output_vars, func) 

1506 

1507 @singledispatch 

1508 def translate_stmt( 

1509 stmt: GenericStmt, 

1510 live_variables: Dict[VariableIdentifier, VariableNode], 

1511 parent: GrFNSubgraph, 

1512 ) -> None: 

1513 raise ValueError(f"Unsupported statement type: {type(stmt)}") 

1514 

1515 @translate_stmt.register 

1516 def _( 

1517 stmt: OperatorStmt, 

1518 live_variables: Dict[VariableIdentifier, VariableNode], 

1519 subgraph: GrFNSubgraph, 

1520 ) -> None: 

1521 # TODO lambda 

1522 func = add_lambda_node(LambdaType.OPERATOR, "lambda : None") 

1523 subgraph.nodes.append(func) 

1524 

1525 if stmt.call_id not in Occs: 

1526 Occs[stmt.call_id] = 0 

1527 inputs = [live_variables[id] for id in stmt.inputs] 

1528 Occs[stmt.call_id] += 1 

1529 

1530 out_nodes = [add_variable_node(var) for var in stmt.outputs] 

1531 subgraph.nodes.extend(out_nodes) 

1532 add_hyper_edge(inputs, func, out_nodes) 

1533 for output_node in out_nodes: 

1534 var_id = output_node.identifier 

1535 live_variables[var_id] = output_node 

1536 

1537 @translate_stmt.register 

1538 def _( 

1539 stmt: CallStmt, 

1540 live_variables: Dict[VariableIdentifier, VariableNode], 

1541 subgraph: GrFNSubgraph, 

1542 ) -> None: 

1543 new_con = air.containers[stmt.call_id] 

1544 if stmt.call_id not in Occs: 

1545 Occs[stmt.call_id] = 0 

1546 

1547 inputs = [variable_nodes[v_id] for v_id in stmt.inputs] 

1548 (con_outputs, interface_func) = translate_container( 

1549 new_con, inputs, subgraph 

1550 ) 

1551 Occs[stmt.call_id] += 1 

1552 out_nodes = [variable_nodes[v_id] for v_id in stmt.outputs] 

1553 subgraph.nodes.extend(out_nodes) 

1554 add_hyper_edge(con_outputs, interface_func, out_nodes) 

1555 for output_node in out_nodes: 

1556 var_id = output_node.identifier 

1557 live_variables[var_id] = output_node 

1558 

1559 def add_live_variables( 

1560 live_variables: Dict[VariableIdentifier, VariableNode], 

1561 subgraph: GrFNSubgraph, 

1562 vars_to_add: List, 

1563 ): 

1564 for output_node in [variable_nodes[v_id] for v_id in vars_to_add]: 

1565 var_id = output_node.identifier 

1566 live_variables[var_id] = output_node 

1567 

1568 @translate_stmt.register 

1569 def _( 

1570 stmt: LambdaStmt, 

1571 live_variables: Dict[VariableIdentifier, VariableNode], 

1572 subgraph: GrFNSubgraph, 

1573 ) -> None: 

1574 # The var inputs into this decision node defined inside the loop 

1575 # may not be defined yet, so guard against that 

1576 if ( 

1577 subgraph.type == "LoopContainer" 

1578 and stmt.type == LambdaType.DECISION 

1579 ): 

1580 if stmt not in Occs: 

1581 # We will add the live variables if this is the first pass on 

1582 # a decision node that needs two passes OR if it is the only pass 

1583 add_live_variables(live_variables, subgraph, stmt.outputs) 

1584 

1585 if not all([id in live_variables for id in stmt.inputs]): 

1586 if stmt in Occs: 

1587 # We have already visited this node and all of the inputs 

1588 # are still not found. 

1589 # TODO custom exception 

1590 raise Exception( 

1591 f"Unable to find inputs required for loop decision node {stmt}" 

1592 ) 

1593 Occs[stmt] = stmt 

1594 return True 

1595 elif stmt in Occs: 

1596 del Occs[stmt] 

1597 else: 

1598 add_live_variables(live_variables, subgraph, stmt.outputs) 

1599 

1600 inputs = [variable_nodes[v_id] for v_id in stmt.inputs] 

1601 out_nodes = [variable_nodes[v_id] for v_id in stmt.outputs] 

1602 func = add_lambda_node(stmt.type, stmt.func_str, stmt.metadata) 

1603 

1604 subgraph.nodes.extend(out_nodes) 

1605 subgraph.nodes.append(func) 

1606 

1607 inputs = [live_variables[id] for id in stmt.inputs] 

1608 add_hyper_edge(inputs, func, out_nodes) 

1609 

1610 start_container = air.containers[air.entrypoint] 

1611 Occs[air.entrypoint] = 0 

1612 translate_container(start_container, []) 

1613 grfn_uid = str(uuid.uuid4()) 

1614 date_created = datetime.datetime.now().strftime("%Y-%m-%d") 

1615 return cls( 

1616 grfn_uid, 

1617 air.entrypoint, 

1618 date_created, 

1619 network, 

1620 hyper_edges, 

1621 subgraphs, 

1622 air.type_definitions, 

1623 air.metadata, 

1624 ) 

1625 

1626 def to_FCG(self): 

1627 G = nx.DiGraph() 

1628 func_to_func_edges = [] 

1629 for node in self.nodes: 

1630 if isinstance(node, LambdaNode): 

1631 preds = list(self.predecessors(node)) 

1632 # DEBUGGING 

1633 # print(f"node {node} has predecessors {preds}") 

1634 for var_node in self.predecessors(node): 

1635 preds = list(self.predecessors(var_node)) 

1636 # DEBUGGING 

1637 # print(f"node {var_node} has predecessors {preds}") 

1638 for func_node in self.predecessors(var_node): 

1639 func_to_func_edges.append((func_node, node)) 

1640 

1641 G.add_edges_from(func_to_func_edges) 

1642 return G 

1643 

1644 def build_function_sets(self): 

1645 subgraphs_to_func_sets = {s.uid: list() for s in self.subgraphs} 

1646 

1647 initial_funcs = [n for n, d in self.FCG.in_degree() if d == 0] 

1648 func2container = {f: s.uid for s in self.subgraphs for f in s.nodes} 

1649 initial_funcs_to_subgraph = { 

1650 n: func2container[n] for n in initial_funcs 

1651 } 

1652 containers_to_initial_funcs = {s.uid: list() for s in self.subgraphs} 

1653 for k, v in initial_funcs_to_subgraph.items(): 

1654 containers_to_initial_funcs[v].append(k) 

1655 

1656 def build_function_set_for_container( 

1657 container, container_initial_funcs 

1658 ): 

1659 all_successors = list() 

1660 distances = dict() 

1661 visited_funcs = set() 

1662 

1663 def find_distances(func, dist, path=[]): 

1664 if func.func_type == LambdaType.OPERATOR: 

1665 return 

1666 new_successors = list() 

1667 func_container = func2container[func] 

1668 if func_container == container: 

1669 distances[func] = ( 

1670 max(dist, distances[func]) 

1671 if func in distances and func not in path 

1672 else dist 

1673 ) 

1674 if func not in visited_funcs: 

1675 visited_funcs.add(func) 

1676 # add new successors if func is in FCG 

1677 if func in self.FCG: 

1678 new_successors.extend(self.FCG.successors(func)) 

1679 

1680 if len(new_successors) > 0: 

1681 all_successors.extend(new_successors) 

1682 for f in new_successors: 

1683 find_distances(f, dist + 1, path=(path + [func])) 

1684 

1685 for f in container_initial_funcs: 

1686 find_distances(f, 0) 

1687 

1688 call_sets = dict() 

1689 

1690 for func_node, call_dist in distances.items(): 

1691 if call_dist in call_sets: 

1692 call_sets[call_dist].add(func_node) 

1693 else: 

1694 call_sets[call_dist] = {func_node} 

1695 

1696 function_set_dists = sorted( 

1697 call_sets.items(), key=lambda t: (t[0], len(t[1])) 

1698 ) 

1699 

1700 subgraphs_to_func_sets[container] = [ 

1701 func_set for _, func_set in function_set_dists 

1702 ] 

1703 

1704 for container in self.subgraphs: 

1705 input_interface_funcs = [ 

1706 n 

1707 for n in container.nodes 

1708 if isinstance(n, LambdaNode) 

1709 and n.func_type == LambdaType.INTERFACE 

1710 and all( 

1711 [ 

1712 var_node in container.nodes 

1713 for var_node in self.successors(n) 

1714 ] 

1715 ) 

1716 ] 

1717 build_function_set_for_container( 

1718 container.uid, 

1719 input_interface_funcs 

1720 + containers_to_initial_funcs[container.uid], 

1721 ) 

1722 

1723 return subgraphs_to_func_sets 

1724 

1725 def to_AGraph(self): 

1726 """Export to a PyGraphviz AGraph object.""" 

1727 

1728 from pygraphviz import AGraph 

1729 var_nodes = [n for n in self.nodes if isinstance(n, VariableNode)] 

1730 input_nodes = [] 

1731 for v in var_nodes: 

1732 if self.in_degree(v) == 0 or ( 

1733 len(list(self.predecessors(v))) == 1 

1734 and list(self.predecessors(v))[0].func_type 

1735 == LambdaType.LITERAL 

1736 ): 

1737 input_nodes.append(v) 

1738 output_nodes = set([v for v in var_nodes if self.out_degree(v) == 0]) 

1739 

1740 A = nx.nx_agraph.to_agraph(self) 

1741 A.graph_attr.update( 

1742 {"dpi": 227, "fontsize": 20, "fontname": "Menlo", "rankdir": "TB"} 

1743 ) 

1744 A.node_attr.update({"fontname": "Menlo"}) 

1745 

1746 # A.add_subgraph(input_nodes, rank="same") 

1747 # A.add_subgraph(output_nodes, rank="same") 

1748 

1749 def get_subgraph_nodes(subgraph: GrFNSubgraph): 

1750 return subgraph.nodes + [ 

1751 node 

1752 for child_graph in self.subgraphs.successors(subgraph) 

1753 for node in get_subgraph_nodes(child_graph) 

1754 ] 

1755 

1756 def populate_subgraph(subgraph: GrFNSubgraph, parent: AGraph): 

1757 all_sub_nodes = get_subgraph_nodes(subgraph) 

1758 container_subgraph = parent.add_subgraph( 

1759 all_sub_nodes, 

1760 name=f"cluster_{str(subgraph)}", 

1761 label=subgraph.basename, 

1762 style="bold, rounded", 

1763 rankdir="TB", 

1764 color=subgraph.border_color, 

1765 ) 

1766 

1767 input_var_nodes = set(input_nodes).intersection(subgraph.nodes) 

1768 # container_subgraph.add_subgraph(list(input_var_nodes), rank="same") 

1769 container_subgraph.add_subgraph( 

1770 [v.uid for v in input_var_nodes], rank="same" 

1771 ) 

1772 

1773 for new_subgraph in self.subgraphs.successors(subgraph): 

1774 populate_subgraph(new_subgraph, container_subgraph) 

1775 

1776 for _, func_sets in self.function_sets.items(): 

1777 for func_set in func_sets: 

1778 func_set = list(func_set.intersection(set(subgraph.nodes))) 

1779 

1780 container_subgraph.add_subgraph( 

1781 func_set, 

1782 ) 

1783 output_var_nodes = list() 

1784 for func_node in func_set: 

1785 succs = list(self.successors(func_node)) 

1786 output_var_nodes.extend(succs) 

1787 output_var_nodes = set(output_var_nodes) - output_nodes 

1788 var_nodes = output_var_nodes.intersection(subgraph.nodes) 

1789 

1790 container_subgraph.add_subgraph( 

1791 [v.uid for v in var_nodes], 

1792 ) 

1793 

1794 root_subgraph = [n for n, d in self.subgraphs.in_degree() if d == 0][0] 

1795 populate_subgraph(root_subgraph, A) 

1796 

1797 # TODO this code helps with the layout of the graph. However, it assumes 

1798 # all var nodes start at -1 and are consecutive. This is currently not 

1799 # the case, so it creates random hanging var nodes if run. Fix this. 

1800 

1801 # unique_var_names = { 

1802 # "::".join(n.name.split("::")[:-1]) 

1803 # for n in A.nodes() 

1804 # if len(n.name.split("::")) > 2 

1805 # } 

1806 # for name in unique_var_names: 

1807 # max_var_version = max( 

1808 # [ 

1809 # int(n.name.split("::")[-1]) 

1810 # for n in A.nodes() 

1811 # if n.name.startswith(name) 

1812 # ] 

1813 # ) 

1814 # min_var_version = min( 

1815 # [ 

1816 # int(n.name.split("::")[-1]) 

1817 # for n in A.nodes() 

1818 # if n.name.startswith(name) 

1819 # ] 

1820 # ) 

1821 # for i in range(min_var_version, max_var_version): 

1822 # e = A.add_edge(f"{name}::{i}", f"{name}::{i + 1}") 

1823 # e = A.get_edge(f"{name}::{i}", f"{name}::{i + 1}") 

1824 # e.attr["color"] = "invis" 

1825 

1826 # for agraph_node in [ 

1827 # a for (a, b) in product(A.nodes(), self.output_names) if a.name == str(b) 

1828 # ]: 

1829 # agraph_node.attr["rank"] = "max" 

1830 

1831 return A 

1832 

    def to_FIB(self, G2):
        """Creates a ForwardInfluenceBlanket object representing the
        intersection of this model with the other input model.

        The blanket is built from ``self``: variables whose short names
        also occur in ``G2`` are "shared". Paths from shared variables to the
        outputs form the main body, and neighbouring variables feeding those
        paths are added as blanket nodes.

        Args:
            G2: The GroundedFunctionNetwork object to compare this model to.

        Returns:
            An ``nx.DiGraph`` (built as ``F``) styled for display: shared
            inputs and outputs in dodger blue, blanket nodes in forest green.

        Raises:
            TypeError: if ``G2`` is not a GroundedFunctionNetwork.

        NOTE(review): this method looks unfinished and likely cannot run
        against the current node model. It applies string operations to
        graph nodes (``shortname``, ``"::IF_" in var_node``), while nodes
        here are VariableNode/LambdaNode objects. It reads
        ``self.output_node``, which ``__init__`` does not set. It reads
        ``F.inputs`` on a fresh ``nx.DiGraph`` (perhaps ``list(inputs)`` was
        meant). Confirm intent before relying on it.
        """
        # TODO: Finish inspection and testing of this function

        if not isinstance(G2, GroundedFunctionNetwork):
            raise TypeError(f"Expected a second GrFN but got: {type(G2)}")

        # Text between the first "::" and the last "_" of a node name
        # (assumes string node names — see NOTE(review) above)
        def shortname(var):
            return var[var.find("::") + 2 : var.rfind("_")]

        # Nodes of ``graph`` whose name contains ``shortname``
        def shortname_vars(graph, shortname):
            return [v for v in graph.nodes() if shortname in v]

        g1_var_nodes = {
            shortname(n)
            for (n, d) in self.nodes(data=True)
            if d["type"] == "variable"
        }
        g2_var_nodes = {
            shortname(n)
            for (n, d) in G2.nodes(data=True)
            if d["type"] == "variable"
        }

        # Full node names in self whose short name occurs in both models
        shared_nodes = {
            full_var
            for shared_var in g1_var_nodes.intersection(g2_var_nodes)
            for full_var in shortname_vars(self, shared_var)
        }

        outputs = self.outputs
        inputs = set(self.inputs).intersection(shared_nodes)

        # Get all paths from shared inputs to shared outputs
        path_inputs = shared_nodes - set(outputs)
        # NOTE(review): `self.output_node` is not assigned in __init__;
        # this likely raises AttributeError — verify intent (self.outputs?)
        io_pairs = [(inp, self.output_node) for inp in path_inputs]
        paths = [
            p for (i, o) in io_pairs for p in all_simple_paths(self, i, o)
        ]

        # Get all edges needed to blanket the included nodes
        main_nodes = {node for path in paths for node in path}
        main_edges = {
            (n1, n2) for path in paths for n1, n2 in zip(path, path[1:])
        }
        blanket_nodes = set()
        add_nodes, add_edges = list(), list()

        # Pull in a variable together with its producer when that producer
        # is labelled "L"; otherwise treat the variable as a blanket node.
        def place_var_node(var_node):
            prev_funcs = list(self.predecessors(var_node))
            if (
                len(prev_funcs) > 0
                and self.nodes[prev_funcs[0]]["label"] == "L"
            ):
                prev_func = prev_funcs[0]
                add_nodes.extend([var_node, prev_func])
                add_edges.append((prev_func, var_node))
            else:
                blanket_nodes.add(var_node)

        for node in main_nodes:
            if self.nodes[node]["type"] == "function":
                for var_node in self.predecessors(node):
                    if var_node not in main_nodes:
                        add_edges.append((var_node, node))
                        # Conditional variables also pull in their producing
                        # function and that function's inputs
                        if "::IF_" in var_node:
                            if_func = list(self.predecessors(var_node))[0]
                            add_nodes.extend([if_func, var_node])
                            add_edges.append((if_func, var_node))
                            for new_var_node in self.predecessors(if_func):
                                add_edges.append((new_var_node, if_func))
                                place_var_node(new_var_node)
                        else:
                            place_var_node(var_node)

        main_nodes |= set(add_nodes)
        main_edges |= set(add_edges)
        main_nodes = main_nodes - inputs - set(outputs)

        orig_nodes = self.nodes(data=True)

        F = nx.DiGraph()

        F.add_nodes_from([(n, d) for n, d in orig_nodes if n in inputs])
        for node in inputs:
            F.nodes[node]["color"] = dodgerblue3
            F.nodes[node]["fontcolor"] = dodgerblue3
            F.nodes[node]["penwidth"] = 3.0
            F.nodes[node]["fontname"] = FONT

        # NOTE(review): F is a fresh nx.DiGraph with no `inputs` attribute,
        # so this likely raises AttributeError — `list(inputs)` may be meant
        F.inputs = list(F.inputs)

        F.add_nodes_from([(n, d) for n, d in orig_nodes if n in blanket_nodes])
        for node in blanket_nodes:
            F.nodes[node]["fontname"] = FONT
            F.nodes[node]["color"] = forestgreen
            F.nodes[node]["fontcolor"] = forestgreen

        F.add_nodes_from([(n, d) for n, d in orig_nodes if n in main_nodes])
        for node in main_nodes:
            F.nodes[node]["fontname"] = FONT

        for out_var_node in outputs:
            F.add_node(out_var_node, **self.nodes[out_var_node])
            F.nodes[out_var_node]["color"] = dodgerblue3
            F.nodes[out_var_node]["fontcolor"] = dodgerblue3

        F.add_edges_from(main_edges)
        return F

1951 

1952 def to_dict(self) -> Dict: 

1953 """Outputs the contents of this GrFN to a dict object. 

1954 

1955 :return: Description of returned object. 

1956 :rtype: type 

1957 :raises ExceptionName: Why the exception is raised. 

1958 """ 

1959 return { 

1960 "uid": self.uid, 

1961 "entry_point": "::".join( 

1962 ["@container", self.namespace, self.scope, self.name] 

1963 ), 

1964 "timestamp": self.timestamp, 

1965 "hyper_edges": [edge.to_dict() for edge in self.hyper_edges], 

1966 "variables": [var.to_dict() for var in self.variables], 

1967 "functions": [func.to_dict() for func in self.lambdas], 

1968 "subgraphs": [sgraph.to_dict() for sgraph in self.subgraphs], 

1969 # TODO fix this 

1970 "types": [ 

1971 t_def.to_dict() 

1972 for t_def in ( 

1973 self.types 

1974 if isinstance(self.types, list) 

1975 else self.types.values() 

1976 ) 

1977 ], 

1978 "metadata": [m.to_dict() for m in self.metadata], 

1979 } 

1980 

1981 def to_json(self) -> str: 

1982 """Outputs the contents of this GrFN to a JSON object string. 

1983 

1984 :return: Description of returned object. 

1985 :rtype: type 

1986 :raises ExceptionName: Why the exception is raised. 

1987 """ 

1988 data = self.to_dict() 

1989 return json.dumps(data) 

1990 

def to_json_file(self, json_path) -> None:
    """Write this GrFN's JSON serialization (``self.to_json()``) to a file.

    The JSON text is produced *before* the file is opened. Opening with
    mode ``"w"`` truncates the file, so if serialization raises, any
    existing file at ``json_path`` is left untouched rather than being
    replaced by an empty one.

    :param json_path: Path of the file to create or overwrite.
    """
    json_text = self.to_json()
    with open(json_path, "w") as outfile:
        outfile.write(json_text)

1994 

@classmethod
def from_dict(cls, data):
    """Rebuild a GrFN from a dict such as the one produced by ``to_dict``.

    Variable and lambda (function) nodes are recreated first and indexed by
    their ``uid``. Subgraphs and hyper-edges are then rebuilt against that
    index, and each hyper-edge contributes ``input -> lambda`` and
    ``lambda -> output`` edges to the network graph.

    :param data: Mapping with the keys ``uid``, ``timestamp``,
        ``variables``, ``functions``, ``subgraphs`` and ``hyper_edges``.
        ``types`` and ``metadata`` are optional and default to empty lists.
        The container identifier is read from ``entry_point``, falling back
        to ``identifier`` (the key ``CausalAnalysisGraph.to_json`` writes)
        and finally to ``""``.
    :return: A new instance of ``cls``.
    :raises KeyError: if one of the required keys is missing.
    """
    # Index every recreated node by uid; lambda nodes win on uid clashes.
    variable_nodes = {
        rec["uid"]: VariableNode.from_dict(rec) for rec in data["variables"]
    }
    lambda_nodes = {
        rec["uid"]: LambdaNode.from_dict(rec) for rec in data["functions"]
    }
    node_lookup = dict(variable_nodes)
    node_lookup.update(lambda_nodes)

    network = nx.DiGraph()
    for node in node_lookup.values():
        network.add_node(node, **node.get_kwargs())

    # Subgraph hierarchy: one edge from each parent to its child.
    subgraph_list = [
        GrFNSubgraph.from_dict(rec, node_lookup) for rec in data["subgraphs"]
    ]
    subgraph_by_uid = {sgraph.uid: sgraph for sgraph in subgraph_list}
    parent_links = [
        (subgraph_by_uid[sgraph.parent], subgraph_by_uid[sgraph.uid])
        for sgraph in subgraph_list
        if sgraph.parent is not None
    ]
    subgraph_tree = nx.DiGraph()
    subgraph_tree.add_nodes_from(subgraph_list)
    subgraph_tree.add_edges_from(parent_links)

    hyper_edges = [
        HyperEdge.from_dict(rec, node_lookup) for rec in data["hyper_edges"]
    ]

    type_defs = []
    if "types" in data:
        type_defs = [TypeDefinition.from_data(rec) for rec in data["types"]]

    metadata = []
    if "metadata" in data:
        metadata = [TypedMetadata.from_data(rec) for rec in data["metadata"]]

    for hyper_edge in hyper_edges:
        fn_node = hyper_edge.lambda_fn
        network.add_edges_from([(src, fn_node) for src in hyper_edge.inputs])
        network.add_edges_from([(fn_node, dst) for dst in hyper_edge.outputs])

    entry_point = ""
    for key in ("entry_point", "identifier"):
        if key in data:
            entry_point = data[key]
            break
    identifier = GenericIdentifier.from_str(entry_point)

    return cls(
        data["uid"],
        identifier,
        data["timestamp"],
        network,
        hyper_edges,
        subgraph_tree,
        type_defs,
        metadata,
    )

2049 

@classmethod
def from_json(cls, json_path):
    """Load a GrFN from a JSON file such as one written by ``to_json_file``.

    :param json_path: Path to the JSON file.
    :return: The instance built by ``cls.from_dict`` from the decoded JSON.
    :raises OSError: if the file cannot be opened.
    :raises json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(json_path, "r") as infile:
        data = json.load(infile)
    return cls.from_dict(data)

2063 

2064 

2065class CausalAnalysisGraph(nx.DiGraph): 

def __init__(self, G, S, uid, date, ns, sc, nm):
    """Create a CAG from graph data plus GrFN bookkeeping fields.

    :param G: Graph data handed to ``nx.DiGraph.__init__`` (the variable
        influence graph).
    :param S: Subgraph hierarchy, stored as ``self.subgraphs``.
    :param uid: Unique identifier, stored as ``self.uid``.
    :param date: Timestamp, stored as ``self.timestamp``.
    :param ns: Namespace, stored as ``self.namespace``.
    :param sc: Scope, stored as ``self.scope``.
    :param nm: Name, stored as ``self.name``.
    """
    super().__init__(G)
    # The remaining arguments become plain attributes, assigned in order.
    bookkeeping = (
        ("subgraphs", S),
        ("uid", uid),
        ("timestamp", date),
        ("namespace", ns),
        ("scope", sc),
        ("name", nm),
    )
    for attr_name, value in bookkeeping:
        setattr(self, attr_name, value)

2074 

@classmethod
def from_GrFN(cls, GrFN: GroundedFunctionNetwork):
    """Export to a Causal Analysis Graph (CAG) object.

    The CAG shows the influence relationships between the variables and
    elides the function nodes. For an ``INTERFACE`` hyper-edge, each input
    is linked only to the output at the same position. For every other
    hyper-edge, each input is linked to every output. Starting from the
    variables without predecessors, a successor with the same variable name
    and a single predecessor is then folded into that predecessor, level by
    level. Finally, every subgraph's ``nodes`` list is trimmed to the
    variables that remain in the CAG.

    The subgraph hierarchy is shallow-copied before trimming, so the
    subgraph objects owned by ``GrFN`` are not modified.

    :param GrFN: The grounded function network to convert.
    :return: A new CAG sharing uid, timestamp, namespace, scope and name
        with ``GrFN``.
    """
    from copy import copy as shallow_copy

    G = nx.DiGraph()
    for var_node in GrFN.variables:
        G.add_node(var_node, **var_node.get_kwargs())
    for edge in GrFN.hyper_edges:
        if edge.lambda_fn.func_type == LambdaType.INTERFACE:
            # Interface lambdas: the i-th input feeds only the i-th output.
            G.add_edges_from(zip(edge.inputs, edge.outputs))
        else:
            # Any other lambda: every input may influence every output.
            G.add_edges_from(product(edge.inputs, edge.outputs))

    def delete_paths_at_level(nodes: list):
        # Fold each successor that shares its predecessor's variable name and
        # has no other predecessor into that predecessor, repeating until
        # nothing more folds, then recurse on the next level of successors.
        # A shallow copy suffices: ``nodes`` is only rebound below, never
        # mutated, so deep-copying the node objects was wasted work.
        orig_nodes = list(nodes)

        while len(nodes) > 0:
            updated_nodes = list()
            for node in nodes:
                node_var_name = node.identifier.var_name
                succs = list(G.successors(node))
                for next_node in succs:
                    if (
                        next_node.identifier.var_name == node_var_name
                        and len(list(G.predecessors(next_node))) == 1
                    ):
                        next_succs = list(G.successors(next_node))
                        G.remove_node(next_node)
                        updated_nodes.append(node)
                        for succ_node in next_succs:
                            G.add_edge(node, succ_node)
            nodes = updated_nodes

        # Deduplicate while keeping first-seen order, so the traversal order
        # does not depend on set/hash ordering.
        next_level_nodes = list(
            dict.fromkeys(
                succ for node in orig_nodes for succ in G.successors(node)
            )
        )
        if len(next_level_nodes) > 0:
            delete_paths_at_level(next_level_nodes)

    # Work on shallow copies of the subgraph objects, with the same
    # hierarchy, so rebinding ``nodes`` below leaves GrFN's subgraphs intact.
    subgraphs = nx.relabel_nodes(
        GrFN.subgraphs,
        {sgraph: shallow_copy(sgraph) for sgraph in GrFN.subgraphs},
        copy=True,
    )

    def correct_subgraph_nodes(subgraph: GrFNSubgraph):
        # Keep only the (unique) nodes of this subgraph that are still in the
        # CAG, in their original order, then recurse into child subgraphs.
        subgraph.nodes = [n for n in dict.fromkeys(subgraph.nodes) if n in G]
        for new_subgraph in subgraphs.successors(subgraph):
            correct_subgraph_nodes(new_subgraph)

    input_nodes = [n for n in G.nodes if G.in_degree(n) == 0]
    delete_paths_at_level(input_nodes)
    # Trim every root of the subgraph hierarchy, not only the first one.
    for root_subgraph in [n for n, d in subgraphs.in_degree() if d == 0]:
        correct_subgraph_nodes(root_subgraph)
    return cls(
        G,
        subgraphs,
        GrFN.uid,
        GrFN.timestamp,
        GrFN.namespace,
        GrFN.scope,
        GrFN.name,
    )

2140 

def to_AGraph(self):
    """Render this variable-only CAG as a pygraphviz ``AGraph``.

    Nodes are drawn as rounded rectangles with black text, and nodes and
    edges use the colour ``#650021``. Each subgraph in ``self.subgraphs``
    becomes a nested ``cluster_<subgraph>`` box that holds its own nodes
    plus the nodes of all its descendant subgraphs. Only the first root
    (in-degree zero) subgraph and its descendants are drawn.

    :return: The populated ``AGraph``.
    :raises IndexError: if ``self.subgraphs`` has no root subgraph.
    """
    cag_color = "#650021"
    font = "Menlo"

    agraph = nx.nx_agraph.to_agraph(self)
    agraph.graph_attr.update(
        {
            "dpi": 227,
            "fontsize": 20,
            "fontcolor": "black",
            "fontname": font,
            "rankdir": "TB",
        }
    )
    agraph.node_attr.update(
        shape="rectangle",
        color=cag_color,
        fontname=font,
        fontcolor="black",
    )
    # Set these per node as well; presumably per-node attributes copied
    # over from the networkx graph would otherwise take precedence.
    for agraph_node in agraph.iternodes():
        agraph_node.attr["fontcolor"] = "black"
        agraph_node.attr["style"] = "rounded"
    agraph.edge_attr.update({"color": cag_color, "arrowsize": 0.5})

    def collect_nodes(subgraph: GrFNSubgraph):
        # The subgraph's own nodes first, then each descendant's (depth-first).
        collected = list(subgraph.nodes)
        for child in self.subgraphs.successors(subgraph):
            collected.extend(collect_nodes(child))
        return collected

    def draw_cluster(subgraph: GrFNSubgraph, container: AGraph):
        cluster = container.add_subgraph(
            collect_nodes(subgraph),
            name=f"cluster_{str(subgraph)}",
            label=subgraph.basename,
            style="bold, rounded",
            rankdir="TB",
            color=subgraph.border_color,
        )
        for child in self.subgraphs.successors(subgraph):
            draw_cluster(child, cluster)

    roots = [sgraph for sgraph, indeg in self.subgraphs.in_degree() if indeg == 0]
    draw_cluster(roots[0], agraph)
    return agraph

2193 

def to_igraph_gml(self, filepath: str) -> None:
    """Write a string-labelled copy of this CAG as a GML file (for igraph).

    Every node and edge endpoint is replaced by its ``str()`` form, so the
    GML output contains only plain string labels. The file is written to
    ``<filepath>/<namespace>__<scope>__<name>--igraph.gml``.

    The return annotation is ``None``: the method returns normally. The
    previous ``NoReturn`` annotation told type checkers it never returns.

    :param filepath: Directory in which the GML file is created.
    """
    filename = os.path.join(
        filepath,
        f"{self.namespace}__{self.scope}__{self.name}--igraph.gml",
    )

    # Read the plain ``nx.DiGraph`` node/edge views of this graph.
    V = [str(v) for v in super().nodes]
    E = [(str(e1), str(e2)) for e1, e2 in super().edges]
    iG = nx.DiGraph()
    iG.add_nodes_from(V)
    iG.add_edges_from(E)
    nx.write_gml(iG, filename)

2206 

def to_json(self) -> str:
    """Serialize this causal analysis graph to a JSON string.

    The container identifier is stored under ``identifier`` as
    ``@container::<namespace>::<scope>::<name>``. Nodes are serialized
    with their ``to_dict`` method under ``variables``. Edges become
    ``[source_uid, target_uid]`` pairs. Subgraphs are serialized with
    their own ``to_dict``.

    :return: The JSON text.
    :rtype: str
    """
    container_id = "::".join(
        ("@container", self.namespace, self.scope, self.name)
    )
    variables = [node.to_dict() for node in self.nodes]
    edges = [(source.uid, target.uid) for source, target in self.edges]
    subgraphs = [subgraph.to_dict() for subgraph in self.subgraphs]
    payload = {
        "uid": self.uid,
        "identifier": container_id,
        "timestamp": self.timestamp,
        "variables": variables,
        "edges": edges,
        "subgraphs": subgraphs,
    }
    return json.dumps(payload)

2225 

def to_json_file(self, json_path) -> None:
    """Write this CAG's JSON serialization (``self.to_json()``) to a file.

    The JSON text is produced *before* the file is opened. Opening with
    mode ``"w"`` truncates the file, so if serialization raises, any
    existing file at ``json_path`` is left untouched rather than being
    replaced by an empty one.

    :param json_path: Path of the file to create or overwrite.
    """
    json_text = self.to_json()
    with open(json_path, "w") as outfile:
        outfile.write(json_text)