Coverage for skema/gromet/execution_engine/execution_engine.py: 73%
163 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import yaml
2import argparse
3import asyncio
4import subprocess
5import json
6import asyncio
7import requests
8import traceback
9from ast import literal_eval
10from pathlib import Path
11from typing import Any, List, Dict
13import torch
15from skema.program_analysis.CAST.pythonAST.builtin_map import retrieve_operator
16from skema.program_analysis.single_file_ingester import process_file
17from skema.gromet.execution_engine.execute import execute_primitive
19from skema.rest.workflows import code_snippets_to_pn_amr
20from skema.skema_py.server import System
21from skema.gromet.execution_engine.query_runner import QueryRunner
22from skema.gromet.execution_engine.symbol_table import SymbolTable
23from skema.utils.fold import dictionary_to_gromet_json, del_nulls
24from skema.rest.utils import fn_preprocessor
25from skema.rest.morae_proxy import post_model
26from skema.skema_py.server import System, fn_given_filepaths
27from skema.rest.proxies import (
28 SKEMA_RS_ADDESS,
29 SKEMA_GRAPH_DB_PROTO,
30 SKEMA_GRAPH_DB_HOST,
31 SKEMA_GRAPH_DB_PORT,
32)
class Execute(torch.autograd.Function):
    """Autograd wrapper that runs a single Gromet primitive via ``execute_primitive``.

    Differentiation through primitives is not supported: ``backward`` returns no
    gradient for either ``forward`` input.
    """

    @staticmethod
    def forward(ctx, primitive: str, inputs: List[torch.Tensor]):
        """Evaluate ``primitive`` on ``inputs`` and return the result."""
        return execute_primitive(primitive, inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return no gradients.

        Autograd expects one entry per ``forward`` input (``primitive`` and
        ``inputs``). A bare ``None`` would be rejected as the wrong number of
        gradients if backward were ever reached, so return one ``None`` per input.
        """
        return None, None


# Functional alias; this is the name the primitive visitor calls.
execute = Execute.apply
class ExecutionEngine:
    """Execute a source file's Gromet graph (stored in Memgraph) to recover values.

    On construction the file at ``source_path`` is converted to Gromet and uploaded
    through the SKEMA-RS ``/models`` endpoint. The returned id is kept in
    ``self.model_id`` and used to query the graph. Values produced while visiting
    the graph are recorded in ``self.symbol_table``.
    """

    def __init__(self, protocol: str, host: str, port: str, source_path: str):
        self.query_runner = QueryRunner(protocol, host, port)
        self.symbol_table = SymbolTable()
        self.source_path = source_path

        # Filename is source path filename minus the extension
        self.filename = Path(source_path).stem

        # Upload source to Memgraph instance
        self.model_id = None
        self.upload_source_remote()

    def enrich_amr(self, amr: Dict) -> Dict:
        """Enrich the AMR for a source file with initial parameter values.

        For each entry of ``amr["semantics"]["ode"]["parameters"]`` that has a
        matching symbol, the first recorded value (``history[0].item()``) is
        written to the parameter's ``"value"`` key. A warning is printed for
        parameters whose value cannot be extracted. The AMR is modified in place
        and also returned.
        """
        for parameter in amr["semantics"]["ode"]["parameters"]:
            name = parameter["name"]
            symbol = self.symbol_table.get_symbol(name)
            if symbol:
                try:
                    parameter["value"] = symbol["history"][0].item()
                    continue
                except (
                    KeyError,
                    IndexError,
                    TypeError,
                    AttributeError,
                    ValueError,
                    RuntimeError,
                ):
                    # Symbol exists but has no usable scalar initial value
                    # (e.g. empty history, or an entry without a scalar .item()).
                    pass
            print(f"WARNING: Could not extract value for parameter {name}")

        return amr

    def upload_source_remote(self):
        """Ingest the source file, convert it to Gromet, and upload it to Memgraph.

        Stores the JSON body returned by ``POST {SKEMA_RS_ADDESS}/models`` in
        ``self.model_id``.

        Raises:
            requests.HTTPError: if the upload request returns an error status.
        """
        system = System(
            files=[self.source_path], blobs=[Path(self.source_path).read_text()]
        )
        gromet_collection = asyncio.run(fn_given_filepaths(system))
        gromet_collection = fn_preprocessor(gromet_collection)[0]

        # Upload to memgraph. Fail loudly on an HTTP error instead of storing an
        # error payload as the model id.
        response = requests.post(f"{SKEMA_RS_ADDESS}/models", json=gromet_collection)
        response.raise_for_status()
        self.model_id = response.json()

    def execute(
        self,
        module: bool = False,
        main: bool = False,
        function: bool = False,
        function_name: str = None,
    ):
        """Run the execution engine at the specified scope.

        Only module-level execution is implemented. ``main``, ``function`` and
        ``function_name`` are accepted but currently ignored.

        Regardless of scope or failure, the uploaded model is deleted from the
        SKEMA-RS service and the Memgraph connection is closed afterwards.
        """
        try:
            if module:
                module_list = self.query_runner.run_query(
                    "module", n_or_m="n", id=self.model_id
                )
                self.visit(module_list[0])
        finally:
            # After execution (or a failed query), delete the model and close
            # down the memgraph connection so neither is leaked.
            try:
                requests.delete(f"{SKEMA_RS_ADDESS}/models/{self.model_id}")
            finally:
                self.query_runner.memgraph.close()

    def parameter_extraction(self):
        """Run the execution engine and extract initial values for each parameter.

        Returns whatever ``self.symbol_table.get_initial_values()`` returns after
        executing the source at module level.
        """
        # Execute the source at the module level
        self.execute(module=True)

        # Extract the initial values from the symbol map
        return self.symbol_table.get_initial_values()

    def visit(self, node):
        """Top-level visitor: dispatch on the node's labels.

        Returns the visitor's result for Opo/Opi/Literal/Primitive nodes and
        ``None`` for Module/Expression/Function nodes. Any exception raised by a
        visitor is printed (with traceback) and swallowed, yielding ``None``.
        """
        node_types = node._labels
        try:
            if "Module" in node_types:
                self.visit_module(node)
            if "Expression" in node_types:
                self.visit_expression(node)
            if "Function" in node_types:
                self.visit_function(node)
            if "Opo" in node_types:
                return self.visit_opo(node)
            if "Opi" in node_types:
                return self.visit_opi(node)
            if "Literal" in node_types:
                return self.visit_literal(node)
            if "Primitive" in node_types:
                return self.visit_primitive(node)
        except Exception as e:
            # Best-effort execution: report and continue with the next node.
            print(f"Visitor for node {node} failed to execute.")
            print(e)
            print(traceback.format_exc())

    def visit_module(self, node):
        """Visitor for top-level module: visit its expressions in query order."""
        node_id = str(node._id)

        expressions = self.query_runner.run_query("ordered_expressions", id=node_id)
        for expression in expressions:
            self.visit(expression)

    def visit_expression(self, node):
        """Visitor for :Expression node type, handled as an assignment.

        Visits the left-hand :Opo to obtain the target symbol name, evaluates the
        right-hand side, and adds or updates that symbol in the symbol table.
        """
        node_id = node._id

        # Only the left hand side is directly connected to the expression. So, we access the right hand side from the left hand side node
        # (Expression) -> (Opo) -> (Primitive | Literal | Opo)
        left_hand_side = self.query_runner.run_query("assignment_left_hand", id=node_id)
        right_hand_side = self.query_runner.run_query(
            "assignment_right_hand", id=left_hand_side[0]._id
        )

        # The lefthand side represents the Opo of the variable we are assigning to
        # TODO: What if we have multiple assignment x,y = 1,2
        # TODO: Does an expression always correspond to an assignment?
        symbol = self.visit(left_hand_side[0])

        # The right hand side can be either a LiteralValue, an Expression, an Opi, or a Primitive.
        # Prefer a Primitive/Expression/Opi over a Literal when several are attached.
        # Any other label raises KeyError here (caught and reported by `visit`).
        priority = {"Primitive": 1, "Expression": 1, "Opi": 1, "Literal": 2}
        right_hand_side = sorted(
            right_hand_side, key=lambda rhs: priority[list(rhs._labels)[0]]
        )
        rhs_node = right_hand_side[0]
        value = self.visit(rhs_node)

        # Visiting an Opi yields a symbol name; resolve it to the symbol's current value.
        if ExecutionEngine.is_node_type(rhs_node, "Opi"):
            value = self.symbol_table.get_symbol(value)["current_value"]

        if not self.symbol_table.get_symbol(symbol):
            self.symbol_table.add_symbol(symbol, value, None)
        else:
            self.symbol_table.update_symbol(symbol, value, None)

    def visit_function(self, node):
        """Visitor for :Function node type (not yet implemented; does nothing)."""
        # TODO: Add support for function calls/definitions
        pass

    def visit_opo(self, node):
        """Visitor for :Opo node type: return the port's name."""
        return node.name

    def visit_opi(self, node):
        """Visitor for :Opi node type: return the port's (symbol) name."""
        node_id = node._id

        # If un-named, we need to get the name from the attached Opo
        if node.name == "un-named":
            return self.visit(
                self.query_runner.run_query("assignment_left_hand", id=node_id)[0]
            )

        return node.name

    def _visit_opo_value(self, node):
        """Visit the :Opo and return the value rather than the name"""
        return self.symbol_table.get_symbol(node.name)

    def _visit_opi_value(self, node):
        """Visit the :Opi and return the value rather than the name"""
        return self.symbol_table.get_symbol(node.name)

    def visit_literal(self, node):
        """Visitor for :Literal node type.

        Converts ``node.value`` (a dict with ``"value"`` and ``"value_type"``
        keys) into a torch tensor:

        - ``Integer`` becomes ``torch.int``.
        - ``AbstractFloat`` becomes ``torch.float64``.
        - ``Boolean`` becomes ``torch.bool``.

        ``List``, ``None`` and unsupported types yield ``None``. ``Complex`` and
        ``Map`` also print a warning.
        """
        # TODO: Update LiteralValue to remove wrapping "" characters
        value = node.value["value"].strip('"')
        value_type = node.value["value_type"]

        if value_type == "Integer":
            return torch.tensor(int(value), dtype=torch.int)
        if value_type == "AbstractFloat":
            return torch.tensor(float(value), dtype=torch.float64)
        if value_type == "Boolean":
            return torch.tensor(value == "True", dtype=torch.bool)
        if value_type == "Complex":
            print(
                "WARNING: Execution for type Complex not support and will be skipped."
            )
        elif value_type == "Map":
            print("WARNING: Execution for type Map not support and will be skipped.")

        # TODO: Update LiteralValue representation for List types; until then
        # List literals are not executed and, like "None", yield None.
        return None

    def visit_primitive(self, node):
        """Visitor for :Primitive node type.

        Visits every operand of the primitive, resolves operands that come back as
        symbol names (``str``) to their current value in the symbol table, and
        executes the operator returned by ``retrieve_operator(node.name)``.
        """
        operands = self.query_runner.run_query("primitive_operands", id=node._id)
        values = [self.visit(operand) for operand in operands]

        # Some inputs may be symbol names, so we need to access the current value from the symbol map
        inputs = [
            self.symbol_table.get_symbol(value)["current_value"]
            if isinstance(value, str)
            else value
            for value in values
        ]

        primitive = retrieve_operator(node.name)
        return execute(primitive, inputs)

    @staticmethod
    def is_node_type(node, node_type: str):
        """Helper function for checking if a node matches a given type"""
        return node_type in node._labels
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Parameter Extraction Script")
    arg_parser.add_argument(
        "source_path", type=str, help="File path to source to execute"
    )
    # TODO: once function-level execution is supported, add a required mutually
    # exclusive group:
    #   --main                     "Extract parameters from the main module"
    #   --function function_name   "Extract parameters from a specific function"
    cli_args = arg_parser.parse_args()

    engine = ExecutionEngine(
        SKEMA_GRAPH_DB_PROTO,
        SKEMA_GRAPH_DB_HOST,
        SKEMA_GRAPH_DB_PORT,
        cli_args.source_path,
    )
    print(engine.parameter_extraction())