Skip to content

chore: add experimental MutimodalEmbeddingGenerator class #1374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
)

_MULTIMODAL_EMBEDDING_001_ENDPOINT = "multimodalembedding@001"

_GEMINI_PRO_ENDPOINT = "gemini-pro"
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
Expand Down Expand Up @@ -762,6 +764,152 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
"""Multimodal embedding generator LLM model.

.. note::
BigFrames Blob is still under experiments. It may not work and subject to change in the future.

Args:
model_name (str, Default to "multimodalembedding@001"):
The model for multimodal embedding. Can set to "multimodalembedding@001". Multimodal-embedding models returns model embeddings for text, image and video inputs.
Default to "multimodalembedding@001".
session (bigframes.Session or None):
BQ session to create the model. If None, use the global default session.
connection_name (str or None):
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
If None, use default connection in session context.
"""

def __init__(
self,
*,
model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
session: Optional[bigframes.Session] = None,
connection_name: Optional[str] = None,
):
if not bigframes.options.experiments.blob:
raise NotImplementedError()
self.model_name = model_name
self.session = session or global_session.get_global_session()
self.connection_name = connection_name

self._bqml_model_factory = globals.bqml_model_factory()
self._bqml_model: core.BqmlModel = self._create_bqml_model()

def _create_bqml_model(self):
# Parse and create connection if needed.
self.connection_name = self.session._create_bq_connection(
connection=self.connection_name, iam_role="aiplatform.user"
)

if self.model_name != _MULTIMODAL_EMBEDDING_001_ENDPOINT:
msg = _MODEL_NOT_SUPPORTED_WARNING.format(
model_name=self.model_name,
known_models=_MULTIMODAL_EMBEDDING_001_ENDPOINT,
)
warnings.warn(msg)

options = {
"endpoint": self.model_name,
}
return self._bqml_model_factory.create_remote_model(
session=self.session, connection_name=self.connection_name, options=options
)

@classmethod
def _from_bq(
cls, session: bigframes.Session, bq_model: bigquery.Model
) -> MultimodalEmbeddingGenerator:
assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
assert "remoteModelInfo" in bq_model._properties
assert "endpoint" in bq_model._properties["remoteModelInfo"]
assert "connection" in bq_model._properties["remoteModelInfo"]

# Parse the remote model endpoint
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
model_connection = bq_model._properties["remoteModelInfo"]["connection"]
model_endpoint = bqml_endpoint.split("/")[-1]

model = cls(
session=session,
model_name=model_endpoint, # type: ignore
connection_name=model_connection,
)

model._bqml_model = core.BqmlModel(session, bq_model)
return model

@property
def _predict_func(
self,
) -> Callable[
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
]:
return self._bqml_model.generate_embedding

@property
def _status_col(self) -> str:
return _ML_GENERATE_EMBEDDING_STATUS

def predict(
self, X: utils.ArrayType, *, max_retries: int = 0
) -> bigframes.dataframe.DataFrame:
"""Predict the result from input DataFrame.

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
The content column must be of string type or BigFrames Blob of image or video.

max_retries (int, default 0):
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.

Returns:
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
"""
if max_retries < 0:
raise ValueError(
f"max_retries must be larger than or equal to 0, but is {max_retries}."
)

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

if len(X.columns) == 1:
# BQML identified the column by name
col_label = cast(blocks.Label, X.columns[0])
X = X.rename(columns={col_label: "content"})

# TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)

options = {
"flatten_json_output": True,
}

return self._predict_and_retry(X, options=options, max_retries=max_retries)

def to_gbq(
self, model_name: str, replace: bool = False
) -> MultimodalEmbeddingGenerator:
"""Save the model to BigQuery.

Args:
model_name (str):
The name of the model.
replace (bool, default False):
Determine whether to replace if the model already exists. Default to False.

Returns:
MultimodalEmbeddingGenerator: Saved model."""

new_model = self._bqml_model.copy(model_name, replace)
return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class GeminiTextGenerator(base.RetriableRemotePredictor):
"""Gemini text generator LLM model.
Expand Down
2 changes: 2 additions & 0 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
llm._TEXT_EMBEDDING_005_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
llm._MULTIMODAL_EMBEDDING_001_ENDPOINT: llm.MultimodalEmbeddingGenerator,
}
)

Expand All @@ -98,6 +99,7 @@ def from_bq(
llm.PaLM2TextEmbeddingGenerator,
llm.Claude3TextGenerator,
llm.TextEmbeddingGenerator,
llm.MultimodalEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
preprocessing.PreprocessingType,
Expand Down
3 changes: 3 additions & 0 deletions bigframes/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
model_name = model_endpoint
version = None

if model_endpoint.startswith("multimodalembedding"):
return model_name, version

at_idx = model_endpoint.find("@")
if at_idx != -1:
version = model_endpoint[at_idx + 1 :]
Expand Down