Coverage for skema/gromet/execution_engine/execution_engine.py: 73%

163 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import yaml 

2import argparse 

3import asyncio 

4import subprocess 

5import json 

6import asyncio 

7import requests 

8import traceback 

9from ast import literal_eval 

10from pathlib import Path 

11from typing import Any, List, Dict 

12 

13import torch 

14 

15from skema.program_analysis.CAST.pythonAST.builtin_map import retrieve_operator 

16from skema.program_analysis.single_file_ingester import process_file 

17from skema.gromet.execution_engine.execute import execute_primitive 

18 

19from skema.rest.workflows import code_snippets_to_pn_amr 

20from skema.skema_py.server import System 

21from skema.gromet.execution_engine.query_runner import QueryRunner 

22from skema.gromet.execution_engine.symbol_table import SymbolTable 

23from skema.utils.fold import dictionary_to_gromet_json, del_nulls 

24from skema.rest.utils import fn_preprocessor 

25from skema.rest.morae_proxy import post_model 

26from skema.skema_py.server import System, fn_given_filepaths 

27from skema.rest.proxies import ( 

28 SKEMA_RS_ADDESS, 

29 SKEMA_GRAPH_DB_PROTO, 

30 SKEMA_GRAPH_DB_HOST, 

31 SKEMA_GRAPH_DB_PORT, 

32) 

33 

34 

class Execute(torch.autograd.Function):
    """Autograd wrapper that evaluates a single primitive operation.

    `forward` delegates to `execute_primitive`; no gradient is defined for
    either forward argument.
    """

    @staticmethod
    def forward(ctx, primitive: str, inputs: List[torch.Tensor]):
        """Evaluate `primitive` on `inputs` and return the result of `execute_primitive`."""
        return execute_primitive(primitive, inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one (absent) gradient per `forward` input.

        `forward` takes two inputs besides `ctx` (`primitive` and `inputs`),
        so autograd expects two return values here. Neither is
        differentiable, hence `None` for both.
        """
        return None, None


# Functional entry point used by the execution engine.
execute = Execute.apply

46 

47 

class ExecutionEngine:
    """Walk the Gromet graph of a source file and record values in a symbol table.

    On construction the source file is ingested and its Gromet is uploaded to
    the SKEMA service; the graph is then queried through `QueryRunner` and
    visited node by node, storing assignment results in a `SymbolTable`.
    """

    def __init__(self, protocol: str, host: str, port: str, source_path: str):
        """Create the query runner and symbol table, then upload `source_path`.

        Args:
            protocol: Protocol passed to `QueryRunner`.
            host: Host passed to `QueryRunner`.
            port: Port passed to `QueryRunner`.
            source_path: Path of the source file to ingest and execute.
        """
        self.query_runner = QueryRunner(protocol, host, port)
        self.symbol_table = SymbolTable()
        self.source_path = source_path

        # Filename is source path filename minus the extension
        self.filename = Path(source_path).stem

        # Upload source to Memgraph instance
        self.model_id = None
        self.upload_source_remote()

    def enrich_amr(self, amr: Dict) -> Dict:
        """Enrich the AMR for a source file with initial parameter values.

        Each entry of `amr["semantics"]["ode"]["parameters"]` whose name is
        found in the symbol table gets a `"value"` key holding the first
        history entry of that symbol. The AMR is modified in place and also
        returned.
        """
        parameters = amr["semantics"]["ode"]["parameters"]
        # For each parameter, see if we have a matching value
        for parameter in parameters:
            value = self.symbol_table.get_symbol(parameter["name"])
            if value:
                try:
                    parameter["value"] = value["history"][0].item()
                except Exception:
                    # Best effort: a symbol with no usable scalar history is
                    # skipped. `Exception` (rather than a bare except) keeps
                    # KeyboardInterrupt/SystemExit propagating.
                    continue
            else:
                print(
                    f"WARNING: Could not extract value for parameter {parameter['name']}"
                )

        return amr

    def upload_source_remote(self):
        """Ingest source file and upload Gromet to Memgraph.

        Stores the JSON body of the upload response in `self.model_id`.
        """
        gromet_collection = asyncio.run(
            fn_given_filepaths(
                System(
                    files=[self.source_path], blobs=[Path(self.source_path).read_text()]
                )
            )
        )
        gromet_collection = fn_preprocessor(gromet_collection)[0]

        # Upload to memgraph
        self.model_id = requests.post(
            f"{SKEMA_RS_ADDESS}/models", json=gromet_collection
        ).json()

    def execute(
        self,
        module: bool = False,
        main: bool = False,
        function: bool = False,
        function_name: str = None,
    ):
        """Run the execution engine at specified scope.

        Only `module` is currently acted on; `main`, `function` and
        `function_name` are accepted but not used by this method.

        The uploaded model is deleted and the Memgraph connection is closed
        when this method finishes, whether or not execution succeeded.
        """
        try:
            if module:
                module_list = self.query_runner.run_query(
                    "module", n_or_m="n", id=self.model_id
                )
                self.visit(module_list[0])
        finally:
            # After execution, delete the model and close down the memgraph
            # connection. Nested so that a failing delete request cannot
            # prevent the connection from being closed.
            try:
                requests.delete(f"{SKEMA_RS_ADDESS}/models/{self.model_id}")
            finally:
                self.query_runner.memgraph.close()

    def parameter_extraction(self):
        """Run the execution engine and extract initial values for each parameter."""

        # Execute the source at the module level
        self.execute(module=True)

        # Extract the initial values from the symbol map
        return self.symbol_table.get_initial_values()

    def visit(self, node):
        """Top-level visitor function.

        Dispatches on the labels of `node`. Module, Expression and Function
        visitors are run for their side effects; Opo, Opi, Literal and
        Primitive visitors return their result. Any exception raised by a
        visitor is printed with its traceback and swallowed, in which case
        `None` is returned.
        """
        node_types = node._labels
        try:
            if "Module" in node_types:
                self.visit_module(node)
            if "Expression" in node_types:
                self.visit_expression(node)
            if "Function" in node_types:
                self.visit_function(node)
            if "Opo" in node_types:
                return self.visit_opo(node)
            if "Opi" in node_types:
                return self.visit_opi(node)
            if "Literal" in node_types:
                return self.visit_literal(node)
            if "Primitive" in node_types:
                return self.visit_primitive(node)
        except Exception as e:
            print(f"Visitor for node {node} failed to execute.")
            print(e)
            print(traceback.format_exc())

    def visit_module(self, node):
        """Visitor for top-level module: visit each of its ordered expressions."""
        node_id = str(node._id)

        expressions = self.query_runner.run_query("ordered_expressions", id=node_id)
        for expression in expressions:
            self.visit(expression)

    def visit_expression(self, node):
        """Visitor for :Expression node type, treated as an assignment.

        Evaluates the right hand side and stores the result in the symbol
        table under the name of the left hand side.
        """
        node_id = node._id

        # Only the left hand side is directly connected to the expression. So, we access the right hand side from the left hand side node
        # (Expression) -> (Opo) -> (Primitive | Literal | Opo)
        left_hand_side = self.query_runner.run_query("assignment_left_hand", id=node_id)
        right_hand_side = self.query_runner.run_query(
            "assignment_right_hand", id=left_hand_side[0]._id
        )

        # The lefthand side represents the Opo of the variable we are assigning to
        # TODO: What if we have multiple assignment x,y = 1,2
        # TODO: Does an expression always correspond to an assignment?
        symbol = self.visit(left_hand_side[0])

        # The right hand side can be either a LiteralValue, an Expression, an Opi, or a Primitive.
        # Literals are ranked last so any other candidate is visited in preference.
        index = {"Primitive": 1, "Expression": 1, "Opi": 1, "Literal": 2}
        right_hand_side = sorted(
            right_hand_side, key=lambda candidate: index[list(candidate._labels)[0]]
        )
        source = right_hand_side[0]
        value = self.visit(source)

        # An Opi visit yields a symbol name, so resolve it to its current value
        if ExecutionEngine.is_node_type(source, "Opi"):
            value = self.symbol_table.get_symbol(value)["current_value"]

        if not self.symbol_table.get_symbol(symbol):
            self.symbol_table.add_symbol(symbol, value, None)
        else:
            self.symbol_table.update_symbol(symbol, value, None)

    def visit_function(self, node):
        """Visitor for :Function node type (currently a no-op)."""
        # TODO: Add support for function calls/definitions
        pass

    def visit_opo(self, node):
        """Visitor for :Opo node type; returns the node's name."""
        return node.name

    def visit_opi(self, node):
        """Visitor for :Opi node type; returns the name the port refers to."""
        node_id = node._id

        # If un-named, we need to get the name from the attached Opo
        if node.name == "un-named":
            return self.visit(
                self.query_runner.run_query("assignment_left_hand", id=node_id)[0]
            )

        return node.name

    def _visit_opo_value(self, node):
        """Visit the :Opo and return the value rather than the name"""
        return self.symbol_table.get_symbol(node.name)

    def _visit_opi_value(self, node):
        """Visit the :Opi and return the value rather than the name"""
        return self.symbol_table.get_symbol(node.name)

    def visit_literal(self, node):
        """Visitor for :Literal node type.

        Converts `node.value` (a dict with "value" and "value_type") into a
        torch tensor for Integer, AbstractFloat and Boolean literals. Complex
        and Map literals print a warning and, like None and any unrecognised
        type, produce `None`.
        """

        def create_dummy_node(value: Dict):
            """Create a dummy gqlalchemy node so that we can pass a LiteralValue to a visitor."""

            class DummyNode:
                pass

            node = DummyNode()
            node._id = -1
            node._labels = ["Literal"]
            node.value = value

            # TODO: Update LiteralValue representation for List types
            node.value["value"] = str(node.value["value"])

            return node

        # TODO: Update LiteralValue to remove wrapping "" characters
        value = node.value["value"].strip('"')
        value_type = node.value["value_type"]

        if value_type == "Integer":
            return torch.tensor(int(value), dtype=torch.int)
        elif value_type == "AbstractFloat":
            return torch.tensor(float(value), dtype=torch.float64)
        elif value_type == "Complex":
            print(
                "WARNING: Execution for type Complex not support and will be skipped."
            )
        elif value_type == "Boolean":
            return torch.tensor(value == "True", dtype=torch.bool)
        elif value_type == "List":
            # NOTE: `value` is the result of str.strip above, so this guard
            # always fires and List literals currently evaluate to None.
            if isinstance(value, str):
                return None
            elements = literal_eval(value)
            return [self.visit(create_dummy_node(element)) for element in elements]
        elif value_type == "Map":
            print("WARNING: Execution for type Map not support and will be skipped.")
        elif value_type == "None":
            return None

    def visit_primitive(self, node):
        """Visitor for :Primitive node type.

        Visits every operand, resolves operands that came back as symbol
        names to their current value, and executes the primitive on them.
        """
        node_id = node._id

        operands = [
            self.visit(operand)
            for operand in self.query_runner.run_query("primitive_operands", id=node_id)
        ]
        # Some inputs may be symbol names, so we need to access the current value from the symbol map
        inputs = [
            self.symbol_table.get_symbol(operand)["current_value"]
            if isinstance(operand, str)
            else operand
            for operand in operands
        ]

        primitive = retrieve_operator(node.name)
        return execute(primitive, inputs)

    @staticmethod
    def is_node_type(node, node_type: str):
        """Helper function for checking if a node matches a given type"""
        return node_type in node._labels

291 

292 

if __name__ == "__main__":
    # Command-line entry point: execute the given source file and print the
    # initial parameter values recovered by the execution engine.
    parser = argparse.ArgumentParser(description="Parameter Extraction Script")
    parser.add_argument("source_path", type=str, help="File path to source to execute")
    cli_args = parser.parse_args()

    # Graph database connection settings come from the SKEMA proxy configuration
    engine = ExecutionEngine(
        SKEMA_GRAPH_DB_PROTO,
        SKEMA_GRAPH_DB_HOST,
        SKEMA_GRAPH_DB_PORT,
        cli_args.source_path,
    )

    extracted = engine.parameter_extraction()
    print(extracted)
    """ TODO: New arguments to add with function execution support
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--main", action="store_true", help="Extract parameters from the main module")
    group.add_argument("--function", type=str, metavar="function_name", help="Extract parameters from a specific function")
    """