|
57 | 57 | _TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
|
58 | 58 | )
|
59 | 59 |
|
# BQML remote-model endpoint name for the Vertex AI multimodal embedding model
# (accepts text, image, and video inputs).
_MULTIMODAL_EMBEDDING_001_ENDPOINT = "multimodalembedding@001"

# BQML remote-model endpoint names for the Gemini text-generation models.
_GEMINI_PRO_ENDPOINT = "gemini-pro"
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
@@ -762,6 +764,152 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerat
|
762 | 764 | return new_model.session.read_gbq_model(model_name)
|
763 | 765 |
|
764 | 766 |
|
@log_adapter.class_logger
class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
    """Multimodal embedding generator LLM model.

    .. note::
        BigFrames Blob is still under experiments. It may not work and is subject to change in the future.

    Args:
        model_name (str, Default to "multimodalembedding@001"):
            The model for multimodal embedding. Can set to "multimodalembedding@001". Multimodal-embedding models returns model embeddings for text, image and video inputs.
            Default to "multimodalembedding@001".
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.
    """

    def __init__(
        self,
        *,
        model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        # Feature-gated: multimodal embedding relies on BigFrames Blob, which is
        # experimental. Fail fast with guidance instead of a bare error.
        if not bigframes.options.experiments.blob:
            raise NotImplementedError(
                "MultimodalEmbeddingGenerator requires the experimental Blob "
                "feature. Set bigframes.options.experiments.blob = True to enable it."
            )
        self.model_name = model_name
        self.session = session or global_session.get_global_session()
        self.connection_name = connection_name

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create (or reuse) the BQ connection and the BQML remote model."""
        # Parse and create connection if needed.
        self.connection_name = self.session._create_bq_connection(
            connection=self.connection_name, iam_role="aiplatform.user"
        )

        # Warn (but do not fail) on unrecognized endpoints, so newly released
        # model versions can still be tried before this list is updated.
        if self.model_name != _MULTIMODAL_EMBEDDING_001_ENDPOINT:
            msg = _MODEL_NOT_SUPPORTED_WARNING.format(
                model_name=self.model_name,
                known_models=_MULTIMODAL_EMBEDDING_001_ENDPOINT,
            )
            warnings.warn(msg)

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> MultimodalEmbeddingGenerator:
        """Reconstruct a wrapper instance from an existing BQ remote model."""
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        assert "endpoint" in bq_model._properties["remoteModelInfo"]
        assert "connection" in bq_model._properties["remoteModelInfo"]

        # Parse the remote model endpoint; the endpoint name is the last path
        # segment of the fully-qualified endpoint URI.
        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
        model_connection = bq_model._properties["remoteModelInfo"]["connection"]
        model_endpoint = bqml_endpoint.split("/")[-1]

        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        # Replace the freshly created BQML model with the existing one.
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    @property
    def _predict_func(
        self,
    ) -> Callable[
        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
    ]:
        # Used by the retry machinery in base.RetriableRemotePredictor.
        return self._bqml_model.generate_embedding

    @property
    def _status_col(self) -> str:
        # Column the retry machinery inspects to find failed rows.
        return _ML_GENERATE_EMBEDDING_STATUS

    def predict(
        self, X: utils.ArrayType, *, max_retries: int = 0
    ) -> bigframes.dataframe.DataFrame:
        """Predict the result from input DataFrame.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
                Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction.
                The content column must be of string type or BigFrames Blob of image or video.

            max_retries (int, default 0):
                Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
                Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.

        Returns:
            bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            raise ValueError(
                f"max_retries must be larger than or equal to 0, but is {max_retries}."
            )

        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

        if len(X.columns) == 1:
            # BQML identifies the input by column name, so a single-column
            # input is renamed to the expected "content" label.
            col_label = cast(blocks.Label, X.columns[0])
            X = X.rename(columns={col_label: "content"})

        # TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
        if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
            X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)

        options = {
            "flatten_json_output": True,
        }

        return self._predict_and_retry(X, options=options, max_retries=max_retries)

    def to_gbq(
        self, model_name: str, replace: bool = False
    ) -> MultimodalEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Default to False.

        Returns:
            MultimodalEmbeddingGenerator: Saved model."""

        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)


765 | 913 | @log_adapter.class_logger
|
766 | 914 | class GeminiTextGenerator(base.RetriableRemotePredictor):
|
767 | 915 | """Gemini text generator LLM model.
|
|
0 commit comments