Coverage for skema/model_assembly/sandbox.py: 51%

41 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from typing import List, Callable 

2from numbers import Number, Real 

3import re 

4 

5import numpy as np 

6 

7# Import math functions that may be used in lambda functions. Import here so 

8# they can be used in the eval() call of lambda strings 

9from math import cos, exp, sqrt 

10 

11 

# Patterns that flag source strings which could reach outside the expression
# being loaded. Searched by load_lambda_function (UNSAFE_BUILTINS) and
# load_derived_type (both patterns) before anything is passed to eval()/exec().
#
# NOTE(review): a regex blacklist is NOT a real sandbox; it only raises the bar
# against accidental or casual misuse. eval()/exec() of genuinely untrusted
# text remains unsafe even when neither pattern matches.
UNSAFE_BUILTINS = re.compile(
    r"""
    (
        __import__\s*\(                         # Runtime import, e.g. __import__("os.path")
      | __loader__\s*\.                         # Attribute access on the module loader
      | __builtins__                            # Direct access to the builtins namespace
      | __subclasses__                          # Object-graph escape via type.__subclasses__()
      | __globals__                             # Reaching a function's module globals
      | (?:globals|locals)\s*\(                 # Namespace-introspection builtins
      | \b(?:vars|eval|exec|compile|open)\s*\(  # Nested code execution / file access
    )
    """,
    re.VERBOSE,
)
# Any `import` statement form: "import os", "from os import  path",
# "from os import *", "from os import (a, b)".
UNSAFE_IMPORT = r"\bimport\s+[A-Za-z_*(]"

25 

26 

class UnsafeOperationError(EnvironmentError):
    """Raised when a lambda or derived-type string contains a blocked construct.

    ``EnvironmentError`` is an alias of ``OSError``, so handlers that catch
    ``OSError`` also catch this error.
    """

29 

30 

class BadLambdaError(ValueError):
    """Raised when evaluating a lambda string does not yield a callable."""

33 

34 

class BadDerivedTypeError(ValueError):
    """Raised when executing a derived-type string does not define the expected class."""

37 

38 

def load_lambda_function(func_str: str) -> Callable:
    """Turn a ``lambda`` expression string into a callable.

    The string is evaluated with ``eval()`` in this module's namespace, so the
    lambda body may use names imported here (``cos``, ``exp``, ``sqrt``,
    ``np``, ...) or any other name in this module's global namespace.

    NOTE(review): the UNSAFE_BUILTINS screen is a blacklist, not a sandbox;
    ``eval()`` on untrusted input remains unsafe.

    Args:
        func_str: Source text of the form ``"lambda <args>: <expr>"``.

    Returns:
        The object produced by evaluating ``func_str`` (guaranteed callable).

    Raises:
        UnsafeOperationError: ``func_str`` matches ``UNSAFE_BUILTINS``.
        RuntimeError: ``func_str`` does not start with ``"lambda"``.
        BadLambdaError: evaluation succeeded but did not yield a callable.
        Exception: anything raised by ``eval()`` itself (e.g. ``SyntaxError``)
            is re-raised unchanged after a diagnostic line is printed.
    """
    # Refuse strings that reach for import machinery or namespace builtins.
    if UNSAFE_BUILTINS.search(func_str) is not None:
        raise UnsafeOperationError(f"found in lambda:\n{func_str}")

    # Checking for expected lambda expression header
    if not func_str.startswith("lambda"):
        raise RuntimeError(
            f"Lambda expression does not start with 'lambda':\n{func_str}"
        )

    # Only the eval() call itself belongs in the try: the diagnostic below is
    # about eval failing, not about the callable check that follows.
    try:
        func_ref = eval(func_str)
    except Exception:
        print(f"eval() failed for lambda: {func_str}")
        raise

    # A string can start with "lambda" yet evaluate to something else, e.g.
    # "lambda x: x, 1" evaluates to a tuple.
    if not callable(func_ref):
        raise BadLambdaError(f"Callable not found for lambda:\n{func_str}")

    return func_ref

59 

60 

def load_derived_type(type_str: str) -> None:
    """Define a ``@dataclass`` from source text and publish it in this module.

    ``type_str`` must begin with a ``@dataclass`` line immediately followed by
    ``class <Name>:``. After the screens below pass, the source is executed,
    and the resulting class is stored in this module's globals under
    ``<Name>``. Lambdas later evaluated by ``load_lambda_function`` (which
    uses this module's namespace) can then refer to it.

    NOTE(review): the regex screens are a blacklist, not a sandbox; ``exec()``
    on untrusted input remains unsafe.

    Args:
        type_str: Python source of a single ``@dataclass`` class definition.

    Raises:
        UnsafeOperationError: ``type_str`` matches ``UNSAFE_BUILTINS`` or
            contains an import statement (``UNSAFE_IMPORT``).
        RuntimeError: ``type_str`` does not start with the expected
            ``@dataclass`` / ``class <Name>:`` header.
        BadDerivedTypeError: execution did not bind ``<Name>``.
        Exception: anything raised by ``exec()`` itself is re-raised
            unchanged after a diagnostic line is printed.
    """
    # Local import: the executed source may not import anything itself
    # (UNSAFE_IMPORT rejects it), so the decorator is supplied from here.
    from dataclasses import dataclass

    # Search the compiled pattern directly. Interpolating it into an f-string
    # would embed its repr() ("re.compile('...', re.VERBOSE)") rather than its
    # pattern text, silently disabling the builtins check.
    if UNSAFE_BUILTINS.search(type_str) or re.search(UNSAFE_IMPORT, type_str):
        raise UnsafeOperationError(f"found in derived-type:\n{type_str}")

    # Capture <Name> with a group. A leading lookbehind can never succeed
    # under re.match(), because nothing precedes position 0.
    type_name_match = re.match(
        r"@dataclass\nclass ([A-Za-z_][A-Za-z0-9_]*):", type_str
    )
    if type_name_match is None:
        raise RuntimeError(f"Unexpected form for derived type:\n{type_str}")
    type_name = type_name_match.group(1)

    # Run against this module's globals so field annotations can use names
    # imported here (List, Number, np, ...). An explicit locals mapping
    # collects the new class definition. Without one, exec() inside a
    # function writes into a throwaway snapshot of the function's locals.
    namespace = {"dataclass": dataclass}
    try:
        exec(type_str, globals(), namespace)
    except Exception:
        print(f"exec() failed for derived-type: {type_str}")
        raise

    if type_name not in namespace:
        raise BadDerivedTypeError(
            f"{type_name} not found for derived-type: {type_str}"
        )

    # Publish the class at module level.
    globals()[type_name] = namespace[type_name]