Coverage for skema/gromet/execution_engine/types/sequence.py: 81%

83 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import numpy 

2from typing import Union, List, Tuple, Any 

3import itertools 

4 

5from skema.gromet.execution_engine.types.defined_types import Field, Sequence 

6 

# TODO: Verify correctness for numpy arrays — how should arrays with more than one dimension behave in these operations?

8 

9 

class Sequence_get:
    """GroMEt primitive descriptor for ``sequence_get`` (CAST name).

    Declares a ``Sequence`` input plus a ``DimensionalIndex`` input and a
    single ``Sequence`` output. No ``exec`` is defined on this class;
    presumably evaluation is handled elsewhere — TODO confirm.
    """

    source_language_name = {"CAST": "sequence_get"}
    inputs = [
        Field("sequence_input", "Sequence"),
        Field("index", "DimensionalIndex"),
    ]
    outputs = [
        Field("sequence_output", "Sequence"),
    ]
    shorthand = "sequence_get"
    documentation = ""

19 

20 

class Sequence_set:
    """GroMEt primitive descriptor for ``sequence_set`` (CAST name).

    Declares three inputs — a ``Sequence``, a ``DimensionalIndex`` and an
    ``Any``-typed element — and one ``Sequence`` output. No ``exec`` is
    defined on this class; presumably evaluation is handled elsewhere —
    TODO confirm.
    """

    source_language_name = {"CAST": "sequence_set"}
    inputs = [
        Field("sequence_input", "Sequence"),
        Field("index", "DimensionalIndex"),
        Field("element", "Any"),
    ]
    outputs = [
        Field("sequence_output", "Sequence"),
    ]
    shorthand = "sequence_set"
    documentation = ""

31 

32 

class Sequence_concatenate(object):
    """GroMEt primitive ``concatenate``: join several sequences of one type.

    ``exec`` takes no ``self`` and is meant to be called on the class,
    e.g. ``Sequence_concatenate.exec([1], [2, 3]) == [1, 2, 3]``.
    """

    source_language_name = {"CAST": "concatenate"}
    inputs = [Field("sequence_inputs", "Sequence", True)]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = ""
    documentation = ""

    def exec(*sequence_inputs: Sequence) -> Sequence:
        """Concatenate ``sequence_inputs`` into one sequence of the same type.

        Args:
            *sequence_inputs: Sequences to join, in order. All must be
                instances of the first input's type.

        Returns:
            For numpy arrays, ``numpy.concatenate`` of the inputs; for
            strings, the inputs joined with ``""``; otherwise the first
            input's type constructed from the chained elements (e.g. list,
            tuple).

        Raises:
            IndexError: if called with no sequences.
            TypeError: if the first input is a ``range`` (ranges cannot be
                concatenated) or the inputs are not all of the same type.
        """
        # TODO: How do we handle type checking, whose responsibility should it be?
        first = sequence_inputs[0]
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if isinstance(first, range):
            raise TypeError("range objects do not support concatenation")
        if not all(isinstance(sequence, type(first)) for sequence in sequence_inputs):
            raise TypeError("cannot concatenate sequences of different types")

        if isinstance(first, numpy.ndarray):
            return Sequence_concatenate.Array_concatenate(sequence_inputs)
        if isinstance(first, str):
            # str(<iterator>) would produce the iterator's repr, not the text.
            return "".join(sequence_inputs)
        return type(first)(itertools.chain.from_iterable(sequence_inputs))

    def Array_concatenate(
        array_inputs: Tuple[numpy.ndarray, ...]
    ) -> numpy.ndarray:
        """Join numpy arrays end to end along their first axis (numpy's default)."""
        return numpy.concatenate(array_inputs)

61 

62 

class Sequence_replicate(object):
    """GroMEt primitive ``replicate``: repeat a sequence ``count`` times.

    ``exec`` takes no ``self`` and is meant to be called on the class,
    e.g. ``Sequence_replicate.exec([1, 2], 2) == [1, 2, 1, 2]``.
    """

    source_language_name = {"CAST": "replicate"}
    inputs = [Field("sequence_input", "Sequence"), Field("count", "Integer")]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence, count: int) -> Sequence:
        """Return ``sequence_input`` repeated ``count`` times.

        Args:
            sequence_input: The sequence to repeat.
            count: Number of repetitions; zero or negative gives an empty
                result (as with Python's ``seq * n``).

        Returns:
            A new sequence of the same kind: ``sequence_input * count`` for
            Python sequences, or an array repeated along axis 0 for numpy.

        Raises:
            TypeError: if ``sequence_input`` is a ``range``.
        """
        # Explicit raise instead of the former always-true assert.
        if isinstance(sequence_input, range):
            raise TypeError("range objects do not support replication")
        if isinstance(sequence_input, numpy.ndarray):
            return Sequence_replicate.Array_replicate(sequence_input, count)
        return sequence_input * count

    def Array_replicate(
        array_input: numpy.ndarray, count: int
    ) -> numpy.ndarray:
        """Repeat ``array_input`` ``count`` times along its first axis.

        A bare ``numpy.tile(a, count)`` tiles the LAST axis of an n-d array,
        which differs from list semantics (``[[1, 2]] * 2`` yields two rows).
        Padding the repetition tuple with ones keeps the other axes intact.
        For 0-d and 1-d arrays the result equals ``numpy.tile(a, count)``.
        """
        # numpy rejects negative repetitions; lists treat them as zero.
        reps = (max(count, 0),) + (1,) * (array_input.ndim - 1)
        return numpy.tile(array_input, reps)

81 

82 

class Sequence_length:
    """GroMEt primitive ``length``: number of elements in a sequence.

    ``exec`` takes no ``self`` and is meant to be called on the class.
    """

    source_language_name = {"CAST": "length"}
    inputs = [
        Field("sequence_input", "Sequence"),
    ]
    outputs = [
        Field("length", "Integer"),
    ]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> int:
        """Return ``len(sequence_input)``."""
        size = len(sequence_input)
        return size

92 

93 

class Sequence_min:
    """GroMEt primitive ``min``: smallest element of a sequence.

    ``exec`` takes no ``self`` and is meant to be called on the class.
    """

    source_language_name = {"CAST": "min"}
    inputs = [
        Field("sequence_input", "Sequence"),
    ]
    outputs = [
        Field("minimum", "Any"),
    ]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> Any:
        """Return the built-in ``min`` over the sequence's elements.

        Raises ``ValueError`` (from ``min``) when the sequence is empty.
        """
        elements = list(sequence_input)
        smallest = min(elements)
        return smallest

103 

104 

class Sequence_max:
    """GroMEt primitive ``max``: largest element of a sequence.

    ``exec`` takes no ``self`` and is meant to be called on the class.
    """

    source_language_name = {"CAST": "max"}
    inputs = [
        Field("sequence_input", "Sequence"),
    ]
    outputs = [
        Field("maximum", "Any"),
    ]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> Any:
        """Return the built-in ``max`` over the sequence's elements.

        Raises ``ValueError`` (from ``max``) when the sequence is empty.
        """
        elements = list(sequence_input)
        largest = max(elements)
        return largest

114 

115 

class Sequence_count:
    """GroMEt primitive ``count``: occurrences of an element in a sequence.

    ``exec`` takes no ``self`` and is meant to be called on the class.
    """

    source_language_name = {"CAST": "count"}
    inputs = [
        Field("sequence_input", "Sequence"),
        Field("element", "Any"),
    ]
    outputs = [
        Field("count", "Integer"),
    ]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence, element: Any) -> Any:
        """Return how many items of ``sequence_input`` match ``element``.

        Matching follows ``list.count`` (identity or ``==``).
        """
        items = list(sequence_input)
        occurrences = items.count(element)
        return occurrences

125 

126 

class Sequence_index:
    """GroMEt primitive ``index``: position of an element in a list.

    ``exec`` takes no ``self`` and is meant to be called on the class.
    """

    source_language_name = {"CAST": "index"}
    inputs = [
        Field("list_input", "List"),
        Field("element", "Any"),
    ]
    outputs = [
        Field("index", "Integer"),
    ]
    shorthand = ""
    documentation = ""

    def exec(list_input: List, element: Any) -> Any:
        """Return the index of the first item equal to ``element``.

        Delegates to ``list.index``, so a missing element raises ``ValueError``.
        """
        as_list = list(list_input)
        position = as_list.index(element)
        return position

136 

137#class Sequence_pop(object): 

138 # source_language_name = {"CAST": "pop"} 

139 # inputs = [Field("list_input", "List"), Field("index", "Integer")] 

140 # outputs = [Field("value", "Any"), Field("list_output", "List")] 

141 #shorthand = "" 

142 #documentation = "" 

143 

144# def exec(list_input: List, element: Any) -> Any: 

145 # return list(list_input).pop(element)