Coverage for skema/img2mml/translate.py: 88%

189 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1# -*- coding: utf-8 -*- 

2 

3import random 

4import numpy as np 

5import torch 

6from torchvision import transforms 

7from PIL import Image 

8from skema.img2mml.models.encoders.cnn_encoder import CNN_Encoder 

9from skema.img2mml.models.encoders.xfmer_encoder import Transformer_Encoder 

10from skema.img2mml.models.decoders.xfmer_decoder import Transformer_Decoder 

11import io 

12from typing import List, Union 

13import logging 

14import re 

15from skema.img2mml.models.image2mml_xfmer import Image2MathML_Xfmer 

16from xml.etree import ElementTree as ET 

17 

# Set logging level to INFO.
# NOTE(review): basicConfig() runs at import time and configures the *root*
# logger for the whole process, not just this module; consider moving it to
# the application entry point.
logging.basicConfig(level=logging.INFO)

20 

21 

def remove_eqn_number(image: Image.Image, threshold: float = 0.1) -> Image.Image:
    """
    Remove a right-aligned equation number from an image of an equation.

    Columns are scanned from the right. As soon as an inked column is found
    with more than ``threshold * width`` blank columns to its right (up to the
    previous inked column, or the image edge), everything right of that inked
    column is cut away.

    Args:
        image (Image.Image): The input image. Single-channel and multi-channel
            (e.g. RGB) images are both accepted.
        threshold (float, optional): Minimum width of the blank gap, as a
            fraction of the image width, that separates the equation from its
            number. A smaller threshold lets narrower gaps trigger the cut.
            Defaults to 0.1.

    Returns:
        Image.Image: The cropped image, or a new image with the original pixels
        when no qualifying gap exists.

    Note:
        Blank columns at the far right edge also count as a gap, so an image
        whose right margin is wider than ``threshold * width`` is cut just after
        its rightmost inked column (only the margin is removed).
    """
    image_arr = np.asarray(image, dtype=np.uint8)
    if image_arr.size == 0:
        # Nothing to scan (and np.max would raise on an empty array)
        return image.copy()

    # Invert so that "ink" (anything darker than the brightest pixel) is non-zero
    inverted = np.max(image_arr) - image_arr
    width = inverted.shape[1]

    # Total ink per column; any channel axis (RGB) is folded into the sum so the
    # result is always 1-D
    column_sum = inverted.reshape(inverted.shape[0], width, -1).sum(axis=(0, 2))
    ink_columns = np.flatnonzero(column_sum)

    # Walk inked columns right-to-left; `right_bound` is the previous inked
    # column (initially one past the right edge)
    right_bound = width
    for col in ink_columns[::-1]:
        blank_run = right_bound - 1 - col
        if blank_run > threshold * width:
            # Keep `col` itself: it is the last inked column of the equation body
            return Image.fromarray(image_arr[:, : col + 1])
        right_bound = col

    return Image.fromarray(image_arr)

59 

60 

def calculate_scale_factor(
    image: Image.Image, target_width: int, target_height: int
) -> float:
    """
    Compute the uniform scale that fits ``image`` inside a target box.

    The aspect ratio is preserved by scaling along the binding dimension. When
    the image is relatively wider than the box (strictly larger width/height
    ratio), the factor comes from the widths. Otherwise it comes from the
    heights.

    Args:
        image (PIL.Image.Image): Image whose ``size`` (width, height) is read.
        target_width (int): Width of the box to fit into.
        target_height (int): Height of the box to fit into.

    Returns:
        float: Multiplier to apply to both image dimensions.
    """
    src_width, src_height = image.size
    wider_than_target = src_width / src_height > target_width / target_height
    if wider_than_target:
        return target_width / src_width
    return target_height / src_height

89 

90 

def preprocess_img(image: Image.Image, config: dict) -> Image.Image:
    """
    Prepare an equation image for the model: crop, resize, and pad.

    Steps:
        1. Strip a right-aligned equation number via ``remove_eqn_number``.
        2. Crop to the bounding box of all pixels that are not pure white (255).
           A completely white image is kept whole.
        3. Resize, preserving aspect ratio, to fit inside the configured size
           minus ``padding`` on every side.
        4. Paste onto a white RGB canvas of the configured size at offset
           (padding, padding).

    Args:
        image (Image.Image): The input image.
        config (dict): Must contain "preprocessed_image_width",
            "preprocessed_image_height" and "padding".

    Returns:
        Image.Image: An RGB image of size
        (preprocessed_image_width, preprocessed_image_height).
    """
    # remove equation number if present
    image = remove_eqn_number(image)

    # locate every pixel that is not pure white (row, column indices)
    image_arr = np.asarray(image, dtype=np.uint8)
    rows, cols = np.nonzero(image_arr != 255)[:2]
    if rows.size:
        # PIL crop boxes are (left, upper, right, lower) with right/lower
        # exclusive, so +1 keeps the last inked column and row (and avoids a
        # zero-size crop when the ink spans a single row or column)
        image = image.crop(
            (
                int(cols.min()),
                int(rows.min()),
                int(cols.max()) + 1,
                int(rows.max()) + 1,
            )
        )
    # else: a blank image has no ink to crop to, so it is kept whole

    pad = config["padding"]
    width = config["preprocessed_image_width"]
    height = config["preprocessed_image_height"]

    # calculate the target width/height and the scale factor
    target_width = width - 2 * pad
    target_height = height - 2 * pad
    resize_factor = calculate_scale_factor(image, target_width, target_height)

    # resize; clamp to 1 px so extreme aspect ratios never request a 0-size image
    image = image.resize(
        (
            max(1, int(image.size[0] * resize_factor)),
            max(1, int(image.size[1] * resize_factor)),
        ),
        Image.LANCZOS,
    )

    # padding: paste onto a white canvas of the configured size
    new_image = Image.new("RGB", (width, height), (255, 255, 255))
    new_image.paste(image, (pad, pad))

    return new_image

132 

133 

def convert_to_torch_tensor(image: bytes, config: dict) -> torch.Tensor:
    """
    Decode raw image bytes and turn them into a model-ready tensor.

    The bytes are opened with PIL and converted to 8-bit grayscale ("L"). The
    result is run through ``preprocess_img`` and finally converted with
    ``torchvision.transforms.ToTensor``.

    Args:
        image (bytes): Encoded image file contents (any format PIL can open).
        config (dict): Preprocessing configuration forwarded to ``preprocess_img``.

    Returns:
        torch.Tensor: The preprocessed image as a tensor.
    """
    with Image.open(io.BytesIO(image)) as decoded:
        grayscale = decoded.convert("L")
    prepared = preprocess_img(grayscale, config)
    to_tensor = transforms.ToTensor()
    return to_tensor(prepared)

143 

144 

def set_random_seed(seed: int) -> None:
    """
    Seed every random number generator this module relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global RNG, PyTorch's default generator and PyTorch's CUDA
    generator.

    Args:
        seed (int): The seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

152 

153 

def define_model(
    config: dict, vocab: List[str], device: torch.device, model_type="xfmer"
) -> Image2MathML_Xfmer:
    """
    Build the encoders, the decoder, and the ``Image2MathML_Xfmer`` model.

    Args:
        config (dict): Hyper-parameters. Keys read: "input_channels",
            "embedding_dim", "decoder_hid_dim", "dropout", "max_len",
            "dim_feedforward_for_xfmer", "n_xfmer_heads",
            "n_xfmer_encoder_layers", "n_xfmer_decoder_layers" and, optionally,
            "model_type".
        vocab (List[str]): Output vocabulary; ``len(vocab)`` is the decoder's
            output dimension.
        device (torch.device): Device handed to every sub-module and the model.
        model_type (str, optional): Label printed in the progress message. It is
            used only when ``config`` has no "model_type" key; the config value
            takes precedence. Defaults to "xfmer".

    Returns:
        Image2MathML_Xfmer: The assembled model (no weights are loaded here).
    """

    print("Defining model...")

    # config["model_type"] wins when present; the argument is only a fallback,
    # so a config without that key no longer raises KeyError
    model_type = config.get("model_type", model_type)
    input_channels = config["input_channels"]
    output_dim = len(vocab)
    emb_dim = config["embedding_dim"]
    dec_hid_dim = config["decoder_hid_dim"]
    dropout = config["dropout"]
    max_len = config["max_len"]

    print(f"building {model_type} model...")

    dim_feedfwd = config["dim_feedforward_for_xfmer"]
    n_heads = config["n_xfmer_heads"]
    n_xfmer_encoder_layers = config["n_xfmer_encoder_layers"]
    n_xfmer_decoder_layers = config["n_xfmer_decoder_layers"]
    # NOTE(review): hard-coded value passed as Transformer_Encoder's `len_dim`
    # argument; its meaning is defined by that class - confirm before changing
    len_dim = 2500

    # Both encoders are instantiated and passed to the model keyed by name;
    # presumably Image2MathML_Xfmer selects between them (not visible here)
    enc = {
        "CNN": CNN_Encoder(input_channels, dec_hid_dim, dropout, device),
        "XFMER": Transformer_Encoder(
            emb_dim,
            dec_hid_dim,
            n_heads,
            dropout,
            device,
            max_len,
            n_xfmer_encoder_layers,
            dim_feedfwd,
            len_dim,
        ),
    }
    dec = Transformer_Decoder(
        emb_dim,
        n_heads,
        dec_hid_dim,
        output_dim,
        dropout,
        max_len,
        n_xfmer_decoder_layers,
        dim_feedfwd,
        device,
    )
    model = Image2MathML_Xfmer(enc, dec, vocab, device)

    return model

208 

209 

def merge_mn_elements(input_mathml: str) -> str:
    """
    Merge runs of consecutive <mn> elements inside every <mrow> of MathML.

    For example ``<mn>1</mn><mn>2</mn>`` becomes ``<mn>12</mn>``. Element texts
    are concatenated as strings (surrounding whitespace stripped), so pieces
    are never altered:

    * multi-digit pieces stay intact ("12" + "34" gives "1234");
    * leading zeros survive ("0" + "5" gives "05");
    * decimal pieces survive ("3." + "14" gives "3.14").

    A merged run keeps the first element's tag and attributes and takes the
    tail text of the last element. A lone <mn> is left untouched.
    Hexadecimal character references (``&#x...``) are preserved verbatim.

    Args:
        input_mathml (str): Input MathML string.

    Returns:
        str: MathML string with consecutive <mn> elements merged.

    Raises:
        xml.etree.ElementTree.ParseError: If the input is not well-formed XML.
    """
    # Mask hex character references so ElementTree does not decode them into
    # literal characters; they are restored after serialisation
    input_mathml = input_mathml.replace("&#x", "###")
    root = ET.fromstring(input_mathml)

    def merge_run(run: List[ET.Element]) -> ET.Element:
        """Collapse a non-empty run of <mn> elements into one element."""
        if len(run) == 1:
            return run[0]
        merged = run[0]
        # String concatenation (not integer arithmetic) keeps leading zeros,
        # decimal points and multi-digit pieces exactly as written
        merged.text = "".join((mn.text or "").strip() for mn in run)
        merged.tail = run[-1].tail
        return merged

    def merge_mn_in_mrow(mrow: ET.Element) -> List[ET.Element]:
        """Return the children of ``mrow`` with consecutive <mn> runs merged."""
        merged_elements = []
        run = []
        for child in mrow:
            if child.tag == "mn":
                run.append(child)
                continue
            if run:
                merged_elements.append(merge_run(run))
                run = []
            merged_elements.append(child)
        if run:
            merged_elements.append(merge_run(run))
        return merged_elements

    def process_element(element: ET.Element) -> None:
        """Recursively merge <mn> runs in every <mrow> at or below ``element``."""
        if element.tag == "mrow":
            element[:] = merge_mn_in_mrow(element)
        for child in element:
            process_element(child)

    process_element(root)
    modified_xml_string = ET.tostring(root, encoding="utf-8", method="xml").decode("utf-8")
    modified_xml_string = modified_xml_string.replace("###", "&#x")

    return modified_xml_string

275 

276 

def process_mtext(xml_string: str) -> str:
    """
    Turn runs of <mi> identifiers into <mtext> words.

    Two rewrites are applied, in order:

    1. Any <mrow> below the root whose children are all <mi> elements, three or
       more of them, is replaced in its parent by a single <mtext>. The <mtext>
       holds the concatenated text and takes over the <mrow>'s tail.
    2. Among the direct children of the root element, every run of five or
       more consecutive <mi> elements is replaced by a single <mtext>.

    The root element keeps its attributes (e.g. ``display``), and hexadecimal
    character references (``&#x...``) are preserved verbatim.

    Args:
        xml_string (str): The input MathML string.

    Returns:
        str: The modified MathML string.

    Raises:
        xml.etree.ElementTree.ParseError: If the input is not well-formed XML.
    """
    # Mask hex character references so ElementTree does not decode them
    xml_string = xml_string.replace("&#x", "###")
    root = ET.fromstring(xml_string)

    def merge_mi_elements(elements):
        """Return an <mtext> whose text is the concatenation of ``elements``' texts."""
        mtext = ET.Element("mtext")
        mtext.text = "".join(elem.text or "" for elem in elements)
        return mtext

    # Parent lookup so an <mrow> can be swapped for its <mtext> in place
    parent_of = {child: parent for parent in root.iter() for child in parent}

    # 1. Replace all-<mi> <mrow>s (>= 3 children) with a single <mtext>
    for mrow in root.findall(".//mrow"):
        if len(mrow) >= 3 and all(child.tag == "mi" for child in mrow):
            mtext = merge_mi_elements(list(mrow))
            mtext.tail = mrow.tail
            parent = parent_of[mrow]
            parent[list(parent).index(mrow)] = mtext

    # 2. Merge runs of >= 5 consecutive <mi> among the root's direct children
    new_children = []
    mi_run = []

    def flush_run():
        if len(mi_run) >= 5:
            mtext = merge_mi_elements(mi_run)
            mtext.tail = mi_run[-1].tail
            new_children.append(mtext)
        else:
            new_children.extend(mi_run)
        mi_run.clear()

    for child in root:
        if child.tag == "mi":
            mi_run.append(child)
        else:
            flush_run()
            new_children.append(child)
    flush_run()

    # Slice assignment replaces only the children; root.clear() would also wipe
    # the root's attributes (e.g. display="block") and text
    root[:] = new_children

    modified_xml_string = ET.tostring(root, encoding="utf-8", method="xml").decode("utf-8")
    modified_xml_string = modified_xml_string.replace("###", "&#x")

    return modified_xml_string

351 

352 

def add_semicolon_to_unicode(string: str) -> str:
    """
    Terminate every hexadecimal character reference (``&#x...``) with ';'.

    References that already end in ';' are left unchanged, so running the
    function twice gives the same result as running it once.

    Args:
        string (str): The input string to check.

    Returns:
        str: The string with a ';' appended to each unterminated ``&#x<hex>``.
    """
    # The optional ';' is part of the match so that already-terminated
    # references are recognised and do not receive a second semicolon
    pattern = r"&#x[0-9A-Fa-f]+;?"

    def add_semicolon(match):
        unicode_value = match.group(0)
        if not unicode_value.endswith(";"):
            unicode_value += ";"
        return unicode_value

    return re.sub(pattern, add_semicolon, string)

376 

377 

def remove_spaces_between_tags(mathml_string: str) -> str:
    """
    Delete every space character in text that sits between '>' and the next '<'.

    Only element content is affected. Spaces inside a tag (e.g. between a tag
    name and its attributes) are kept. A content segment that contains a
    newline before its closing '<' is left untouched.

    Args:
        mathml_string (str): The MathML string to process.

    Returns:
        str: The string with spaces removed from between-tag content.
    """
    between_tags = re.compile(r">[^<\n]*<")

    def squeeze(segment):
        return segment.group(0).replace(" ", "")

    return between_tags.sub(squeeze, mathml_string)

393 

394 

def render_mml(
    model: Image2MathML_Xfmer,
    vocab_itos: dict,
    vocab_stoi: dict,
    img: torch.Tensor,
    device: torch.device,
) -> str:
    """
    Perform sequence prediction for an input image to translate it into MathML contents.

    The predicted token ids are mapped back to symbols and joined with spaces,
    dropping the first and last tokens. Post-processing then runs in this order:

    1. Remove spaces between tags.
    2. Terminate ``&#x`` references with ';'.
    3. Merge <mn> runs.
    4. Build <mtext> words.

    Steps 3 and 4 are best-effort: if one raises, a warning is logged and the
    result of the previous step is kept.

    Args:
        model (Image2MathML_Xfmer): The image-to-MathML model. It is switched
            to eval mode.
        vocab_itos (dict): The vocabulary lookup dictionary (index, as a
            string, to symbol).
        vocab_stoi (dict): The vocabulary lookup dictionary (symbol to index);
            must contain "<sos>", "<eos>" and "<pad>".
        img (torch.Tensor): The input image as a tensor.
        device (torch.device): The device (GPU or CPU) to be used for computation.

    Returns:
        str: The generated MathML string.
    """
    logger = logging.getLogger(__name__)

    model.eval()
    with torch.no_grad():
        output = model(
            img.to(device),
            device,
            is_inference=True,
            SOS_token=int(vocab_stoi["<sos>"]),
            EOS_token=int(vocab_stoi["<eos>"]),
            PAD_token=int(vocab_stoi["<pad>"]),
        )
        # NOTE(review): `output` is presumably a sequence of ints; vocab_itos is
        # keyed by their string form - confirm against Image2MathML_Xfmer
        pred = [vocab_itos[str(p)] for p in output]

    # Drop the first and last tokens (presumably <sos>/<eos> - TODO confirm the
    # model always emits both)
    pred_seq = " ".join(pred[1:-1])

    res = add_semicolon_to_unicode(remove_spaces_between_tags(pred_seq))

    # XML-based clean-ups: a malformed prediction must not abort rendering, so
    # each step falls back to the previous result (but the failure is logged)
    for step in (merge_mn_elements, process_mtext):
        try:
            res = step(res)
        except Exception:
            logger.warning(
                "MathML post-processing step %s failed; keeping previous result",
                step.__name__,
                exc_info=True,
            )
    return res