Coverage for skema/program_analysis/CAST/fortran/variable_context.py: 76%

95 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from typing import List, Dict, Set 

2from skema.program_analysis.CAST2FN.model.cast import ( 

3 Var, 

4 Name, 

5 FunctionDef 

6) 

7 

class VariableContext(object):
    """Scoped symbol table used while building CAST from Fortran.

    Scopes are a stack of dictionaries (``context``) mapping a fully qualified
    symbol name (record prefixes joined with ".") to an entry of the form
    ``{"node": Name, "type": str}``. ``all_symbols`` indexes every visible
    symbol by the same key. Its value is a single entry, or a list of entries
    (innermost binding last) once the name has been shadowed by an inner scope.
    """

    def __init__(self):
        self.context = [{}]  # Stack of context dictionaries
        self.context_return_values = [set()]  # Stack of context return values

        # All symbols use a separate naming convention to prevent two scopes using the same symbol name.
        # The name will be a dot notation list of scopes, i.e. scope1.scope2.symbol
        self.all_symbols_scopes = []
        self.all_symbols = {}

        # The prefix is used to handle adding Record types to the variable context.
        # This gives each symbol a unique name. For example "a" would become "type_name.a".
        # For nested type definitions (derived type in a module), multiple prefixes can be added.
        self.prefix = []

        # Flag necessary to declare if a function is internal or external.
        # While it is set, push_context/pop_context are no-ops.
        self.internal = False

        # Counters used to give Name nodes unique ids and generated symbols unique names.
        self.variable_id = 0
        self.iterator_id = 0
        self.stop_condition_id = 0
        self.function_name_id = 0

        # NOTE(review): the "_class" value is a dict, whereas register_module_function
        # stores bare FunctionDef objects, so copy_class_function("_class", ...) would
        # fail on attribute assignment. Kept as-is in case other code relies on it.
        self.class_functions = {"_class": {"function": FunctionDef()}}

    def push_context(self):
        """Create a new variable context and push it onto the stack.

        Does nothing while the internal flag is set, so symbols declared in the
        "new" scope are added to the current context instead.
        """
        # TODO: Could this add unwanted variables to the context or overwrite existing variables
        if self.internal:
            return None

        self.context.append({})
        self.context_return_values.append(set())

    def pop_context(self):
        """Pop the current variable context and drop the bindings it introduced.

        For each symbol in the popped scope, the innermost binding is removed
        from ``all_symbols``. If that was the symbol's last binding, the key is
        deleted entirely, so the symbol no longer counts as a variable.
        Does nothing while the internal flag is set.
        """
        if self.internal:
            return None

        popped = self.context.pop()

        for symbol in popped:
            bindings = self.all_symbols[symbol]
            if isinstance(bindings, list):
                bindings.pop()
                # An empty list would keep is_variable() True while get_node()
                # raised IndexError, so remove the key once nothing is left.
                if not bindings:
                    del self.all_symbols[symbol]
            else:
                del self.all_symbols[symbol]

        self.context_return_values.pop()

    def add_variable(self, symbol: str, type: str, source_refs: List) -> "Name":
        """Declare ``symbol`` in the current scope and return its new Name node.

        Args:
            symbol: Unqualified symbol name; the active record prefix is prepended.
            type: Type string stored alongside the node.
            source_refs: Source references attached to the Name node (None for
                generated symbols).

        Returns:
            A fresh Name whose ``id`` is unique among ids issued by this context.
        """
        full_symbol_name = ".".join(self.prefix + [symbol])

        cast_name = Name(source_refs=source_refs)
        cast_name.name = symbol
        cast_name.id = self.variable_id
        self.variable_id += 1

        entry = {
            "node": cast_name,
            "type": type,
        }

        current_scope = self.context[-1]
        redeclared = full_symbol_name in current_scope
        current_scope[full_symbol_name] = entry

        existing = self.all_symbols.get(full_symbol_name)
        if existing is None:
            self.all_symbols[full_symbol_name] = entry
        elif redeclared:
            # Same scope: the context entry was just overwritten, so replace the
            # innermost binding rather than stacking a second one. pop_context
            # removes exactly one binding per name, and a stacked duplicate would
            # outlive its scope.
            if isinstance(existing, list):
                existing[-1] = entry
            else:
                self.all_symbols[full_symbol_name] = entry
        elif isinstance(existing, list):
            # Shadowing again from a deeper scope.
            existing.append(entry)
        else:
            # First shadowing: switch from a single entry to a binding stack.
            self.all_symbols[full_symbol_name] = [existing, entry]

        return cast_name

    def _innermost(self, full_symbol_name: str) -> Dict:
        """Return the innermost ``{"node", "type"}`` entry for a fully qualified name (KeyError if absent)."""
        bindings = self.all_symbols[full_symbol_name]
        if isinstance(bindings, list):
            return bindings[-1]
        return bindings

    def is_variable(self, symbol: str) -> bool:
        """Check if a fully qualified symbol is bound in any context (no prefix is applied)."""
        return symbol in self.all_symbols

    def get_node(self, symbol: str) -> "Name":
        """Return the Name node of the innermost binding of ``symbol``.

        ``symbol`` must be the fully qualified key; the record prefix is not
        applied here. Raises KeyError if the symbol is not bound.
        """
        return self._innermost(symbol)["node"]

    def get_type(self, symbol: str) -> str:
        """Return the type of the innermost binding of ``symbol``.

        ``symbol`` must be the fully qualified key; the record prefix is not
        applied here. Raises KeyError if the symbol is not bound.
        """
        return self._innermost(symbol)["type"]

    def update_type(self, symbol: str, type: str):
        """Update the type of the innermost binding of ``symbol`` (the record prefix is applied)."""
        full_symbol_name = ".".join(self.prefix + [symbol])
        self._innermost(full_symbol_name)["type"] = type

    def add_return_value(self, symbol):
        """Record ``symbol`` as a return value of the current context."""
        self.context_return_values[-1].add(symbol)

    def remove_return_value(self, symbol):
        """Forget ``symbol`` as a return value of the current context (no error if absent)."""
        self.context_return_values[-1].discard(symbol)

    def generate_iterator(self):
        """Declare a new ``generated_iter_N`` variable of type "iterator" and return its Name."""
        symbol = f"generated_iter_{self.iterator_id}"
        self.iterator_id += 1

        return self.add_variable(symbol, "iterator", None)

    def generate_stop_condition(self):
        """Declare a new ``sc_N`` variable of type "boolean" and return its Name."""
        symbol = f"sc_{self.stop_condition_id}"
        self.stop_condition_id += 1

        return self.add_variable(symbol, "boolean", None)

    def generate_func(self, name):
        """Declare a new ``<name>_N`` symbol of type "unknown" and return its Name."""
        symbol = f"{name}_{self.function_name_id}"
        self.function_name_id += 1

        return self.add_variable(symbol, "unknown", None)

    def enter_record_definition(self, name: str):
        """Enter a record definition by pushing ``name`` onto the symbol prefix."""
        self.prefix.append(name)

    def exit_record_definition(self):
        """Exit a record definition by removing the most recently pushed prefix."""
        self.prefix.pop()

    def set_internal(self):
        """Set the internal flag: push_context/pop_context become no-ops until unset."""
        self.internal = True

    def unset_internal(self):
        """Clear the internal flag so push_context/pop_context manage scopes again."""
        self.internal = False

    def register_module_function(self, function: str):
        """Register a placeholder FunctionDef for ``function`` and return it.

        The placeholder has an empty name/args/body; copy_class_function later
        fills it in place.
        """
        # Fortran variables are case INSENSITIVE so we should lower it first
        function = function.lower()
        function_def = FunctionDef(
            name=Name(
                name="",
                id=-1,
                source_refs=[]
            ),
            func_args=[],
            body=[],
            source_refs=[]
        )
        self.class_functions[function] = function_def

        return function_def

    def is_class_function(self, function: str):
        """Return True if ``function`` (case-insensitive) has been registered."""
        function = function.lower()
        return function in self.class_functions

    def copy_class_function(self, function: str, function_def: "FunctionDef"):
        """Copy name, args, body and source refs from ``function_def`` onto the registered placeholder.

        The placeholder is mutated in place rather than replaced, so any object
        previously returned by register_module_function reflects the copied
        fields. Raises KeyError if ``function`` was never registered.
        """
        function = function.lower()
        target = self.class_functions[function]
        target.name = function_def.name
        target.func_args = function_def.func_args
        target.body = function_def.body
        target.source_refs = function_def.source_refs

182