Coverage for skema/gromet/execution_engine/types/sequence.py: 81%
83 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import numpy
2from typing import Union, List, Tuple, Any
3import itertools
5from skema.gromet.execution_engine.types.defined_types import Field, Sequence
7# TODO: Check the correctness for numpy arrays - How do n>1d arrays work in this case
class Sequence_get(object):
    """Primitive descriptor for reading from a sequence (CAST ``sequence_get``).

    Only interface metadata is declared here; this class defines no ``exec``.
    """

    # Name of this primitive in each source language representation.
    source_language_name = {"CAST": "sequence_get"}
    # Operands: the sequence to read and the index to read at.
    inputs = [
        Field("sequence_input", "Sequence"),
        Field("index", "DimensionalIndex"),
    ]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = "sequence_get"
    documentation = ""
class Sequence_set(object):
    """Primitive descriptor for writing into a sequence (CAST ``sequence_set``).

    Only interface metadata is declared here; this class defines no ``exec``.
    """

    # Name of this primitive in each source language representation.
    source_language_name = {"CAST": "sequence_set"}
    # Operands: target sequence, position to write, and the value to store.
    inputs = [
        Field("sequence_input", "Sequence"),
        Field("index", "DimensionalIndex"),
        Field("element", "Any"),
    ]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = "sequence_set"
    documentation = ""
class Sequence_concatenate(object):
    """Primitive that joins sequences end to end (CAST ``concatenate``).

    ``exec`` and ``Array_concatenate`` take no ``self``; they are invoked
    through the class, e.g. ``Sequence_concatenate.exec(a, b)``.
    """

    source_language_name = {"CAST": "concatenate"}
    # The third Field argument marks this input as accepting multiple values.
    inputs = [Field("sequence_inputs", "Sequence", True)]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = ""
    documentation = ""

    def exec(*sequence_inputs: Sequence) -> Sequence:
        """Concatenate ``sequence_inputs`` in argument order.

        Args:
            *sequence_inputs: One or more sequences, all instances of the
                first input's type.

        Returns:
            A new sequence of the first input's type. numpy arrays are joined
            with ``numpy.concatenate`` (first axis); strings with ``str.join``;
            any other type is rebuilt from the chained elements
            (``list`` + ``list`` -> ``list``, ``tuple`` + ``tuple`` -> ``tuple``).

        Raises:
            TypeError: If the first input is a ``range``, which cannot be
                rebuilt from an arbitrary run of elements.
            AssertionError: If the inputs are not all of the first input's type.
            IndexError: If called with no inputs.
        """
        # TODO: How do we handle type checking, whose responsibility should it be?
        first = sequence_inputs[0]
        if isinstance(first, range):
            # range() cannot be constructed from an iterable of elements.
            raise TypeError("Cannot concatenate range objects")
        assert all(
            isinstance(sequence, type(first)) for sequence in sequence_inputs
        )  # Cannot concatenate sequences of different types

        if isinstance(first, numpy.ndarray):
            return Sequence_concatenate.Array_concatenate(sequence_inputs)
        if isinstance(first, str):
            # str(iterable) would return the iterator's repr, not the text.
            return "".join(sequence_inputs)
        return type(first)(itertools.chain.from_iterable(sequence_inputs))

    def Array_concatenate(
        array_inputs: Tuple[numpy.ndarray, ...]
    ) -> numpy.ndarray:
        """Join numpy arrays along their first axis with ``numpy.concatenate``."""
        return numpy.concatenate(array_inputs)
class Sequence_replicate(object):
    """Primitive that repeats a sequence ``count`` times (CAST ``replicate``).

    ``exec`` and ``Array_replicate`` take no ``self``; they are invoked
    through the class, e.g. ``Sequence_replicate.exec(seq, 3)``.
    """

    source_language_name = {"CAST": "replicate"}
    inputs = [Field("sequence_input", "Sequence"), Field("count", "Integer")]
    outputs = [Field("sequence_output", "Sequence")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence, count: int) -> Sequence:
        """Return ``sequence_input`` repeated ``count`` times.

        Python sequences use ``*`` repetition; numpy arrays are delegated to
        ``Array_replicate`` so both paths repeat along the first axis.

        Raises:
            TypeError: If ``sequence_input`` is a ``range``, which does not
                support repetition.
        """
        if isinstance(sequence_input, range):
            raise TypeError("Cannot replicate range objects")
        if isinstance(sequence_input, numpy.ndarray):
            return Sequence_replicate.Array_replicate(sequence_input, count)
        return sequence_input * count

    def Array_replicate(
        array_input: numpy.ndarray, count: int
    ) -> numpy.ndarray:
        """Stack ``count`` copies of ``array_input`` along its first axis.

        Equivalent to concatenating ``count`` copies with
        ``numpy.concatenate`` (the same axis ``Sequence_concatenate`` uses).
        A non-positive ``count`` yields an empty result, matching Python's
        ``seq * n`` for ``n <= 0``.
        """
        # Repeat only along axis 0; keep every trailing axis at 1 so n-d
        # arrays gain rows rather than being widened along the last axis.
        reps = (max(count, 0),) + (1,) * max(array_input.ndim - 1, 0)
        return numpy.tile(array_input, reps)
class Sequence_length(object):
    """Primitive reporting how many elements a sequence holds (CAST ``length``)."""

    source_language_name = {"CAST": "length"}
    inputs = [Field("sequence_input", "Sequence")]
    outputs = [Field("length", "Integer")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> int:
        """Return ``len(sequence_input)``."""
        size = len(sequence_input)
        return size
class Sequence_min(object):
    """Primitive returning the smallest element of a sequence (CAST ``min``)."""

    source_language_name = {"CAST": "min"}
    inputs = [Field("sequence_input", "Sequence")]
    outputs = [Field("minimum", "Any")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> Any:
        """Return the minimum element of ``sequence_input``.

        The input is first copied into a list; an empty input makes the
        built-in ``min`` raise ``ValueError``.
        """
        elements = list(sequence_input)
        smallest = min(elements)
        return smallest
class Sequence_max(object):
    """Primitive returning the largest element of a sequence (CAST ``max``)."""

    source_language_name = {"CAST": "max"}
    inputs = [Field("sequence_input", "Sequence")]
    outputs = [Field("maximum", "Any")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence) -> Any:
        """Return the maximum element of ``sequence_input``.

        The input is first copied into a list; an empty input makes the
        built-in ``max`` raise ``ValueError``.
        """
        elements = list(sequence_input)
        largest = max(elements)
        return largest
class Sequence_count(object):
    """Primitive counting occurrences of an element in a sequence (CAST ``count``)."""

    source_language_name = {"CAST": "count"}
    inputs = [Field("sequence_input", "Sequence"), Field("element", "Any")]
    outputs = [Field("count", "Integer")]
    shorthand = ""
    documentation = ""

    def exec(sequence_input: Sequence, element: Any) -> int:
        """Return how many items of ``sequence_input`` compare equal to ``element``.

        The input is copied into a list so any iterable is accepted; the
        result is the ``int`` produced by ``list.count`` (annotated to match
        the declared ``Integer`` output).
        """
        return list(sequence_input).count(element)
class Sequence_index(object):
    """Primitive locating the first position of an element (CAST ``index``)."""

    source_language_name = {"CAST": "index"}
    inputs = [Field("list_input", "List"), Field("element", "Any")]
    outputs = [Field("index", "Integer")]
    shorthand = ""
    documentation = ""

    def exec(list_input: List, element: Any) -> int:
        """Return the index of the first item of ``list_input`` equal to ``element``.

        The input is copied into a list before searching; the result is the
        ``int`` produced by ``list.index`` (annotated to match the declared
        ``Integer`` output).

        Raises:
            ValueError: If ``element`` is not present (raised by ``list.index``).
        """
        return list(list_input).index(element)
137#class Sequence_pop(object):
138 # source_language_name = {"CAST": "pop"}
139 # inputs = [Field("list_input", "List"), Field("index", "Integer")]
140 # outputs = [Field("value", "Any"), Field("list_output", "List")]
141 #shorthand = ""
142 #documentation = ""
144# def exec(list_input: List, element: Any) -> Any:
145 # return list(list_input).pop(element)