diff --git a/Cargo.lock b/Cargo.lock index 7374aefc6..112fc0800 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,7 +45,7 @@ checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -236,7 +236,7 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn", + "syn 1.0.107", ] [[package]] @@ -253,7 +253,7 @@ checksum = "ebf883b7aacd7b2aeb2a7b338648ee19f57c140d4ee8e52c68979c6b2f7f2263" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -387,7 +387,7 @@ checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -800,9 +800,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] @@ -854,7 +854,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -865,14 +865,24 @@ checksum = "e0b78ccbb160db1556cdb6fd96c50334c5d4ec44dc5e0a968d0a1208fa0efa8b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", +] + +[[package]] +name = "pythonize" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e35b716d430ace57e2d1b4afb51c9e5b7c46d2bce72926e07f9be6a98ced03e" +dependencies = [ + "pyo3", + "serde", ] [[package]] name = "quote" -version = "1.0.23" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "5fe8a65d69dd0808184ebb5f836ab526bb259db23c657efa38711b1072ee47f0" dependencies = [ "proc-macro2", ] @@ -1032,22 +1042,22 @@ checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2" [[package]] name = "serde" -version = "1.0.152" +version = "1.0.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +checksum = "3b88756493a5bd5e5395d53baa70b194b05764ab85b59e43e4b8f4e1192fa9b1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.152" +version = "1.0.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +checksum = "6e5c3a298c7f978e53536f95a63bdc4c4a64550582f31a0359a9afda6aede62e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.27", ] [[package]] @@ -1111,6 +1121,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "tantivy" version = "0.20.1" @@ -1120,6 +1141,8 @@ dependencies = [ "itertools", "pyo3", "pyo3-build-config", + "pythonize", + "serde", "serde_json", "tantivy 0.20.2", ] @@ -1313,7 +1336,7 @@ checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1384,7 +1407,7 @@ checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1505,7 +1528,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-shared", ] @@ -1527,7 +1550,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index a46cc1885..dfab42791 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ chrono = "0.4.23" tantivy = "0.20.1" itertools = "0.10.5" futures = "0.3.26" +pythonize = "0.19.0" +serde = "~1.0" serde_json = "1.0.91" [dependencies.pyo3] diff --git a/src/document.rs b/src/document.rs index 529567a4e..3c9b08ead 100644 --- a/src/document.rs +++ b/src/document.rs @@ -12,15 +12,20 @@ use pyo3::{ use chrono::{offset::TimeZone, NaiveDateTime, Utc}; -use tantivy as tv; +use tantivy::{self as tv, schema::Value}; use crate::{facet::Facet, schema::Schema, to_pyerr}; +use serde::{ + de::{MapAccess, Visitor}, + ser::SerializeMap, + Deserialize, Deserializer, Serialize, Serializer, +}; use serde_json::Value as JsonValue; use std::{ collections::{BTreeMap, HashMap}, fmt, + net::Ipv6Addr, }; -use tantivy::schema::Value; fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject { match val { @@ -106,6 +111,98 @@ fn value_to_string(value: &Value) -> String { } } +/// Serializes a [`tv::DateTime`] object. +/// +/// Since tantivy stores it as a single `i64` nanosecond timestamp, it is serialized and +/// deserialized as one. +fn serialize_datetime( + dt: &tv::DateTime, + serializer: S, +) -> Result { + dt.into_timestamp_nanos().serialize(serializer) +} + +/// Deserializes a [`tv::DateTime`] object. +/// +/// Since tantivy stores it as a single `i64` nanosecond timestamp, it is serialized and +/// deserialized as one. +fn deserialize_datetime<'de, D>( + deserializer: D, +) -> Result +where + D: Deserializer<'de>, +{ + i64::deserialize(deserializer).map(tv::DateTime::from_timestamp_nanos) +} + +/// An equivalent type to [`tantivy::schema::Value`] but uses tagging in its serialization to +/// differentiate between different integer types. +#[derive(Deserialize, Serialize)] +enum SerdeValue { + /// The str type is used for any text information. + Str(String), + /// Pre-tokenized str type, + PreTokStr(tv::tokenizer::PreTokenizedString), + /// Unsigned 64-bits Integer `u64` + U64(u64), + /// Signed 64-bits Integer `i64` + I64(i64), + /// 64-bits Float `f64` + F64(f64), + /// Bool value + Bool(bool), + #[serde( + deserialize_with = "deserialize_datetime", + serialize_with = "serialize_datetime" + )] + /// Date/time with microseconds precision + Date(tv::DateTime), + /// Facet + Facet(tv::schema::Facet), + /// Arbitrarily sized byte array + Bytes(Vec), + /// Json object value. + JsonObject(serde_json::Map), + /// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`. + IpAddr(Ipv6Addr), +} + +impl From for Value { + fn from(value: SerdeValue) -> Self { + match value { + SerdeValue::Str(v) => Self::Str(v), + SerdeValue::PreTokStr(v) => Self::PreTokStr(v), + SerdeValue::U64(v) => Self::U64(v), + SerdeValue::I64(v) => Self::I64(v), + SerdeValue::F64(v) => Self::F64(v), + SerdeValue::Date(v) => Self::Date(v), + SerdeValue::Facet(v) => Self::Facet(v), + SerdeValue::Bytes(v) => Self::Bytes(v), + SerdeValue::JsonObject(v) => Self::JsonObject(v), + SerdeValue::Bool(v) => Self::Bool(v), + SerdeValue::IpAddr(v) => Self::IpAddr(v), + } + } +} + +impl From for SerdeValue { + fn from(value: Value) -> Self { + match value { + Value::Str(v) => Self::Str(v), + Value::PreTokStr(v) => Self::PreTokStr(v), + Value::U64(v) => Self::U64(v), + Value::I64(v) => Self::I64(v), + Value::F64(v) => Self::F64(v), + Value::Date(v) => Self::Date(v), + Value::Facet(v) => Self::Facet(v), + Value::Bytes(v) => Self::Bytes(v), + Value::JsonObject(v) => Self::JsonObject(v), + Value::Bool(v) => Self::Bool(v), + Value::IpAddr(v) => Self::IpAddr(v), + } + } +} + /// Tantivy's Document is the object that can be indexed and then searched for. /// /// Documents are fundamentally a collection of unordered tuples @@ -148,10 +245,10 @@ fn value_to_string(value: &Value) -> String { /// {"unsigned": 1000, "signed": -5, "float": 0.4}, /// schema, /// ) -#[pyclass] +#[pyclass(module = "tantivy")] #[derive(Clone, Default, PartialEq)] pub(crate) struct Document { - pub(crate) field_values: BTreeMap>, + pub(crate) field_values: BTreeMap>, } impl fmt::Debug for Document { @@ -174,6 +271,67 @@ impl fmt::Debug for Document { } } +impl Serialize for Document { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = + serializer.serialize_map(Some(self.field_values.len()))?; + for (k, v) in &self.field_values { + let ser_v: Vec<_> = + v.iter().cloned().map(SerdeValue::from).collect(); + map.serialize_entry(&k, &ser_v)?; + } + map.end() + } +} + +impl<'de> Deserialize<'de> for Document { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct MapVisitor; + + impl<'de> Visitor<'de> for MapVisitor { + type Value = BTreeMap>; + + fn expecting( + &self, + formatter: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map( + self, + mut access: M, + ) -> Result + where + M: MapAccess<'de>, + { + let mut map = BTreeMap::new(); + + while let Some((key, value)) = + access.next_entry::>()? + { + map.insert( + key, + value.into_iter().map(SerdeValue::into).collect(), + ); + } + + Ok(map) + } + } + + Ok(Self { + field_values: deserializer.deserialize_map(MapVisitor)?, + }) + } +} + fn add_value(doc: &mut Document, field_name: String, value: T) where Value: From, @@ -574,6 +732,26 @@ impl Document { _ => py.NotImplemented(), } } + + #[staticmethod] + fn _internal_from_pythonized(serialized: &PyAny) -> PyResult { + pythonize::depythonize(serialized).map_err(to_pyerr) + } + + fn __reduce__<'a>( + slf: PyRef<'a, Self>, + py: Python<'a>, + ) -> PyResult<&'a PyTuple> { + let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?; + + Ok(PyTuple::new( + py, + [ + slf.into_py(py).getattr(py, "_internal_from_pythonized")?, + PyTuple::new(py, [serialized]).to_object(py), + ], + )) + } } impl Document { diff --git a/src/facet.rs b/src/facet.rs index a624e2495..2983fe238 100644 --- a/src/facet.rs +++ b/src/facet.rs @@ -1,4 +1,10 @@ -use pyo3::{basic::CompareOp, prelude::*, types::PyType}; +use crate::to_pyerr; +use pyo3::{ + basic::CompareOp, + prelude::*, + types::{PyTuple, PyType}, +}; +use serde::{Deserialize, Serialize}; use tantivy::schema; /// A Facet represent a point in a given hierarchy. @@ -10,14 +16,22 @@ use tantivy::schema; /// implicitely imply that a document belonging to a facet also belongs to the /// ancestor of its facet. In the example above, /electronics/tv_and_video/ /// and /electronics. -#[pyclass(frozen)] -#[derive(Clone, PartialEq)] +#[pyclass(frozen, module = "tantivy")] +#[derive(Clone, Deserialize, PartialEq, Serialize)] pub(crate) struct Facet { pub(crate) inner: schema::Facet, } #[pymethods] impl Facet { + /// Creates a `Facet` from its binary representation. + #[staticmethod] + fn from_encoded(encoded_bytes: Vec) -> PyResult { + let inner = + schema::Facet::from_encoded(encoded_bytes).map_err(to_pyerr)?; + Ok(Self { inner }) + } + /// Create a new instance of the "root facet" Equivalent to /. #[classmethod] fn root(_cls: &PyType) -> Facet { @@ -80,4 +94,18 @@ impl Facet { _ => py.NotImplemented(), } } + + fn __reduce__<'a>( + slf: PyRef<'a, Self>, + py: Python<'a>, + ) -> PyResult<&'a PyTuple> { + let encoded_bytes = slf.inner.encoded_str().as_bytes().to_vec(); + Ok(PyTuple::new( + py, + [ + slf.into_py(py).getattr(py, "from_encoded")?, + PyTuple::new(py, [encoded_bytes]).to_object(py), + ], + )) + } } diff --git a/src/lib.rs b/src/lib.rs index 7fe6c2af2..245cfee3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ use facet::Facet; use index::Index; use schema::Schema; use schemabuilder::SchemaBuilder; -use searcher::{DocAddress, Searcher}; +use searcher::{DocAddress, SearchResult, Searcher}; /// Python bindings for the search engine library Tantivy. /// @@ -71,6 +71,7 @@ fn tantivy(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/schema.rs b/src/schema.rs index 61cf27392..ba0c74066 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,12 +1,14 @@ -use pyo3::{basic::CompareOp, prelude::*}; +use crate::to_pyerr; +use pyo3::{basic::CompareOp, prelude::*, types::PyTuple}; +use serde::{Deserialize, Serialize}; use tantivy as tv; /// Tantivy schema. /// /// The schema is very strict. To build the schema the `SchemaBuilder` class is /// provided. -#[pyclass(frozen)] -#[derive(PartialEq)] +#[pyclass(frozen, module = "tantivy")] +#[derive(Deserialize, PartialEq, Serialize)] pub(crate) struct Schema { pub(crate) inner: tv::schema::Schema, } @@ -25,4 +27,24 @@ impl Schema { _ => py.NotImplemented(), } } + + #[staticmethod] + fn _internal_from_pythonized(serialized: &PyAny) -> PyResult { + pythonize::depythonize(serialized).map_err(to_pyerr) + } + + fn __reduce__<'a>( + slf: PyRef<'a, Self>, + py: Python<'a>, + ) -> PyResult<&'a PyTuple> { + let serialized = pythonize::pythonize(py, &*slf).map_err(to_pyerr)?; + + Ok(PyTuple::new( + py, + [ + slf.into_py(py).getattr(py, "_internal_from_pythonized")?, + PyTuple::new(py, [serialized]).to_object(py), + ], + )) + } } diff --git a/src/searcher.rs b/src/searcher.rs index ae37fa500..d76d984c7 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -2,6 +2,7 @@ use crate::{document::Document, query::Query, to_pyerr}; use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*}; +use serde::{Deserialize, Serialize}; use tantivy as tv; use tantivy::collector::{Count, MultiCollector, TopDocs}; @@ -13,9 +14,11 @@ pub(crate) struct Searcher { pub(crate) inner: tv::Searcher, } -#[derive(Clone, PartialEq)] +#[derive(Clone, Deserialize, FromPyObject, PartialEq, Serialize)] enum Fruit { + #[pyo3(transparent)] Score(f32), + #[pyo3(transparent)] Order(u64), } @@ -37,8 +40,8 @@ impl ToPyObject for Fruit { } } -#[pyclass(frozen)] -#[derive(Clone, PartialEq)] +#[pyclass(frozen, module = "tantivy")] +#[derive(Clone, Default, Deserialize, PartialEq, Serialize)] /// Object holding a results successful search. pub(crate) struct SearchResult { hits: Vec<(Fruit, DocAddress)>, @@ -50,6 +53,19 @@ pub(crate) struct SearchResult { #[pymethods] impl SearchResult { + #[new] + fn new( + py: Python, + hits: Vec<(PyObject, DocAddress)>, + count: Option, + ) -> PyResult { + let hits = hits + .iter() + .map(|(f, d)| Ok((f.extract(py)?, d.clone()))) + .collect::>>()?; + Ok(Self { hits, count }) + } + fn __repr__(&self) -> PyResult { if let Some(count) = self.count { Ok(format!( @@ -74,6 +90,13 @@ impl SearchResult { } } + fn __getnewargs__( + &self, + py: Python, + ) -> PyResult<(Vec<(PyObject, DocAddress)>, Option)> { + Ok((self.hits(py)?, self.count)) + } + #[getter] /// The list of tuples that contains the scores and DocAddress of the /// search results. @@ -214,8 +237,8 @@ impl Searcher { /// It consists in an id identifying its segment, and its segment-local DocId. /// The id used for the segment is actually an ordinal in the list of segment /// hold by a Searcher. -#[pyclass(frozen)] -#[derive(Clone, Debug, PartialEq)] +#[pyclass(frozen, module = "tantivy")] +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] pub(crate) struct DocAddress { pub(crate) segment_ord: tv::SegmentOrdinal, pub(crate) doc: tv::DocId, @@ -223,6 +246,11 @@ pub(crate) struct DocAddress { #[pymethods] impl DocAddress { + #[new] + fn new(segment_ord: tv::SegmentOrdinal, doc: tv::DocId) -> Self { + DocAddress { segment_ord, doc } + } + /// The segment ordinal is an id identifying the segment hosting the /// document. It is only meaningful, in the context of a searcher. #[getter] @@ -248,6 +276,10 @@ impl DocAddress { _ => py.NotImplemented(), } } + + fn __getnewargs__(&self) -> PyResult<(tv::SegmentOrdinal, tv::DocId)> { + Ok((self.segment_ord, self.doc)) + } } impl From<&tv::DocAddress> for DocAddress { diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index c18aaae8b..2bc99f4bf 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -1,6 +1,7 @@ from io import BytesIO import copy import tantivy +import pickle import pytest from tantivy import Document, Index, SchemaBuilder @@ -473,6 +474,15 @@ def test_search_result_eq(self, ram_index, spanish_index): assert eng_result1 != esp_result assert eng_result2 != esp_result + def test_search_result_pickle(self, ram_index): + index = ram_index + query = index.parse_query("sea whale", ["title", "body"]) + + orig = index.searcher().search(query, 10) + pickled = pickle.loads(pickle.dumps(orig)) + + assert orig == pickled + class TestUpdateClass(object): def test_delete_update(self, ram_index): @@ -541,7 +551,10 @@ def test_create_readers(self): class TestSearcher(object): def test_searcher_repr(self, ram_index, ram_index_numeric_fields): assert repr(ram_index.searcher()) == "Searcher(num_docs=3, num_segments=1)" - assert repr(ram_index_numeric_fields.searcher()) == "Searcher(num_docs=2, num_segments=1)" + assert ( + repr(ram_index_numeric_fields.searcher()) + == "Searcher(num_docs=2, num_segments=1)" + ) class TestDocument(object): @@ -604,6 +617,12 @@ def test_document_copy(self): assert doc1 == doc3 assert doc2 == doc3 + def test_document_pickle(self): + orig = Document(id=1, title="hello world!") + pickled = pickle.loads(pickle.dumps(orig)) + + assert orig == pickled + class TestJsonField: def test_query_from_json_field(self): @@ -719,3 +738,34 @@ def test_facet_eq(): assert facet1 == facet2 assert facet1 != facet3 assert facet2 != facet3 + + +def test_schema_pickle(): + orig = ( + SchemaBuilder() + .add_integer_field("id", stored=True, indexed=True) + .add_unsigned_field("unsigned") + .add_float_field("rating", stored=True, indexed=True) + .add_text_field("body", stored=True) + .add_date_field("date") + .add_json_field("json") + .add_bytes_field("bytes") + .build() + ) + + pickled = pickle.loads(pickle.dumps(orig)) + + assert orig == pickled + + +def test_facet_pickle(): + orig = tantivy.Facet.from_string("/europe/france") + pickled = pickle.loads(pickle.dumps(orig)) + + assert orig == pickled + +def test_doc_address_pickle(): + orig = tantivy.DocAddress(42, 123) + pickled = pickle.loads(pickle.dumps(orig)) + + assert orig == pickled