Coverage for skema/img2mml/eqn2mml.py: 83%

52 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1# -*- coding: utf-8 -*- 

2""" 

3Convert the LaTeX equation to the corresponding presentation MathML using the MathJAX service. 

4Please run the following command to initialize the MathJAX service: 

5node data_generation/mathjax_server.js 

6""" 

7 

8from typing import Text 

9from typing_extensions import Annotated 

10from fastapi import APIRouter, FastAPI, status, Response, Request, Query, UploadFile 

11from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS 

12from skema.img2mml.api import ( 

13 get_mathml_from_bytes, 

14 get_mathml_from_latex, 

15) 

16from skema.img2mml import schema 

17import base64 

18import requests 

19from skema.img2mml.api import Image2MathML 

20from pathlib import Path 

21import os 

22 

# Directory that contains this module; the default asset paths below are
# resolved relative to it.
cwd = Path(__file__).parent

# Each location can be overridden through an environment variable; otherwise
# the file shipped next to this module is used.
config_path = os.environ.get(
    "SKEMA_IMG2MML_CONFIG_PATH",
    str(cwd / "configs" / "xfmer_mml_config.json"),
)
vocab_path = os.environ.get(
    "SKEMA_IMG2MML_VOCAB_PATH",
    str(cwd / "trained_models" / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"),
)
model_path = os.environ.get(
    "SKEMA_IMG2MML_MODEL_PATH",
    str(cwd / "trained_models" / "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"),
)

# Built once at import time and shared by every image endpoint in this module.
image2mathml_db = Image2MathML(
    config_path=config_path,
    vocab_path=vocab_path,
    model_path=model_path,
)

# Routes are registered on this router and mounted on the app at the bottom
# of the module.
router = APIRouter()

41 

42 

def b64_image_to_mml(img_b64: str) -> str:
    """Decode a base64-encoded image and convert it to MathML.

    The payload is decoded with ``base64.b64decode`` and the resulting raw
    image bytes are handed to ``get_mathml_from_bytes`` together with the
    module-level ``image2mathml_db`` instance.

    NOTE(review): the annotation says ``str``, but the endpoint in this
    module passes the raw request body; ``base64.b64decode`` accepts both
    ASCII strings and bytes-like objects.
    """
    return get_mathml_from_bytes(base64.b64decode(img_b64), image2mathml_db)

48 

49 

# Named sample equations attached to the LaTeX query parameter.
# NOTE(review): recent FastAPI releases document the dict-of-named-examples
# form under ``openapi_examples`` -- confirm against the pinned FastAPI version.
_EQUATION_EXAMPLES = {
    "lotka eq1": {
        "summary": "Lotka-Volterra (eq1)",
        "description": "Lotka-Volterra (eq1)",
        "value": "\\frac{\\delta x}{\\delta t} = {\\alpha x} - {\\beta x y}",
    },
    "lotka eq2": {
        "summary": "Lotka-Volterra (eq2)",
        "description": "Lotka-Volterra (eq2)",
        "value": "\\frac{\\delta y}{\\delta t} = {\\alpha x y} - {\\gamma y}",
    },
    "simple": {
        "summary": "A familiar equation",
        "description": "A simple equation (mass-energy equivalence)",
        "value": "E = mc^{2}",
    },
    "complex": {
        "summary": "A more feature-rich equation (Bayes' rule)",
        "description": "A equation drawing on latex features",
        "value": "\\frac{P(\\textrm{a } | \\textrm{ b}) \\times P(\\textrm{b})}{P(\\textrm{a})}",
    },
}

# Annotated text type used for the ``tex_src`` query parameter of the GET
# LaTeX endpoint.
EquationQueryParameter = Annotated[Text, Query(examples=_EQUATION_EXAMPLES)]

77 

78 

def process_latex_equation(eqn: Text) -> Response:
    """Convert a LaTeX equation to MathML and wrap it in an XML response.

    Shared by the GET and POST ``/latex/mml`` endpoints: the equation is
    passed to ``get_mathml_from_latex`` and the result is returned with the
    ``application/xml`` media type.
    """
    return Response(
        content=get_mathml_from_latex(eqn),
        media_type="application/xml",
    )

83 

84 

# Liveness probe for this service itself: it performs no downstream call and
# always reports 200 in the response body.
@router.get(
    "/img2mml/healthcheck",
    status_code=status.HTTP_200_OK,
    response_model=int,
    summary="Check health of eqn2mml service",
)
def img2mml_healthcheck() -> int:
    healthy = status.HTTP_200_OK
    return healthy

93 

94 

@router.get(
    "/latex2mml/healthcheck",
    summary="Check health of mathjax service",
    response_model=int,
    status_code=status.HTTP_200_OK,
)
def latex2mml_healthcheck() -> int:
    """
    Report the health of the MathJax service.

    The response body is the HTTP status code returned by the MathJax
    service's own `/healthcheck` route, or 500 when that service cannot be
    reached or does not answer in time.
    """
    # Bound the probe so an unresponsive MathJax service cannot stall this
    # endpoint indefinitely (requests applies no timeout by default).
    # NOTE(review): 5 seconds is a judgement call -- tune to the deployment.
    timeout_seconds = 5
    try:
        response = requests.get(
            f"{SKEMA_MATHJAX_ADDRESS}/healthcheck", timeout=timeout_seconds
        )
    except requests.exceptions.RequestException:
        # Connection failures, timeouts and malformed addresses all count as
        # "unhealthy". Unlike a bare ``except:``, this does not swallow
        # KeyboardInterrupt or SystemExit.
        return status.HTTP_500_INTERNAL_SERVER_ERROR
    return int(response.status_code)

106 

107 

@router.post("/image/mml", summary="Get MathML representation of an equation image")
async def post_image_to_mathml(data: UploadFile) -> Response:
    """
    Endpoint for generating MathML from an uploaded equation image.

    ### Python example
    ```
    import requests

    files = {
      "data": open("bayes-rule-white-bg.png", "rb"),
    }
    r = requests.post("http://0.0.0.0:8000/image/mml", files=files)
    print(r.text)
    ```
    """
    # Pull the whole upload into memory, then run it through the shared model.
    uploaded_bytes = await data.read()
    mathml = get_mathml_from_bytes(uploaded_bytes, image2mathml_db)
    return Response(content=mathml, media_type="application/xml")

130 

131 

@router.post(
    "/image/base64/mml",
    summary="Get MathML representation of an equation image",
)
async def post_b64image_to_mathml(request: Request) -> Response:
    """
    Endpoint for generating MathML from a base64-encoded equation image
    sent as the raw request body.

    ### Python example
    ```
    from pathlib import Path
    import base64
    import requests

    url = "http://0.0.0.0:8000/image/base64/mml"
    with Path("bayes-rule-white-bg.png").open("rb") as infile:
      img_bytes = infile.read()
    img_b64 = base64.b64encode(img_bytes).decode("utf-8")
    r = requests.post(url, data=img_b64)
    print(r.text)
    ```
    """
    # The body is forwarded as-is; decoding happens in b64_image_to_mml.
    encoded_image = await request.body()
    return Response(
        content=b64_image_to_mml(encoded_image),
        media_type="application/xml",
    )

155 

156 

@router.get("/latex/mml", summary="Get MathML representation of a LaTeX equation")
async def get_tex_to_mathml(tex_src: EquationQueryParameter) -> Response:
    """
    GET endpoint for generating MathML from an input LaTeX equation.

    The equation is supplied through the `tex_src` query parameter and the
    MathML is returned as `application/xml`.

    ### Python example
    ```
    import requests

    r = requests.get("http://0.0.0.0:8000/latex/mml", params={"tex_src":"E = mc^{2}"})
    print(r.text)
    ```
    """
    # Same conversion path as the POST variant of this route.
    return process_latex_equation(tex_src)

170 

171 

@router.post("/latex/mml", summary="Get MathML representation of a LaTeX equation")
async def post_tex_to_mathml(eqn: schema.LatexEquation) -> Response:
    """
    POST endpoint for generating MathML from an input LaTeX equation.

    ### Python example
    ```
    import requests

    r = requests.post("http://0.0.0.0:8000/latex/mml", json={"tex_src":"E = mc^{2}"})
    print(r.text)
    ```
    """
    # Unwrap the LaTeX source from the request model and reuse the shared
    # conversion helper.
    latex_source = eqn.tex_src
    return process_latex_equation(latex_source)

186 

187 

# Standalone FastAPI application that serves the routes registered on
# ``router`` earlier in this module.
app = FastAPI()
app.include_router(router)