Skip to content

Commit

Permalink
feat: upgrade tantivy to 0.22 (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjrh authored May 3, 2024
1 parent 9fafdf2 commit 983364b
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 136 deletions.
265 changes: 191 additions & 74 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ crate-type = ["cdylib"]
pyo3-build-config = "0.20.0"

[dependencies]
base64 = "0.21"
base64 = "0.22"
chrono = "0.4.23"
tantivy = "0.21.0"
tantivy = "0.22.0"
itertools = "0.12.0"
futures = "0.3.26"
pythonize = "0.20.0"
Expand Down
89 changes: 46 additions & 43 deletions src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ use pyo3::{

use chrono::{offset::TimeZone, NaiveDateTime, Utc};

use tantivy::{self as tv, schema::Value};
use tantivy::{self as tv, schema::document::OwnedValue as Value};

use crate::{facet::Facet, schema::Schema, to_pyerr};
use serde::{
ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer,
};
use serde_json::Value as JsonValue;
use std::{
collections::{BTreeMap, HashMap},
collections::BTreeMap,
fmt,
net::{IpAddr, Ipv6Addr},
str::FromStr,
Expand Down Expand Up @@ -54,7 +53,7 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
}
if let Ok(dict) = any.downcast::<PyDict>() {
if let Ok(json) = pythonize::depythonize(dict) {
return Ok(Value::JsonObject(json));
return Ok(Value::Object(json));
}
}
Err(to_pyerr(format!("Value unsupported {any:?}")))
Expand Down Expand Up @@ -119,11 +118,11 @@ pub(crate) fn extract_value_for_type(
tv::schema::Type::Json => {
if let Ok(json_str) = any.extract::<&str>() {
return serde_json::from_str(json_str)
.map(Value::JsonObject)
.map(Value::Object)
.map_err(to_pyerr_for_type("Json", field_name, any));
}

Value::JsonObject(
Value::Object(
any.downcast::<PyDict>()
.map(|dict| pythonize::depythonize(dict))
.map_err(to_pyerr_for_type("Json", field_name, any))?
Expand Down Expand Up @@ -192,32 +191,20 @@ fn extract_value_single_or_list_for_type(
}
}

fn value_to_object(val: &JsonValue, py: Python<'_>) -> PyObject {
match val {
JsonValue::Null => py.None(),
JsonValue::Bool(b) => b.to_object(py),
JsonValue::Number(n) => match n {
n if n.is_i64() => n.as_i64().to_object(py),
n if n.is_u64() => n.as_u64().to_object(py),
n if n.is_f64() => n.as_f64().to_object(py),
_ => panic!("number too large"),
},
JsonValue::String(s) => s.to_object(py),
JsonValue::Array(v) => {
let inner: Vec<_> =
v.iter().map(|x| value_to_object(x, py)).collect();
inner.to_object(py)
}
JsonValue::Object(m) => {
let inner: HashMap<_, _> =
m.iter().map(|(k, v)| (k, value_to_object(v, py))).collect();
inner.to_object(py)
}
/// Convert a tantivy object value (`BTreeMap<String, OwnedValue>`) into a
/// Python `dict`, translating each contained value via `value_to_py`.
///
/// Returns the populated `PyDict` as a `PyObject`, or propagates the first
/// conversion/insertion error.
fn object_to_py(
    py: Python,
    obj: &BTreeMap<String, Value>,
) -> PyResult<PyObject> {
    let dict = PyDict::new(py);
    // `&BTreeMap` iterates in key order; insert each converted pair,
    // bailing out on the first failure.
    for (key, val) in obj {
        dict.set_item(key, value_to_py(py, val)?)?;
    }
    Ok(dict.into())
}

fn value_to_py(py: Python, value: &Value) -> PyResult<PyObject> {
Ok(match value {
Value::Null => py.None(),
Value::Str(text) => text.into_py(py),
Value::U64(num) => (*num).into_py(py),
Value::I64(num) => (*num).into_py(py),
Expand All @@ -243,20 +230,19 @@ fn value_to_py(py: Python, value: &Value) -> PyResult<PyObject> {
.into_py(py)
}
Value::Facet(f) => Facet { inner: f.clone() }.into_py(py),
Value::JsonObject(json_object) => {
let inner: HashMap<_, _> = json_object
.iter()
.map(|(k, v)| (k, value_to_object(v, py)))
.collect();
inner.to_object(py)
Value::Array(arr) => {
// TODO implement me
unimplemented!();
}
Value::Object(obj) => object_to_py(py, obj)?,
Value::Bool(b) => b.into_py(py),
Value::IpAddr(i) => (*i).to_string().into_py(py),
})
}

fn value_to_string(value: &Value) -> String {
match value {
Value::Null => format!("{:?}", value),
Value::Str(text) => text.clone(),
Value::U64(num) => format!("{num}"),
Value::I64(num) => format!("{num}"),
Expand All @@ -268,7 +254,11 @@ fn value_to_string(value: &Value) -> String {
// TODO implement me
unimplemented!();
}
Value::JsonObject(json_object) => {
Value::Array(arr) => {
let inner: Vec<_> = arr.iter().map(value_to_string).collect();
format!("{inner:?}")
}
Value::Object(json_object) => {
serde_json::to_string(&json_object).unwrap()
}
Value::Bool(b) => format!("{b}"),
Expand Down Expand Up @@ -308,6 +298,8 @@ where
/// necessary for serialization.
#[derive(Deserialize, Serialize)]
enum SerdeValue {
/// Null
Null,
/// The str type is used for any text information.
Str(String),
/// Pre-tokenized str type,
Expand All @@ -330,15 +322,18 @@ enum SerdeValue {
Facet(tv::schema::Facet),
/// Arbitrarily sized byte array
Bytes(Vec<u8>),
/// Json object value.
JsonObject(serde_json::Map<String, serde_json::Value>),
/// Array
Array(Vec<Value>),
/// Object value.
Object(BTreeMap<String, Value>),
/// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
IpAddr(Ipv6Addr),
}

impl From<SerdeValue> for Value {
fn from(value: SerdeValue) -> Self {
match value {
SerdeValue::Null => Self::Null,
SerdeValue::Str(v) => Self::Str(v),
SerdeValue::PreTokStr(v) => Self::PreTokStr(v),
SerdeValue::U64(v) => Self::U64(v),
Expand All @@ -347,7 +342,8 @@ impl From<SerdeValue> for Value {
SerdeValue::Date(v) => Self::Date(v),
SerdeValue::Facet(v) => Self::Facet(v),
SerdeValue::Bytes(v) => Self::Bytes(v),
SerdeValue::JsonObject(v) => Self::JsonObject(v),
SerdeValue::Array(v) => Self::Array(v),
SerdeValue::Object(v) => Self::Object(v),
SerdeValue::Bool(v) => Self::Bool(v),
SerdeValue::IpAddr(v) => Self::IpAddr(v),
}
Expand All @@ -357,6 +353,7 @@ impl From<SerdeValue> for Value {
impl From<Value> for SerdeValue {
fn from(value: Value) -> Self {
match value {
Value::Null => Self::Null,
Value::Str(v) => Self::Str(v),
Value::PreTokStr(v) => Self::PreTokStr(v),
Value::U64(v) => Self::U64(v),
Expand All @@ -365,7 +362,8 @@ impl From<Value> for SerdeValue {
Value::Date(v) => Self::Date(v),
Value::Facet(v) => Self::Facet(v),
Value::Bytes(v) => Self::Bytes(v),
Value::JsonObject(v) => Self::JsonObject(v),
Value::Array(v) => Self::Array(v),
Value::Object(v) => Self::Object(v),
Value::Bool(v) => Self::Bool(v),
Value::IpAddr(v) => Self::IpAddr(v),
}
Expand All @@ -376,6 +374,8 @@ impl From<Value> for SerdeValue {
/// cloning.
#[derive(Serialize)]
enum BorrowedSerdeValue<'a> {
/// Null
Null,
/// The str type is used for any text information.
Str(&'a str),
/// Pre-tokenized str type,
Expand All @@ -395,15 +395,18 @@ enum BorrowedSerdeValue<'a> {
Facet(&'a tv::schema::Facet),
/// Arbitrarily sized byte array
Bytes(&'a [u8]),
/// Array
Array(&'a Vec<Value>),
/// Json object value.
JsonObject(&'a serde_json::Map<String, serde_json::Value>),
Object(&'a BTreeMap<String, Value>),
/// IpV6 Address. Internally there is no IpV4, it needs to be converted to `Ipv6Addr`.
IpAddr(&'a Ipv6Addr),
}

impl<'a> From<&'a Value> for BorrowedSerdeValue<'a> {
fn from(value: &'a Value) -> Self {
match value {
Value::Null => Self::Null,
Value::Str(v) => Self::Str(v),
Value::PreTokStr(v) => Self::PreTokStr(v),
Value::U64(v) => Self::U64(v),
Expand All @@ -412,7 +415,8 @@ impl<'a> From<&'a Value> for BorrowedSerdeValue<'a> {
Value::Date(v) => Self::Date(v),
Value::Facet(v) => Self::Facet(v),
Value::Bytes(v) => Self::Bytes(v),
Value::JsonObject(v) => Self::JsonObject(v),
Value::Array(v) => Self::Array(v),
Value::Object(v) => Self::Object(v),
Value::Bool(v) => Self::Bool(v),
Value::IpAddr(v) => Self::IpAddr(v),
}
Expand Down Expand Up @@ -559,8 +563,7 @@ impl Document {
py_dict: &PyDict,
schema: Option<&Schema>,
) -> PyResult<Document> {
let mut field_values: BTreeMap<String, Vec<tv::schema::Value>> =
BTreeMap::new();
let mut field_values: BTreeMap<String, Vec<Value>> = BTreeMap::new();
Document::extract_py_values_from_dict(
py_dict,
schema,
Expand Down Expand Up @@ -809,7 +812,7 @@ impl Document {
fn extract_py_values_from_dict(
py_dict: &PyDict,
schema: Option<&Schema>,
out_field_values: &mut BTreeMap<String, Vec<tv::schema::Value>>,
out_field_values: &mut BTreeMap<String, Vec<Value>>,
) -> PyResult<()> {
// TODO: Reserve when https://github.com/rust-lang/rust/issues/72631 is stable.
// out_field_values.reserve(py_dict.len());
Expand Down
29 changes: 22 additions & 7 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use crate::{
use tantivy as tv;
use tantivy::{
directory::MmapDirectory,
schema::{NamedFieldDocument, Term, Value},
schema::{
document::TantivyDocument, NamedFieldDocument, OwnedValue as Value,
Term,
},
tokenizer::{
Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer,
TextAnalyzer,
Expand Down Expand Up @@ -73,7 +76,8 @@ impl IndexWriter {
/// since the creation of the index.
pub fn add_document(&mut self, doc: &Document) -> PyResult<u64> {
let named_doc = NamedFieldDocument(doc.field_values.clone());
let doc = self.schema.convert_named_doc(named_doc).map_err(to_pyerr)?;
let doc = TantivyDocument::convert_named_doc(&self.schema, named_doc)
.map_err(to_pyerr)?;
self.inner()?.add_document(doc).map_err(to_pyerr)
}

Expand All @@ -86,7 +90,8 @@ impl IndexWriter {
/// The `opstamp` represents the number of documents that have been added
/// since the creation of the index.
pub fn add_json(&mut self, json: &str) -> PyResult<u64> {
let doc = self.schema.parse_document(json).map_err(to_pyerr)?;
let doc = TantivyDocument::parse_json(&self.schema, json)
.map_err(to_pyerr)?;
let opstamp = self.inner()?.add_document(doc);
opstamp.map_err(to_pyerr)
}
Expand Down Expand Up @@ -154,6 +159,11 @@ impl IndexWriter {
let field = get_field(&self.schema, field_name)?;
let value = extract_value(field_value)?;
let term = match value {
Value::Null => {
return Err(exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is null type not deletable."
)))
},
Value::Str(text) => Term::from_field_text(field, &text),
Value::U64(num) => Term::from_field_u64(field, num),
Value::I64(num) => Term::from_field_i64(field, num),
Expand All @@ -170,7 +180,12 @@ impl IndexWriter {
"Field `{field_name}` is pretokenized. This is not authorized for delete."
)))
}
Value::JsonObject(_) => {
Value::Array(_) => {
return Err(exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is array type not deletable."
)))
}
Value::Object(_) => {
return Err(exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is json object type not deletable."
)))
Expand Down Expand Up @@ -297,9 +312,9 @@ impl Index {
) -> Result<(), PyErr> {
let reload_policy = reload_policy.to_lowercase();
let reload_policy = match reload_policy.as_ref() {
"commit" => tv::ReloadPolicy::OnCommit,
"on-commit" => tv::ReloadPolicy::OnCommit,
"oncommit" => tv::ReloadPolicy::OnCommit,
"commit" => tv::ReloadPolicy::OnCommitWithDelay,
"on-commit" => tv::ReloadPolicy::OnCommitWithDelay,
"oncommit" => tv::ReloadPolicy::OnCommitWithDelay,
"manual" => tv::ReloadPolicy::Manual,
_ => return Err(exceptions::PyValueError::new_err(
"Invalid reload policy, valid choices are: 'manual' and 'OnCommit'"
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ::tantivy as tv;
use ::tantivy::schema::{Term, Value};
use ::tantivy::schema::{OwnedValue as Value, Term};
use pyo3::{exceptions, prelude::*, wrap_pymodule};

mod document;
Expand Down
2 changes: 1 addition & 1 deletion src/parser_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl ExpectedBase64Error {

/// If `true`, the length of the base64 string was invalid.
fn caused_by_invalid_length(&self) -> bool {
matches!(self.decode_error, base64::DecodeError::InvalidLength)
matches!(self.decode_error, base64::DecodeError::InvalidLength(_))
}

/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
Expand Down
12 changes: 9 additions & 3 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ use pyo3::{basic::CompareOp, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};
use tantivy::TantivyDocument;
// Bring the trait into scope. This is required for the `to_named_doc` method.
// However, tantivy-py declares its own `Document` class, so we need to avoid
// introducing the `Document` trait into the namespace.
use tantivy::Document as _;

/// Tantivy's Searcher class
///
Expand Down Expand Up @@ -248,9 +253,10 @@ impl Searcher {
///
/// Returns the Document, raises ValueError if the document can't be found.
fn doc(&self, doc_address: &DocAddress) -> PyResult<Document> {
let doc = self.inner.doc(doc_address.into()).map_err(to_pyerr)?;
let named_doc = self.inner.schema().to_named_doc(&doc);
Ok(Document {
let doc: TantivyDocument =
self.inner.doc(doc_address.into()).map_err(to_pyerr)?;
let named_doc = doc.to_named_doc(self.inner.schema());
Ok(crate::document::Document {
field_values: named_doc.0,
})
}
Expand Down
4 changes: 3 additions & 1 deletion src/snippet.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::to_pyerr;
use pyo3::prelude::*;
use tantivy as tv;
// Bring the trait into scope to use methods like `as_str()` on `OwnedValue`.
use tantivy::schema::Value;

/// Tantivy Snippet
///
Expand Down Expand Up @@ -71,7 +73,7 @@ impl SnippetGenerator {
pub fn snippet_from_doc(&self, doc: &crate::Document) -> crate::Snippet {
let text: String = doc
.iter_values_for_field(&self.field_name)
.flat_map(tv::schema::Value::as_text)
.flat_map(|ov| ov.as_str())
.collect::<Vec<&str>>()
.join(" ");

Expand Down
Loading

0 comments on commit 983364b

Please sign in to comment.