Coverage for skema/program_analysis/CAST2FN/ann_cast/cast_to_annotated_cast.py: 95%

117 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from functools import singledispatchmethod 

2import typing 

3 

4from skema.program_analysis.CAST2FN.cast import CAST 

5 

6from skema.program_analysis.CAST2FN.model.cast import ( 

7 Assignment, 

8 AstNode, 

9 Attribute, 

10 Call, 

11 FunctionDef, 

12 Goto, 

13 Label, 

14 CASTLiteralValue, 

15 Loop, 

16 ModelBreak, 

17 ModelContinue, 

18 ModelIf, 

19 ModelReturn, 

20 Module, 

21 Name, 

22 RecordDef, 

23 ScalarType, 

24 Var, 

25) 

26 

27from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import * 

28from skema.program_analysis.CAST2FN.model.cast.structure_type import StructureType 

29 

30 

class CASTTypeError(TypeError):
    """Error type meant for the CAST visitor.

    Intended for reporting a value the visitor was not expecting to see while
    walking the tree. It is a plain ``TypeError`` subclass: it adds no fields
    and no behavior of its own, so it is constructed and caught like any other
    ``TypeError``.
    """

    pass

38 

39 

class CastToAnnotatedCastVisitor:
    """Visitor that converts a CAST tree into an annotated CAST (AnnCast) tree.

    Every supported CAST node type has an overload of ``_visit`` (registered
    via ``singledispatchmethod``) that converts its children recursively and
    builds the matching ``AnnCast*`` node. The AnnCast nodes carry additional
    attributes (fields) that a later pass uses to maintain scoping information
    for GrFN containers.
    """

    def __init__(self, cast: CAST):
        # The CAST whose top-level ``nodes`` are converted by
        # ``generate_annotated_cast``.
        self.cast = cast

    def visit_node_list(self, node_list: typing.List[AstNode]):
        """Visit each node of ``node_list`` and return the results in the same order."""
        return [self.visit(node) for node in node_list]

    def generate_annotated_cast(self, grfn_2_2: bool = False):
        """Convert every top-level node of ``self.cast`` and wrap the result.

        Args:
            grfn_2_2: Passed through unchanged to ``PipelineState``.

        Returns:
            A ``PipelineState`` built from the list of converted top-level nodes.
        """
        annotated_cast = self.visit_node_list(self.cast.nodes)
        return PipelineState(annotated_cast, grfn_2_2)

    def visit(self, node: AstNode) -> AnnCastNode:
        """Dispatch ``node`` to the ``_visit`` overload registered for its type.

        Debugging tip: printing ``type(node).__name__`` here traces which node
        types are being converted.
        """
        return self._visit(node)

    @singledispatchmethod
    def _visit(self, node: AstNode):
        # Fallback for any type without a registered overload below (this
        # includes ``None``). NameError is kept rather than CASTTypeError since
        # existing callers may already catch NameError.
        raise NameError(f"Unrecognized node type: {type(node)}")

    @_visit.register
    def visit_assignment(self, node: Assignment):
        """Convert both sides of an assignment (left first, then right)."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        return AnnCastAssignment(left, right, node.source_refs)

    @_visit.register
    def visit_attribute(self, node: Attribute):
        """Convert the object (``value``) and the accessed ``attr`` of an attribute access."""
        value = self.visit(node.value)
        attr = self.visit(node.attr)
        return AnnCastAttribute(value, attr, node.source_refs)

    @_visit.register
    def visit_operator(self, node: Operator):
        """Convert an operator application; only its operands are visited."""
        operands = self.visit_node_list(node.operands)
        return AnnCastOperator(
            node.source_language,
            node.interpreter,
            node.version,
            node.op,
            operands,
            node.source_refs,
        )

    @_visit.register
    def visit_call(self, node: Call):
        """Convert a call: the callee expression first, then each argument."""
        func = self.visit(node.func)
        arguments = self.visit_node_list(node.arguments)
        return AnnCastCall(func, arguments, node.source_refs)

    @_visit.register
    def visit_record_def(self, node: RecordDef):
        """Convert a record (class) definition's bases, functions and fields."""
        bases = self.visit_node_list(node.bases)
        funcs = self.visit_node_list(node.funcs)
        fields = self.visit_node_list(node.fields)
        return AnnCastRecordDef(node.name, bases, funcs, fields, node.source_refs)

    @_visit.register
    def visit_function_def(self, node: FunctionDef):
        """Convert a function definition; ``name`` is copied, args and body are visited."""
        args = self.visit_node_list(node.func_args)
        body = self.visit_node_list(node.body)
        return AnnCastFunctionDef(node.name, args, body, node.source_refs)

    @_visit.register
    def visit_goto(self, node: Goto):
        """Convert a goto; ``expr`` is optional and stays ``None`` when absent."""
        expr = self.visit(node.expr) if node.expr is not None else None
        return AnnCastGoto(expr, node.label, node.source_refs)

    @_visit.register
    def visit_label(self, node: Label):
        """Convert a label; the label value is copied as-is."""
        return AnnCastLabel(node.label, node.source_refs)

    @_visit.register
    def visit_literal_value(self, node: CASTLiteralValue):
        """Convert a literal, also converting any CAST nodes nested in its value.

        - ``"List[Any]"``: the value object's ``size`` and ``initial_value``
          are CAST nodes. A shallow copy of the value object is made with
          both fields converted, so the input CAST is left untouched.
        - ``StructureType.TUPLE``: the value is a list of CAST nodes, each of
          which is converted.
        - Anything else: the value is passed through unchanged.
        """
        value = node.value
        if node.value_type == "List[Any]":
            # Local import: this overload is the only user of ``copy``.
            import copy

            # Copy rather than assigning into ``node.value``: writing AnnCast
            # nodes back into the source CAST would corrupt it, and converting
            # the same CAST again would re-visit already-converted nodes.
            value = copy.copy(node.value)
            value.size = self.visit(node.value.size)  # Turns the cast var into annCast
            value.initial_value = self.visit(
                node.value.initial_value
            )  # Turns the literalValue into annCast
        elif node.value_type == StructureType.TUPLE:
            value = self.visit_node_list(node.value)
        return AnnCastLiteralValue(
            node.value_type,
            value,
            node.source_code_data_type,
            node.source_refs,
        )

    @_visit.register
    def visit_loop(self, node: Loop):
        """Convert a loop; a missing ``pre`` or an empty/missing ``post`` becomes ``[]``."""
        # Visit order (pre, post, expr, body) matches the original traversal.
        pre = self.visit_node_list(node.pre) if node.pre is not None else []
        post = self.visit_node_list(node.post) if node.post else []
        expr = self.visit(node.expr)
        body = self.visit_node_list(node.body)
        return AnnCastLoop(pre, expr, body, post, node.source_refs)

    @_visit.register
    def visit_model_break(self, node: ModelBreak):
        """Convert a ``break``; only its source references are kept."""
        return AnnCastModelBreak(node.source_refs)

    @_visit.register
    def visit_model_continue(self, node: ModelContinue):
        """Convert a ``continue``.

        Unlike ``visit_model_break``, the whole CAST node is handed to the
        constructor. Presumably ``AnnCastModelContinue`` reads what it needs
        from it; confirm against ``annotated_cast``.
        """
        return AnnCastModelContinue(node)

    @_visit.register
    def visit_model_import(self, node: ModelImport):
        """Convert an import; the whole CAST node is handed to the constructor."""
        return AnnCastModelImport(node)

    @_visit.register
    def visit_model_if(self, node: ModelIf):
        """Convert an ``if``: the condition, then the ``body`` and ``orelse`` lists."""
        expr = self.visit(node.expr)
        body = self.visit_node_list(node.body)
        orelse = self.visit_node_list(node.orelse)
        return AnnCastModelIf(expr, body, orelse, node.source_refs)

    @_visit.register
    def visit_model_return(self, node: ModelReturn):
        """Convert a ``return`` and its returned value."""
        # NOTE(review): a ModelReturn whose ``value`` is None would reach the
        # ``_visit`` fallback and raise NameError. Confirm that bare returns
        # never arrive here without a value node.
        value = self.visit(node.value)
        return AnnCastModelReturn(value, node.source_refs)

    @_visit.register
    def visit_module(self, node: Module):
        """Convert a module by converting every statement in its body."""
        body = self.visit_node_list(node.body)
        return AnnCastModule(node.name, body, node.source_refs)

    @_visit.register
    def visit_name(self, node: Name):
        """Convert a name reference; ``name`` and ``id`` are copied as-is."""
        return AnnCastName(node.name, node.id, node.source_refs)

    @_visit.register
    def visit_var(self, node: Var):
        """Convert a variable; ``default_value`` is optional and stays ``None`` when absent."""
        val = self.visit(node.val)
        default_value = (
            self.visit(node.default_value) if node.default_value is not None else None
        )
        return AnnCastVar(val, node.type, default_value, node.source_refs)