# Coverage for skema/img2mml/api.py: 78% (118 statements)
# Report generated by coverage.py v7.5.0 at 2024-04-30 17:15 +0000
import json
import os
from io import BytesIO
from logging import info
from pathlib import Path
from typing import Any, Dict, List, Tuple

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from skema.img2mml.models.image2mml_xfmer import Image2MathML_Xfmer
from skema.img2mml.translate import convert_to_torch_tensor, define_model, render_mml
from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS


def retrieve_model(model_path=None) -> str:
    """
    Retrieve the img2mml model from the specified path, downloading it if not found.

    Args:
        model_path (str, optional): Path to the img2mml model file. Defaults to None.

    Returns:
        str: Path to the loaded model file.
    """
    cwd = Path(__file__).parents[0]
    REPO_NAME = "lum-ai/img2mml"
    MODEL_NAME = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"
    # If the model path is None or doesn't exist, fall back to the default model.
    if model_path is None or not os.path.exists(model_path):
        model_path = cwd / "trained_models" / MODEL_NAME
        # Download the checkpoint from HuggingFace only if it is not already cached.
        if not os.path.exists(model_path):
            print("Downloading the model checkpoint from HuggingFace...")
            hf_hub_download(
                repo_id=REPO_NAME,
                filename=MODEL_NAME,
                local_dir=model_path.parent,
                local_dir_use_symlinks=False,
            )

    return str(model_path)
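
# Usage sketch (illustrative, not part of the module API): resolve the
# checkpoint path, downloading the default weights from the lum-ai/img2mml
# HuggingFace repo on first use.
#
#     ckpt_path = retrieve_model()               # default checkpoint, cached locally
#     ckpt_path = retrieve_model("my/model.pt")  # used as-is if the file exists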


def check_gpu_availability() -> torch.device:
    """
    Check if a GPU is available and return the appropriate device.

    Returns:
        torch.device: The device (GPU or CPU) to be used for computation.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available, falling back to using the CPU.")
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")

    return device


def load_model(
    model_path: str,
    config: dict,
    vocab: List[str],
    device: torch.device = torch.device("cpu"),
) -> Image2MathML_Xfmer:
    """
    Load the model's state dictionary from a file.

    Args:
        model_path: The path to the model state dictionary file.
        config: The configuration settings.
        vocab: The vocabulary list of the img2mml model.
        device: The device (GPU or CPU) to be used for computation.

    Returns:
        The model with the loaded state dictionary.

    Raises:
        FileNotFoundError: If the model state dictionary file does not exist.
        RuntimeError: If there is an error while loading the state dictionary.

    Note:
        The "module." prefix that torch.nn.DataParallel adds to state_dict keys
        is stripped if present. If CUDA is not available, the function falls
        back to using the CPU for loading the state dictionary.
    """

    model: Image2MathML_Xfmer = define_model(config, vocab, device).to(device)
    cwd = Path(__file__).parents[0]
    if model_path is None:
        model_path = (
            cwd / "trained_models" / "arxiv_im2mml_with_fonts_with_boldface_best.pt"
        )
    try:
        if not torch.cuda.is_available():
            info("CUDA is not available, falling back to using the CPU.")

        # Strip the "module." prefix (added when the model was saved under
        # DataParallel) from each key, if present.
        new_model = dict()
        for key, value in torch.load(model_path, map_location=device).items():
            new_key = key[7:] if key.startswith("module.") else key
            new_model[new_key] = value
        model.load_state_dict(new_model, strict=False)

    except FileNotFoundError:
        raise FileNotFoundError(f"Model state dictionary file not found: {model_path}")
    except Exception as e:
        raise RuntimeError(
            f"Error loading state dictionary from file: {model_path}\n{e}"
        )

    return model
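
# Usage sketch (illustrative; assumes a config dict and a vocab list loaded
# elsewhere, e.g. via load_vocab below):
#
#     device = check_gpu_availability()
#     model = load_model(retrieve_model(), config, vocab, device)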


def load_vocab(vocab_path: str = None) -> Tuple[List[str], dict, dict]:
    """
    Load the vocabulary from a file and build dictionaries for both forward
    and backward mapping.

    Args:
        vocab_path (str, optional): The vocabulary path. Defaults to None.

    Returns:
        Tuple[List[str], dict, dict]: A tuple containing:
            - vocab (List[str]): The raw vocabulary lines.
            - vocab_itos (dict): A dictionary mapping index to token.
            - vocab_stoi (dict): A dictionary mapping token to index.
    """
    cwd = Path(__file__).parents[0]
    if vocab_path is None:
        vocab_path = (
            cwd / "trained_models" / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
        )

    # read vocab.txt
    with open(vocab_path) as f:
        vocab = f.readlines()

    vocab_itos = dict()
    vocab_stoi = dict()

    # Each line holds a token/index pair; build both directions of the mapping.
    for line in vocab:
        token, index = line.split()
        vocab_itos[index.strip()] = token.strip()
        vocab_stoi[token.strip()] = index.strip()

    return vocab, vocab_itos, vocab_stoi
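
# Usage sketch (illustrative; the token name below is hypothetical):
#
#     vocab, itos, stoi = load_vocab()
#     idx = stoi["<sos>"]   # token -> index
#     tok = itos[idx]       # index -> token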


class Image2MathML:
    def __init__(self, config_path: str, vocab_path: str, model_path: str) -> None:
        self.config = self.load_config(config_path)
        self.vocab, self.vocab_itos, self.vocab_stoi = self.load_vocab(vocab_path)
        self.device = self.check_gpu_availability()
        self.model = self.load_model(model_path)

    def load_config(self, config_path: str) -> Dict[str, Any]:
        with open(config_path, "r") as cfg:
            config = json.load(cfg)
        return config

    def load_vocab(self, vocab_path: str) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
        # Load the image2mathml vocabulary
        vocab, vocab_itos, vocab_stoi = load_vocab(vocab_path=vocab_path)
        return vocab, vocab_itos, vocab_stoi

    def check_gpu_availability(self) -> torch.device:
        # Check GPU availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return device

    def load_model(self, model_path: str) -> Image2MathML_Xfmer:
        # Load the image2mathml model
        MODEL_PATH = retrieve_model(model_path=model_path)
        img2mml_model: Image2MathML_Xfmer = load_model(
            model_path=MODEL_PATH,
            config=self.config,
            vocab=self.vocab,
            device=self.device,
        )
        return img2mml_model
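
# Usage sketch (paths below are illustrative; the real config and vocab files
# ship with the skema/img2mml package):
#
#     img2mml = Image2MathML(
#         config_path="configs/xfmer_mml_config.json",
#         vocab_path="trained_models/arxiv_im2mml_with_fonts_with_boldface_vocab.txt",
#         model_path=None,  # None triggers the default checkpoint download
#     )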


def replace_transparent_background(image_bytes: bytes) -> bytes:
    """
    Replace a transparent background with white if the image has transparency.

    Args:
        image_bytes (bytes): Bytes of the input image.

    Returns:
        bytes: Bytes of the processed image with the background replaced.
    """
    # Open the image using PIL
    image = Image.open(BytesIO(image_bytes))

    # Check if the image has an alpha (transparency) channel
    if image.mode in ("RGBA", "LA"):
        # Create a new image with a white background
        new_image = Image.new("RGB", image.size, (255, 255, 255))
        # Paste the original image onto the white canvas, using its alpha
        # channel as the mask (getchannel handles both RGBA and LA modes,
        # unlike indexing split()[3], which fails for two-channel LA images)
        new_image.paste(image, mask=image.getchannel("A"))
        # Save the new image to bytes
        output_bytes = BytesIO()
        new_image.save(output_bytes, format="PNG")
        return output_bytes.getvalue()
    else:
        # If the image does not have transparency, return the original image data
        return image_bytes
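
# Usage sketch (illustrative; "equation.png" is a hypothetical input file):
#
#     with open("equation.png", "rb") as f:
#         flattened = replace_transparent_background(f.read())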


def get_mathml_from_bytes(
    data: bytes,
    image2mathml_db: Image2MathML,
) -> str:
    """
    Convert an image in bytes format to its MathML representation using the provided model.

    Args:
        data (bytes): The image data in bytes format.
        image2mathml_db (Image2MathML): Bundle holding the pre-trained
            image-to-MathML model, its configuration, the vocabulary mappings
            (index-to-token and token-to-index), and the compute device.

    Returns:
        str: The MathML representation of the input image.
    """
    # replace transparent background with white if the image has transparency
    data = replace_transparent_background(data)
    # convert png image to tensor
    imagetensor = convert_to_torch_tensor(data, image2mathml_db.config)

    # change the shape of tensor from (C_in, H, W)
    # to (1, C_in, H, W) [batch = 1]
    imagetensor = imagetensor.unsqueeze(0)

    return render_mml(
        image2mathml_db.model,
        image2mathml_db.vocab_itos,
        image2mathml_db.vocab_stoi,
        imagetensor,
        image2mathml_db.device,
    )


def get_mathml_from_file(filepath, image2mathml_db: Image2MathML) -> str:
    """Read an equation image file and convert it to MathML."""

    with open(filepath, "rb") as f:
        data = f.read()

    # Pass the model bundle through; get_mathml_from_bytes requires it.
    return get_mathml_from_bytes(data, image2mathml_db)
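
# Usage sketch (illustrative; `img2mml` as constructed in the Image2MathML
# example above, "equation.png" is a hypothetical input):
#
#     mathml = get_mathml_from_file("equation.png", img2mml)
#     print(mathml)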


def get_mathml_from_latex(eqn: str) -> str:
    """Read a LaTeX equation string and convert it to presentation MathML."""

    # Define the webservice address of the MathJax service
    webservice = SKEMA_MATHJAX_ADDRESS
    print(f"Connecting to {webservice}")

    # Translate the LaTeX string using the NodeJS service for MathJax.
    # The POST itself sits inside the try block so that connection and timeout
    # errors are caught; a bare `return` in a `finally` clause (as previously
    # written) would have discarded every error message returned below.
    try:
        res = requests.post(
            f"{webservice}/tex2mml",
            headers={"Content-type": "application/json"},
            json={"tex_src": eqn},
        )
        if res.status_code == 200:
            return res.text
        res.raise_for_status()
    except requests.HTTPError as e:
        return f"HTTP error occurred: {e}"
    except requests.ConnectionError as e:
        return f"Connection error occurred: {e}"
    except requests.Timeout as e:
        return f"Timeout error occurred: {e}"
    except requests.RequestException as e:
        return f"An error occurred: {e}"
    return "Conversion Failed."
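
# Usage sketch (requires a running MathJax tex2mml service at
# SKEMA_MATHJAX_ADDRESS; the equation is illustrative):
#
#     mml = get_mathml_from_latex(r"\frac{a}{b}")
#     print(mml)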