Coverage for skema/program_analysis/CAST2FN/visitors/cast_function_call_visitor.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import typing 

2from functools import singledispatchmethod 

3 

4from skema.program_analysis.CAST2FN.visitors.cast_visitor import CASTVisitor 

5 

6from skema.program_analysis.CAST2FN.model.cast import ( 

7 AstNode, 

8 Assignment, 

9 Attribute, 

10 Call, 

11 FunctionDef, 

12 Loop, 

13 ModelBreak, 

14 ModelContinue, 

15 ModelIf, 

16 ModelReturn, 

17 Module, 

18 Name, 

19 RecordDef, 

20 Var, 

21) 

22 

23 

def flatten(l):
    """Lazily yield the leaf elements of a nested iterable, depth-first.

    ``str`` and ``bytes`` are iterable but are treated as atoms: they are
    yielded whole instead of being split into characters / ints.

    Args:
        l: Any iterable whose elements may themselves be (nested) iterables.

    Yields:
        Each non-container element, in left-to-right order.
    """
    atomic_iterables = (str, bytes)
    for element in l:
        is_container = isinstance(element, typing.Iterable)
        if is_container and not isinstance(element, atomic_iterables):
            # Descend into the nested container before moving on.
            yield from flatten(element)
            continue
        yield element

32 

33 

def get_function_visit_order(cast):
    """Return function names ordered so that callees precede their callers.

    The call graph comes from running ``CASTFunctionCallVisitor`` over
    *cast*; for a ``Module`` it maps each function name to the set of names
    called from that function's body. Functions that no other function calls
    are the roots. A post-order depth-first walk from each root produces the
    ordering. Called names that are not keys of the call graph (functions not
    defined in the module) are skipped.

    Args:
        cast: The CAST node to analyze (presumably a ``Module``; the visitor
            must return a ``{name: callees}`` mapping for it).

    Returns:
        A list of function names, each at most once. For acyclic call graphs
        every callee appears before its callers. Functions involved in call
        cycles (including self-recursion) are still included; inside a cycle
        the relative order is simply the order in which the walk reaches them.
    """
    visitor = CASTFunctionCallVisitor()
    calls = visitor.visit(cast)

    # Every name called by some function, built once so the root test is a
    # single set lookup instead of a scan over every call set.
    called = set()
    for callees in calls.values():
        called.update(callees)
    roots = [name for name in calls if name not in called]

    order = []
    visited = set()

    def get_order(name):
        # Skip names not defined in this module, and names already reached:
        # the visited check both removes duplicates (diamond-shaped call
        # graphs) and stops infinite recursion on call cycles.
        if name not in calls or name in visited:
            return
        visited.add(name)
        for callee in calls[name]:
            get_order(callee)
        order.append(name)

    for root in roots:
        get_order(root)

    # Functions reachable only through a cycle have no root (e.g. a
    # self-recursive function nothing else calls); make sure they are emitted.
    for name in calls:
        get_order(name)

    return order

61 

62 

class CASTTypeError(TypeError):
    """Raised by ``CASTFunctionCallVisitor`` for an unexpected node.

    The visitor's generic ``visit`` raises this when it is handed a value
    whose type has no registered handler. The message names the offending
    type.
    """

72 

73 

class CASTFunctionCallVisitor(CASTVisitor):
    """Collect the names of the functions called from each function in a CAST.

    Return values by node kind:
      * ``Module``: a dict mapping each top-level function name to the set of
        names it calls.
      * ``FunctionDef``: a ``(name, set_of_called_names)`` tuple.
      * Most other handled nodes: a (possibly nested) list of called names.

    Node types without a registered handler raise ``CASTTypeError``.
    No handlers are registered for BinaryOp, UnaryOp, Boolean, Dict, Expr or
    List, because those types are not imported from the CAST model here.
    Such nodes therefore also raise ``CASTTypeError``.
    """

    @singledispatchmethod
    def visit(self, node: AstNode):
        """Generic visitor for unimplemented/unexpected nodes"""
        raise CASTTypeError(f"Unrecognized node type: {type(node)}")

    @visit.register
    def _(self, node: list):
        # Delegates to CASTVisitor.visit_list (defined outside this file;
        # presumably returns one visit result per element — confirm).
        return self.visit_list(node)

    @visit.register
    def _(self, node: Assignment):
        return self.visit(node.left) + self.visit(node.right)

    @visit.register
    def _(self, node: Attribute):
        return self.visit(node.value) + self.visit(node.attr)

    @visit.register
    def _(self, node: Call):
        # NOTE(review): assumes the callee node has a ``name`` attribute (e.g.
        # ``Name``); an ``Attribute`` callee such as ``obj.method()`` would
        # raise AttributeError here — confirm against the CAST producers.
        return [node.func.name] + self.visit(node.arguments)

    @visit.register
    def _(self, node: RecordDef):
        # Fields should not contain function calls, so only methods are visited.
        return self.visit(node.funcs)

    @visit.register
    def _(self, node: FunctionDef):
        # flatten() collapses the nested lists produced for the body. A nested
        # FunctionDef's (name, callees) tuple is flattened as well, so its name
        # and its callees are attributed to the enclosing function.
        return (node.name, set(flatten(self.visit(node.body))))

    @visit.register
    def _(self, node: Loop):
        return self.visit(node.expr) + self.visit(node.body)

    @visit.register
    def _(self, node: ModelBreak):
        return []

    @visit.register
    def _(self, node: ModelContinue):
        return []

    @visit.register
    def _(self, node: ModelIf):
        return (
            self.visit(node.expr)
            + self.visit(node.body)
            + self.visit(node.orelse)
        )

    @visit.register
    def _(self, node: ModelReturn):
        # A bare ``return`` may carry no value; it calls nothing.
        if node.value is None:
            return []
        return self.visit(node.value)

    @visit.register
    def _(self, node: Module):
        # Only FunctionDef results are (name, callees) tuples. Testing the type
        # rather than ``len(item) == 2`` keeps two-element results from being
        # misread as function entries. Examples are the call-name list from
        # ``x = f(g())``, or a RecordDef holding two methods, whose tuple key
        # would be unhashable.
        return {
            item[0]: item[1]
            for item in self.visit(node.body)
            if isinstance(item, tuple)
        }

    @visit.register
    def _(self, node: Name):
        return []

    @visit.register
    def _(self, node: Var):
        return []