Coverage for skema/rest/tests/test_model_to_amr.py: 83%

163 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import json 

2import httpx 

3import pytest 

4import asyncio 

5import requests 

6from pathlib import Path 

7from zipfile import ZipFile 

8from io import BytesIO 

9 

10from fastapi import UploadFile 

11 

12from skema.rest.workflows import ( 

13 llm_assisted_codebase_to_pn_amr, 

14 code_snippets_to_pn_amr, 

15 lx_equations_to_amr, 

16 equation_to_amrs 

17) 

18from skema.rest import schema 

19from skema.rest.llm_proxy import Dynamics 

20from skema.rest.proxies import SKEMA_RS_ADDESS 

21from skema.skema_py import server as code2fn 

22from skema.data.program_analysis import MODEL_ZIP_ROOT_PATH 

23 

24 

# Paths to the zipped model fixtures exercised by the tests below.
# MODEL_ZIP_ROOT_PATH is provided by skema.data.program_analysis.
CHIME_SIR_PATH = MODEL_ZIP_ROOT_PATH.resolve() / "CHIME-SIR-model.zip"
SIDARTHE_PATH = MODEL_ZIP_ROOT_PATH.resolve() / "SIDARTHE.zip"

27 

28 

@pytest.mark.asyncio
async def test_any_amr_chime_sir():
    """
    Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR.
    This will test if just the core dynamics works, the whole script, and also rewritten scripts work.
    """

    zip_bytes = BytesIO(CHIME_SIR_PATH.read_bytes())

    # NOTE: For CI we are unable to use the LLM assisted functions due to API keys
    # So, we will instead mock the output for those functions instead
    dyn1 = Dynamics(name="CHIME_SIR-old.py", description=None, block=["L21-L31"])
    dyn2 = Dynamics(name="CHIME_SIR.py", description=None, block=["L101-L121"])
    dyn3 = Dynamics(name="CHIME_SIR_core.py", description=None, block=["L1-L9"])
    dyn4 = Dynamics(name="CHIME_SIR_while_loop.py", description=None, block=["L161-L201"])
    llm_mock_output = [dyn1, dyn2, dyn3, dyn4]

    line_begin = []
    import_begin = []
    line_end = []
    import_end = []
    files = []
    blobs = []
    amrs = []

    for linespan in llm_mock_output:
        blocks = len(linespan.block)
        # The last entry of `block` is the dynamics span ("L<start>-L<end>");
        # an optional first entry is a separate imports span.
        lines = linespan.block[blocks - 1].split("-")
        line_begin.append(
            max(int(lines[0][1:]) - 1, 0)
        )  # Normalizing the 1-index response from llm_proxy
        line_end.append(int(lines[1][1:]))
        if blocks == 2:
            lines = linespan.block[0].split("-")
            import_begin.append(
                max(int(lines[0][1:]) - 1, 0)
            )  # Normalizing the 1-index response from llm_proxy
            import_end.append(int(lines[1][1:]))
        else:
            # BUG FIX: append placeholders so the import lists stay
            # index-aligned with line_begin/line_end; the original only
            # appended for 2-block spans, so indexing them by file position
            # below was misaligned whenever span counts were mixed.
            import_begin.append(None)
            import_end.append(None)

    # So we are required to do the same when slicing the source code using its output.
    with ZipFile(zip_bytes, "r") as zip:
        for file in zip.namelist():
            if Path(file).suffix == ".py":
                files.append(file)
                blobs.append(zip.open(file).read().decode("utf-8"))

    # The source code is a string, so to slice using the line spans, we must first convert it to a list.
    # Then we can convert it back to a string using .join
    logging = []
    for i in range(len(blobs)):
        if line_begin[i] == line_end[i]:
            print("failed linespan")
            continue
        source_lines = blobs[i].splitlines(keepends=True)
        # BUG FIX: the original tested the stale loop variable `blocks`
        # (the span count of the *last* linespan), not the current file's
        # own span count. Check the per-file placeholder instead.
        if import_begin[i] is not None:
            temp = "".join(source_lines[import_begin[i]:import_end[i]])
            blobs[i] = temp + "\n" + "".join(source_lines[line_begin[i]:line_end[i]])
        else:
            blobs[i] = "".join(source_lines[line_begin[i]:line_end[i]])
        try:
            async with httpx.AsyncClient() as client:
                code_snippet_response = await code_snippets_to_pn_amr(
                    system=code2fn.System(
                        files=[files[i]],
                        blobs=[blobs[i]],
                    ),
                    client=client,
                )
            if "model" in code_snippet_response:
                code_snippet_response["header"]["name"] = "LLM-assisted code to amr model"
                code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}"
                code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}"
                amrs.append(code_snippet_response)
            else:
                print("snippets failure")
                logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
        except Exception as e:
            print("Hit except to snippets failure")
            print(f"Exception for test_any_amr_chime_sir:\t{e}")
            logging.append(f"{files[i]} failed to parse an AMR from the dynamics")

    # we will return the amr with most states, in assumption it is the most "correct"
    # by default it returns the first entry
    print(f"{amrs}")
    try:
        amr = amrs[0]
        for temp_amr in amrs:
            try:
                if len(temp_amr["model"]["states"]) > len(amr["model"]["states"]):
                    amr = temp_amr
            except (KeyError, TypeError):
                # Malformed AMR entry (missing "model"/"states"); skip it.
                continue
    except Exception as e:
        # No AMRs were produced at all; fall back to the failure log so the
        # assertion message shows which files failed.
        print(f"Exception for test_any_amr_chime_sir:\t{e}")
        amr = logging
    print(f"final amr: {amr}\n")
    # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy.
    assert "model" in amr, f"'model' should be in AMR response, but got {amr}"

130 

@pytest.mark.asyncio
async def test_any_amr_sidarthe():
    """
    Unit test for checking that the SIDARTHE model produces any AMR. This test zip contains 2 versions of SIDARTHE,
    each with a separate imports span and dynamics span in its mocked linespan output.
    """
    zip_bytes = BytesIO(SIDARTHE_PATH.read_bytes())

    # NOTE: For CI we are unable to use the LLM assisted functions due to API keys
    # So, we will instead mock the output for those functions instead
    dyn1 = Dynamics(name="commented_Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L59"])
    dyn2 = Dynamics(name="Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L18"])
    llm_mock_output = [dyn1, dyn2]

    line_begin = []
    import_begin = []
    line_end = []
    import_end = []
    files = []
    blobs = []
    amrs = []

    for linespan in llm_mock_output:
        blocks = len(linespan.block)
        # The last entry of `block` is the dynamics span ("L<start>-L<end>");
        # an optional first entry is a separate imports span.
        lines = linespan.block[blocks - 1].split("-")
        line_begin.append(
            max(int(lines[0][1:]) - 1, 0)
        )  # Normalizing the 1-index response from llm_proxy
        line_end.append(int(lines[1][1:]))
        if blocks == 2:
            lines = linespan.block[0].split("-")
            import_begin.append(
                max(int(lines[0][1:]) - 1, 0)
            )  # Normalizing the 1-index response from llm_proxy
            import_end.append(int(lines[1][1:]))
        else:
            # BUG FIX: append placeholders so the import lists stay
            # index-aligned with line_begin/line_end for mixed span counts.
            import_begin.append(None)
            import_end.append(None)

    # So we are required to do the same when slicing the source code using its output.
    with ZipFile(zip_bytes, "r") as zip:
        for file in zip.namelist():
            if Path(file).suffix == ".py":
                files.append(file)
                blobs.append(zip.open(file).read().decode("utf-8"))

    # The source code is a string, so to slice using the line spans, we must first convert it to a list.
    # Then we can convert it back to a string using .join
    logging = []
    for i in range(len(blobs)):
        if line_begin[i] == line_end[i]:
            print("failed linespan")
            continue
        source_lines = blobs[i].splitlines(keepends=True)
        # BUG FIX: the original tested the stale loop variable `blocks`
        # (the span count of the *last* linespan), not the current file's
        # own span count. Check the per-file placeholder instead.
        if import_begin[i] is not None:
            temp = "".join(source_lines[import_begin[i]:import_end[i]])
            blobs[i] = temp + "\n" + "".join(source_lines[line_begin[i]:line_end[i]])
        else:
            blobs[i] = "".join(source_lines[line_begin[i]:line_end[i]])
        try:
            async with httpx.AsyncClient() as client:
                code_snippet_response = await code_snippets_to_pn_amr(
                    system=code2fn.System(
                        files=[files[i]],
                        blobs=[blobs[i]],
                    ),
                    client=client,
                )
            if "model" in code_snippet_response:
                code_snippet_response["header"]["name"] = "LLM-assisted code to amr model"
                code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}"
                code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}"
                amrs.append(code_snippet_response)
            else:
                print("snippets failure")
                logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
        except Exception as e:
            print("Hit except to snippets failure")
            print(f"Exception for test_any_amr_sidarthe:\t{e}")
            logging.append(f"{files[i]} failed to parse an AMR from the dynamics")

    # we will return the amr with most states, in assumption it is the most "correct"
    # by default it returns the first entry
    print(f"{amrs}")
    try:
        amr = amrs[0]
        for temp_amr in amrs:
            try:
                if len(temp_amr["model"]["states"]) > len(amr["model"]["states"]):
                    amr = temp_amr
            except (KeyError, TypeError):
                # Malformed AMR entry (missing "model"/"states"); skip it.
                continue
    except Exception as e:
        # No AMRs were produced at all; fall back to the failure log so the
        # assertion message shows which files failed.
        print(f"Exception for final amr of test_any_amr_sidarthe:\t{e}")
        amr = logging
    print(f"final amr: {amr}\n")
    # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy.
    assert "model" in amr, f"'model' should be in AMR response, but got {amr}"

228 

@pytest.mark.asyncio
async def test_eq_to_regnet():
    """Check that a pair of Lotka-Volterra LaTeX ODEs converts to a RegNet AMR."""
    lotka_volterra_equations = [
        "\\frac{\\partial x}{\\partial t} = {\\alpha x} - {\\beta x y}",
        "\\frac{\\partial y}{\\partial t} = {\\alpha x y} - {\\gamma y}",
    ]
    payload = schema.EquationLatexToAMR(
        equations=lotka_volterra_equations,
        model="regnet",
    )
    async with httpx.AsyncClient() as client:
        regnet_amr_response = await lx_equations_to_amr(payload, client=client)

    assert "model" in regnet_amr_response, f"'model' should be in AMR response, but got {regnet_amr_response}"

243 

@pytest.mark.asyncio
async def test_eq_to_gamr():
    """Check that a pair of Lotka-Volterra LaTeX ODEs converts to a generalized AMR."""
    lotka_volterra_equations = [
        "\\frac{\\partial x}{\\partial t} = {\\alpha x} - {\\beta x y}",
        "\\frac{\\partial y}{\\partial t} = {\\alpha x y} - {\\gamma y}",
    ]
    payload = schema.EquationsToAMRs(
        equations=lotka_volterra_equations,
        model="gamr",
    )
    async with httpx.AsyncClient() as client:
        regnet_amr_response = await equation_to_amrs(payload, client=client)

    assert "met" in regnet_amr_response, f"'met' should be in AMR response, but got {regnet_amr_response}"

258