Coverage for skema/img2mml/api.py: 78%

118 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import os 

2import requests 

3from pathlib import Path 

4import urllib.request 

5from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS 

6from skema.img2mml.translate import convert_to_torch_tensor, render_mml 

7from skema.img2mml.models.image2mml_xfmer import Image2MathML_Xfmer 

8import torch 

9from typing import Tuple, List, Any, Dict 

10from logging import info 

11from skema.img2mml.translate import define_model 

12import json 

13from PIL import Image 

14from io import BytesIO 

15 

16from huggingface_hub import hf_hub_download 

17 

def retrieve_model(model_path=None) -> str:
    """
    Resolve which img2mml checkpoint to use, downloading the default one if needed.

    Args:
        model_path (str, optional): Candidate checkpoint location. Defaults to None.

    Returns:
        str: `model_path` itself when it exists on disk; otherwise the location of
        the default checkpoint inside the `trained_models` directory next to this
        module (fetched from HuggingFace first when it is not present yet).
    """
    repo_id = "lum-ai/img2mml"
    checkpoint_name = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"

    # A caller-supplied path that exists is used as-is.
    if model_path is not None and os.path.exists(model_path):
        return str(model_path)

    # Missing or non-existent path: fall back to the bundled default location.
    default_path = Path(__file__).parents[0] / "trained_models" / checkpoint_name
    if not os.path.exists(default_path):
        print("Downloading the model checkpoint from HuggingFace...")
        hf_hub_download(
            repo_id=repo_id,
            filename=checkpoint_name,
            local_dir=default_path.parent,
            local_dir_use_symlinks=False,
        )

    return str(default_path)

42 

43 

def check_gpu_availability() -> torch.device:
    """
    Choose the device computation should run on.

    Returns:
        torch.device: The CUDA device when CUDA is available; otherwise the CPU
        device (a notice is printed in that case).
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    print("CUDA is not available, falling back to using the CPU.")
    return torch.device("cpu")

58 

59 

def load_model(
    model_path: str,
    config: dict,
    vocab: List[str],
    device: torch.device = torch.device("cpu"),
) -> Image2MathML_Xfmer:
    """
    Build the img2mml model and populate it from a saved state dictionary.

    Args:
        model_path: The path to the model state dictionary file. When None, a
            default file inside the `trained_models` directory next to this
            module is used.
        config: The configuration setting, forwarded to `define_model`.
        vocab: The vocabulary of the img2mml model, forwarded to `define_model`.
        device: The device (GPU or CPU) the model is moved to and the state
            dictionary is mapped onto.

    Returns:
        The model with loaded state dictionary.

    Raises:
        FileNotFoundError: If the model state dictionary file does not exist.
        RuntimeError: If there is any other error while loading the state dictionary.

    Note:
        A leading "module." is removed from state-dict keys only when it is
        actually present; keys without that prefix are kept unchanged.

        The state dictionary is applied with `strict=False`, so keys that do
        not match the model are ignored instead of raising.
    """
    model: Image2MathML_Xfmer = define_model(config, vocab, device).to(device)

    if model_path is None:
        # NOTE(review): this default file name differs from the checkpoint name
        # used by `retrieve_model` in this file -- confirm which one is current.
        model_path = (
            Path(__file__).parents[0]
            / "trained_models"
            / "arxiv_im2mml_with_fonts_with_boldface_best.pt"
        )

    prefix = "module."
    try:
        if not torch.cuda.is_available():
            info("CUDA is not available, falling back to using the CPU.")

        # NOTE(security): torch.load may unpickle arbitrary objects; only load
        # checkpoints that come from a trusted source.
        state_dict = torch.load(model_path, map_location=device)

        # Strip the prefix only where it exists; blindly slicing the first
        # seven characters would corrupt keys that never had it.
        cleaned_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(cleaned_state, strict=False)

    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Model state dictionary file not found: {model_path}"
        ) from e
    except Exception as e:
        raise RuntimeError(
            f"Error loading state dictionary from file: {model_path}\n{e}"
        ) from e

    return model

112 

113 

def load_vocab(vocab_path: str = None) -> Tuple[List[str], dict, dict]:
    """
    Load the vocabulary file and build forward and backward lookup tables.

    Each non-blank line of the file is expected to hold exactly two
    whitespace-separated fields: the token followed by its index.

    Args:
        vocab_path (str or Path, optional): The vocabulary file path. When None,
            a default file inside the `trained_models` directory next to this
            module is used.

    Returns:
        Tuple[List[str], dict, dict]: A tuple with three items:
            - vocab (List[str]): The raw lines of the vocabulary file.
            - vocab_itos (dict): A dictionary mapping index to token.
            - vocab_stoi (dict): A dictionary mapping token to index.
        Indices are kept as the strings read from the file.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    if vocab_path is None:
        vocab_path = (
            Path(__file__).parents[0]
            / "trained_models"
            / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
        )

    # Read with an explicit encoding so the result does not depend on the
    # platform's locale default.
    with open(vocab_path, encoding="utf-8") as f:
        vocab = f.readlines()

    vocab_itos = dict()
    vocab_stoi = dict()

    for line in vocab:
        fields = line.split()
        if not fields:
            # Blank lines (e.g. a trailing newline at end of file) carry no entry.
            continue
        token, index = fields
        vocab_itos[index] = token
        vocab_stoi[token] = index

    return vocab, vocab_itos, vocab_stoi

146 

147 

class Image2MathML:
    """Holds the config, vocabulary tables, device and model used for img2mml inference."""

    def __init__(self, config_path: str, vocab_path: str, model_path: str) -> None:
        # Order matters: the model is built from the config, vocabulary and device.
        self.config = self.load_config(config_path)
        vocabulary = self.load_vocab(vocab_path)
        self.vocab, self.vocab_itos, self.vocab_stoi = vocabulary
        self.device = self.check_gpu_availability()
        self.model = self.load_model(model_path)

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Parse the JSON file at `config_path` and return its content."""
        with open(config_path, "r") as handle:
            return json.load(handle)

    def load_vocab(self, vocab_path: str) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
        """Return (vocab, vocab_itos, vocab_stoi) from the module-level `load_vocab`."""
        raw_vocab, index_to_token, token_to_index = load_vocab(vocab_path=vocab_path)
        return raw_vocab, index_to_token, token_to_index

    def check_gpu_availability(self) -> torch.device:
        """Return the CUDA device when CUDA is available, else the CPU device."""
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device_name)

    def load_model(self, model_path: str) -> Image2MathML_Xfmer:
        """Resolve the checkpoint via `retrieve_model`, then load it with the module-level `load_model`."""
        resolved_path = retrieve_model(model_path=model_path)
        return load_model(
            model_path=resolved_path,
            config=self.config,
            vocab=self.vocab,
            device=self.device,
        )

183 

184 

def replace_transparent_background(image_bytes: bytes) -> bytes:
    """
    Replace transparent background with white if the image has an alpha channel.

    Args:
        image_bytes (bytes): Bytes of the input image.

    Returns:
        bytes: For images in "RGBA" or "LA" mode, PNG-encoded bytes of the image
        composited onto a white RGB background. For every other mode, the
        original `image_bytes` unchanged.
    """
    image = Image.open(BytesIO(image_bytes))

    # Only modes that carry an explicit alpha band are rewritten.
    if image.mode not in ("RGBA", "LA"):
        return image_bytes

    # Normalise to RGBA first: an "LA" image has only two bands, so indexing
    # the fourth band of `split()` would raise IndexError.
    rgba_image = image.convert("RGBA")

    # Composite onto white, using the alpha band as the paste mask.
    flattened = Image.new("RGB", rgba_image.size, (255, 255, 255))
    flattened.paste(rgba_image, mask=rgba_image.getchannel("A"))

    output_bytes = BytesIO()
    flattened.save(output_bytes, format="PNG")
    return output_bytes.getvalue()

212 

213 

def get_mathml_from_bytes(
    data: bytes,
    image2mathml_db: Image2MathML,
) -> str:
    """
    Convert an image in bytes format to its MathML representation.

    Args:
        data (bytes): The image data in bytes format.
        image2mathml_db (Image2MathML): Container whose `config`, `model`,
            `vocab_itos`, `vocab_stoi` and `device` attributes are used for
            the conversion.

    Returns:
        str: The MathML representation of the input image, as produced by `render_mml`.
    """
    db = image2mathml_db

    # Flatten any transparent background onto white before tensor conversion.
    opaque_data = replace_transparent_background(data)

    # Add a leading batch dimension of size 1 to the image tensor.
    batch = convert_to_torch_tensor(opaque_data, db.config).unsqueeze(0)

    return render_mml(db.model, db.vocab_itos, db.vocab_stoi, batch, db.device)

248 

249 

def get_mathml_from_file(filepath, image2mathml_db: "Image2MathML" = None) -> str:
    """
    Read an equation image file and convert it to MathML.

    Args:
        filepath: Path of the image file; it is read in binary mode.
        image2mathml_db (Image2MathML): The loaded model bundle forwarded to
            `get_mathml_from_bytes`. It has a default only so the existing
            one-argument signature stays valid; a value must be supplied.

    Returns:
        str: The MathML representation returned by `get_mathml_from_bytes`.

    Raises:
        TypeError: If `image2mathml_db` is not provided.
    """
    with open(filepath, "rb") as f:
        data = f.read()

    # get_mathml_from_bytes requires the model bundle; calling it without one
    # always failed with TypeError, so fail the same way but with a clear message.
    if image2mathml_db is None:
        raise TypeError(
            "get_mathml_from_file() requires an Image2MathML instance; "
            "pass it as `image2mathml_db`."
        )

    return get_mathml_from_bytes(data, image2mathml_db)

257 

258 

def get_mathml_from_latex(eqn: str) -> str:
    """
    Read a LaTeX equation string and convert it to presentation MathML.

    The equation is POSTed as JSON to the `tex2mml` endpoint of the MathJax
    service configured in `SKEMA_MATHJAX_ADDRESS`.

    Args:
        eqn (str): The LaTeX source of the equation.

    Returns:
        str: The response body when the service answers with status 200.
        Otherwise a message describing the HTTP, connection, timeout or other
        request error, or "Conversion Failed." when the status is not 200 but
        no request error was raised.
    """
    # Define the webservice address from the MathJAX service
    webservice = SKEMA_MATHJAX_ADDRESS
    print(f"Connecting to {webservice}")

    try:
        # The request itself is inside the try block so that connection and
        # timeout failures reach the handlers below.
        # NOTE(review): no timeout is passed, so a stalled service can block
        # indefinitely -- consider adding one.
        res = requests.post(
            f"{webservice}/tex2mml",
            headers={"Content-type": "application/json"},
            json={"tex_src": eqn},
        )
        if res.status_code == 200:
            return res.text
        res.raise_for_status()
    except requests.HTTPError as e:
        return f"HTTP error occurred: {e}"
    except requests.ConnectionError as e:
        return f"Connection error occurred: {e}"
    except requests.Timeout as e:
        return f"Timeout error occurred: {e}"
    except requests.RequestException as e:
        return f"An error occurred: {e}"

    # Non-200 status for which raise_for_status() did not raise.
    return "Conversion Failed."

287 

288