Coverage for skema/img2mml/translate.py: 88%
189 statements
coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
# -*- coding: utf-8 -*-

import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from skema.img2mml.models.encoders.cnn_encoder import CNN_Encoder
from skema.img2mml.models.encoders.xfmer_encoder import Transformer_Encoder
from skema.img2mml.models.decoders.xfmer_decoder import Transformer_Decoder
import io
from typing import List, Union
import logging
import re
from skema.img2mml.models.image2mml_xfmer import Image2MathML_Xfmer
from xml.etree import ElementTree as ET

# Set logging level to INFO
logging.basicConfig(level=logging.INFO)

def remove_eqn_number(image: Image.Image, threshold: float = 0.1) -> Image.Image:
    """
    Remove the equation number from an image of an equation.

    Args:
        image (Image.Image): The input image.
        threshold (float, optional): The fraction of the image width that the
            blank gap before the equation number must exceed. A smaller threshold
            treats more of the right-hand content as an equation number.
            Defaults to 0.1.

    Returns:
        Image.Image: The modified image with the equation number removed.
    """
    image_arr = np.asarray(image, dtype=np.uint8)

    # Invert the image by subtracting it from the maximum pixel value
    inverted = np.max(image_arr) - image_arr

    # Get the width and height of the image
    height, width = inverted.shape[:2]

    # Scan columns from the right; a blank gap wider than threshold * width
    # between two runs of content separates the equation from its number
    column_sum = np.sum(inverted, axis=0)
    rightmost_column = width - 1
    leftmost_column = rightmost_column
    while leftmost_column >= 0:
        if column_sum[leftmost_column] != 0:
            if rightmost_column - leftmost_column > threshold * width:
                # keep the current content column as well; the slice's
                # upper bound is exclusive
                image_arr = image_arr[:, 0 : leftmost_column + 1]
                return Image.fromarray(image_arr)
            leftmost_column -= 1
            rightmost_column = leftmost_column
        else:
            leftmost_column -= 1

    return Image.fromarray(image_arr)
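
# Usage sketch ("equation_with_number.png" is a hypothetical file, not part of
# the original module):
#     >>> img = Image.open("equation_with_number.png").convert("L")
#     >>> cleaned = remove_eqn_number(img, threshold=0.1)
#     >>> cleaned.size  # narrower than img.size if a number was detected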

def calculate_scale_factor(
    image: Image.Image, target_width: int, target_height: int
) -> float:
    """
    Calculate the scale factor to normalize the input image to the target width
    and height while preserving the original aspect ratio. If the original
    aspect ratio is larger than the target aspect ratio, the scale factor is
    calculated based on width; otherwise, it is calculated based on height.

    Args:
        image (PIL.Image.Image): The input image to be normalized.
        target_width (int): The target width for normalization.
        target_height (int): The target height for normalization.

    Returns:
        float: The scale factor to normalize the image.
    """
    original_width, original_height = image.size
    original_aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height

    if original_aspect_ratio > target_aspect_ratio:
        # Calculate scale factor based on width
        scale_factor = target_width / original_width
    else:
        # Calculate scale factor based on height
        scale_factor = target_height / original_height

    return scale_factor
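
# Worked example (hypothetical sizes, not from the original module): a 1000x200
# image targeted at 500x50 has aspect ratio 1000/200 = 5.0, below the target
# 500/50 = 10.0, so the factor is height-based: 50/200 = 0.25.
#     >>> calculate_scale_factor(Image.new("L", (1000, 200)), 500, 50)
#     0.25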

def preprocess_img(image: Image.Image, config: dict) -> Image.Image:
    """Preprocess an image: cropping, resizing, and padding."""
    # remove the equation number, if present
    image = remove_eqn_number(image)

    # convert to an np array
    image_arr = np.asarray(image, dtype=np.uint8)
    # find where the data lies
    indices = np.where(image_arr != 255)
    # get the boundaries
    x_min = np.min(indices[1])
    x_max = np.max(indices[1])
    y_min = np.min(indices[0])
    y_max = np.max(indices[0])

    # crop the image; PIL's crop box is exclusive on the right and lower
    # bounds, so add 1 to keep the boundary pixels
    image = image.crop((x_min, y_min, x_max + 1, y_max + 1))

    # calculate the target width and height
    target_width = config["preprocessed_image_width"] - 2 * config["padding"]
    target_height = config["preprocessed_image_height"] - 2 * config["padding"]
    # calculate the scale factor
    resize_factor = calculate_scale_factor(image, target_width, target_height)

    # resize the image
    image = image.resize(
        (
            int(image.size[0] * resize_factor),
            int(image.size[1] * resize_factor),
        ),
        Image.LANCZOS,
    )

    # paste the resized image onto a white canvas of the configured size
    pad = config["padding"]
    width = config["preprocessed_image_width"]
    height = config["preprocessed_image_height"]
    new_image = Image.new("RGB", (width, height), (255, 255, 255))
    new_image.paste(image, (pad, pad))

    return new_image
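
# A minimal config sketch (the key names are the ones read above; the values
# and the file name are illustrative only):
#     >>> config = {
#     ...     "preprocessed_image_width": 500,
#     ...     "preprocessed_image_height": 50,
#     ...     "padding": 8,
#     ... }
#     >>> out = preprocess_img(Image.open("equation.png").convert("L"), config)
#     >>> out.size
#     (500, 50)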

def convert_to_torch_tensor(image: bytes, config: dict) -> torch.Tensor:
    """Convert raw image bytes to a torch tensor."""
    image = Image.open(io.BytesIO(image)).convert("L")
    image = preprocess_img(image, config)

    # convert to tensor
    image = transforms.ToTensor()(image)

    return image
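
# Usage sketch ("equation.png" is a hypothetical file; the shape assumes the
# config sketch above, and preprocess_img produces a 3-channel RGB canvas):
#     >>> with open("equation.png", "rb") as f:
#     ...     tensor = convert_to_torch_tensor(f.read(), config)
#     >>> tensor.shape
#     torch.Size([3, 50, 500])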

def set_random_seed(seed: int) -> None:
    """Seed the random, numpy, and torch generators for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def define_model(
    config: dict, vocab: List[str], device: torch.device, model_type="xfmer"
) -> Image2MathML_Xfmer:
    """
    Define the model: initialize the encoder, the decoder, and the full model.
    """
    print("Defining model...")

    # note: the value from the config overrides the model_type argument
    model_type = config["model_type"]
    input_channels = config["input_channels"]
    output_dim = len(vocab)
    emb_dim = config["embedding_dim"]
    dec_hid_dim = config["decoder_hid_dim"]
    dropout = config["dropout"]
    max_len = config["max_len"]

    print(f"building {model_type} model...")

    dim_feedfwd = config["dim_feedforward_for_xfmer"]
    n_heads = config["n_xfmer_heads"]
    n_xfmer_encoder_layers = config["n_xfmer_encoder_layers"]
    n_xfmer_decoder_layers = config["n_xfmer_decoder_layers"]
    len_dim = 2500

    enc = {
        "CNN": CNN_Encoder(input_channels, dec_hid_dim, dropout, device),
        "XFMER": Transformer_Encoder(
            emb_dim,
            dec_hid_dim,
            n_heads,
            dropout,
            device,
            max_len,
            n_xfmer_encoder_layers,
            dim_feedfwd,
            len_dim,
        ),
    }
    dec = Transformer_Decoder(
        emb_dim,
        n_heads,
        dec_hid_dim,
        output_dim,
        dropout,
        max_len,
        n_xfmer_decoder_layers,
        dim_feedfwd,
        device,
    )
    model = Image2MathML_Xfmer(enc, dec, vocab, device)

    return model
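
# Config sketch for define_model (key names are the ones read above; the
# hyperparameter values are illustrative, not the project's defaults, and
# `vocab` is assumed to be a token list loaded elsewhere):
#     >>> model_config = {
#     ...     "model_type": "xfmer", "input_channels": 3,
#     ...     "embedding_dim": 256, "decoder_hid_dim": 512,
#     ...     "dropout": 0.1, "max_len": 350,
#     ...     "dim_feedforward_for_xfmer": 1024, "n_xfmer_heads": 8,
#     ...     "n_xfmer_encoder_layers": 4, "n_xfmer_decoder_layers": 4,
#     ... }
#     >>> model = define_model(model_config, vocab, torch.device("cpu"))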

def merge_mn_elements(input_mathml: str) -> str:
    """
    Merge consecutive <mn> elements in the <mrow> structure of MathML.

    Args:
        input_mathml (str): Input MathML string.

    Returns:
        str: MathML string with consecutive <mn> elements merged.
    """
    # protect unicode entity references ("&#x...") from the XML parser
    input_mathml = input_mathml.replace("&#x", "###")
    root = ET.fromstring(input_mathml)

    def merge_mn_in_mrow(mrow: ET.Element) -> List[ET.Element]:
        """
        Merge consecutive <mn> elements within a given <mrow>.

        Args:
            mrow (ET.Element): <mrow> element.

        Returns:
            List[ET.Element]: List of merged elements.
        """
        merged_elements = []
        current_number = None

        for child in mrow:
            if child.tag == "mn":
                # accumulate consecutive single-digit <mn> values into one number
                if current_number is None:
                    current_number = int(child.text)
                else:
                    current_number = current_number * 10 + int(child.text)
            else:
                if current_number is not None:
                    merged_elements.append(ET.Element("mn"))
                    merged_elements[-1].text = str(current_number)
                    current_number = None
                merged_elements.append(child)

        if current_number is not None:
            merged_elements.append(ET.Element("mn"))
            merged_elements[-1].text = str(current_number)

        return merged_elements

    def process_element(element: ET.Element) -> None:
        """
        Recursively process each element in the MathML structure.

        Args:
            element (ET.Element): Current element to process.
        """
        if element.tag == "mrow":
            element[:] = merge_mn_in_mrow(element)
        for child in element:
            process_element(child)

    process_element(root)
    modified_xml_string = ET.tostring(root, encoding="utf-8", method="xml").decode("utf-8")
    # restore the protected unicode entity references
    modified_xml_string = modified_xml_string.replace("###", "&#x")

    return modified_xml_string
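
# Example (illustrative MathML; the output is shown without the XML declaration
# that ET.tostring(..., encoding="utf-8") prepends):
#     >>> merge_mn_elements("<mrow><mn>1</mn><mn>2</mn><mi>x</mi></mrow>")
#     "...<mrow><mn>12</mn><mi>x</mi></mrow>"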

def process_mtext(xml_string: str) -> str:
    """
    Process the input MathML string by merging consecutive <mi> elements and
    replacing specific <mrow> structures.

    Args:
        xml_string (str): The input MathML string.

    Returns:
        str: The modified MathML string.
    """
    # protect unicode entity references ("&#x...") from the XML parser
    xml_string = xml_string.replace("&#x", "###")
    root = ET.fromstring(xml_string)

    def merge_mi_elements(elements):
        """
        Merge consecutive <mi> elements into an <mtext> element.

        Args:
            elements (list): List of consecutive <mi> elements.

        Returns:
            Element: The merged <mtext> element.
        """
        merged_text = "".join([elem.text for elem in elements])
        mtext = ET.Element("mtext")
        mtext.text = merged_text
        return mtext

    # Replace <mrow> structures that contain only <mi> children with <mtext>
    for mrow in root.findall(".//mrow"):
        mi_count = sum(1 for child in mrow if child.tag == "mi")
        non_mi_count = sum(1 for child in mrow if child.tag != "mi")

        if mi_count >= 3 and non_mi_count == 0:
            mi_elements = [child for child in mrow if child.tag == "mi"]
            mtext = merge_mi_elements(mi_elements)
            mrow.clear()
            # rename the wrapper so its tags can be stripped from the output below
            mrow.tag = "to_be_removed"
            mrow.append(mtext)

    mi_count = 0
    consecutive_mi_elements = []
    new_children = []

    # Merge runs of five or more consecutive top-level <mi> elements
    for child in root:
        if child.tag == "mi":
            mi_count += 1
            consecutive_mi_elements.append(child)
        else:
            if mi_count >= 5:
                mtext = merge_mi_elements(consecutive_mi_elements)
                new_children.append(mtext)
            else:
                new_children.extend(consecutive_mi_elements)
            mi_count = 0
            consecutive_mi_elements = []
            new_children.append(child)

    # flush any trailing run of <mi> elements
    if mi_count >= 5:
        mtext = merge_mi_elements(consecutive_mi_elements)
        new_children.append(mtext)
    else:
        new_children.extend(consecutive_mi_elements)

    root.clear()
    root.extend(new_children)

    modified_xml_string = ET.tostring(root, encoding="utf-8", method="xml").decode("utf-8")
    # strip the placeholder wrapper tags and restore unicode entity references
    modified_xml_string = modified_xml_string.replace("<to_be_removed>", "")
    modified_xml_string = modified_xml_string.replace("</to_be_removed>", "")
    modified_xml_string = modified_xml_string.replace("###", "&#x")

    return modified_xml_string
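
# Example (illustrative; an <mrow> holding only three <mi> children collapses
# into one <mtext> and the placeholder wrapper is stripped; the XML declaration
# prepended by ET.tostring is omitted here):
#     >>> process_mtext("<math><mrow><mi>s</mi><mi>i</mi><mi>n</mi></mrow></math>")
#     "...<math><mtext>sin</mtext></math>"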

def add_semicolon_to_unicode(string: str) -> str:
    """
    Checks if the string contains unicode entities starting with '&#x' and adds
    a semicolon ';' after each occurrence where it is missing.

    Args:
        string (str): The input string to check.

    Returns:
        str: The modified string with semicolons added after each unicode
        entity where necessary.
    """
    # Match '&#x' followed by hexadecimal characters and an optional trailing
    # semicolon; including the semicolon in the match keeps entities that
    # already end in ';' from gaining a second one
    pattern = r"&#x[0-9A-Fa-f]+;?"

    def add_semicolon(match):
        unicode_value = match.group(0)
        if not unicode_value.endswith(";"):
            unicode_value += ";"
        return unicode_value

    # Find all matches in the string using the pattern and process each match
    # individually
    modified_string = re.sub(pattern, add_semicolon, string)

    return modified_string
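
# Example: the first entity is missing its semicolon, the second already has one.
#     >>> add_semicolon_to_unicode("&#x2212x &#x3B1;")
#     '&#x2212;x &#x3B1;'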

def remove_spaces_between_tags(mathml_string: str) -> str:
    """
    Remove spaces from the text between ">" and "<" in a MathML string.

    Args:
        mathml_string (str): The MathML string to process.

    Returns:
        str: The modified MathML string with spaces removed between tags.
    """
    pattern = r">(.*?)<"
    replaced_string = re.sub(
        pattern, lambda match: match.group(0).replace(" ", ""), mathml_string
    )
    return replaced_string
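
# Example:
#     >>> remove_spaces_between_tags("<mn> 1 2 </mn>")
#     '<mn>12</mn>'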

def render_mml(
    model: Image2MathML_Xfmer,
    vocab_itos: dict,
    vocab_stoi: dict,
    img: torch.Tensor,
    device: torch.device,
) -> str:
    """
    Perform sequence prediction for an input image to translate it into MathML contents.

    Args:
        model (Image2MathML_Xfmer): The image-to-MathML model.
        vocab_itos (dict): The vocabulary lookup dictionary (index to symbol).
        vocab_stoi (dict): The vocabulary lookup dictionary (symbol to index).
        img (torch.Tensor): The input image as a tensor.
        device (torch.device): The device (GPU or CPU) to be used for computation.

    Returns:
        str: The generated MathML string.
    """
    model.eval()
    with torch.no_grad():
        img = img.to(device)

        output = model(
            img,
            device,
            is_inference=True,
            SOS_token=int(vocab_stoi["<sos>"]),
            EOS_token=int(vocab_stoi["<eos>"]),
            PAD_token=int(vocab_stoi["<pad>"]),
        )  # O: (1, max_len, output_dim), preds: (1, max_len)

        # map the predicted indices back to vocabulary symbols
        pred = list()
        for p in output:
            pred.append(vocab_itos[str(p)])

        # drop the <sos>/<eos> markers and join the tokens
        pred_seq = " ".join(pred[1:-1])

        # normalize the token sequence first, so there is always a result to
        # return; if the MathML post-processing fails (e.g., the prediction is
        # not well-formed XML), fall back to the normalized sequence
        res = add_semicolon_to_unicode(remove_spaces_between_tags(pred_seq))
        try:
            res = merge_mn_elements(res)
            res = process_mtext(res)
        except Exception:
            pass
        return res
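
# End-to-end sketch (hedged: the file path, config, and vocab objects are
# hypothetical and would be loaded elsewhere; the call signatures match the
# functions above):
#     >>> set_random_seed(42)
#     >>> model = define_model(config, vocab, device).to(device)
#     >>> with open("equation.png", "rb") as f:
#     ...     img_tensor = convert_to_torch_tensor(f.read(), config)
#     >>> mathml = render_mml(model, vocab_itos, vocab_stoi, img_tensor, device)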