Coverage for skema/program_analysis/CAST/fortran/variable_context.py: 76%
95 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from typing import List, Dict, Set
2from skema.program_analysis.CAST2FN.model.cast import (
3 Var,
4 Name,
5 FunctionDef
6)
class VariableContext(object):
    """Track variable symbols across nested scopes while building CAST.

    ``context`` is a stack of scope dictionaries. Each one maps a full symbol
    name to an entry ``{"node": Name, "type": str}``. ``all_symbols`` maps each
    full symbol name to its live declaration(s):
    - a single entry dict while only one declaration exists;
    - a list of entry dicts (most recent last) once the name has been declared
      again while still visible.
    """

    def __init__(self):
        self.context = [{}]  # Stack of context dictionaries
        self.context_return_values = [set()]  # Stack of context return values

        # All symbols will use a separate naming convention to prevent two scopes using the same symbol name.
        # The name will be a dot notation list of scopes, i.e. scope1.scope2.symbol
        # NOTE(review): all_symbols_scopes is not read or written anywhere else in this class.
        self.all_symbols_scopes = []
        self.all_symbols = {}

        # The prefix is used to handle adding Record types to the variable context.
        # This gives each symbol a unique name. For example "a" would become "type_name.a".
        # For nested type definitions (derived type in a module), multiple prefixes can be added.
        self.prefix = []

        # Flag necessary to declare if a function is internal or external.
        # While it is set, push_context/pop_context are no-ops.
        self.internal = False

        # Counters used to build unique variable ids and generated symbol names.
        self.variable_id = 0
        self.iterator_id = 0
        self.stop_condition_id = 0
        self.function_name_id = 0

        # Lower-cased function name -> FunctionDef placeholder (see register_module_function).
        self.class_functions = {"_class": {"function": FunctionDef()}}

    def push_context(self):
        """Create a new variable context and add it to the stack.

        Does nothing while the internal flag is set. Symbols added in the "new"
        scope then land in the current top context instead.
        """
        # TODO: Could this add unwanted variables to the context or overwrite existing variables
        if self.internal:
            return None

        self.context.append({})
        self.context_return_values.append(set())

    def pop_context(self):
        """Pop the current variable context off the stack and remove any references to its symbols.

        For a shadowed symbol only the most recent declaration is removed, which
        re-exposes the enclosing one. When a symbol's last declaration is
        removed, its key is deleted from ``all_symbols`` so that ``is_variable``
        no longer reports it.

        Does nothing while the internal flag is set.
        """
        if self.internal:
            return None

        context = self.context.pop()

        for symbol in context:
            entries = self.all_symbols[symbol]
            if isinstance(entries, list):
                entries.pop()
                # A shadow list can drain completely, e.g. when two nested scopes
                # both declared the symbol and both have now been popped. An empty
                # list left behind would make is_variable() return True while
                # get_node()/get_type() raise IndexError.
                if not entries:
                    del self.all_symbols[symbol]
            else:
                del self.all_symbols[symbol]

        self.context_return_values.pop()

    def add_variable(self, symbol: str, type: str, source_refs: List) -> Name:
        """Declare ``symbol`` in the current (innermost) variable context.

        The symbol is stored under its full name: the active record prefixes
        and the symbol joined with ".". A new ``Name`` node is created with the
        next variable id. If the full name is already visible, the new
        declaration shadows it until this context is popped.

        Args:
            symbol: Unqualified symbol name.
            type: Type string stored alongside the node.
            source_refs: Source references for the new ``Name`` node
                (the generate_* helpers in this class pass None).

        Returns:
            The newly created ``Name`` node.
        """
        full_symbol_name = ".".join(self.prefix + [symbol])

        cast_name = Name(source_refs=source_refs)
        cast_name.name = symbol
        cast_name.id = self.variable_id
        self.variable_id += 1

        entry = {
            "node": cast_name,
            "type": type,
        }
        self.context[-1][full_symbol_name] = entry

        if full_symbol_name not in self.all_symbols:
            self.all_symbols[full_symbol_name] = entry
        elif isinstance(self.all_symbols[full_symbol_name], list):
            self.all_symbols[full_symbol_name].append(entry)
        else:
            # Second live declaration: promote the single entry to a shadow list.
            self.all_symbols[full_symbol_name] = [self.all_symbols[full_symbol_name], entry]
        return cast_name

    def is_variable(self, symbol: str) -> bool:
        """Check if a symbol (full, already-prefixed name) exists in any context"""
        return symbol in self.all_symbols

    def _visible_entry(self, full_symbol_name: str) -> Dict:
        """Return the most recent live entry dict for ``full_symbol_name``; KeyError if unknown."""
        entries = self.all_symbols[full_symbol_name]
        if isinstance(entries, list):
            return entries[-1]
        return entries

    def get_node(self, symbol: str) -> Name:
        """Return the ``Name`` node of the visible declaration of ``symbol``.

        ``symbol`` is looked up as given; the record prefix is NOT applied.
        Raises KeyError if the symbol is unknown.
        """
        return self._visible_entry(symbol)["node"]

    def get_type(self, symbol: str) -> str:
        """Return the type string of the visible declaration of ``symbol``.

        ``symbol`` is looked up as given; the record prefix is NOT applied.
        Raises KeyError if the symbol is unknown.
        """
        return self._visible_entry(symbol)["type"]

    def update_type(self, symbol: str, type: str):
        """Update the type associated with the visible declaration of ``symbol``.

        Unlike get_node/get_type, the active record prefix IS applied to
        ``symbol`` before lookup. Raises KeyError if the resulting name is unknown.
        """
        full_symbol_name = ".".join(self.prefix + [symbol])
        self._visible_entry(full_symbol_name)["type"] = type

    def add_return_value(self, symbol):
        """Record ``symbol`` as a return value of the current context."""
        self.context_return_values[-1].add(symbol)

    def remove_return_value(self, symbol):
        """Remove ``symbol`` from the current context's return values (no error if absent)."""
        self.context_return_values[-1].discard(symbol)

    def generate_iterator(self):
        """Declare a fresh ``generated_iter_<n>`` variable of type "iterator" and return its Name."""
        symbol = f"generated_iter_{self.iterator_id}"
        self.iterator_id += 1

        return self.add_variable(symbol, "iterator", None)

    def generate_stop_condition(self):
        """Declare a fresh ``sc_<n>`` variable of type "boolean" and return its Name."""
        symbol = f"sc_{self.stop_condition_id}"
        self.stop_condition_id += 1

        return self.add_variable(symbol, "boolean", None)

    def generate_func(self, name):
        """Declare a fresh ``<name>_<n>`` symbol of type "unknown" and return its Name."""
        symbol = f"{name}_{self.function_name_id}"
        self.function_name_id += 1

        return self.add_variable(symbol, "unknown", None)

    def enter_record_definition(self, name: str):
        """Enter a record definition. Appends ``name`` to the prefix used for new symbols."""
        self.prefix.append(name)

    def exit_record_definition(self):
        """Exit a record definition. Removes the most recently entered prefix element."""
        self.prefix.pop()

    def set_internal(self):
        """Set the internal flag.

        While set, push_context and pop_context are no-ops, so new scopes reuse
        the current context.
        """
        self.internal = True

    def unset_internal(self):
        """Clear the internal flag so push_context/pop_context manage scopes again."""
        self.internal = False

    def register_module_function(self, function: str):
        """Register a placeholder FunctionDef for ``function`` and return it.

        The placeholder has an empty name, no args and no body. It is stored
        under the lower-cased function name. copy_class_function later fills it
        in place, so the returned object reflects that update.
        """
        # Fortran variables are case INSENSITIVE so we should lower it first
        function = function.lower()
        function_def = FunctionDef(
            name=Name(
                name="",
                id=-1,
                source_refs=[]
            ),
            func_args=[],
            body=[],
            source_refs=[]
        )
        self.class_functions[function] = function_def

        return function_def

    def is_class_function(self, function: str):
        """Return True if ``function`` (case-insensitive) has a registered entry."""
        function = function.lower()
        return function in self.class_functions

    def copy_class_function(self, function: str, function_def: FunctionDef):
        """Copy name, args, body and source refs from ``function_def`` into the registered placeholder.

        Raises KeyError if ``function`` (case-insensitive) was never registered.
        """
        function = function.lower()
        target = self.class_functions[function]
        target.name = function_def.name
        target.func_args = function_def.func_args
        target.body = function_def.body
        target.source_refs = function_def.source_refs