From e82acb991691e1280b67799900fb8f0fa9746f39 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 23 Aug 2024 18:48:58 +0800 Subject: [PATCH] fix: Add user-friendly message for BGE-M3 users on Google Colab. (#33) Signed-off-by: wxywb --- milvus_model/hybrid/bge_m3.py | 13 +++++++++++-- milvus_model/utils/__init__.py | 4 ++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/milvus_model/hybrid/bge_m3.py b/milvus_model/hybrid/bge_m3.py index f5347bf..6d4f6d1 100644 --- a/milvus_model/hybrid/bge_m3.py +++ b/milvus_model/hybrid/bge_m3.py @@ -5,11 +5,20 @@ import numpy as np from milvus_model.base import BaseEmbeddingFunction -from milvus_model.utils import import_FlagEmbedding +from milvus_model.utils import import_FlagEmbedding, import_datasets from milvus_model.sparse.utils import stack_sparse_embeddings +import_datasets() import_FlagEmbedding() -from FlagEmbedding import BGEM3FlagModel + +try: + from FlagEmbedding import BGEM3FlagModel +except AttributeError as e: + import sys + if "google.colab" in sys.modules and "ListView" in str(e): + print("\033[91mIt looks like you're running on Google Colab. Please restart the session to resolve this issue.\033[0m") + print("\033[91mFor further details, visit: https://github.com/milvus-io/milvus-model/issues/32.\033[0m") + raise logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/milvus_model/utils/__init__.py b/milvus_model/utils/__init__.py index 6587f9d..39a5e23 100644 --- a/milvus_model/utils/__init__.py +++ b/milvus_model/utils/__init__.py @@ -17,6 +17,7 @@ "import_mistralai", "import_nomic", "import_instructor", + "import_datasets", ] import importlib.util @@ -79,6 +80,9 @@ def import_nomic(): def import_instructor(): _check_library("InstructorEmbedding", package="InstructorEmbedding") +def import_datasets(): + _check_library("datasets", package="datasets") + def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None): is_avail = False if importlib.util.find_spec(libname):