Coverage for skema/program_analysis/CAST2FN/cast.py: 29%

127 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import ast 

2import json 

3import difflib 

4import typing 

5import networkx as nx 

6 

7from skema.program_analysis.CAST2FN.model.cast import ( 

8 AstNode, 

9 Assignment, 

10 Attribute, 

11 Call, 

12 CASTLiteralValue, 

13 FunctionDef, 

14 Goto, 

15 Label, 

16 Loop, 

17 ModelBreak, 

18 ModelContinue, 

19 ModelIf, 

20 ModelImport, 

21 ModelReturn, 

22 Module, 

23 Name, 

24 Operator, 

25 RecordDef, 

26 ScalarType, 

27 StructureType, 

28 SourceRef, 

29 VarType, 

30 Var, 

31 ValueConstructor, 

32) 

33# from skema.program_analysis.CAST2FN.visitors import ( 

34# CASTToAIRVisitor, 

35#) 

36from skema.model_assembly.air import AutoMATES_IR 

37from skema.model_assembly.networks import GroundedFunctionNetwork 

38from skema.model_assembly.structures import ( 

39 GenericContainer, 

40 GenericStmt, 

41 GenericIdentifier, 

42 GenericDefinition, 

43 TypeDefinition, 

44 VariableDefinition, 

45) 

46 

# CAST node classes that `CAST.parse_cast_json` can instantiate. A JSON object
# is decoded into the class whose `__name__` equals its "node_type" field.
CAST_NODES_TYPES_LIST = [
    AstNode,
    Assignment,
    Attribute,
    Call,
    CASTLiteralValue,
    FunctionDef,
    Goto,
    Label,
    Loop,
    ModelBreak,
    ModelContinue,
    ModelIf,
    ModelImport,
    ModelReturn,
    Module,
    Name,
    Operator,
    RecordDef,
    ScalarType,
    StructureType,
    SourceRef,
    VarType,
    Var,
    ValueConstructor,
]

73 

74 

def compare_name_nodes(name1: Name, name2: Name) -> bool:
    """
    Decide whether two Name nodes refer to the same name.

    Only the `name` fields are compared. The Swagger-generated `Name.__eq__`
    also compares the `id` attribute, which is a UUID and is not expected to
    be consistent across CAST generations.

    Returns False when either argument is not a Name node.
    """
    both_are_names = isinstance(name1, Name) and isinstance(name2, Name)
    return both_are_names and name1.name == name2.name

85 

86 

class CASTJsonException(Exception):
    """Raised when CAST json data cannot be encoded or decoded."""

93 

94 

class CAST(object):
    """
    Represents the Common Abstract Syntax Tree (CAST) that will be used to
    generically represent any language's AST.

    Attributes:
        nodes: The top-level CAST nodes.
        cast_source_language: Name of the language the CAST was generated
            from. `to_air_dict` passes it to the CAST-to-AIR visitor; the JSON
            loaders default it to "unknown".
    """

    nodes: typing.List[AstNode]
    cast_source_language: str

    def __init__(self, nodes: typing.List[AstNode], cast_source_language: str):
        self.nodes = nodes
        self.cast_source_language = cast_source_language

    def __eq__(self, other):
        """
        For equality, the two CAST objects must have the same top-level node
        data (`cast_source_language` is not compared).

        Top-level Name nodes are compared with `compare_name_nodes`, which
        looks at the name only and ignores the UUID `id`. Every other node
        uses its Swagger-generated `__eq__`. On the first mismatch, the
        differing lines of the two nodes' string forms are printed to help
        locate the difference.

        Returns NotImplemented when `other` is not a CAST, so comparing with
        an unrelated object yields False instead of raising AttributeError.
        """
        if not isinstance(other, CAST):
            return NotImplemented
        if len(self.nodes) != len(other.nodes):
            return False

        for node, other_node in zip(self.nodes, other.nodes):
            if isinstance(node, Name):
                nodes_match = compare_name_nodes(node, other_node)
            else:
                nodes_match = node == other_node

            if not nodes_match:
                # printing diff to help locating difference
                print("CAST __eq__ failed:")
                self_lines = str(node).splitlines()
                other_lines = str(other_node).splitlines()
                for diff_index, diff in enumerate(
                    difflib.ndiff(self_lines, other_lines)
                ):
                    if diff[0] == " ":
                        continue
                    print(f"Line {diff_index}: {diff}")
                return False

        return True

    def to_AGraph(self):
        """
        Build a pygraphviz AGraph of the CAST node hierarchy.

        Every AstNode reachable from `self.nodes` becomes a graph node
        labeled with its class name. An edge runs from each node to every
        AstNode held in one of its `attribute_map` fields, either directly or
        inside a (possibly nested) list.

        Returns:
            pygraphviz.AGraph: graph with a top-to-bottom layout and the
            Menlo font.
        """
        # `import skema.utils.misc.test_pygraphviz` binds only `skema`, so the
        # bare `test_pygraphviz(...)` call needs the name imported directly.
        # NOTE(review): assumes test_pygraphviz is a function defined in
        # skema.utils.misc.
        from skema.utils.misc import test_pygraphviz

        test_pygraphviz(
            "The to_AGraph method requires the pygraphviz package to be installed!"
        )

        def ast_children(value):
            # Yield the AstNode objects in `value`, flattening nested lists.
            if isinstance(value, list):
                for item in value:
                    yield from ast_children(item)
            elif isinstance(value, AstNode):
                yield value

        G = nx.DiGraph()
        # Graph nodes are keyed by id() because CAST nodes override __eq__:
        # distinct nodes may compare equal (and may not be hashable).
        pending = list(ast_children(self.nodes))
        for root in pending:
            G.add_node(id(root), label=type(root).__name__)

        visited = set()
        while pending:
            parent = pending.pop()
            if id(parent) in visited:
                continue
            visited.add(id(parent))
            for attr in parent.attribute_map:
                for child in ast_children(getattr(parent, attr)):
                    G.add_node(id(child), label=type(child).__name__)
                    G.add_edge(id(parent), id(child))
                    pending.append(child)

        A = nx.nx_agraph.to_agraph(G)
        A.graph_attr.update(
            {"dpi": 227, "fontsize": 20, "fontname": "Menlo", "rankdir": "TB"}
        )
        A.node_attr.update({"fontname": "Menlo"})
        return A

    def to_air_dict(self):
        """
        Convert the CAST to an AIR dictionary and set its "entrypoint".

        The entrypoint is the first container whose name ends in "::main".
        If there is none, it is the first container that no statement calls
        as a "container" function.

        Raises:
            Exception: If no suitable root container exists.
        """
        # Imported lazily: the module-level import of CASTToAIRVisitor is
        # commented out (presumably to avoid an import cycle), which left this
        # name undefined at call time.
        from skema.program_analysis.CAST2FN.visitors import CASTToAIRVisitor

        c2a_visitor = CASTToAIRVisitor(self.nodes, self.cast_source_language)
        air = c2a_visitor.to_air()
        containers = air["containers"]

        main_container = [
            c["name"] for c in containers if c["name"].endswith("::main")
        ]

        # A set, since it is only used for membership tests below.
        called_containers = {
            s["function"]["name"]
            for c in containers
            for s in c["body"]
            if s["function"]["type"] == "container"
        }
        root_containers = [
            c["name"] for c in containers if c["name"] not in called_containers
        ]

        if main_container:
            container_id_to_start_from = main_container[0]
        elif root_containers:
            container_id_to_start_from = root_containers[0]
        else:
            # TODO
            raise Exception(
                "Error: Unable to find root container to build GrFN."
            )

        air["entrypoint"] = container_id_to_start_from

        return air

    def to_AIR(self):
        """
        Build an AutoMATES_IR object from the dictionary returned by
        `to_air_dict`.

        Container arguments that have no entry in the AIR "variables" list
        get a VariableDefinition created from the bare identifier.
        """
        air = self.to_air_dict()

        C, V, T = dict(), dict(), dict()

        # Create variable definitions
        for var_data in air["variables"]:
            new_var = VariableDefinition.from_data(var_data)
            V[new_var.identifier] = new_var

        # Create type definitions
        for type_data in air["types"]:
            new_type = TypeDefinition.from_dict(type_data)
            T[new_type.identifier] = new_type

        # Create container definitions
        for con_data in air["containers"]:
            new_container = GenericContainer.from_dict(con_data)
            for in_var in new_container.arguments:
                if in_var not in V:
                    V[in_var] = VariableDefinition.from_identifier(in_var)
            C[new_container.identifier] = new_container

        return AutoMATES_IR(
            GenericIdentifier.from_str(air["entrypoint"]), C, V, T, [], [], []
        )

    def to_GrFN(self):
        """Build a GroundedFunctionNetwork from this CAST's AIR."""
        return GroundedFunctionNetwork.from_AIR(self.to_AIR())

    def write_cast_object(self, cast_value):
        """
        Recursively convert a CAST value into json-serializable data.

        Lists are converted element by element. AstNode and SourceRef objects
        become dicts of their `attribute_map` fields plus a "node_type" key
        holding the class name. Any other value is returned unchanged.
        """
        if isinstance(cast_value, list):
            return [self.write_cast_object(val) for val in cast_value]
        if not isinstance(cast_value, (AstNode, SourceRef)):
            return cast_value

        fields = {
            attr: self.write_cast_object(getattr(cast_value, attr))
            for attr in cast_value.attribute_map
        }
        fields["node_type"] = type(cast_value).__name__
        return fields

    def to_json_object(self):
        """
        Returns a json object of the CAST
        """
        return {"nodes": [self.write_cast_object(n) for n in self.nodes]}

    def to_json_str(self):
        """
        Returns a json string of the CAST
        """
        return json.dumps(
            self.to_json_object(),
            sort_keys=True,
            indent=4,
        )

    @classmethod
    def parse_cast_json(cls, data):
        """
        Recursively decode json data into CAST objects.

        Lists are decoded element by element, and None and primitives are
        returned as-is. Dicts with row_start/row_end/col_start/col_end keys
        become SourceRef objects. Dicts with a "node_type" key become the
        matching class from CAST_NODES_TYPES_LIST.

        Raises:
            CASTJsonException: If the data cannot be matched to a CAST node.
        """
        if isinstance(data, list):
            # If we see a list parse each one of its elements
            return [cls.parse_cast_json(item) for item in data]
        if data is None:
            return None
        if isinstance(data, (float, int, str, bool)):
            # If we see a primitive type, simply return its value
            return data
        if all(
            k in data for k in ("row_start", "row_end", "col_start", "col_end")
        ):
            return SourceRef(
                row_start=data["row_start"],
                row_end=data["row_end"],
                col_start=data["col_start"],
                col_end=data["col_end"],
                # Not part of the detection check above, so it may be absent.
                source_file_name=data.get("source_file_name"),
            )

        if "node_type" in data:
            # Create the object specified by "node_type" object with the values
            # from its children nodes
            for node_type in CAST_NODES_TYPES_LIST:
                if node_type.__name__ == data["node_type"]:
                    node_results = {
                        k: cls.parse_cast_json(v)
                        for k, v in data.items()
                        if k != "node_type"
                    }
                    return node_type(**node_results)

        raise CASTJsonException(
            f"Unable to decode json CAST field with field names: {set(data.keys())}"
        )

    @classmethod
    def from_json_data(cls, json_data, cast_source_language="unknown"):
        """
        Parses json CAST data object and returns the created CAST object

        Args:
            json_data: JSON object with a "nodes" field containing a
                list of the top level nodes
            cast_source_language: Source language to record on the CAST.

        Raises:
            CASTJsonException: If we encounter an unknown CAST node

        Returns:
            CAST: The parsed CAST object.
        """
        nodes = cls.parse_cast_json(json_data["nodes"])
        return cls(nodes, cast_source_language)

    @classmethod
    def from_json_file(cls, json_filepath, cast_source_language="unknown"):
        """
        Loads json CAST data from a file and returns the created CAST object

        Args:
            json_filepath: string of a full filepath to a UTF-8 JSON file
                representing a CAST with a `nodes` field
            cast_source_language: Source language to record on the CAST.

        Raises:
            CASTJsonException: If we encounter an unknown CAST node

        Returns:
            CAST: The parsed CAST object.
        """
        # `with` closes the file even when decoding fails.
        with open(json_filepath, "r", encoding="utf-8") as json_file:
            json_data = json.load(json_file)
        return cls.from_json_data(json_data, cast_source_language)

    @classmethod
    def from_json_str(cls, json_str, cast_source_language="unknown"):
        """
        Parses json CAST string and returns the created CAST object

        Args:
            json_str: JSON string representing a CAST with a "nodes" field
                containing a list of the top level nodes
            cast_source_language: Source language to record on the CAST.

        Raises:
            CASTJsonException: If we encounter an unknown CAST node

        Returns:
            CAST: The parsed CAST object.
        """
        return cls.from_json_data(json.loads(json_str), cast_source_language)

    @classmethod
    def from_python_ast(cls):
        """
        Build a CAST from a Python AST. Not implemented yet.

        Raises:
            NotImplementedError: Always. Previously this silently returned
                None, which hid the missing implementation.
        """
        raise NotImplementedError("CAST.from_python_ast is not implemented")