Coverage for skema/img2mml/tests/test_model_loading.py: 82%

34 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from pathlib import Path 

2from skema.img2mml.api import get_mathml_from_bytes, retrieve_model, Image2MathML 

3import os 

4 

5 

def test_model_retrieval():
    """Tests model retrieval.

    Calls ``retrieve_model()`` to obtain the model path, deletes the file at
    that path (if present), verifies it is gone, then calls
    ``retrieve_model()`` again and checks that the model file exists.
    """
    model_path = Path(retrieve_model())
    # Delete the model file if it exists. ``missing_ok=True`` replaces the
    # separate exists() check and avoids a check-then-remove race.
    model_path.unlink(missing_ok=True)
    # Make sure the deletion really happened; otherwise the second call below
    # would not exercise re-retrieval at all.
    assert not model_path.exists(), f"model could not be removed from {model_path}"
    # Retrieve the model again (presumably this re-fetches the deleted file
    # — behavior of retrieve_model is defined outside this module).
    model_path = retrieve_model()
    assert Path(model_path).exists(), f"model was not found at {model_path}"

15 

16 

def local_loading():
    """Build an ``Image2MathML`` instance from the repository's local files.

    This is a helper rather than a test: with pytest's default discovery
    rules it is not collected (no ``test_`` prefix). It is called by
    ``test_local_loading_prediction``.

    Paths are resolved relative to the parent of this test file's directory:

    - ``configs/xfmer_mml_config.json``
    - ``trained_models/arxiv_im2mml_with_fonts_with_boldface_vocab.txt``
    - ``trained_models/cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt``

    Returns:
        The constructed ``Image2MathML`` object, after asserting that its
        ``model``, ``vocab`` and ``config`` attributes are not ``None``.
    """
    # Two levels up from this file, i.e. the parent of the tests directory
    # (equivalent to the original ``parents[0].parents[0]``).
    base_dir = Path(__file__).parents[1]
    config_path = base_dir / "configs" / "xfmer_mml_config.json"
    vocab_path = (
        base_dir
        / "trained_models"
        / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
    )
    model_path = (
        base_dir
        / "trained_models"
        / "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"
    )

    image2mathml_db = Image2MathML(
        config_path=config_path, vocab_path=vocab_path, model_path=model_path
    )
    # ``is not None`` is an identity check; ``!= None`` would dispatch to the
    # object's ``__ne__``, which array-like objects may override to return a
    # non-bool (or raise when coerced to bool).
    assert image2mathml_db.model is not None, "Fail to load the model checkpoint"
    assert image2mathml_db.vocab is not None, "Fail to load the vocabulary file"
    assert image2mathml_db.config is not None, "Fail to load the configuration file"
    return image2mathml_db

35 

36 

def test_local_loading_prediction():
    """Tests model loading and prediction.

    a) Builds an ``Image2MathML`` instance from local files via
       ``local_loading()``.
    b) Reads ``data/261.png`` (located next to this test file), passes its
       bytes to ``get_mathml_from_bytes`` and asserts a non-``None`` result.

    Raises:
        FileNotFoundError, RuntimeError, Exception: re-raised from errors in
        ``get_mathml_from_bytes`` with a descriptive prefix, the original
        message, and the original exception chained as ``__cause__``.
    """
    # a) Local loading test
    image2mathml_db = local_loading()
    # b) Prediction test
    image_path = Path(__file__).parent / "data" / "261.png"
    img_bytes = image_path.read_bytes()

    # Keep the same exception types as before, but include the original
    # message and chain explicitly with ``from err`` so the real cause is
    # reported instead of being hidden behind a generic description.
    try:
        mathml = get_mathml_from_bytes(img_bytes, image2mathml_db)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Model state dictionary file not found: {err}"
        ) from err
    except RuntimeError as err:
        raise RuntimeError(
            f"Error loading state dictionary from file: {err}"
        ) from err
    except Exception as err:
        raise Exception(f"Error converting the image: {err}") from err
    assert mathml is not None, "model failed to generate mml from image"