Coverage for skema/program_analysis/CAST2FN/visitors/cast_function_call_visitor.py: 0%
79 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import typing
2from functools import singledispatchmethod
4from skema.program_analysis.CAST2FN.visitors.cast_visitor import CASTVisitor
6from skema.program_analysis.CAST2FN.model.cast import (
7 AstNode,
8 Assignment,
9 Attribute,
10 Call,
11 FunctionDef,
12 Loop,
13 ModelBreak,
14 ModelContinue,
15 ModelIf,
16 ModelReturn,
17 Module,
18 Name,
19 RecordDef,
20 Var,
21)
def flatten(l):
    """Lazily yield the leaves of an arbitrarily nested iterable.

    Elements are produced depth-first, left to right. Any element that is
    iterable is descended into, except ``str`` and ``bytes``, which are
    treated as single values rather than sequences of characters/bytes.
    """
    for item in l:
        # A leaf is anything non-iterable, or a str/bytes value.
        is_leaf = isinstance(item, (str, bytes)) or not isinstance(
            item, typing.Iterable
        )
        if is_leaf:
            yield item
            continue
        yield from flatten(item)
def get_function_visit_order(cast):
    """Return the names of the functions defined in ``cast``, callees first.

    The call graph is built with ``CASTFunctionCallVisitor``; its result is
    used as a mapping of function name -> set of called names (presumably
    ``cast`` is a ``Module`` node, whose handler returns that dict).

    Every defined function appears exactly once. Each function is placed
    after all the functions it calls, except where a call cycle (direct or
    mutual recursion) makes that impossible. Called names that are not
    defined in ``cast`` (e.g. builtins) are not included.

    Args:
        cast: The CAST node to analyze.

    Returns:
        A list of function names in visit order.
    """
    visitor = CASTFunctionCallVisitor()
    calls = visitor.visit(cast)

    # A root is a function that no *other* function calls. Self-calls are
    # ignored so that a directly recursive function can still be a root.
    called_by_others = {
        callee
        for caller, callees in calls.items()
        for callee in callees
        if callee != caller
    }
    roots = [name for name in calls if name not in called_by_others]

    order = []
    visited = set()

    def get_order(name):
        # Skip names that are not defined in this CAST, and names already
        # visited. The visited check stops infinite recursion on call cycles
        # and keeps a callee shared by several callers from being listed
        # more than once.
        if name not in calls or name in visited:
            return
        visited.add(name)
        for callee in calls[name]:
            get_order(callee)
        order.append(name)

    for root in roots:
        get_order(root)

    # Functions that only take part in call cycles with no outside caller
    # are not reachable from any root; visit them too so none are dropped.
    for name in calls:
        get_order(name)

    return order
class CASTTypeError(TypeError):
    """Raised by a CAST visitor when it meets a value it cannot handle.

    For example, ``CASTFunctionCallVisitor.visit`` raises this error for a
    node type that has no registered handler. Because it subclasses
    ``TypeError``, callers catching ``TypeError`` also catch it.
    """
class CASTFunctionCallVisitor(CASTVisitor):
    """Collect the names of functions called within a CAST tree.

    The value returned by ``visit`` depends on the node type:

    * ``Module`` -> dict mapping each module-level function name to the set
      of names that function calls.
    * ``FunctionDef`` -> a ``(name, set_of_called_names)`` tuple.
    * other handled nodes -> a list of called names. This list may be nested
      (``list`` nodes are delegated to ``CASTVisitor.visit_list``, presumably
      one result per element), which is why the ``FunctionDef`` handler
      flattens it.

    Node types with no registered handler raise ``CASTTypeError``.
    """

    @singledispatchmethod
    def visit(self, node: AstNode):
        """Generic visitor for unimplemented/unexpected nodes"""
        raise CASTTypeError(f"Unrecognized node type: {type(node)}")

    @visit.register
    def _(self, node: list):
        # Delegates to visit_list, defined on the CASTVisitor base class.
        return self.visit_list(node)

    @visit.register
    def _(self, node: Assignment):
        return self.visit(node.left) + self.visit(node.right)

    @visit.register
    def _(self, node: Attribute):
        return self.visit(node.value) + self.visit(node.attr)

    # Disabled handlers below reference node types that are not imported
    # from the CAST model in this module.

    # @visit.register
    # def _(self, node: BinaryOp):
    #     return self.visit(node.left) + self.visit(node.right)

    # @visit.register
    # def _(self, node: Boolean):
    #     return []

    @visit.register
    def _(self, node: Call):
        # The called function's name, followed by any calls made while
        # evaluating the arguments.
        return [node.func.name] + self.visit(node.arguments)

    @visit.register
    def _(self, node: RecordDef):
        # Fields should not have function calls
        return self.visit(node.funcs)

    # @visit.register
    # def _(self, node: Dict):
    #     return self.visit(node.keys) + self.visit(node.values)

    # @visit.register
    # def _(self, node: Expr):
    #     return self.visit(node.expr)

    @visit.register
    def _(self, node: FunctionDef):
        return (node.name, set(flatten(self.visit(node.body))))

    # @visit.register
    # def _(self, node: List):
    #     return self.visit(node.values)

    @visit.register
    def _(self, node: Loop):
        return self.visit(node.expr) + self.visit(node.body)

    @visit.register
    def _(self, node: ModelBreak):
        return []

    @visit.register
    def _(self, node: ModelContinue):
        return []

    @visit.register
    def _(self, node: ModelIf):
        return (
            self.visit(node.expr)
            + self.visit(node.body)
            + self.visit(node.orelse)
        )

    @visit.register
    def _(self, node: ModelReturn):
        return self.visit(node.value)

    @visit.register
    def _(self, node: Module):
        # Only the FunctionDef handler returns a (name, callees) *tuple*.
        # Other module-level statements return lists, and a list that happens
        # to have two entries (e.g. a call result, or a RecordDef with two
        # methods) must not be mistaken for a function entry.
        return {
            item[0]: item[1]
            for item in self.visit(node.body)
            if isinstance(item, tuple) and len(item) == 2
        }

    @visit.register
    def _(self, node: Name):
        return []

    # @visit.register
    # def _(self, node: UnaryOp):
    #     return self.visit(node.value)

    @visit.register
    def _(self, node: Var):
        return []