From 89a57e9c6f5e9ba161ae2b030ac18fd2c859f99e Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Tue, 21 May 2024 11:10:55 +0800 Subject: [PATCH] Add schema limit (#132) Signed-off-by: junjie.jiang --- src/common.h | 2 + src/create_collection_task.cpp | 31 ++++++++++--- src/unittest/run_examples.py | 3 +- tests/test_delete.py | 2 + tests/test_schema.py | 80 ++++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 9 deletions(-) create mode 100644 tests/test_schema.py diff --git a/src/common.h b/src/common.h index ee044e8..523e5e8 100644 --- a/src/common.h +++ b/src/common.h @@ -58,6 +58,8 @@ const std::string kMetaFieldName("$meta"); const std::string kPlaceholderTag("$0"); const int64_t kTopkLimit = 16384; +const int64_t kSchemaFieldLimit = 64; +const int64_t kMaxLengthLimit = 65535; // scalar index type const std::string kDefaultStringIndexType("Trie"); diff --git a/src/create_collection_task.cpp b/src/create_collection_task.cpp index 8297501..d57cf96 100644 --- a/src/create_collection_task.cpp +++ b/src/create_collection_task.cpp @@ -55,7 +55,7 @@ CreateCollectionTask::GetVarcharFieldMaxLength( if (kv_pair.key() == kMaxLengthKey) { try { auto length = std::stoll(kv_pair.value()); - if (length <= 0) { + if (length <= 0 || length > kMaxLengthLimit) { return Status::ParameterInvalid( "the maximum length specified for a VarChar should be " "in (0, 65535])"); @@ -73,8 +73,17 @@ CreateCollectionTask::GetVarcharFieldMaxLength( for (const auto& kv_pair : field.index_params()) { if (kv_pair.key() == kMaxLengthKey) { try { - *max_len = std::stoll(kv_pair.value()); - return Status::Ok(); + auto length = std::stoll(kv_pair.value()); + if (length <= 0 || length > kMaxLengthLimit) { + return Status::ParameterInvalid( + "the maximum length specified for a VarChar should be " + "in (0, 65535])"); + + return Status::Ok(); + } else { + *max_len = static_cast(length); + return Status::Ok(); + } } catch (std::exception& e) { return Status::ParameterInvalid("Invalid max length {}", kv_pair.value()); @@ -142,7 +151,8 @@ CreateCollectionTask::CheckDefaultValue( case DCase::kFloatData: if (f.data_type() != DType::Float) { LOG_ERROR( - "{} field's default value is Float type, mismatches " + "{} field's default value is Float type, " + "mismatches " "field type", f.name()); return false; @@ -151,7 +161,8 @@ CreateCollectionTask::CheckDefaultValue( case DCase::kDoubleData: if (f.data_type() != DType::Double) { LOG_ERROR( - "{} field's default value is Double type, mismatches " + "{} field's default value is Double type, " + "mismatches " "field type", f.name()); return false; @@ -229,6 +240,11 @@ CreateCollectionTask::AppendSysFields( Status CreateCollectionTask::ValidateSchema( const ::milvus::proto::schema::CollectionSchema& schema) { + if (schema.fields_size() > kSchemaFieldLimit) + return Status::ParameterInvalid( + "maximum field's number should be limited to {}", + kSchemaFieldLimit); + std::set field_names; std::string pk_name; for (const auto& field_schema : schema.fields()) { @@ -239,13 +255,14 @@ CreateCollectionTask::ValidateSchema( if (field_schema.is_primary_key()) { if (!pk_name.empty()) { return Status::ParameterInvalid( - "there are more than one primary key, field_name = {}, {}", + "there are more than one primary key, field_name = {}, " + "{}", pk_name, field_schema.name()); } else { pk_name = field_schema.name(); } - } + } if (field_schema.is_dynamic()) { return Status::ParameterInvalid( "cannot explicitly set a field as a dynamic field"); diff --git a/src/unittest/run_examples.py b/src/unittest/run_examples.py index da497ba..a3ff2ff 100644 --- a/src/unittest/run_examples.py +++ b/src/unittest/run_examples.py @@ -16,8 +16,7 @@ def run_all(py_path): - - for f in examples_dir.glob('*.py'): + for f in py_path.glob('*.py'): if str(f).endswith('bfloat16_example.py') or str(f).endswith('dynamic_field.py'): continue print(str(f)) diff --git a/tests/test_delete.py b/tests/test_delete.py index 254353e..4eb5a20 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -40,6 +40,7 @@ def test_delete_by_ids(self): result = milvus_client.delete(collection_name, ids=['-xf%^@#$%^&***)(*/.', '中文id']) result = milvus_client.search(collection_name, [[0.0, 1.0]], limit=3) self.assertEqual([item['id']for item in result[0]], ['Título', 'Cien años de soledad']) + milvus_client.release_collection(collection_name) del milvus_client local_client = MilvusClient('./local_test.db') @@ -77,6 +78,7 @@ def test_delete_by_filter(self): result = milvus_client.delete(collection_name, filter='(a==100) && (b==300)') result = milvus_client.search(collection_name, [[0.0, 1.0]], limit=3) self.assertEqual([item['id']for item in result[0]], ['中文id', 'Título', 'Cien años de soledad']) + milvus_client.release_collection(collection_name) del milvus_client local_client = MilvusClient('./local_test.db') diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..72e36c7 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,80 @@ +# Copyright (C) 2019-2024 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +import unittest +from pymilvus import MilvusClient, MilvusException, DataType + + +class TestDefaultSearch(unittest.TestCase): + def test_schema_field_limits(self): + collection_name = "hello_milvus" + milvus_client = MilvusClient("./local_test.db") + has_collection = milvus_client.has_collection(collection_name) + if has_collection: + milvus_client.drop_collection(collection_name) + schema = milvus_client.create_schema(enable_dynamic_field=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2) + for i in range(62): + schema.add_field('a' + str(i), DataType.INT64) + index_params = milvus_client.prepare_index_params() + index_params.add_index(field_name = "embeddings", metric_type="L2") + milvus_client.create_collection(collection_name, schema=schema, index_params=index_params) + + def test_schema_field_out_limits(self): + collection_name = "hello_milvus" + milvus_client = MilvusClient("./local_test.db") + has_collection = milvus_client.has_collection(collection_name) + if has_collection: + milvus_client.drop_collection(collection_name) + schema = milvus_client.create_schema(enable_dynamic_field=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2) + for i in range(63): + schema.add_field('a' + str(i), DataType.INT64) + index_params = milvus_client.prepare_index_params() + index_params.add_index(field_name = "embeddings", metric_type="L2") + with self.assertRaises(MilvusException): + milvus_client.create_collection(collection_name, schema=schema, index_params=index_params) + + def test_varchar_field_maxlen(self): + collection_name = "hello_milvus" + milvus_client = MilvusClient("./local_test.db") + has_collection = milvus_client.has_collection(collection_name) + if has_collection: + milvus_client.drop_collection(collection_name) + schema = milvus_client.create_schema(enable_dynamic_field=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2) + schema.add_field("string", DataType.VARCHAR, max_length=65535) + index_params = milvus_client.prepare_index_params() + index_params.add_index(field_name = "embeddings", metric_type="L2") + milvus_client.create_collection(collection_name, schema=schema, index_params=index_params) + + def test_varchar_field_out_maxlen(self): + collection_name = "hello_milvus" + milvus_client = MilvusClient("./local_test.db") + has_collection = milvus_client.has_collection(collection_name) + if has_collection: + milvus_client.drop_collection(collection_name) + schema = milvus_client.create_schema(enable_dynamic_field=True) + schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2) + schema.add_field("string", DataType.VARCHAR, max_length=65536) + index_params = milvus_client.prepare_index_params() + index_params.add_index(field_name = "embeddings", metric_type="L2") + with self.assertRaises(MilvusException): + milvus_client.create_collection(collection_name, schema=schema, index_params=index_params) + + +if __name__ == '__main__': + unittest.main()