Coverage for skema/img2mml/tests/test_model_loading.py: 82%
34 statements
Report generated by coverage.py v7.5.0 on 2024-04-30 17:15 +0000
1from pathlib import Path
2from skema.img2mml.api import get_mathml_from_bytes, retrieve_model, Image2MathML
3import os
def test_model_retrieval():
    """Check that ``retrieve_model`` produces the model file again after deletion.

    A first call determines where the model lives; any file present at that
    location is removed, and a second call must leave a file at the path it
    returns.
    """
    initial_path = Path(retrieve_model())
    # Remove whatever is currently at that path so the second call has to
    # provide the file itself.
    if initial_path.exists():
        os.remove(initial_path)

    model_path = retrieve_model()
    assert Path(model_path).exists(), f"model was not found at {model_path}"
def local_loading():
    """Build an ``Image2MathML`` converter from the files stored in the repository.

    Paths are resolved relative to the ``img2mml`` package directory (two
    levels above this file, i.e. the parent of ``tests/``):

    * ``configs/xfmer_mml_config.json`` -- configuration file
    * ``trained_models/arxiv_im2mml_with_fonts_with_boldface_vocab.txt`` -- vocabulary
    * ``trained_models/cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt`` -- checkpoint

    The name has no ``test_`` prefix, so pytest's default collection does not
    run it directly; ``test_local_loading_prediction`` calls it as a helper.

    Returns:
        The constructed ``Image2MathML`` instance.

    Raises:
        AssertionError: if the instance's ``model``, ``vocab`` or ``config``
            attribute is ``None`` after construction.
    """
    # tests/test_model_loading.py -> tests/ -> img2mml/
    img2mml_dir = Path(__file__).parents[1]
    config_path = img2mml_dir / "configs" / "xfmer_mml_config.json"
    vocab_path = (
        img2mml_dir
        / "trained_models"
        / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
    )
    model_path = (
        img2mml_dir
        / "trained_models"
        / "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"
    )

    image2mathml_db = Image2MathML(
        config_path=config_path, vocab_path=vocab_path, model_path=model_path
    )
    # Identity check against None: ``!=`` can be overloaded (e.g. by tensors)
    # and would then not yield a plain bool.
    assert image2mathml_db.model is not None, "Fail to load the model checkpoint"
    assert image2mathml_db.vocab is not None, "Fail to load the vocabulary file"
    assert image2mathml_db.config is not None, "Fail to load the configuration file"
    return image2mathml_db
def test_local_loading_prediction():
    """Load the model from local files and convert a sample image to MathML.

    Builds the converter with ``local_loading``, reads ``data/261.png`` from
    the directory containing this file, and asserts that
    ``get_mathml_from_bytes`` returns a non-``None`` result.

    Exceptions raised by the conversion are re-raised with the same type and
    added context. The original exception is chained (``from err``) so its
    message and traceback still appear in the test report.
    """
    # a) Local loading test
    image2mathml_db = local_loading()

    # b) Prediction test
    tests_dir = Path(__file__).parent
    image_path = tests_dir / "data" / "261.png"
    img_bytes = image_path.read_bytes()

    try:
        mathml = get_mathml_from_bytes(img_bytes, image2mathml_db)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Model state dictionary file not found: {err}"
        ) from err
    except RuntimeError as err:
        raise RuntimeError(
            f"Error loading state dictionary from file: {err}"
        ) from err
    except Exception as err:
        raise Exception(f"Error converting the image: {err}") from err
    assert mathml is not None, "model failed to generate mml from image"