Coverage for skema/program_analysis/tests/test_conditional.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1# import json NOTE: json and Path aren't used right now, 

2# from pathlib import Path but will be used in the future 

3from skema.program_analysis.multi_file_ingester import process_file_system 

4from skema.gromet.fn import GrometFNModuleCollection 

5from skema.gromet.fn import FunctionType 

6import ast 

7 

8from skema.program_analysis.CAST.pythonAST import py_ast_to_cast 

9from skema.program_analysis.CAST2FN.model.cast import SourceRef 

10from skema.program_analysis.CAST2FN import cast 

11from skema.program_analysis.CAST2FN.cast import CAST 

12from skema.program_analysis.run_ann_cast_pipeline import ann_cast_pipeline 

13 

14# NOTE: these examples are very trivial for the realm of conditionals 

15# more complex ones will follow later as needed 

16 

def cond1():
    """Return a Python snippet with one variable ``x`` and a single if/else.

    The snippet starts with a newline and ends with an indented blank line,
    mirroring a triple-quoted literal written inside a function body.
    """
    snippet_lines = [
        "",
        "x = 2",
        "",
        "if x < 5:",
        "    x = x + 1",
        "else:",
        "    x = x - 3",
        "    ",
    ]
    return "\n".join(snippet_lines)

26 

def cond2():
    """Return a Python snippet with variables ``x`` and ``y`` and one if/else.

    The ``if`` branch reads ``y``, so ``y`` flows into the conditional as
    well as ``x``.
    """
    snippet_lines = [
        "",
        "x = 2",
        "y = 4",
        "",
        "if x < 5:",
        "    x = x + y",
        "else:",
        # Trailing space kept so the text matches the original literal exactly.
        "    x = x - 3 ",
        "    ",
    ]
    return "\n".join(snippet_lines)

37 

def generate_gromet(test_file_string):
    """Run a Python source string through the Python -> CAST -> GroMEt pipeline.

    The source is parsed with ``ast.parse`` and converted to CAST via
    ``PyASTToCAST`` under the placeholder file name ``"temp"``. The result is
    wrapped in a ``cast.CAST`` tagged ``"python"`` and passed to
    ``ann_cast_pipeline`` with ``gromet=True``, ``to_file=False`` and
    ``from_obj=True``.

    Args:
        test_file_string: Python source code to convert.

    Returns:
        Whatever ``ann_cast_pipeline`` returns for those flags. The tests in
        this module read ``.fn`` and ``.fn_array`` from it.
    """
    python_ast = ast.parse(test_file_string)

    # Same value as len(test_file_string.split("\n")): number of lines.
    total_lines = test_file_string.count("\n") + 1

    converter = py_ast_to_cast.PyASTToCAST("temp")
    module_cast = converter.visit(python_ast, {}, {})
    # NOTE(review): presumably a SourceRef spanning rows 1..total_lines of the
    # "temp" file, with no column info -- confirm against SourceRef's signature.
    module_cast.source_refs = [SourceRef("temp", None, None, 1, total_lines)]

    return ann_cast_pipeline(
        cast.CAST([module_cast], "python"),
        gromet=True,
        to_file=False,
        from_obj=True,
    )

53 

def test_cond1():
    """Check the GroMEt generated for ``cond1()`` (single variable, if/else).

    ``fn`` is the top-level FN that holds the conditional box.
    ``fn_array[1]`` is expected to be the predicate FN for ``x < 5``.
    Each assertion tests a single condition, so a failure points at the
    exact field that is wrong.
    """
    cond_gromet = generate_gromet(cond1())
    base_fn = cond_gromet.fn
    # NOTE(review): relies on the predicate FN being at index 1 of fn_array;
    # revisit if the pipeline's FN ordering changes.
    predicate_fn = cond_gromet.fn_array[1]

    assert predicate_fn.b[0].function_type == FunctionType.PREDICATE

    # Base FN with cond
    assert len(base_fn.pof) == 1
    assert len(base_fn.pic) == 1
    assert len(base_fn.poc) == 1
    assert len(base_fn.wfc) == 1
    assert len(base_fn.bc) == 1

    # Check Ports: only x enters and leaves the conditional.
    assert base_fn.pic[0].name == "x"
    assert base_fn.pic[0].box == 1
    assert base_fn.poc[0].name == "x"
    assert base_fn.poc[0].box == 1

    # Check bc box: indices of the condition, if-body and else-body FNs.
    assert base_fn.bc[0].condition == 2
    assert base_fn.bc[0].body_if == 3
    assert base_fn.bc[0].body_else == 5

    # Check predicate
    assert len(predicate_fn.opi) == 1
    assert len(predicate_fn.opo) == 2
    assert len(predicate_fn.wopio) == 1
    assert len(predicate_fn.wfopi) == 1
    assert len(predicate_fn.wff) == 1
    assert len(predicate_fn.wfopo) == 1

    # Check bf count
    assert len(predicate_fn.bf) == 2

    # Check port boxes
    assert predicate_fn.pif[0].box == 2
    assert predicate_fn.pif[1].box == 2
    assert predicate_fn.pof[0].box == 1
    assert predicate_fn.pof[1].box == 2

    # Check wires
    assert predicate_fn.wopio[0].src == 1
    assert predicate_fn.wopio[0].tgt == 1

    assert predicate_fn.wfopi[0].src == 1
    assert predicate_fn.wfopi[0].tgt == 1
    assert predicate_fn.wff[0].src == 2
    assert predicate_fn.wff[0].tgt == 1
    assert predicate_fn.wfopo[0].src == 2
    assert predicate_fn.wfopo[0].tgt == 2

    # Check bf: the second box function is the `<` comparison.
    assert predicate_fn.bf[1].name == "ast.Lt"

101 

102 

def test_cond2():
    """Check the GroMEt generated for ``cond2()`` (variables x and y, if/else).

    ``fn`` is the top-level FN that holds the conditional box.
    ``fn_array[2]`` is expected to be the predicate FN for ``x < 5``.
    Each assertion tests a single condition, so a failure points at the
    exact field that is wrong.
    """
    cond_gromet = generate_gromet(cond2())
    base_fn = cond_gromet.fn
    # NOTE(review): relies on the predicate FN being at index 2 of fn_array;
    # revisit if the pipeline's FN ordering changes.
    predicate_fn = cond_gromet.fn_array[2]

    assert predicate_fn.b[0].function_type == FunctionType.PREDICATE

    # Base FN with cond
    assert len(base_fn.pof) == 2
    assert len(base_fn.pic) == 2
    assert len(base_fn.poc) == 2
    assert len(base_fn.wfc) == 2
    assert len(base_fn.bc) == 1

    # Check Ports: both x and y enter and leave the conditional.
    assert base_fn.pic[0].name == "x"
    assert base_fn.pic[0].box == 1
    assert base_fn.pic[1].name == "y"
    assert base_fn.pic[1].box == 1
    assert base_fn.poc[0].name == "x"
    assert base_fn.poc[0].box == 1
    assert base_fn.poc[1].name == "y"
    assert base_fn.poc[1].box == 1

    # Check bc box: indices of the condition, if-body and else-body FNs.
    assert base_fn.bc[0].condition == 3
    assert base_fn.bc[0].body_if == 4
    assert base_fn.bc[0].body_else == 6

    # Check predicate
    assert len(predicate_fn.opi) == 2
    assert len(predicate_fn.opo) == 3
    assert len(predicate_fn.wopio) == 2
    assert len(predicate_fn.wfopi) == 1
    assert len(predicate_fn.wff) == 1
    assert len(predicate_fn.wfopo) == 1

    # Check bf count
    assert len(predicate_fn.bf) == 2

    # Check port boxes
    assert predicate_fn.pif[0].box == 2
    assert predicate_fn.pif[1].box == 2
    assert predicate_fn.pof[0].box == 1
    assert predicate_fn.pof[1].box == 2

    # Check wires
    assert predicate_fn.wopio[0].src == 1
    assert predicate_fn.wopio[0].tgt == 1

    assert predicate_fn.wfopi[0].src == 1
    assert predicate_fn.wfopi[0].tgt == 1
    assert predicate_fn.wff[0].src == 2
    assert predicate_fn.wff[0].tgt == 1
    assert predicate_fn.wfopo[0].src == 3
    assert predicate_fn.wfopo[0].tgt == 2

    # Check bf: the second box function is the `<` comparison.
    assert predicate_fn.bf[1].name == "ast.Lt"

152 

153 

def test_conditional():
    """Run both conditional GroMEt checks, ``test_cond1`` then ``test_cond2``.

    Under pytest's default collection rules the two ``test_*`` functions are
    also collected individually, so each check runs twice per session.
    """
    for conditional_check in (test_cond1, test_cond2):
        conditional_check()