Skip to content

Commit

Permalink
Add schema limit (#132)
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
  • Loading branch information
junjiejiangjjj authored May 21, 2024
1 parent 593c4f2 commit 89a57e9
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ const std::string kMetaFieldName("$meta");
const std::string kPlaceholderTag("$0");

const int64_t kTopkLimit = 16384;
const int64_t kSchemaFieldLimit = 64;
const int64_t kMaxLengthLimit = 65535;

// scalar index type
const std::string kDefaultStringIndexType("Trie");
Expand Down
31 changes: 24 additions & 7 deletions src/create_collection_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ CreateCollectionTask::GetVarcharFieldMaxLength(
if (kv_pair.key() == kMaxLengthKey) {
try {
auto length = std::stoll(kv_pair.value());
if (length <= 0) {
if (length <= 0 || length > kMaxLengthLimit) {
return Status::ParameterInvalid(
"the maximum length specified for a VarChar should be "
"in (0, 65535])");
Expand All @@ -73,8 +73,17 @@ CreateCollectionTask::GetVarcharFieldMaxLength(
for (const auto& kv_pair : field.index_params()) {
if (kv_pair.key() == kMaxLengthKey) {
try {
*max_len = std::stoll(kv_pair.value());
return Status::Ok();
auto length = std::stoll(kv_pair.value());
if (length <= 0 || length > kMaxLengthLimit) {
return Status::ParameterInvalid(
"the maximum length specified for a VarChar should be "
"in (0, 65535])");

return Status::Ok();
} else {
*max_len = static_cast<uint64_t>(length);
return Status::Ok();
}
} catch (std::exception& e) {
return Status::ParameterInvalid("Invalid max length {}",
kv_pair.value());
Expand Down Expand Up @@ -142,7 +151,8 @@ CreateCollectionTask::CheckDefaultValue(
case DCase::kFloatData:
if (f.data_type() != DType::Float) {
LOG_ERROR(
"{} field's default value is Float type, mismatches "
"{} field's default value is Float type, "
"mismatches "
"field type",
f.name());
return false;
Expand All @@ -151,7 +161,8 @@ CreateCollectionTask::CheckDefaultValue(
case DCase::kDoubleData:
if (f.data_type() != DType::Double) {
LOG_ERROR(
"{} field's default value is Double type, mismatches "
"{} field's default value is Double type, "
"mismatches "
"field type",
f.name());
return false;
Expand Down Expand Up @@ -229,6 +240,11 @@ CreateCollectionTask::AppendSysFields(
Status
CreateCollectionTask::ValidateSchema(
const ::milvus::proto::schema::CollectionSchema& schema) {
if (schema.fields_size() > kSchemaFieldLimit)
return Status::ParameterInvalid(
"maximum field's number should be limited to {}",
kSchemaFieldLimit);

std::set<std::string> field_names;
std::string pk_name;
for (const auto& field_schema : schema.fields()) {
Expand All @@ -239,13 +255,14 @@ CreateCollectionTask::ValidateSchema(
if (field_schema.is_primary_key()) {
if (!pk_name.empty()) {
return Status::ParameterInvalid(
"there are more than one primary key, field_name = {}, {}",
"there are more than one primary key, field_name = {}, "
"{}",
pk_name,
field_schema.name());
} else {
pk_name = field_schema.name();
}
}
}
if (field_schema.is_dynamic()) {
return Status::ParameterInvalid(
"cannot explicitly set a field as a dynamic field");
Expand Down
3 changes: 1 addition & 2 deletions src/unittest/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@


def run_all(py_path):

for f in examples_dir.glob('*.py'):
for f in py_path.glob('*.py'):
if str(f).endswith('bfloat16_example.py') or str(f).endswith('dynamic_field.py'):
continue
print(str(f))
Expand Down
2 changes: 2 additions & 0 deletions tests/test_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_delete_by_ids(self):
result = milvus_client.delete(collection_name, ids=['-xf%^@#$%^&***)(*/.', '中文id'])
result = milvus_client.search(collection_name, [[0.0, 1.0]], limit=3)
self.assertEqual([item['id']for item in result[0]], ['Título', 'Cien años de soledad'])
milvus_client.release_collection(collection_name)
del milvus_client

local_client = MilvusClient('./local_test.db')
Expand Down Expand Up @@ -77,6 +78,7 @@ def test_delete_by_filter(self):
result = milvus_client.delete(collection_name, filter='(a==100) && (b==300)')
result = milvus_client.search(collection_name, [[0.0, 1.0]], limit=3)
self.assertEqual([item['id']for item in result[0]], ['中文id', 'Título', 'Cien años de soledad'])
milvus_client.release_collection(collection_name)
del milvus_client

local_client = MilvusClient('./local_test.db')
Expand Down
80 changes: 80 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (C) 2019-2024 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import unittest
from pymilvus import MilvusClient, MilvusException, DataType


class TestDefaultSearch(unittest.TestCase):
def test_schema_field_limits(self):
collection_name = "hello_milvus"
milvus_client = MilvusClient("./local_test.db")
has_collection = milvus_client.has_collection(collection_name)
if has_collection:
milvus_client.drop_collection(collection_name)
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2)
for i in range(62):
schema.add_field('a' + str(i), DataType.INT64)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params)

def test_schema_field_out_limits(self):
collection_name = "hello_milvus"
milvus_client = MilvusClient("./local_test.db")
has_collection = milvus_client.has_collection(collection_name)
if has_collection:
milvus_client.drop_collection(collection_name)
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2)
for i in range(63):
schema.add_field('a' + str(i), DataType.INT64)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
with self.assertRaises(MilvusException):
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params)

def test_varchar_field_maxlen(self):
collection_name = "hello_milvus"
milvus_client = MilvusClient("./local_test.db")
has_collection = milvus_client.has_collection(collection_name)
if has_collection:
milvus_client.drop_collection(collection_name)
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2)
schema.add_field("string", DataType.VARCHAR, max_length=65535)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params)

def test_varchar_field_out_maxlen(self):
collection_name = "hello_milvus"
milvus_client = MilvusClient("./local_test.db")
has_collection = milvus_client.has_collection(collection_name)
if has_collection:
milvus_client.drop_collection(collection_name)
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=2)
schema.add_field("string", DataType.VARCHAR, max_length=65536)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
with self.assertRaises(MilvusException):
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params)


if __name__ == '__main__':
unittest.main()

0 comments on commit 89a57e9

Please sign in to comment.