diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py
index 7b66191a11..72c49e124b 100644
--- a/bigframes/ml/llm.py
+++ b/bigframes/ml/llm.py
@@ -57,6 +57,8 @@
     _TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
 )
 
+_MULTIMODAL_EMBEDDING_001_ENDPOINT = "multimodalembedding@001"
+
 _GEMINI_PRO_ENDPOINT = "gemini-pro"
 _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
 _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
@@ -762,6 +764,152 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
         return new_model.session.read_gbq_model(model_name)
 
 
+@log_adapter.class_logger
+class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
+    """Multimodal embedding generator LLM model.
+
+    .. note::
+        BigFrames Blob is still experimental. It may not work and is subject to change in the future.
+
+    Args:
+        model_name (str, default "multimodalembedding@001"):
+            The model for multimodal embedding. Can be set to "multimodalembedding@001". Multimodal embedding models return embeddings for text, image and video inputs.
+            Defaults to "multimodalembedding@001".
+        session (bigframes.Session or None):
+            BQ session to create the model. If None, uses the global default session.
+        connection_name (str or None):
+            Connection to connect with the remote service, a str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
+            If None, uses the default connection in the session context.
+    """
+
+    def __init__(
+        self,
+        *,
+        model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
+        session: Optional[bigframes.Session] = None,
+        connection_name: Optional[str] = None,
+    ):
+        if not bigframes.options.experiments.blob:
+            raise NotImplementedError()
+        self.model_name = model_name
+        self.session = session or global_session.get_global_session()
+        self.connection_name = connection_name
+
+        self._bqml_model_factory = globals.bqml_model_factory()
+        self._bqml_model: core.BqmlModel = self._create_bqml_model()
+
+    def _create_bqml_model(self):
+        # Parse the connection name and create the connection if needed.
+        self.connection_name = self.session._create_bq_connection(
+            connection=self.connection_name, iam_role="aiplatform.user"
+        )
+
+        if self.model_name != _MULTIMODAL_EMBEDDING_001_ENDPOINT:
+            msg = _MODEL_NOT_SUPPORTED_WARNING.format(
+                model_name=self.model_name,
+                known_models=_MULTIMODAL_EMBEDDING_001_ENDPOINT,
+            )
+            warnings.warn(msg)
+
+        options = {
+            "endpoint": self.model_name,
+        }
+        return self._bqml_model_factory.create_remote_model(
+            session=self.session, connection_name=self.connection_name, options=options
+        )
+
+    @classmethod
+    def _from_bq(
+        cls, session: bigframes.Session, bq_model: bigquery.Model
+    ) -> MultimodalEmbeddingGenerator:
+        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
+        assert "remoteModelInfo" in bq_model._properties
+        assert "endpoint" in bq_model._properties["remoteModelInfo"]
+        assert "connection" in bq_model._properties["remoteModelInfo"]
+
+        # Parse the remote model endpoint
+        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
+        model_connection = bq_model._properties["remoteModelInfo"]["connection"]
+        model_endpoint = bqml_endpoint.split("/")[-1]
+
+        model = cls(
+            session=session,
+            model_name=model_endpoint,  # type: ignore
+            connection_name=model_connection,
+        )
+
+        model._bqml_model = core.BqmlModel(session, bq_model)
+        return model
+
+    @property
+    def _predict_func(
+        self,
+    ) -> Callable[
+        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
+    ]:
+        return self._bqml_model.generate_embedding
+
+    @property
+    def _status_col(self) -> str:
+        return _ML_GENERATE_EMBEDDING_STATUS
+
+    def predict(
+        self, X: utils.ArrayType, *, max_retries: int = 0
+    ) -> bigframes.dataframe.DataFrame:
+        """Predict the result from input DataFrame.
+
+        Args:
+            X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
+                Input DataFrame or Series, which can contain one or more columns. If the DataFrame has multiple columns, it must contain a "content" column for prediction.
+                The content column must be of string type or a BigFrames Blob of image or video.
+
+            max_retries (int, default 0):
+                Max number of retries if the prediction for any rows fails. Each retry must make progress (i.e. successfully predict some rows) for retries to continue.
+                Each retry appends the newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) are appended to the end of the result.
+
+        Returns:
+            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
+        """
+        if max_retries < 0:
+            raise ValueError(
+                f"max_retries must be larger than or equal to 0, but is {max_retries}."
+            )
+
+        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
+
+        if len(X.columns) == 1:
+            # BQML identifies the column by name, so rename the single input column to "content".
+            col_label = cast(blocks.Label, X.columns[0])
+            X = X.rename(columns={col_label: "content"})
+
+        # TODO(garrettwu): remove the transform to ObjRefRuntime when BQML supports ObjRef as input
+        if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
+            X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)
+
+        options = {
+            "flatten_json_output": True,
+        }
+
+        return self._predict_and_retry(X, options=options, max_retries=max_retries)
+
+    def to_gbq(
+        self, model_name: str, replace: bool = False
+    ) -> MultimodalEmbeddingGenerator:
+        """Save the model to BigQuery.
+
+        Args:
+            model_name (str):
+                The name of the model.
+            replace (bool, default False):
+                Whether to replace the model if it already exists. Defaults to False.
+
+        Returns:
+            MultimodalEmbeddingGenerator: Saved model."""
+
+        new_model = self._bqml_model.copy(model_name, replace)
+        return new_model.session.read_gbq_model(model_name)
+
+
 @log_adapter.class_logger
 class GeminiTextGenerator(base.RetriableRemotePredictor):
     """Gemini text generator LLM model.
diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py
index 5d52927ded..eef72584bc 100644
--- a/bigframes/ml/loader.py
+++ b/bigframes/ml/loader.py
@@ -75,6 +75,7 @@
         llm._TEXT_EMBEDDING_005_ENDPOINT: llm.TextEmbeddingGenerator,
         llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
         llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
+        llm._MULTIMODAL_EMBEDDING_001_ENDPOINT: llm.MultimodalEmbeddingGenerator,
     }
 )
 
@@ -98,6 +99,7 @@ def from_bq(
     llm.PaLM2TextEmbeddingGenerator,
     llm.Claude3TextGenerator,
    llm.TextEmbeddingGenerator,
+    llm.MultimodalEmbeddingGenerator,
     pipeline.Pipeline,
     compose.ColumnTransformer,
     preprocessing.PreprocessingType,
diff --git a/bigframes/ml/utils.py b/bigframes/ml/utils.py
index e1620485d5..e034fd00f7 100644
--- a/bigframes/ml/utils.py
+++ b/bigframes/ml/utils.py
@@ -100,6 +100,9 @@ def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
     model_name = model_endpoint
     version = None
 
+    if model_endpoint.startswith("multimodalembedding"):
+        return model_name, version
+
     at_idx = model_endpoint.find("@")
     if at_idx != -1:
         version = model_endpoint[at_idx + 1 :]
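
A minimal usage sketch of the new class for reviewers, not part of the patch. It assumes the blob experiment flag is enabled and that bigframes.pandas.from_glob_path is available to build a blob column; the bucket path and column name are hypothetical.

# Usage sketch only (reviewer aid, not part of the patch). The bucket path,
# column name, and from_glob_path helper are assumptions.
import bigframes
import bigframes.pandas as bpd
from bigframes.ml import llm

bigframes.options.experiments.blob = True  # __init__ raises NotImplementedError otherwise

# A single input column is renamed to "content" by predict() automatically.
df = bpd.from_glob_path("gs://my-bucket/images/*", name="content")

model = llm.MultimodalEmbeddingGenerator()  # defaults to "multimodalembedding@001"

# Blob input is converted to an ObjRefRuntime column before the call to
# ML.GENERATE_EMBEDDING; failed rows are retried up to twice, as long as
# each attempt makes progress.
embeddings = model.predict(df, max_retries=2)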
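
A note on the parse_model_endpoint change: for other endpoints the "@" suffix is split off as a version, but the loader registers the full "multimodalembedding@001" string as the endpoint key, so the early return must keep the suffix in the model name. A sketch of the expected behavior; the non-multimodal return value is inferred from the code just past the hunk context:

from bigframes.ml.utils import parse_model_endpoint

# Early return: the "@001" suffix stays in the model name so it matches the
# endpoint key registered in loader.py above.
parse_model_endpoint("multimodalembedding@001")  # ("multimodalembedding@001", None)

# Other endpoints still split at "@" into (model_name, version).
parse_model_endpoint("claude-3-sonnet@20240229")  # ("claude-3-sonnet", "20240229")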