Coverage for skema/gromet/execution_engine/tests/test_execution.py: 100%
15 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1import pytest
2import torch
3from pathlib import Path
4from tempfile import TemporaryDirectory, TemporaryFile
6from skema.rest.proxies import SKEMA_GRAPH_DB_PROTO, SKEMA_GRAPH_DB_HOST, SKEMA_GRAPH_DB_PORT
7from skema.gromet.execution_engine.execution_engine import ExecutionEngine
@pytest.mark.ci_only
def test_parameter_extraction():
    """Unit test for basic parameter extraction with the execution engine.

    Writes a small Python program to a temporary file, runs
    ``ExecutionEngine.parameter_extraction`` on it (backed by the SKEMA graph
    DB configured in ``skema.rest.proxies``), and checks that each top-level
    variable is reported with the expected tensor value.
    """
    # Named ``source_code`` rather than ``input`` to avoid shadowing the builtin.
    source_code = """
x = 2
y = x+1
z = x+y
"""
    expected_output = {"x": torch.tensor(2), "y": torch.tensor(3), "z": torch.tensor(5)}

    with TemporaryDirectory() as temp:
        source_path = Path(temp) / "test_parameter_extraction.py"
        source_path.write_text(source_code, encoding="utf-8")

        output = ExecutionEngine(
            protocol=SKEMA_GRAPH_DB_PROTO,
            host=SKEMA_GRAPH_DB_HOST,
            port=SKEMA_GRAPH_DB_PORT,
            source_path=str(source_path),
        ).parameter_extraction()

        # Compare variable names first so a missing/extra variable yields a
        # readable failure instead of a bare ``False`` from dict equality.
        assert set(output) == set(expected_output), (
            f"extracted variables {sorted(output)} != expected {sorted(expected_output)}"
        )
        # Dict equality compares each value pair with ``==`` and then calls
        # bool() on the result. For the 0-d tensors used here that is the
        # elementwise (scalar) comparison, so this is a valid value check; it
        # would raise for multi-element tensors.
        assert output == expected_output