1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use crate::{
    errors::{unexpected, Result},
    messages::{BoltRequest, BoltResponse},
    pool::ManagedConnection,
    stream::{DetachedRowStream, RowStream},
    types::{BoltList, BoltMap, BoltString, BoltType},
};

/// A cypher statement plus the named parameters that accompany it when it is
/// sent to the neo4j server.
#[derive(Clone)]
pub struct Query {
    /// The raw cypher text.
    query: String,
    /// Named parameters bound to the statement (see `param` / `params`).
    params: BoltMap,
}

impl Query {
    /// Creates a query from `query` (cypher text) with no parameters bound.
    pub fn new(query: String) -> Self {
        Query {
            query,
            params: BoltMap::default(),
        }
    }

    /// Binds `value` to the parameter named `key` and returns the updated query.
    pub fn param<T: Into<BoltType>>(mut self, key: &str, value: T) -> Self {
        self.params.put(key.into(), value.into());
        self
    }

    /// Binds every `(key, value)` pair yielded by `input_params`, in iteration
    /// order, and returns the updated query.
    pub fn params<K, V>(mut self, input_params: impl IntoIterator<Item = (K, V)>) -> Self
    where
        K: Into<BoltString>,
        V: Into<BoltType>,
    {
        for (key, value) in input_params {
            self.params.put(key.into(), value.into());
        }

        self
    }

    /// Returns `true` if a parameter named `key` has been bound on this query.
    pub fn has_param_key(&self, key: &str) -> bool {
        self.params.value.contains_key(key)
    }

    /// Sends a RUN request for this query against database `db`, then a DISCARD
    /// request. No result rows are returned to the caller.
    ///
    /// # Errors
    ///
    /// Propagates any error from `send_recv`. Returns an unexpected-message
    /// error if either the RUN or the DISCARD response is not `Success`.
    pub(crate) async fn run(self, db: &str, connection: &mut ManagedConnection) -> Result<()> {
        let run = BoltRequest::run(db, &self.query, self.params);
        match connection.send_recv(run).await? {
            BoltResponse::Success(_) => match connection.send_recv(BoltRequest::discard()).await? {
                BoltResponse::Success(_) => Ok(()),
                msg => Err(unexpected(msg, "DISCARD")),
            },
            msg => Err(unexpected(msg, "RUN")),
        }
    }

    /// Runs the query via [`Query::execute_mut`] and returns a
    /// [`DetachedRowStream`] that takes ownership of `connection`.
    ///
    /// # Errors
    ///
    /// Same as [`Query::execute_mut`].
    pub(crate) async fn execute(
        self,
        db: &str,
        fetch_size: usize,
        mut connection: ManagedConnection,
    ) -> Result<DetachedRowStream> {
        let stream = self.execute_mut(db, fetch_size, &mut connection).await?;
        Ok(DetachedRowStream::new(stream, connection))
    }

    /// Sends a RUN request for this query against database `db`. On `Success`,
    /// it returns a [`RowStream`] built from the response's `fields` and `qid`
    /// entries and `fetch_size`. This function itself only sends the RUN request.
    ///
    /// # Errors
    ///
    /// Propagates any error from `send_recv` unchanged. Returns an
    /// unexpected-message error if the server answers RUN with anything other
    /// than `Success`.
    pub(crate) async fn execute_mut(
        self,
        db: &str,
        fetch_size: usize,
        connection: &mut ManagedConnection,
    ) -> Result<RowStream> {
        let run = BoltRequest::run(db, &self.query, self.params);
        // `?` propagates transport/protocol errors as-is, matching `run`.
        // Previously they were re-wrapped as an "unexpected RUN response",
        // which hid the real cause.
        match connection.send_recv(run).await? {
            BoltResponse::Success(success) => {
                let fields: BoltList = success.get("fields").unwrap_or_default();
                // When `qid` is absent or unreadable, -1 is passed to `RowStream`.
                let qid: i64 = success.get("qid").unwrap_or(-1);
                Ok(RowStream::new(qid, fields, fetch_size))
            }
            msg => Err(unexpected(msg, "RUN")),
        }
    }
}

impl From<String> for Query {
    fn from(query: String) -> Self {
        Query::new(query)
    }
}

impl From<&str> for Query {
    fn from(query: &str) -> Self {
        Query::new(query.to_owned())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Bulk-binding via `params` stores each pair and `has_param_key`
    /// reports only the keys that were bound.
    #[test]
    fn add_params() {
        let cypher = "MATCH (n) WHERE n.name = $name AND n.age > $age RETURN n";
        let query = Query::new(String::from(cypher)).params([
            ("name", BoltType::from("Frobniscante")),
            ("age", BoltType::from(42)),
        ]);

        let name = query.params.get::<String>("name").unwrap();
        assert_eq!(name, "Frobniscante");

        let age = query.params.get::<i64>("age").unwrap();
        assert_eq!(age, 42);

        assert!(query.has_param_key("name"));
        assert!(!query.has_param_key("country"));
    }
}