From c5c759c4c540a1405aaebe00e28c81f3fc8a160a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20Clavi=C3=A9?= Date: Fri, 23 Feb 2024 12:12:56 +0100 Subject: [PATCH] Patch/llama index and device (#155) * fix: llamaindex imports * fix: rare device mismatch * chore: use only ruff linting * ruff * ruff * isort --- ...hout_annotations_with_instructor_and_RAGatouille.ipynb | 3 ++- pyproject.toml | 5 +---- ragatouille/data/preprocessors.py | 8 ++++++-- ragatouille/models/colbert.py | 8 ++++---- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb b/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb index a270857..9091807 100644 --- a/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb +++ b/examples/03-finetuning_without_annotations_with_instructor_and_RAGatouille.ipynb @@ -106,6 +106,7 @@ "outputs": [], "source": [ "import instructor\n", + "# If you're using llamaindex 0.10 or above, these need to be imported from llama_index.core instead!\n", "from llama_index import Document\n", "from llama_index.text_splitter import SentenceSplitter\n", "from openai import OpenAI\n", @@ -459,7 +460,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index b9a1415..58c3e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,4 @@ unfixable = [ "T201", "T203", ] -ignore-init-module-imports = true - -[tool.ruff.lint.isort] -section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] +ignore-init-module-imports = true \ No newline at end of file diff --git a/ragatouille/data/preprocessors.py b/ragatouille/data/preprocessors.py index 0984aaf..13f872d 100644 --- a/ragatouille/data/preprocessors.py +++ b/ragatouille/data/preprocessors.py @@ -1,5 +1,9 @@ -from llama_index import Document -from llama_index.text_splitter import SentenceSplitter +try: + from llama_index import Document + from llama_index.text_splitter import SentenceSplitter +except ImportError: + from llama_index.core import Document + from llama_index.core.text_splitter import SentenceSplitter def llama_index_sentence_splitter( diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 8024531..534a952 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -191,8 +191,8 @@ def add_to_index( bsize=bsize, ) else: - if self.config.index_bsize != bsize: # Update bsize if it's different - self.config.index_bsize = bsize + if self.config.index_bsize != bsize: # Update bsize if it's different + self.config.index_bsize = bsize updater = IndexUpdater( config=self.config, searcher=searcher, checkpoint=self.checkpoint @@ -757,7 +757,7 @@ def encode( - encodings.shape[1], encodings.shape[2], ) - ), + ).to(device=encodings.device), ], dim=1, ) @@ -771,7 +771,7 @@ def encode( - doc_masks.shape[1], ), -float("inf"), - ), + ).to(device=encodings.device), ], dim=1, )