Coverage for skema/model_assembly/model_dynamics.py: 0%
150 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from copy import deepcopy
3from networkx.algorithms.simple_paths import all_simple_paths
4from skema.model_assembly.networks import (
5 GrFNLoopSubgraph,
6 GrFNSubgraph,
7 GroundedFunctionNetwork,
8 HyperEdge,
9 LambdaNode,
10)
11from skema.model_assembly.structures import LambdaType
14def remove_node_and_hyper_edges(grfn: GroundedFunctionNetwork, node):
15 for grfn_node in grfn.nodes:
16 if node.uid == grfn_node.uid:
17 grfn.remove_node(grfn_node)
18 break
20 def remove_from_subgraphs(g):
21 if node in g.nodes:
22 g.nodes = [n for n in g.nodes if n != node]
23 for sub in grfn.subgraphs.successors(g):
24 remove_from_subgraphs(sub)
26 remove_from_subgraphs(grfn.root_subgraph)
28 edges_to_remove = set()
29 for hyper_edge in grfn.hyper_edges:
30 if node in hyper_edge.inputs:
31 hyper_edge.inputs.remove(node)
32 if len(hyper_edge.inputs) == 0:
33 edges_to_remove.add(hyper_edge)
34 elif node in hyper_edge.outputs:
35 hyper_edge.outputs.remove(node)
36 if len(hyper_edge.outputs) == 0:
37 edges_to_remove.add(hyper_edge)
38 elif node == hyper_edge.lambda_fn:
39 edges_to_remove.add(hyper_edge)
41 grfn.hyper_edges = [
42 h for h in grfn.hyper_edges if h not in edges_to_remove
43 ]
44 for edge in edges_to_remove:
45 grfn.lambdas = [l for l in grfn.lambdas if edge.lambda_fn.uid != l.uid]
48def get_input_interface_node(
49 grfn: GroundedFunctionNetwork, subgraph: GrFNSubgraph
50):
51 return [
52 node
53 for node in subgraph.nodes
54 if isinstance(node, LambdaNode)
55 and node.func_type == LambdaType.INTERFACE
56 and all(
57 [
58 node_succ in subgraph.nodes
59 for node_succ in grfn.successors(node)
60 ]
61 )
62 ][0]
65def get_output_interface_node(
66 grfn: GroundedFunctionNetwork, subgraph: GrFNSubgraph
67):
68 return [
69 node
70 for node in subgraph.nodes
71 if isinstance(node, LambdaNode)
72 and node.func_type == LambdaType.INTERFACE
73 and all(
74 [
75 node_succ in subgraph.nodes
76 for node_succ in grfn.predecessors(node)
77 ]
78 )
79 ][0]
82def get_decision_nodes(subgraph: GrFNSubgraph):
83 return [
84 node
85 for node in subgraph.nodes
86 if isinstance(node, LambdaNode)
87 and node.func_type == LambdaType.DECISION
88 ]
91def extract_dynamics_from_loop(
92 grfn: GroundedFunctionNetwork, loop: GrFNLoopSubgraph
93):
94 # Create a copy of the current grfn to trim nodes out of to create the
95 # model dynamics grfn
96 dynamics_grfn = deepcopy(grfn)
97 dynamics_grfn_subgraphs_graph = dynamics_grfn.subgraphs
98 loop_copy = [s for s in dynamics_grfn.subgraphs if s == loop][0]
99 to_remove = set()
101 # Delete all other loop subgraphs besides the loop we are operating on from
102 # the root subgraph. TODO test if this works
103 loop_subgraphs_to_remove = []
104 for subgraph in dynamics_grfn_subgraphs_graph.successors(
105 dynamics_grfn.root_subgraph
106 ):
107 if subgraph != loop and isinstance(subgraph, GrFNLoopSubgraph):
108 loop_subgraphs_to_remove.append(subgraph)
109 for subgraph in loop_subgraphs_to_remove:
110 dynamics_grfn_subgraphs_graph.remove_node(subgraph)
112 # Generatate the input/output var pairs for the loop interface
113 loop_successors = dynamics_grfn_subgraphs_graph.successors(loop)
114 loop_interface = get_input_interface_node(dynamics_grfn, loop)
115 loop_output_interface = get_output_interface_node(dynamics_grfn, loop)
116 loop_decisions = get_decision_nodes(loop)
118 loop_interface_hyper_edge = [
119 h for h in grfn.hyper_edges if h.lambda_fn == loop_interface
120 ][0]
121 loop_interfaces_input_output_var_pairs = []
122 for (input, output) in zip(
123 loop_interface_hyper_edge.inputs, loop_interface_hyper_edge.outputs
124 ):
125 loop_interfaces_input_output_var_pairs.append((input, output))
127 # For each variable going through the loop interface in main, if it then
128 # goes through the loop decision node, create an edge from the original
129 # output var of the interface to where the decision variable goes.
130 for input_var, output_var in loop_interfaces_input_output_var_pairs:
131 output_succs = list(dynamics_grfn.successors(output_var))
132 for output_var_succ in output_succs:
133 if output_var_succ.func_type == LambdaType.DECISION:
134 var_after_decision = [
135 v
136 for v in dynamics_grfn.successors(output_var_succ)
137 if v.identifier.var_name == output_var.identifier.var_name
138 ][0]
139 for new_output_var_succ in dynamics_grfn.successors(
140 var_after_decision
141 ):
142 dynamics_grfn.add_edge(output_var, new_output_var_succ)
143 remove_node_and_hyper_edges(dynamics_grfn, var_after_decision)
145 output_decision_node = [
146 node
147 for node in loop_decisions
148 if len(list(grfn.predecessors(node)))
149 == ((len(list(grfn.successors(node))) * 2) + 1)
150 ][0]
151 for succ in dynamics_grfn.successors(loop_output_interface):
152 for decision_pred in dynamics_grfn.predecessors(output_decision_node):
153 if decision_pred.identifier.var_name == succ.identifier.var_name:
154 for src, _ in list(dynamics_grfn.in_edges(decision_pred)):
155 # dynamics_grfn.remove_edge(src, decision_pred)
156 # dynamics_grfn.add_edge(src, succ)
157 # to_remove.add(decision_pred)
158 var_output_edge_matches = [
159 h
160 for h in dynamics_grfn.hyper_edges
161 if decision_pred in h.outputs
162 ]
163 if len(var_output_edge_matches) > 0:
164 var_output_edge = var_output_edge_matches[0]
165 var_output_idx = [
166 idx
167 for idx, v in enumerate(var_output_edge.outputs)
168 if decision_pred.uid == v.uid
169 if decision_pred.uid == v.uid
170 ][0]
171 var_output_edge.outputs[var_output_idx] = succ
172 dynamics_grfn.add_edge(src, succ)
173 remove_node_and_hyper_edges(dynamics_grfn, decision_pred)
174 break
176 to_remove.update(dynamics_grfn.successors(output_decision_node))
178 # Now that we have created new edges ignoring the decision node, remove
179 # the decision nodes
180 for loop_decision in loop_decisions:
181 remove_node_and_hyper_edges(dynamics_grfn, loop_decision)
183 # For each subgraph within the loop, add an edge from the grfn root
184 # subgraph to it, move nodes from loop to root subgraph, and track
185 # these variables
186 loop_nodes_to_preserve = set()
187 for loop_succ in loop_successors:
188 loop_succ.parent = dynamics_grfn.root_subgraph.uid
189 dynamics_grfn_subgraphs_graph.add_edge(
190 dynamics_grfn.root_subgraph, loop_succ
191 )
192 loop_succ_interface = get_input_interface_node(
193 dynamics_grfn, loop_succ
194 )
195 loop_succ_interface_pred = set(
196 dynamics_grfn.predecessors(loop_succ_interface)
197 )
198 # Find potential paths to this loop successors interface
199 paths_to_interface = all_simple_paths(
200 dynamics_grfn, loop_interface, loop_succ_interface
201 )
202 interface_hyper_edge_inputs = list()
203 # for each path found to this interface
204 for path in paths_to_interface:
205 # for each node on the path, if it is from the loop subgraph,
206 # add it into the root subgraph
207 for node in path:
208 if node != loop_interface and node in loop.nodes:
209 dynamics_grfn.root_subgraph.nodes.append(node)
210 loop_nodes_to_preserve.add(node)
211 if (
212 node in loop_succ_interface_pred
213 and node not in interface_hyper_edge_inputs
214 ):
215 interface_hyper_edge_inputs.append(node)
217 existing_hyper_edges = [
218 h
219 for h in dynamics_grfn.hyper_edges
220 if isinstance(h.lambda_fn, LambdaNode)
221 and h.lambda_fn.uid == loop_succ_interface.uid
222 ]
223 if len(existing_hyper_edges) > 0:
224 dynamics_grfn.hyper_edges = [
225 h
226 for h in dynamics_grfn.hyper_edges
227 if h != existing_hyper_edges[0]
228 ]
229 dynamics_grfn.hyper_edges.append(
230 HyperEdge(
231 interface_hyper_edge_inputs,
232 loop_succ_interface,
233 existing_hyper_edges[0].outputs,
234 )
235 )
237 # Preserve the output vars of the loop successors we are keeping
238 loop_succ_output_interface = loop_succ.get_output_interface_node(
239 dynamics_grfn.hyper_edges
240 )
241 for v in loop_succ_output_interface.outputs:
242 dynamics_grfn.root_subgraph.nodes.append(v)
243 loop_nodes_to_preserve.add(v)
245 loop_nodes_to_preserve.add(loop_succ_interface)
246 loop_nodes_to_preserve.add(loop_succ_output_interface)
248 # Create an edge from the variable going through the loop interface in main
249 # to wherever the output variable of the loop interface is going to.
250 # Remove the output var node in the loop from the graph.
251 for input_var, output_var in loop_interfaces_input_output_var_pairs:
252 for output_var_succ in dynamics_grfn.successors(output_var):
253 dynamics_grfn.add_edge(input_var, output_var_succ)
254 for edge in dynamics_grfn.hyper_edges:
255 for idx, input in enumerate(edge.inputs):
256 if output_var == input:
257 edge.inputs[idx] = input_var
258 break
260 remove_node_and_hyper_edges(dynamics_grfn, output_var)
262 # Remove variables going out of the loop sugraph as the results in main
263 loop_edges = [
264 e for e in dynamics_grfn.hyper_edges if e.lambda_fn in loop.nodes
265 ]
266 loop_output_interface_edge = loop.get_output_interface_node(loop_edges)
267 for loop_output_var in loop_output_interface_edge.outputs:
268 if loop_output_var not in loop_nodes_to_preserve:
269 remove_node_and_hyper_edges(dynamics_grfn, loop_output_var)
271 # Remove all loop nodes that we dont want to preserve from the grfn
272 for node in loop.nodes:
273 if node not in loop_nodes_to_preserve:
274 remove_node_and_hyper_edges(dynamics_grfn, node)
276 # Remove the model driver loop from the dynamics grfn
277 dynamics_grfn_subgraphs_graph.remove_node(loop_copy)
279 def remove_empty_path(node):
280 if node in dynamics_grfn.nodes:
281 node_succs = list(dynamics_grfn.successors(node))
282 if len(node_succs) == 0:
283 predecessors = dynamics_grfn.predecessors(node)
284 remove_node_and_hyper_edges(dynamics_grfn, node)
285 # TODO this works for now, but the node might not always
286 # be in the root subgraph.
287 # dynamics_grfn.root_subgraph.nodes.remove(node)
288 for p in predecessors:
289 remove_empty_path(p)
291 # Remove hanging variables (and there potential singular path) going into
292 # the loop interface that are not used anymore. (This applies to variables
293 # like a loop iterator "i" or variables only used in the condition.)
294 for n in loop_interface_hyper_edge.inputs:
295 if n not in loop_nodes_to_preserve:
296 remove_empty_path(n)
298 for l_node in [
299 n for n in dynamics_grfn.nodes if isinstance(n, LambdaNode)
300 ]:
301 # There is only one output with a literal node
302 output_var = list(dynamics_grfn.successors(l_node))[0]
303 if (
304 # TODO maybe fix?
305 l_node.func_type == LambdaType.LITERAL
306 and len(list(dynamics_grfn.successors(output_var))) == 0
307 ):
308 to_remove.add(output_var)
309 to_remove.add(l_node)
311 for n in to_remove:
312 remove_node_and_hyper_edges(dynamics_grfn, n)
314 return dynamics_grfn
317def extract_model_dynamics(grfn: GroundedFunctionNetwork):
318 resulting_model_dynamics_grfns = []
319 root_subgraph = grfn.root_subgraph
320 root_successors = grfn.subgraphs.successors(root_subgraph)
321 for succ in root_successors:
322 if isinstance(succ, GrFNLoopSubgraph):
323 extracted_dynamics = extract_dynamics_from_loop(grfn, succ)
324 resulting_model_dynamics_grfns.append(extracted_dynamics)
326 return resulting_model_dynamics_grfns