Coverage for skema/gromet/execution_engine/query_runner.py: 94%

34 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1import yaml 

2import traceback 

3from pathlib import Path 

4 

5from neo4j import GraphDatabase 

6 

# The named Cypher query templates live next to this module.
QUERIES_PATH = Path(__file__).parent.joinpath("queries.yaml")

8 

9 

class QueryRunner:
    """Run named Cypher queries (templates loaded from ``QUERIES_PATH``) against Memgraph.

    The database connection is opened in ``__init__``. Call :meth:`close`, or use
    the runner as a context manager, to release the driver when finished.
    """

    def __init__(self, protocol: str, host: str, port: str):
        """Load the query templates and connect to the database.

        Args:
            protocol: URI scheme prefix, concatenated directly in front of
                ``host``. It must therefore include the ``://`` separator
                (presumably something like ``"bolt://"``).
            host: Database host name.
            port: Database port.

        Raises:
            Whatever ``verify_connectivity()`` raises if the server cannot be
            reached. The driver is closed before the error propagates.
        """
        # First set up the queries map: query name -> Cypher template text.
        self.queries_path = QUERIES_PATH
        # safe_load returns None for an empty document. Fall back to an empty map
        # so run_query reports "unknown query" instead of raising TypeError.
        self.queries_map = (
            yaml.safe_load(self.queries_path.read_text(encoding="utf-8")) or {}
        )

        # Set up the Memgraph connection (empty credentials).
        self.memgraph = GraphDatabase.driver(
            uri=f"{protocol}{host}:{port}", auth=("", "")
        )
        try:
            self.memgraph.verify_connectivity()
        except Exception:
            # Don't leak the driver's connection pool when the server is unreachable.
            self.memgraph.close()
            raise

    def close(self) -> None:
        """Close the underlying database driver and release its connections."""
        self.memgraph.close()

    def __enter__(self) -> "QueryRunner":
        return self

    def __exit__(self, exc_type, exc_value, tb) -> None:
        self.close()

    def _render_query(self, query: str, filename: str = None, id: str = None):
        """Return the Cypher text for the named ``query`` with placeholders filled in.

        ``$FILENAME`` is replaced by ``filename`` (only when ``filename`` is
        truthy). ``$ID`` is replaced by ``str(id)`` (only when ``id`` is not
        None). Returns None if ``query`` is not a known query name.

        NOTE(review): values are spliced into the query text rather than passed
        as driver parameters, so they must never come from untrusted input.
        """
        if query not in self.queries_map:
            return None
        cypher = self.queries_map[query]

        # These clauses limit the scope the query runs in.
        if filename:
            cypher = cypher.replace("$FILENAME", filename)
        if id is not None:
            cypher = cypher.replace("$ID", str(id))
        return cypher

    def run_query(
        self,
        query: str,
        n_or_m: str = "m",
        filename: str = None,
        function: str = None,
        id: str = None,
    ):
        """Execute a named query and return the result nodes.

        Args:
            query: Key of the query template in the queries map.
            n_or_m: Result column holding the node to return for each record.
            filename: Substituted for ``$FILENAME`` in the template, if given.
            function: Currently unused; accepted for API compatibility.
            id: Substituted (via ``str``) for ``$ID`` in the template, if not None.

        Returns:
            A list of node objects built by ``neo4j_to_memgprah``, or None if
            ``query`` is not a known query name.
        """
        cypher = self._render_query(query, filename=filename, id=id)
        if cypher is None:
            # TODO: Improve error handling; unknown queries currently yield None.
            return None

        # In most cases we only want the node objects themselves, so return a list of nodes.
        records, _summary, _keys = self.memgraph.execute_query(
            cypher, database_="memgraph"
        )
        return neo4j_to_memgprah(records, n_or_m)

49 

50 

def neo4j_to_memgprah(neo4j_output, n_or_m: str):
    """Convert neo4j output format to memgraph output format.

    For each record, the node in result column ``n_or_m`` is copied onto a plain
    object. The object gets ``_labels`` (a list of the node's labels), ``_id``
    (the node's ``element_id``) and one attribute per node property.

    Args:
        neo4j_output: Iterable of records, e.g. the records from ``execute_query``.
        n_or_m: Name of the result column that holds the node.

    Returns:
        A list of node objects, one per record, in input order.

    Raises:
        KeyError: if a record has no column named ``n_or_m``.
    """

    class DummyNode:
        # Plain attribute container standing in for a Memgraph-style node object.
        pass

    results = []
    for record in neo4j_output:
        # Records support lookup by column name directly; no need to copy into a dict.
        node_ptr = record[n_or_m]

        dummy_node = DummyNode()
        dummy_node._labels = list(node_ptr.labels)
        dummy_node._id = node_ptr.element_id

        # Use the public items() API rather than the driver's private _properties.
        # Properties are set last, so a property named e.g. "_id" overrides the above.
        for key, value in node_ptr.items():
            setattr(dummy_node, key, value)

        results.append(dummy_node)

    return results