Skip to content

Commit

Permalink
add milvus logic to ingestor as psuedo task (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 authored Jan 16, 2025
1 parent 3ec4015 commit 5edceb5
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 16 deletions.
17 changes: 12 additions & 5 deletions client/src/nv_ingest_client/client/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks import VdbUploadTask
from nv_ingest_client.util.util import filter_function_kwargs
from nv_ingest_client.util.milvus import MilvusOperator

DEFAULT_JOB_QUEUE_ID = "morpheus_task_queue"

Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(
self._documents = documents or []
self._client = client
self._job_queue_id = job_queue_id
self._vdb_bulk_upload = None

if self._client is None:
client_kwargs = filter_function_kwargs(NvIngestClient, **kwargs)
Expand Down Expand Up @@ -223,7 +224,10 @@ def ingest(self, **kwargs: Any) -> List[Dict[str, Any]]:

fetch_kwargs = filter_function_kwargs(self._client.fetch_job_result, **kwargs)
result = self._client.fetch_job_result(self._job_ids, **fetch_kwargs)

if self._vdb_bulk_upload:
self._vdb_bulk_upload.run(result)
# only upload as part of jobs user specified this action
self._vdb_bulk_upload = None
return result

def ingest_async(self, **kwargs: Any) -> Future:
Expand Down Expand Up @@ -271,6 +275,11 @@ def _done_callback(future):
for future in future_to_job_id:
future.add_done_callback(_done_callback)

if self._vdb_bulk_upload:
self._vdb_bulk_upload.run(combined_future)
# only upload as part of jobs user specified this action
self._vdb_bulk_upload = None

return combined_future

@ensure_job_specs
Expand Down Expand Up @@ -454,7 +463,6 @@ def store_embed(self, **kwargs: Any) -> "Ingestor":

return self

@ensure_job_specs
def vdb_upload(self, **kwargs: Any) -> "Ingestor":
"""
Adds a VdbUploadTask to the batch job specification.
Expand All @@ -469,8 +477,7 @@ def vdb_upload(self, **kwargs: Any) -> "Ingestor":
Ingestor
Returns self for chaining.
"""
vdb_upload_task = VdbUploadTask(**kwargs)
self._job_specs.add_task(vdb_upload_task)
self._vdb_bulk_upload = MilvusOperator(**kwargs)

return self

Expand Down
93 changes: 89 additions & 4 deletions client/src/nv_ingest_client/util/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,87 @@
from typing import List
import time
from urllib.parse import urlparse
from typing import Union, Dict


def _dict_to_params(collections_dict: dict, write_params: dict):
params_tuple_list = []
for coll_name, data_type in collections_dict.items():
cp_write_params = write_params.copy()
enabled_dtypes = {
"enable_text": False,
"enable_charts": False,
"enable_tables": False,
}
if not isinstance(data_type, list):
data_type = [data_type]
for d_type in data_type:
enabled_dtypes[f"enable_{d_type}"] = True
cp_write_params.update(enabled_dtypes)
params_tuple_list.append((coll_name, cp_write_params))
return params_tuple_list


class MilvusOperator:
def __init__(
self,
collection_name: Union[str, Dict] = "nv_ingest_collection",
milvus_uri: str = "http://localhost:19530",
sparse: bool = True,
recreate: bool = True,
gpu_index: bool = True,
gpu_search: bool = False,
dense_dim: int = 1024,
minio_endpoint: str = "localhost:9000",
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
**kwargs,
):
self.milvus_kwargs = locals()
self.milvus_kwargs.pop("self")
self.collection_name = self.milvus_kwargs.pop("collection_name")
self.milvus_kwargs.pop("kwargs", None)

def get_connection_params(self):
conn_dict = {
"milvus_uri": self.milvus_kwargs["milvus_uri"],
"sparse": self.milvus_kwargs["sparse"],
"recreate": self.milvus_kwargs["recreate"],
"gpu_index": self.milvus_kwargs["gpu_index"],
"gpu_search": self.milvus_kwargs["gpu_search"],
"dense_dim": self.milvus_kwargs["dense_dim"],
}
return (self.collection_name, conn_dict)

def get_write_params(self):
write_params = self.milvus_kwargs.copy()
del write_params["recreate"]
del write_params["gpu_index"]
del write_params["gpu_search"]
del write_params["dense_dim"]

return (self.collection_name, write_params)

def run(self, records):
collection_name, create_params = self.get_connection_params()
_, write_params = self.get_write_params()
if isinstance(collection_name, str):
create_nvingest_collection(collection_name, **create_params)
write_to_nvingest_collection(records, collection_name, **write_params)
elif isinstance(collection_name, dict):
split_params_list = _dict_to_params(collection_name, write_params)
for sub_params in split_params_list:
coll_name, sub_write_params = sub_params
create_nvingest_collection(coll_name, **create_params)
write_to_nvingest_collection(records, coll_name, **sub_write_params)
else:
raise ValueError(f"Unsupported type for collection_name detected: {type(collection_name)}")


def create_nvingest_schema(dense_dim: int = 1024, sparse: bool = False) -> CollectionSchema:
Expand Down Expand Up @@ -414,11 +495,12 @@ def write_to_nvingest_collection(
collection_name: str,
milvus_uri: str = "http://localhost:19530",
minio_endpoint: str = "localhost:9000",
sparse: bool = False,
sparse: bool = True,
enable_text: bool = True,
enable_charts: bool = True,
enable_tables: bool = True,
bm25_save_path: str = "bm25_model.json",
compute_bm25_stats: bool = True,
access_key: str = "minioadmin",
secret_key: str = "minioadmin",
bucket_name: str = "a-bucket",
Expand Down Expand Up @@ -465,11 +547,14 @@ def write_to_nvingest_collection(
else:
stream = True
bm25_ef = None
if sparse:
if sparse and compute_bm25_stats:
bm25_ef = create_bm25_model(
records, enable_text=enable_text, enable_charts=enable_charts, enable_tables=enable_tables
)
bm25_ef.save(bm25_save_path)
elif sparse and not compute_bm25_stats:
bm25_ef = BM25EmbeddingFunction(build_default_analyzer(language="en"))
bm25_ef.load(bm25_save_path)
client = MilvusClient(milvus_uri)
schema = Collection(collection_name).schema
if stream:
Expand Down Expand Up @@ -619,7 +704,7 @@ def hybrid_retrieval(
"data": dense_embeddings,
"anns_field": dense_field,
"param": s_param_1,
"limit": top_k,
"limit": top_k * 2,
}

dense_req = AnnSearchRequest(**search_param_1)
Expand All @@ -628,7 +713,7 @@ def hybrid_retrieval(
"data": sparse_embeddings,
"anns_field": sparse_field,
"param": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
"limit": top_k,
"limit": top_k * 2,
}
sparse_req = AnnSearchRequest(**search_param_2)

Expand Down
12 changes: 5 additions & 7 deletions tests/nv_ingest_client/client/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks import TableExtractionTask
from nv_ingest_client.primitives.tasks import VdbUploadTask
from nv_ingest_client.util.milvus import MilvusOperator

MODULE_UNDER_TEST = "nv_ingest_client.client.interface"

Expand Down Expand Up @@ -193,15 +193,13 @@ def test_store_task_some_args_extra_param(ingestor):
def test_vdb_upload_task_no_args(ingestor):
ingestor.vdb_upload()

assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], VdbUploadTask)
assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator)


def test_vdb_upload_task_some_args(ingestor):
ingestor.vdb_upload(filter_errors=True)

task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
assert isinstance(task, VdbUploadTask)
assert task._filter_errors is True
assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator)


def test_caption_task_no_args(ingestor):
Expand All @@ -228,8 +226,8 @@ def test_chain(ingestor):
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[5], FilterTask)
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[6], SplitTask)
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[7], StoreTask)
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[8], VdbUploadTask)
assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 9
assert isinstance(ingestor._vdb_bulk_upload, MilvusOperator)
assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 8


def test_ingest(ingestor, mock_client):
Expand Down
67 changes: 67 additions & 0 deletions tests/nv_ingest_client/util/test_milvus_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
from nv_ingest_client.util.milvus import MilvusOperator, _dict_to_params


@pytest.fixture
def milvus_test_dict():
mil_op = MilvusOperator()
kwargs = mil_op.milvus_kwargs
kwargs["collection_name"] = mil_op.collection_name
return kwargs


def test_extra_kwargs(milvus_test_dict):
mil_op = MilvusOperator(filter_errors=True)
milvus_test_dict.pop("collection_name")
assert mil_op.milvus_kwargs == milvus_test_dict


@pytest.mark.parametrize("collection_name", [None, "name"])
def test_op_collection_name(collection_name):
if collection_name:
mo = MilvusOperator(collection_name=collection_name)
else:
# default
collection_name = "nv_ingest_collection"
mo = MilvusOperator()
cr_collection_name, conn_params = mo.get_connection_params()
wr_collection_name, write_params = mo.get_write_params()
assert cr_collection_name == wr_collection_name == collection_name


def test_op_connection_params(milvus_test_dict):
mo = MilvusOperator()
cr_collection_name, conn_params = mo.get_connection_params()
assert cr_collection_name == milvus_test_dict["collection_name"]
for k, v in conn_params.items():
assert milvus_test_dict[k] == v


def test_op_write_params(milvus_test_dict):
mo = MilvusOperator()
collection_name, wr_params = mo.get_write_params()
assert collection_name == milvus_test_dict["collection_name"]
for k, v in wr_params.items():
assert milvus_test_dict[k] == v


@pytest.mark.parametrize(
"collection_name, expected_results",
[
({"text": ["text", "charts", "tables"]}, {"enable_text": True, "enable_charts": True, "enable_tables": True}),
({"text": ["text", "tables"]}, {"enable_text": True, "enable_charts": False, "enable_tables": True}),
({"text": ["text", "charts"]}, {"enable_text": True, "enable_charts": True, "enable_tables": False}),
({"text": ["text"]}, {"enable_text": True, "enable_charts": False, "enable_tables": False}),
],
)
def test_op_dict_to_params(collection_name, expected_results):
mo = MilvusOperator()
_, wr_params = mo.get_write_params()
response = _dict_to_params(collection_name, wr_params)
if isinstance(collection_name, str):
collection_name = {collection_name: None}
for res in response:
coll_name, write_params = res
for k, v in expected_results.items():
assert write_params[k] == v
coll_name in collection_name.keys()

0 comments on commit 5edceb5

Please sign in to comment.