Skip to content

Commit 3989fc2

Browse files
authored
chore: add experimental MultimodalEmbeddingGenerator class (#1374)
* chore: add experimental MultimodalEmbeddingGenerator class * fix
1 parent 923da03 commit 3989fc2

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

bigframes/ml/llm.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
5858
)
5959

60+
_MULTIMODAL_EMBEDDING_001_ENDPOINT = "multimodalembedding@001"
61+
6062
_GEMINI_PRO_ENDPOINT = "gemini-pro"
6163
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
6264
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
@@ -762,6 +764,152 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
762764
return new_model.session.read_gbq_model(model_name)
763765

764766

767+
@log_adapter.class_logger
class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
    """Multimodal embedding generator LLM model.

    .. note::
        BigFrames Blob is still experimental. It may not work and is subject to change in the future.

    Args:
        model_name (str, Default to "multimodalembedding@001"):
            The model for multimodal embedding. Can set to "multimodalembedding@001". Multimodal-embedding models return model embeddings for text, image and video inputs.
            Default to "multimodalembedding@001".
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.
    """

    def __init__(
        self,
        *,
        model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        # Hard-gate behind the experimental Blob option: the class is unusable
        # unless bigframes.options.experiments.blob is enabled.
        if not bigframes.options.experiments.blob:
            raise NotImplementedError()
        self.model_name = model_name
        self.session = session or global_session.get_global_session()
        self.connection_name = connection_name

        self._bqml_model_factory = globals.bqml_model_factory()
        # Eagerly creates the BQML remote model (issues a CREATE MODEL under
        # the hood), so construction has BigQuery-side side effects.
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create the underlying BQML remote model pointing at the Vertex endpoint.

        Resolves (and creates, if needed) the BQ connection with the
        ``aiplatform.user`` IAM role, then registers a remote model whose
        ``endpoint`` option is ``self.model_name``.
        """
        # Parse and create connection if needed.
        self.connection_name = self.session._create_bq_connection(
            connection=self.connection_name, iam_role="aiplatform.user"
        )

        # Unknown endpoints are still forwarded to BQML; warn instead of
        # raising so new endpoint versions can be tried without a library update.
        if self.model_name != _MULTIMODAL_EMBEDDING_001_ENDPOINT:
            msg = _MODEL_NOT_SUPPORTED_WARNING.format(
                model_name=self.model_name,
                known_models=_MULTIMODAL_EMBEDDING_001_ENDPOINT,
            )
            warnings.warn(msg)

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> MultimodalEmbeddingGenerator:
        """Reconstruct a MultimodalEmbeddingGenerator from an existing BQ model.

        Internal alternate constructor used by the model loader; expects a
        remote model (BQML reports these as ``MODEL_TYPE_UNSPECIFIED``) with
        endpoint and connection metadata present.
        """
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        assert "endpoint" in bq_model._properties["remoteModelInfo"]
        assert "connection" in bq_model._properties["remoteModelInfo"]

        # Parse the remote model endpoint
        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
        model_connection = bq_model._properties["remoteModelInfo"]["connection"]
        # The endpoint is a full resource path; only the trailing segment is
        # the model name (e.g. ".../multimodalembedding@001").
        model_endpoint = bqml_endpoint.split("/")[-1]

        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        # Reuse the already-existing BQ model instead of the one created by
        # __init__, so loading does not leave a second remote model behind.
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    @property
    def _predict_func(
        self,
    ) -> Callable[
        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
    ]:
        # Prediction delegates to BQML's ML.GENERATE_EMBEDDING wrapper.
        return self._bqml_model.generate_embedding

    @property
    def _status_col(self) -> str:
        # Column the retry machinery inspects to detect per-row failures.
        return _ML_GENERATE_EMBEDDING_STATUS

    def predict(
        self, X: utils.ArrayType, *, max_retries: int = 0
    ) -> bigframes.dataframe.DataFrame:
        """Predict the result from input DataFrame.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
                Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
                The content column must be of string type or BigFrames Blob of image or video.

            max_retries (int, default 0):
                Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
                Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            raise ValueError(
                f"max_retries must be larger than or equal to 0, but is {max_retries}."
            )

        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

        if len(X.columns) == 1:
            # BQML identified the column by name
            col_label = cast(blocks.Label, X.columns[0])
            X = X.rename(columns={col_label: "content"})

        # TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
        if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
            X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)

        options = {
            "flatten_json_output": True,
        }

        return self._predict_and_retry(X, options=options, max_retries=max_retries)

    def to_gbq(
        self, model_name: str, replace: bool = False
    ) -> MultimodalEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Default to False.

        Returns:
            MultimodalEmbeddingGenerator: Saved model."""
        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)
911+
912+
765913
@log_adapter.class_logger
766914
class GeminiTextGenerator(base.RetriableRemotePredictor):
767915
"""Gemini text generator LLM model.

bigframes/ml/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
llm._TEXT_EMBEDDING_005_ENDPOINT: llm.TextEmbeddingGenerator,
7676
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
7777
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
78+
llm._MULTIMODAL_EMBEDDING_001_ENDPOINT: llm.MultimodalEmbeddingGenerator,
7879
}
7980
)
8081

@@ -98,6 +99,7 @@ def from_bq(
9899
llm.PaLM2TextEmbeddingGenerator,
99100
llm.Claude3TextGenerator,
100101
llm.TextEmbeddingGenerator,
102+
llm.MultimodalEmbeddingGenerator,
101103
pipeline.Pipeline,
102104
compose.ColumnTransformer,
103105
preprocessing.PreprocessingType,

bigframes/ml/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
100100
model_name = model_endpoint
101101
version = None
102102

103+
if model_endpoint.startswith("multimodalembedding"):
104+
return model_name, version
105+
103106
at_idx = model_endpoint.find("@")
104107
if at_idx != -1:
105108
version = model_endpoint[at_idx + 1 :]

0 commit comments

Comments
 (0)