Skip to content

Commit cf8c5f5

Browse files
authored
Enable the embedding task for all content types including image captions (#336)
1 parent 9ef8e0b commit cf8c5f5

File tree

8 files changed

+156
-131
lines changed

8 files changed

+156
-131
lines changed

client/src/nv_ingest_client/nv_ingest_cli.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
--task 'extract:{"document_type":"pdf", "extract_method":"unstructured_io"}'
121121
--task 'extract:{"document_type":"docx", "extract_text":true, "extract_images":true}'
122122
--task 'store:{"content_type":"image", "store_method":"minio", "endpoint":"minio:9000"}'
123-
--task 'embed:{"text":true, "tables":true}'
123+
--task 'embed'
124124
--task 'vdb_upload'
125125
--task 'caption:{}'
126126
@@ -143,8 +143,6 @@
143143
- embed: Computes embeddings on multimodal extractions.
144144
Options:
145145
- filter_errors (bool): Flag to filter embedding errors. Optional.
146-
- tables (bool): Flag to create embeddings for table extractions. Optional.
147-
- text (bool): Flag to create embeddings for text extractions. Optional.
148146
\b
149147
- extract: Extracts content from documents, customizable per document type.
150148
Can be specified multiple times for different 'document_type' values.

client/src/nv_ingest_client/primitives/tasks/embed.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,30 @@
99
import logging
1010
from typing import Dict
1111

12-
from pydantic import BaseModel
12+
from pydantic import BaseModel, root_validator
1313

1414
from .task_base import Task
1515

1616
logger = logging.getLogger(__name__)
1717

1818

1919
class EmbedTaskSchema(BaseModel):
20-
text: bool = True
21-
tables: bool = True
2220
filter_errors: bool = False
2321

22+
@root_validator(pre=True)
23+
def handle_deprecated_fields(cls, values):
24+
if "text" in values:
25+
logger.warning(
26+
"'text' parameter is deprecated and will be ignored. Future versions will remove this argument."
27+
)
28+
values.pop("text")
29+
if "tables" in values:
30+
logger.warning(
31+
"'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
32+
)
33+
values.pop("tables")
34+
return values
35+
2436
class Config:
2537
extra = "forbid"
2638

@@ -30,13 +42,22 @@ class EmbedTask(Task):
3042
Object for document embedding task
3143
"""
3244

33-
def __init__(self, text: bool = True, tables: bool = True, filter_errors: bool = False) -> None:
45+
def __init__(self, text: bool = None, tables: bool = None, filter_errors: bool = False) -> None:
3446
"""
3547
Setup Embed Task Config
3648
"""
3749
super().__init__()
38-
self._text = text
39-
self._tables = tables
50+
51+
if text is not None:
52+
logger.warning(
53+
"'text' parameter is deprecated and will be ignored. Future versions will remove this argument."
54+
)
55+
56+
if tables is not None:
57+
logger.warning(
58+
"'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
59+
)
60+
4061
self._filter_errors = filter_errors
4162

4263
def __str__(self) -> str:
@@ -45,8 +66,6 @@ def __str__(self) -> str:
4566
"""
4667
info = ""
4768
info += "Embed Task:\n"
48-
info += f" text: {self._text}\n"
49-
info += f" tables: {self._tables}\n"
5069
info += f" filter_errors: {self._filter_errors}\n"
5170
return info
5271

@@ -56,8 +75,6 @@ def to_dict(self) -> Dict:
5675
"""
5776

5877
task_properties = {
59-
"text": self._text,
60-
"tables": self._tables,
6178
"filter_errors": False,
6279
}
6380

src/nv_ingest/modules/transforms/embed_extractions.py

Lines changed: 113 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,54 @@ def _add_embeddings(row, embeddings, info_msgs):
281281
return row
282282

283283

284-
def _get_text_content(row):
284+
def _get_pandas_text_content(row):
285285
"""
286286
A pandas UDF used to select extracted text content to be used to create embeddings.
287287
"""
288288

289289
return row["content"]
290290

291291

292-
def _get_table_content(row):
292+
def _get_pandas_table_content(row):
293293
"""
294294
A pandas UDF used to select extracted table/chart content to be used to create embeddings.
295295
"""
296296

297297
return row["table_metadata"]["table_content"]
298298

299299

300+
def _get_pandas_image_content(row):
301+
"""
302+
A pandas UDF used to select extracted image captions to be used to create embeddings.
303+
"""
304+
305+
return row["image_metadata"]["caption"]
306+
307+
308+
def _get_cudf_text_content(df: cudf.DataFrame):
309+
"""
310+
A cuDF UDF used to select extracted text content to be used to create embeddings.
311+
"""
312+
313+
return df.struct.field("content")
314+
315+
316+
def _get_cudf_table_content(df: cudf.DataFrame):
317+
"""
318+
A cuDF UDF used to select extracted table/chart content to be used to create embeddings.
319+
"""
320+
321+
return df.struct.field("table_metadata").struct.field("table_content")
322+
323+
324+
def _get_cudf_image_content(df: cudf.DataFrame):
325+
"""
326+
A cuDF UDF used to select extracted image captions to be used to create embeddings.
327+
"""
328+
329+
return df.struct.field("image_metadata").struct.field("caption")
330+
331+
300332
def _batch_generator(iterable: Iterable, batch_size=10):
301333
"""
302334
A generator to yield batches of size `batch_size` from an iterable.
@@ -349,7 +381,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100):
349381

350382
def _generate_embeddings(
351383
ctrl_msg: ControlMessage,
352-
content_type: ContentTypeEnum,
353384
event_loop: asyncio.SelectorEventLoop,
354385
batch_size: int,
355386
api_key: str,
@@ -361,8 +392,10 @@ def _generate_embeddings(
361392
filter_errors: bool,
362393
):
363394
"""
364-
A function to generate embeddings for the supplied `ContentTypeEnum`. The `ContentTypeEnum` will
365-
drive filtering criteria used to select rows of data to enrich with embeddings.
395+
A function to generate text embeddings for supported content types (TEXT, STRUCTURED, IMAGE).
396+
397+
This function dynamically selects the appropriate metadata field based on content type and
398+
calculates embeddings using the NIM embedding service. AUDIO and VIDEO types are stubbed and skipped.
366399
367400
Parameters
368401
----------
@@ -403,53 +436,71 @@ def _generate_embeddings(
403436
content_mask : cudf.Series
404437
A boolean mask representing rows filtered to calculate embeddings.
405438
"""
439+
cudf_content_extractor = {
440+
ContentTypeEnum.TEXT: _get_cudf_text_content,
441+
ContentTypeEnum.STRUCTURED: _get_cudf_table_content,
442+
ContentTypeEnum.IMAGE: _get_cudf_image_content,
443+
ContentTypeEnum.AUDIO: lambda _: None, # Not supported yet.
444+
ContentTypeEnum.VIDEO: lambda _: None, # Not supported yet.
445+
}
446+
pandas_content_extractor = {
447+
ContentTypeEnum.TEXT: _get_pandas_text_content,
448+
ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
449+
ContentTypeEnum.IMAGE: _get_pandas_image_content,
450+
ContentTypeEnum.AUDIO: lambda _: None, # Not supported yet.
451+
ContentTypeEnum.VIDEO: lambda _: None, # Not supported yet.
452+
}
453+
454+
logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")
455+
456+
embedding_dataframes = []
457+
content_masks = []
406458

407459
with ctrl_msg.payload().mutable_dataframe() as mdf:
408460
if mdf.empty:
409-
return None, None
410-
411-
# generate table text mask
412-
if content_type == ContentTypeEnum.TEXT:
413-
content_mask = (mdf["document_type"] == content_type.value) & (
414-
mdf["metadata"].struct.field("content") != ""
415-
).fillna(False)
416-
content_getter = _get_text_content
417-
elif content_type == ContentTypeEnum.STRUCTURED:
418-
table_mask = mdf["document_type"] == content_type.value
419-
if not table_mask.any():
420-
return None, None
421-
content_mask = table_mask & (
422-
mdf["metadata"].struct.field("table_metadata").struct.field("table_content") != ""
423-
).fillna(False)
424-
content_getter = _get_table_content
425-
426-
# exit if matches found
427-
if not content_mask.any():
428-
return None, None
429-
430-
df_text = mdf.loc[content_mask].to_pandas().reset_index(drop=True)
431-
# get text list
432-
filtered_text = df_text["metadata"].apply(content_getter)
433-
# calculate embeddings
434-
filtered_text_batches = _generate_batches(filtered_text.tolist(), batch_size)
435-
text_embeddings = _async_runner(
436-
filtered_text_batches,
437-
api_key,
438-
embedding_nim_endpoint,
439-
embedding_model,
440-
encoding_format,
441-
input_type,
442-
truncate,
443-
event_loop,
444-
filter_errors,
445-
)
446-
# update embeddings in metadata
447-
df_text[["metadata", "document_type", "_contains_embeddings"]] = df_text.apply(
448-
_add_embeddings, **text_embeddings, axis=1
449-
)[["metadata", "document_type", "_contains_embeddings"]]
450-
df_text["_content"] = filtered_text
461+
return ctrl_msg
462+
463+
for content_type, content_getter in pandas_content_extractor.items():
464+
if not content_getter:
465+
logger.debug(f"Skipping unsupported content type: {content_type}")
466+
continue
467+
468+
content_mask = mdf["document_type"] == content_type.value
469+
if not content_mask.any():
470+
continue
471+
472+
cudf_content_getter = cudf_content_extractor[content_type]
473+
content_mask = (content_mask & (cudf_content_getter(mdf["metadata"]) != "")).fillna(False)
474+
if not content_mask.any():
475+
continue
476+
477+
df_content = mdf.loc[content_mask].to_pandas().reset_index(drop=True)
478+
filtered_content = df_content["metadata"].apply(content_getter)
479+
# calculate embeddings
480+
filtered_content_batches = _generate_batches(filtered_content.tolist(), batch_size)
481+
content_embeddings = _async_runner(
482+
filtered_content_batches,
483+
api_key,
484+
embedding_nim_endpoint,
485+
embedding_model,
486+
encoding_format,
487+
input_type,
488+
truncate,
489+
event_loop,
490+
filter_errors,
491+
)
492+
# update embeddings in metadata
493+
df_content[["metadata", "document_type", "_contains_embeddings"]] = df_content.apply(
494+
_add_embeddings, **content_embeddings, axis=1
495+
)[["metadata", "document_type", "_contains_embeddings"]]
496+
df_content["_content"] = filtered_content
497+
498+
embedding_dataframes.append(df_content)
499+
content_masks.append(content_mask)
500+
501+
message = _concatenate_extractions(ctrl_msg, embedding_dataframes, content_masks)
451502

452-
return df_text, content_mask
503+
return message
453504

454505

455506
def _concatenate_extractions(ctrl_msg: ControlMessage, dataframes: List[pd.DataFrame], masks: List[cudf.Series]):
@@ -493,8 +544,8 @@ def _concatenate_extractions(ctrl_msg: ControlMessage, dataframes: List[pd.DataF
493544
@register_module(MODULE_NAME, MODULE_NAMESPACE)
494545
def _embed_extractions(builder: mrc.Builder):
495546
"""
496-
A pipeline module that receives incoming messages in ControlMessage format and calculates embeddings for
497-
supported document types.
547+
A pipeline module that receives incoming messages in ControlMessage format
548+
and calculates text embeddings for all supported content types.
498549
499550
Parameters
500551
----------
@@ -519,56 +570,20 @@ def embed_extractions_fn(message: ControlMessage):
519570
try:
520571
task_props = message.remove_task("embed")
521572
model_dump = task_props.model_dump()
522-
embed_text = model_dump.get("text")
523-
embed_tables = model_dump.get("tables")
524573
filter_errors = model_dump.get("filter_errors", False)
525574

526-
logger.debug(f"Generating embeddings: text={embed_text}, tables={embed_tables}")
527-
embedding_dataframes = []
528-
content_masks = []
529-
530-
if embed_text:
531-
df_text, content_mask = _generate_embeddings(
532-
message,
533-
ContentTypeEnum.TEXT,
534-
event_loop,
535-
validated_config.batch_size,
536-
validated_config.api_key,
537-
validated_config.embedding_nim_endpoint,
538-
validated_config.embedding_model,
539-
validated_config.encoding_format,
540-
validated_config.input_type,
541-
validated_config.truncate,
542-
filter_errors,
543-
)
544-
if df_text is not None:
545-
embedding_dataframes.append(df_text)
546-
content_masks.append(content_mask)
547-
548-
if embed_tables:
549-
df_tables, table_mask = _generate_embeddings(
550-
message,
551-
ContentTypeEnum.STRUCTURED,
552-
event_loop,
553-
validated_config.batch_size,
554-
validated_config.api_key,
555-
validated_config.embedding_nim_endpoint,
556-
validated_config.embedding_model,
557-
validated_config.encoding_format,
558-
validated_config.input_type,
559-
validated_config.truncate,
560-
filter_errors,
561-
)
562-
if df_tables is not None:
563-
embedding_dataframes.append(df_tables)
564-
content_masks.append(table_mask)
565-
566-
if len(content_masks) == 0:
567-
return message
568-
569-
message = _concatenate_extractions(message, embedding_dataframes, content_masks)
570-
571-
return message
575+
return _generate_embeddings(
576+
message,
577+
event_loop,
578+
validated_config.batch_size,
579+
validated_config.api_key,
580+
validated_config.embedding_nim_endpoint,
581+
validated_config.embedding_model,
582+
validated_config.encoding_format,
583+
validated_config.input_type,
584+
validated_config.truncate,
585+
filter_errors,
586+
)
572587

573588
except Exception as e:
574589
traceback.print_exc()

src/nv_ingest/schemas/ingest_job_schema.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,6 @@ class IngestTaskDedupSchema(BaseModelNoExt):
130130

131131

132132
class IngestTaskEmbedSchema(BaseModelNoExt):
133-
text: bool = True
134-
tables: bool = True
135133
filter_errors: bool = False
136134

137135

src/nv_ingest/schemas/metadata_schema.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ class AccessLevelEnum(int, Enum):
3636

3737

3838
class ContentTypeEnum(str, Enum):
39-
TEXT = "text"
39+
AUDIO = "audio"
40+
EMBEDDING = "embedding"
4041
IMAGE = "image"
42+
INFO_MSG = "info_message"
4143
STRUCTURED = "structured"
44+
TEXT = "text"
4245
UNSTRUCTURED = "unstructured"
43-
INFO_MSG = "info_message"
44-
EMBEDDING = "embedding"
46+
VIDEO = "video"
4547

4648

4749
class StdContentDescEnum(str, Enum):

0 commit comments

Comments
 (0)