Skip to content

Commit

Permalink
chore: add experimental MultimodalEmbeddingGenerator class (#1374)
Browse files Browse the repository at this point in the history
* chore: add experimental MultimodalEmbeddingGenerator class

* fix
  • Loading branch information
GarrettWu authored Feb 7, 2025
1 parent 923da03 commit 3989fc2
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 0 deletions.
148 changes: 148 additions & 0 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
)

# Endpoint identifier for the multimodal embedding model. It is the default
# `model_name` of MultimodalEmbeddingGenerator, which passes it as the
# "endpoint" option when creating the remote model.
_MULTIMODAL_EMBEDDING_001_ENDPOINT = "multimodalembedding@001"

# Gemini endpoint identifiers (their consumers are not visible in this excerpt).
_GEMINI_PRO_ENDPOINT = "gemini-pro"
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
Expand Down Expand Up @@ -762,6 +764,152 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
    """Multimodal embedding generator LLM model.

    .. note::
        BigFrames Blob is still experimental. It may not work and is subject
        to change in the future.

    Args:
        model_name (str, Default to "multimodalembedding@001"):
            The model for multimodal embedding. Can be set to
            "multimodalembedding@001". Multimodal embedding models return
            embeddings for text, image and video inputs.
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default
            session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format
            <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.

    Raises:
        NotImplementedError:
            If ``bigframes.options.experiments.blob`` is not enabled.
    """

    def __init__(
        self,
        *,
        model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        # The class is gated behind the blob experiment flag; tell the user
        # how to opt in instead of failing with an empty error.
        if not bigframes.options.experiments.blob:
            raise NotImplementedError(
                "MultimodalEmbeddingGenerator requires the experimental Blob "
                "feature. Enable it with "
                "`bigframes.options.experiments.blob = True`."
            )
        self.model_name = model_name
        self.session = session or global_session.get_global_session()
        self.connection_name = connection_name

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create the remote BQML model that backs this generator.

        Resolves ``self.connection_name`` through the session (overwriting the
        attribute with the resolved value), warns if ``self.model_name`` is
        not the known multimodal endpoint, and creates a remote model whose
        "endpoint" option is ``self.model_name``.
        """
        # Parse and create connection if needed.
        self.connection_name = self.session._create_bq_connection(
            connection=self.connection_name, iam_role="aiplatform.user"
        )

        # Unknown endpoints are allowed but flagged, not rejected.
        if self.model_name != _MULTIMODAL_EMBEDDING_001_ENDPOINT:
            msg = _MODEL_NOT_SUPPORTED_WARNING.format(
                model_name=self.model_name,
                known_models=_MULTIMODAL_EMBEDDING_001_ENDPOINT,
            )
            warnings.warn(msg)

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> MultimodalEmbeddingGenerator:
        """Build a generator that wraps an existing BigQuery remote model.

        Args:
            session (bigframes.Session):
                Session used for the returned generator.
            bq_model (bigquery.Model):
                Remote model whose ``remoteModelInfo`` carries the endpoint
                and connection.

        Returns:
            MultimodalEmbeddingGenerator: Generator bound to ``bq_model``.
        """
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        assert "endpoint" in bq_model._properties["remoteModelInfo"]
        assert "connection" in bq_model._properties["remoteModelInfo"]

        # Parse the remote model endpoint: keep only the last path segment.
        remote_info = bq_model._properties["remoteModelInfo"]
        bqml_endpoint = remote_info["endpoint"]
        model_connection = remote_info["connection"]
        model_endpoint = bqml_endpoint.split("/")[-1]

        # NOTE: cls(...) runs __init__, which builds a model through
        # _create_bqml_model(); that model is replaced just below by a
        # wrapper around the provided bq_model.
        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    @property
    def _predict_func(
        self,
    ) -> Callable[
        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
    ]:
        """Callable used to run prediction: the model's ``generate_embedding``."""
        return self._bqml_model.generate_embedding

    @property
    def _status_col(self) -> str:
        """Name of the status column in the prediction output."""
        return _ML_GENERATE_EMBEDDING_STATUS

    def predict(
        self, X: utils.ArrayType, *, max_retries: int = 0
    ) -> bigframes.dataframe.DataFrame:
        """Predict the result from input DataFrame.

        The input passed by the caller is not modified.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
                Input DataFrame or Series, can contain one or more columns.
                If multiple columns are in the DataFrame, it must contain a
                "content" column for prediction. The content column must be
                of string type or BigFrames Blob of image or video.
            max_retries (int, default 0):
                Max number of retries if the prediction for any rows failed.
                Each try needs to make progress (i.e. has successfully
                predicted rows) to continue the retry. Each retry will append
                newly succeeded rows. When the max retries are reached, the
                remaining rows (the ones without successful predictions) will
                be appended to the end of the result.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples,
            n_input_columns + n_prediction_columns). Returns predicted values.

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            raise ValueError(
                f"max_retries must be larger than or equal to 0, but is {max_retries}."
            )

        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

        if len(X.columns) == 1:
            # BQML identified the column by name
            col_label = cast(blocks.Label, X.columns[0])
            X = X.rename(columns={col_label: "content"})

        # TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
        if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
            # Use assign() rather than item assignment: with multi-column
            # input, X may still be the caller's own DataFrame, and item
            # assignment would overwrite their "content" column in place.
            X = X.assign(
                content=X["content"].blob._get_runtime("R", with_metadata=True)
            )

        options = {
            "flatten_json_output": True,
        }

        return self._predict_and_retry(X, options=options, max_retries=max_retries)

    def to_gbq(
        self, model_name: str, replace: bool = False
    ) -> MultimodalEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists.
                Default to False.

        Returns:
            MultimodalEmbeddingGenerator: Saved model.
        """
        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class GeminiTextGenerator(base.RetriableRemotePredictor):
"""Gemini text generator LLM model.
Expand Down
2 changes: 2 additions & 0 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
llm._TEXT_EMBEDDING_005_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
llm._MULTIMODAL_EMBEDDING_001_ENDPOINT: llm.MultimodalEmbeddingGenerator,
}
)

Expand All @@ -98,6 +99,7 @@ def from_bq(
llm.PaLM2TextEmbeddingGenerator,
llm.Claude3TextGenerator,
llm.TextEmbeddingGenerator,
llm.MultimodalEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
preprocessing.PreprocessingType,
Expand Down
3 changes: 3 additions & 0 deletions bigframes/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
model_name = model_endpoint
version = None

if model_endpoint.startswith("multimodalembedding"):
return model_name, version

at_idx = model_endpoint.find("@")
if at_idx != -1:
version = model_endpoint[at_idx + 1 :]
Expand Down

0 comments on commit 3989fc2

Please sign in to comment.