Coverage for skema/rest/schema.py: 96%

73 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1# -*- coding: utf-8 -*- 

2""" 

3Response models for API 

4""" 

5from typing import List, Optional, Dict, Any 

6 

7from askem_extractions.data_model import AttributeCollection 

8from fastapi import UploadFile 

9from pydantic import BaseModel, Field 

10 

11# see https://github.com/pydantic/pydantic/issues/5821#issuecomment-1559196859 

12from typing_extensions import Literal 

13 

14from skema.img2mml import schema as eqn2mml_schema 

15 

16 

class HealthStatus(BaseModel):
    # One integer field per service, each described as that service's HTTP
    # status code. The ge/le bounds restrict every value to 100..599.
    morae: int = Field(description="HTTP status code for MORAE service", ge=100, le=599)
    mathjax: int = Field(
        description="HTTP status code for mathjax service (used by eqn2mml latex2mml endpoints)",
        ge=100,
        le=599,
    )
    eqn2mml: int = Field(
        description="HTTP status code for eqn2mml service (img2mml endpoints)",
        ge=100,
        le=599,
    )
    code2fn: int = Field(
        description="HTTP status code for code2fn service", ge=100, le=599
    )
    integrated_text_reading: int = Field(
        description="HTTP status code for the integrated text reading service",
        ge=100,
        le=599,
    )
    metal: int = Field(
        # Fixed: this description used to be a verbatim copy of the one on
        # `integrated_text_reading`, which mislabelled the field in the schema.
        description="HTTP status code for the METAL service",
        ge=100,
        le=599,
    )

42 

43 

class EquationImagesToAMR(BaseModel):
    # Equation images plus the requested model type.
    # FIXME: will this work or do we need base64?
    images: List[eqn2mml_schema.ImageBytes]
    # Restricted to the two literal values below.
    model: Literal["regnet", "petrinet"] = Field(
        description="The model type",
        examples=["petrinet"],
    )

50 

51 

class EquationLatexToAMR(BaseModel):
    # LaTeX equations plus the requested model type.
    equations: List[str] = Field(
        description="Equations in LaTeX",
        examples=[
            [
                r"\frac{\partial x}{\partial t} = {\alpha x} - {\beta x y}",
                r"\frac{\partial y}{\partial t} = {\alpha x y} - {\gamma y}",
            ]
        ],
    )
    # Restricted to the two literal values below.
    model: Literal["regnet", "petrinet"] = Field(
        description="The model type",
        examples=["regnet"],
    )

63 

64 

class EquationToMET(BaseModel):
    # A list of equation strings; the example mixes LaTeX and MathML entries.
    equations: List[str] = Field(
        description="Equations in LaTeX or pMathML",
        examples=[
            [
                r"\frac{\partial x}{\partial t} = {\alpha x} - {\beta x y}",
                r"\frac{\partial y}{\partial t} = {\alpha x y} - {\gamma y}",
                "<math><mfrac><mrow><mi>d</mi><mi>Susceptible</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>−</mo><mi>Infection</mi><mi>Infected</mi><mi>Susceptible</mi></math>",
            ]
        ],
    )

74 

class EquationsToAMRs(BaseModel):
    # Equation strings (LaTeX or MathML) plus the requested model type.
    equations: List[str] = Field(
        description="Equations in LaTeX or pMathML",
        examples=[
            [
                r"\frac{\partial x}{\partial t} = {\alpha x} - {\beta x y}",
                r"\frac{\partial y}{\partial t} = {\alpha x y} - {\gamma y}",
                "<math><mfrac><mrow><mi>d</mi><mi>Susceptible</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>−</mo><mi>Infection</mi><mi>Infected</mi><mi>Susceptible</mi></math>",
            ]
        ],
    )
    # Accepts a wider set of model types than the other equation request models.
    model: Literal["regnet", "petrinet", "met", "gamr", "decapode"] = Field(
        description="The model type",
        examples=["gamr"],
    )

87 

class MmlToAMR(BaseModel):
    # MathML equation strings plus the requested model type.
    equations: List[str] = Field(
        description="Equations in pMML",
        examples=[
            [
                "<math><mfrac><mrow><mi>d</mi><mi>Susceptible</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>−</mo><mi>Infection</mi><mi>Infected</mi><mi>Susceptible</mi></math>",
                "<math><mfrac><mrow><mi>d</mi><mi>Infected</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>−</mo><mi>Recovery</mi><mi>Infected</mi><mo>+</mo><mi>Infection</mi><mi>Infected</mi><mi>Susceptible</mi></math>",
                "<math><mfrac><mrow><mi>d</mi><mi>Recovered</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>Recovery</mi><mi>Infected</mi></math>",
            ]
        ],
    )
    # Restricted to the two literal values below.
    model: Literal["regnet", "petrinet"] = Field(
        description="The model type",
        examples=["petrinet"],
    )

100 

101 

class CodeSnippet(BaseModel):
    # A piece of source code together with the language it is written in.
    code: str = Field(
        title="code",
        description="snippet of code in referenced language",
        examples=["# this is a comment\ngreet = lambda: print('howdy!')"],
    )
    # Only the three literal language names below are accepted.
    language: Literal["Python", "Fortran", "CppOrC"] = Field(
        title="language",
        description="Programming language corresponding to `code`",
    )

111 

112 

class MiraGroundingInputs(BaseModel):
    """Model of MIRA grounding request body"""

    # The docstring above previously read "Model of text reading request body",
    # a copy of the text reading model's docstring; this model only carries
    # grounding queries.
    queries: List[str] = Field(
        description="List of input plain texts to be grounded to MIRA using embedding similarity",
        examples=[["susceptible population", "covid-19"]],
    )

120 

121 

class MiraGroundingOutputItem(BaseModel):
    # One grounding result: a similarity score and the DKG concept it refers to.

    class MiraDKGConcept(BaseModel):
        # Nested model describing a single DKG concept.
        id: str = Field(description="DKG element id", examples=["apollosv:00000233"])
        name: str = Field(
            description="Canonical name of the concept", examples=["infected population"]
        )
        description: Optional[str] = Field(
            description="Long winded description of the concept",
            examples=["A population of only infected members of one species."],
            default=None,
        )
        synonyms: List[str] = Field(
            # Fixed: "cannonical" typo, and the example was nested one level too
            # deep. Each entry of `examples` must itself be a List[str], as in
            # the other List[str] fields of this module.
            description="Any alternative name to the canonical one for the concept",
            examples=[["Ill individuals", "The sick and ailing"]],
        )
        embedding: List[float] = Field(
            description="Word embedding of the underlying model for the concept"
        )

        def __hash__(self):
            # Lists are unhashable, so synonyms and embedding are converted to
            # tuples. `name` and `description` are not part of the hash.
            return hash((self.id, tuple(self.synonyms), tuple(self.embedding)))

    score: float = Field(
        description="Cosine similarity of the embedding representation of the input with that of the DKG element",
        examples=[0.7896],
    )
    groundingConcept: MiraDKGConcept = Field(
        description="DKG concept associated to the query",
        examples=[
            MiraDKGConcept(
                id="apollosv:00000233",
                name="infected population",
                description="A population of only infected members of one species.",
                synonyms=[],
                embedding=[
                    0.01590670458972454,
                    0.03795482963323593,
                    -0.08787763118743896,
                ],
            )
        ],
    )

162 

163 

class TextReadingInputDocuments(BaseModel):
    """Model of text reading request body"""

    # Plain-text inputs, one string per document.
    texts: List[str] = Field(
        title="texts",
        description="List of input plain texts to be annotated by the text reading pipelines",
        examples=[["x = 0", "y = 1", "I: Infected population"]],
    )
    # NOTE(review): the description calls these "optional", but the field has
    # no default, so it is required as declared. Confirm the intended contract.
    amrs: List[str] = Field(
        description="List of optional AMR files to align with the extractions",
        examples=[[]],
    )

176 

177 

class TextReadingError(BaseModel):
    # An error record: which pipeline raised it and the accompanying message.
    pipeline: str = Field(
        name="pipeline",
        description="TextReading pipeline that originated the error",
        examples=["SKEMA"],
    )
    message: str = Field(
        name="message",
        description="Error message describing the problem. For debugging purposes",
        examples=["Out of memory error"],
    )

    def __hash__(self):
        # Hash the two fields joined by a dash.
        combined = "{}-{}".format(self.pipeline, self.message)
        return hash(combined)

192 

193 

class TextReadingDocumentResults(BaseModel):
    # Result for one document: extracted data and/or a list of errors.
    # Both fields default to None.
    data: Optional[AttributeCollection] = Field(
        default=None,
        title="data",
        description="AttributeCollection instance with the results of text reading. None if there was an error",
        examples=[AttributeCollection(attributes=[])],  # Too verbose to add a value here
    )
    errors: Optional[List[TextReadingError]] = Field(
        default=None,
        name="errors",
        description="A list of errors reported by the text reading pipelines. None if all pipelines ran successfully",
        examples=[[TextReadingError(pipeline="MIT", message="Unauthorized API key")]],
    )

    def __hash__(self):
        # A missing error list is represented by the sentinel string "NONE";
        # otherwise the list is converted to a tuple so it can be hashed.
        if self.errors is None:
            errors_key = "NONE"
        else:
            errors_key = tuple(self.errors)
        return hash((self.data, errors_key))

210 

211 

class TextReadingEvaluationResults(BaseModel):
    """ Evaluation results of the SKEMA TR extractions against manual annotations """
    # Fixed: every Field(...) below used to be followed by a trailing comma.
    # At class level `x: int = Field(...),` assigns the one-element tuple
    # `(FieldInfo,)` as the default, so the field silently became optional with
    # a tuple default and its description was lost. Without the comma each
    # field is a required, documented field.
    num_manual_annotations: int = Field(
        description="Number of manual annotations in the reference document"
    )
    # Trailing underscore because `yield` is a reserved word in Python.
    # NOTE(review): `name="yield"` is passed to Field unchanged; if the
    # serialized key is meant to be "yield", an alias is probably what is
    # wanted. Confirm before changing.
    yield_: int = Field(
        name="yield",
        description="Total number of extractions detected in the current document",
    )
    correct_extractions: int = Field(
        description="Number of extractions matched in the ground-truth annotations"
    )
    recall: float = Field(
        description="How many of the GT annotations we found through extractions"
    )
    precision: float = Field(
        description="How many of the extraction are correct according to the GT annotations"
    )
    f1: float

231 

232 

class AMRLinkingEvaluationResults(BaseModel):
    """ Evaluation results of the AMR Linking procedure """
    # All four fields are required: one integer count and three float metrics.
    num_gt_elems_with_metadata: int
    precision: float
    recall: float
    f1: float

239 

240 

class TextReadingAnnotationsOutput(BaseModel):
    """Contains the TR document results for all the documents submitted for annotation"""

    # Per-document results (required).
    outputs: List[TextReadingDocumentResults] = Field(
        name="outputs",
        description="Contains the results of TR annotations for each input document. There is one entry per input and "
        "inputs and outputs are matched by the same index in the list",
        examples=[
            [
                TextReadingDocumentResults(
                    data=AttributeCollection(attributes=[]),
                    errors=None,
                ),
                TextReadingDocumentResults(
                    data=AttributeCollection(attributes=[]),
                    errors=[TextReadingError(pipeline="SKEMA", message="Dummy error")],
                ),
            ]
        ],
    )

    # Errors not tied to a single input; defaults to None.
    generalized_errors: Optional[List[TextReadingError]] = Field(
        default=None,
        name="generalized_errors",
        description="Any pipeline-wide errors, not specific to a particular input",
        examples=[[TextReadingError(pipeline="MIT", message="API quota exceeded")]],
    )

    # A fresh empty list is created for each instance when no value is given.
    aligned_amrs: List[Dict[str, Any]] = Field(
        description="An aligned list of AMRs to the text extractions. This field will be populated only if it was"
        " provided as part of the input",
        default_factory=list,
    )