Coverage for skema/program_analysis/CAST2FN/ann_cast/cast_to_annotated_cast.py: 95%
117 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from functools import singledispatchmethod
2import typing
4from skema.program_analysis.CAST2FN.cast import CAST
6from skema.program_analysis.CAST2FN.model.cast import (
7 Assignment,
8 AstNode,
9 Attribute,
10 Call,
11 FunctionDef,
12 Goto,
13 Label,
14 CASTLiteralValue,
15 Loop,
16 ModelBreak,
17 ModelContinue,
18 ModelIf,
19 ModelReturn,
20 Module,
21 Name,
22 RecordDef,
23 ScalarType,
24 Var,
25)
27from skema.program_analysis.CAST2FN.ann_cast.annotated_cast import *
28from skema.program_analysis.CAST2FN.model.cast.structure_type import StructureType
class CASTTypeError(TypeError):
    """Error type for the CAST-to-AnnCAST visitor.

    Meant for cases where the visitor encounters a value it was not
    expecting. It is a plain ``TypeError`` subclass with no extra state, so
    it is constructed like any other exception (typically with a message
    string), and ``except TypeError`` also catches it.

    Note: the visitor's ``_visit`` fallback for unregistered node types
    raises ``NameError``, not this class.
    """
40class CastToAnnotatedCastVisitor:
41 """
42 class CastToAnnotatedCastVisitor - A visitor that traverses CAST nodes
43 and generates an annotated cast version of the CAST.
45 The AnnCastNodes have additional attributes (fields) that are used
46 in a later pass to maintain scoping information for GrFN containers.
47 """
49 def __init__(self, cast: CAST):
50 self.cast = cast
52 def visit_node_list(self, node_list: typing.List[AstNode]):
53 return [self.visit(node) for node in node_list]
55 def generate_annotated_cast(self, grfn_2_2: bool = False):
56 nodes = self.cast.nodes
58 annotated_cast = []
59 for node in nodes:
60 annotated_cast.append(self.visit(node))
62 return PipelineState(annotated_cast, grfn_2_2)
64 def visit(self, node: AstNode) -> AnnCastNode:
65 # print current node being visited.
66 # this can be useful for debugging
67 # class_name = node.__class__.__name__
68 # print(f"\nProcessing node type {class_name}")
69 return self._visit(node)
71 @singledispatchmethod
72 def _visit(self, node: AstNode):
73 raise NameError(f"Unrecognized node type: {type(node)}")
75 @_visit.register
76 def visit_assignment(self, node: Assignment):
77 left = self.visit(node.left)
78 right = self.visit(node.right)
79 return AnnCastAssignment(left, right, node.source_refs)
81 @_visit.register
82 def visit_attribute(self, node: Attribute):
83 value = self.visit(node.value)
84 attr = self.visit(node.attr)
85 return AnnCastAttribute(value, attr, node.source_refs)
87 @_visit.register
88 def visit_operator(self, node: Operator):
89 operands = self.visit_node_list(node.operands)
90 return AnnCastOperator(node.source_language, node.interpreter, node.version, node.op, operands, node.source_refs)
92 @_visit.register
93 def visit_call(self, node: Call):
94 func = self.visit(node.func)
95 arguments = self.visit_node_list(node.arguments)
97 return AnnCastCall(func, arguments, node.source_refs)
99 @_visit.register
100 def visit_record_def(self, node: RecordDef):
101 bases = self.visit_node_list(node.bases)
102 funcs = self.visit_node_list(node.funcs)
103 fields = self.visit_node_list(node.fields)
104 return AnnCastRecordDef(
105 node.name, bases, funcs, fields, node.source_refs
106 )
108 @_visit.register
109 def visit_function_def(self, node: FunctionDef):
110 name = node.name
111 args = self.visit_node_list(node.func_args)
112 body = self.visit_node_list(node.body)
113 return AnnCastFunctionDef(name, args, body, node.source_refs)
115 @_visit.register
116 def visit_goto(self, node: Goto):
117 expr = self.visit(node.expr) if node.expr != None else None
118 label = node.label
119 return AnnCastGoto(expr, label, node.source_refs)
121 @_visit.register
122 def visit_label(self, node: Label):
123 label = node.label
124 return AnnCastLabel(label, node.source_refs)
126 @_visit.register
127 def visit_literal_value(self, node: CASTLiteralValue):
128 if node.value_type == "List[Any]":
129 node.value.size = self.visit(
130 node.value.size
131 ) # Turns the cast var into annCast
132 node.value.initial_value = self.visit(
133 node.value.initial_value
134 ) # Turns the literalValue into annCast
135 return AnnCastLiteralValue(
136 node.value_type,
137 node.value,
138 node.source_code_data_type,
139 node.source_refs,
140 )
141 elif node.value_type == StructureType.TUPLE:
142 values = self.visit_node_list(node.value)
143 return AnnCastLiteralValue(
144 node.value_type,
145 values,
146 node.source_code_data_type,
147 node.source_refs,
148 )
149 return AnnCastLiteralValue(
150 node.value_type,
151 node.value,
152 node.source_code_data_type,
153 node.source_refs,
154 )
156 @_visit.register
157 def visit_loop(self, node: Loop):
158 if node.pre != None:
159 pre = self.visit_node_list(node.pre)
160 else:
161 pre = []
162 if node.post != None and len(node.post) > 0:
163 post = self.visit_node_list(node.post)
164 else:
165 post = []
166 expr = self.visit(node.expr)
167 body = self.visit_node_list(node.body)
168 return AnnCastLoop(pre, expr, body, post, node.source_refs)
170 @_visit.register
171 def visit_model_break(self, node: ModelBreak):
172 return AnnCastModelBreak(node.source_refs)
174 @_visit.register
175 def visit_model_continue(self, node: ModelContinue):
176 return AnnCastModelContinue(node)
178 @_visit.register
179 def visit_model_import(self, node: ModelImport):
180 return AnnCastModelImport(node)
182 @_visit.register
183 def visit_model_if(self, node: ModelIf):
184 expr = self.visit(node.expr)
185 body = self.visit_node_list(node.body)
186 orelse = self.visit_node_list(node.orelse)
187 return AnnCastModelIf(expr, body, orelse, node.source_refs)
189 @_visit.register
190 def visit_model_return(self, node: ModelReturn):
191 value = self.visit(node.value)
192 return AnnCastModelReturn(value, node.source_refs)
194 @_visit.register
195 def visit_module(self, node: Module):
196 body = self.visit_node_list(node.body)
197 return AnnCastModule(node.name, body, node.source_refs)
199 @_visit.register
200 def visit_name(self, node: Name):
201 return AnnCastName(node.name, node.id, node.source_refs)
203 @_visit.register
204 def visit_var(self, node: Var):
205 val = self.visit(node.val)
206 if node.default_value != None:
207 default_value = self.visit(node.default_value)
208 else:
209 default_value = None
210 return AnnCastVar(val, node.type, default_value, node.source_refs)