This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit 9494135
Merge branch 'main' into feature/lancedb-cloud-support
SeeknnDestroy authored Dec 31, 2023
2 parents 61b81ec + 2ec1b21
Showing 10 changed files with 250 additions and 122 deletions.
2 changes: 1 addition & 1 deletion autollm/__init__.py
@@ -4,7 +4,7 @@
and vector databases, along with various utility functions.
"""

__version__ = '0.1.3'
__version__ = '0.1.5'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'

98 changes: 98 additions & 0 deletions autollm/auto/embedding.py
@@ -0,0 +1,98 @@
import asyncio
from typing import Any, List

from litellm import embedding as lite_embedding
from llama_index.bridge.pydantic import Field
from llama_index.embeddings.base import BaseEmbedding, Embedding


class AutoEmbedding(BaseEmbedding):
"""
Custom embedding class for flexible and efficient text embedding.
This class interfaces with the LiteLLM library to use its embedding functionality, making it compatible
with a wide range of LLM models.
"""

# Define the model attribute using Pydantic's Field
model: str = Field(default="text-embedding-ada-002", description="The name of the embedding model.")

def __init__(self, model: str, **kwargs: Any) -> None:
"""
Initialize the AutoEmbedding with a specific model.
Args:
model (str): ID of the embedding model to use.
**kwargs (Any): Additional keyword arguments.
"""
super().__init__(**kwargs)
self.model = model # Set the model ID for embedding

def _get_query_embedding(self, query: str) -> Embedding:
"""
Synchronously get the embedding for a query string.
Args:
query (str): The query text to embed.
Returns:
Embedding: The embedding vector.
"""
response = lite_embedding(model=self.model, input=[query])
return self._parse_embedding_response(response)

async def _aget_query_embedding(self, query: str) -> Embedding:
"""
Asynchronously get the embedding for a query string.
Args:
query (str): The query text to embed.
Returns:
Embedding: The embedding vector.
"""
response = await asyncio.to_thread(lite_embedding, model=self.model, input=[query])
return self._parse_embedding_response(response)

def _get_text_embedding(self, text: str) -> Embedding:
"""
Synchronously get the embedding for a text string.
Args:
text (str): The text to embed.
Returns:
Embedding: The embedding vector.
"""
return self._get_query_embedding(text)

async def _aget_text_embedding(self, text: str) -> Embedding:
"""
Asynchronously get the embedding for a text string.
Args:
text (str): The text to embed.
Returns:
Embedding: The embedding vector.
"""
return await self._aget_query_embedding(text)

def _parse_embedding_response(self, response):
"""
Parse the embedding response from LiteLLM and extract the embedding data.
Args:
response: The response object from LiteLLM's embedding function.
Returns:
List[float]: The extracted embedding list.
"""
try:
if 'data' in response and len(response['data']) > 0 and 'embedding' in response['data'][0]:
return response['data'][0]['embedding']
else:
raise ValueError("Invalid response structure from embedding function.")
except (TypeError, KeyError, IndexError) as e:
# Handle any parsing errors
raise ValueError(f"Error parsing embedding response: {e}")
4 changes: 2 additions & 2 deletions autollm/auto/llm.py
@@ -1,7 +1,7 @@
from typing import Optional

from llama_index.llms import LiteLLM
from llama_index.llms.base import LLM
from llama_index.llms.base import BaseLLM


class AutoLiteLLM:
@@ -14,7 +14,7 @@ def from_defaults(
model: str = "gpt-3.5-turbo",
max_tokens: Optional[int] = 256,
temperature: float = 0.1,
api_base: Optional[str] = None) -> LLM:
api_base: Optional[str] = None) -> BaseLLM:
"""
Create any LLM by model name. Check https://docs.litellm.ai/docs/providers for a list of
supported models.
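The only functional change in llm.py is the return annotation (LLM becomes BaseLLM), but for context, a brief sketch of how AutoLiteLLM.from_defaults is typically called follows; it is illustrative only, not part of the commit, and assumes a provider API key matching the chosen model is available.

```python
# Illustrative sketch (not part of the commit) of AutoLiteLLM.from_defaults.
from autollm.auto.llm import AutoLiteLLM

llm = AutoLiteLLM.from_defaults(
    model="gpt-3.5-turbo",  # any model listed at https://docs.litellm.ai/docs/providers
    max_tokens=256,
    temperature=0.1,
    api_base=None,          # override for proxies or self-hosted endpoints
)

# The returned object is a llama_index BaseLLM, so the standard completion API applies.
print(llm.complete("Reply with a single word: ready?"))
```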
123 changes: 33 additions & 90 deletions autollm/auto/query_engine.py
@@ -3,11 +3,12 @@
from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.embeddings.utils import EmbedType
from llama_index.indices.query.base import BaseQueryEngine
from llama_index.prompts.base import PromptTemplate
from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.prompts.prompt_type import PromptType
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.schema import BaseNode

from autollm.auto.embedding import AutoEmbedding
from autollm.auto.llm import AutoLiteLLM
from autollm.auto.service_context import AutoServiceContext
from autollm.auto.vector_store_index import AutoVectorStoreIndex
@@ -24,11 +25,11 @@ def create_query_engine(
llm_api_base: Optional[str] = None,
# service_context_params
system_prompt: str = None,
query_wrapper_prompt: str = None,
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
embed_model: Optional[str] = "text-embedding-ada-002",
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 200,
chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
@@ -44,11 +45,8 @@
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create a query engine from parameters.
@@ -61,7 +59,7 @@
llm_temperature (float): The temperature to use for the LLM.
llm_api_base (str): The API base to use for the LLM.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
@@ -83,32 +81,14 @@
Returns:
A llama_index.BaseQueryEngine instance.
"""
# Check for deprecated parameters
if llm_params is not None:
raise ValueError(
"llm_params is deprecated. Instead of llm_params={'llm_model': 'model_name', ...}, "
"use llm_model='model_name', llm_api_base='api_base', llm_max_tokens=1028, llm_temperature=0.1 directly as arguments."
)
if vector_store_params is not None:
raise ValueError(
"vector_store_params is deprecated. Instead of vector_store_params={'vector_store_type': 'type', ...}, "
"use vector_store_type='type', lancedb_uri='uri', lancedb_table_name='table', enable_metadata_extraction=True directly as arguments."
)
if service_context_params is not None:
raise ValueError(
"service_context_params is deprecated. Use the explicit parameters like system_prompt='prompt', "
"query_wrapper_prompt='wrapper', enable_cost_calculator=True, embed_model='model', chunk_size=512, "
"chunk_overlap=..., context_window=... directly as arguments.")
if query_engine_params is not None:
raise ValueError(
"query_engine_params is deprecated. Instead of query_engine_params={'similarity_top_k': 5, ...}, "
"use similarity_top_k=5 directly as an argument.")

llm = AutoLiteLLM.from_defaults(
model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature)

embedding = AutoEmbedding(model=embed_model)

service_context = AutoServiceContext.from_defaults(
llm=llm,
embed_model=embed_model,
embed_model=embedding,
system_prompt=system_prompt,
query_wrapper_prompt=query_wrapper_prompt,
enable_cost_calculator=enable_cost_calculator,
@@ -128,15 +108,22 @@
documents=documents,
nodes=nodes,
service_context=service_context,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)
if refine_prompt is not None:
refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE)
else:
refine_prompt_template = None

# Convert query_wrapper_prompt to PromptTemplate if it is a string
if isinstance(query_wrapper_prompt, str):
query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)
response_synthesizer = get_response_synthesizer(
service_context=service_context,
response_mode=response_mode,
text_qa_template=query_wrapper_prompt,
refine_template=refine_prompt_template,
response_mode=response_mode,
structured_answer_filtering=structured_answer_filtering)

return vector_store_index.as_query_engine(
@@ -168,7 +155,7 @@ class AutoQueryEngine:
system_prompt=None,
query_wrapper_prompt=None,
enable_cost_calculator=True,
embed_model="default", # ["default", "local"]
embed_model="text-embedding-ada-002",
chunk_size=512,
chunk_overlap=None,
context_window=None,
Expand All @@ -178,7 +165,6 @@ class AutoQueryEngine:
vector_store_type="LanceDBVectorStore",
lancedb_uri="./.lancedb",
lancedb_table_name="vectors",
enable_metadata_extraction=False,
**vector_store_kwargs)
)
```
@@ -207,15 +193,15 @@ def from_defaults(
documents: Optional[Sequence[Document]] = None,
nodes: Optional[Sequence[BaseNode]] = None,
# llm_params
llm_model: str = "gpt-3.5-turbo",
llm_model: Optional[str] = "gpt-3.5-turbo",
llm_api_base: Optional[str] = None,
llm_max_tokens: Optional[int] = None,
llm_temperature: float = 0.1,
llm_temperature: Optional[float] = 0.1,
# service_context_params
system_prompt: str = None,
query_wrapper_prompt: str = None,
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
embed_model: Optional[str] = "text-embedding-ada-002",
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 200,
context_window: Optional[int] = None,
@@ -228,12 +214,8 @@ def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
enable_metadata_extraction: bool = False,
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create an AutoQueryEngine from default parameters.
@@ -246,10 +228,9 @@
llm_temperature (float): The temperature to use for the LLM.
llm_api_base (str): The API base to use for the LLM.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings.
chunk_size (int): The token chunk size for each chunk.
chunk_overlap (int): The token overlap between each chunk.
context_window (int): The maximum context size that will get sent to the LLM.
@@ -264,6 +245,8 @@
vector_store_type (str): The vector store type to use for the query engine.
lancedb_uri (str): The URI to use for the LanceDB vector store.
lancedb_table_name (str): The table name to use for the LanceDB vector store.
exist_ok (bool): Flag to allow overwriting an existing vector store.
overwrite_existing (bool): Flag to allow overwriting an existing vector store.
Returns:
A llama_index.BaseQueryEngine instance.
@@ -294,50 +277,10 @@ def from_defaults(
vector_store_type=vector_store_type,
lancedb_uri=lancedb_uri,
lancedb_table_name=lancedb_table_name,
enable_metadata_extraction=enable_metadata_extraction,
# Deprecated parameters
llm_params=llm_params,
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)

@staticmethod
def from_parameters(
documents: Sequence[Document] = None,
system_prompt: str = None,
query_wrapper_prompt: str = None,
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None) -> BaseQueryEngine:
"""
DEPRECATED. Use AutoQueryEngine.from_defaults instead.
Create an AutoQueryEngine from parameters.
Parameters:
documents (Sequence[Document]): Sequence of llama_index.Document instances.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
llm_params (dict): Parameters for the LLM.
vector_store_params (dict): Parameters for the vector store.
service_context_params (dict): Parameters for the service context.
query_engine_params (dict): Parameters for the query engine.
Returns:
A llama_index.BaseQueryEngine instance.
"""

# TODO: Remove this method in the next release
raise ValueError(
"AutoQueryEngine.from_parameters is deprecated. Use AutoQueryEngine.from_defaults instead.")

@staticmethod
def from_config(
config_file_path: str,
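Taken together, the query_engine.py changes simplify the public signature: the deprecated dict-style parameters are gone, embed_model now takes a LiteLLM embedding model name, and exist_ok / overwrite_existing control reuse of an existing vector store. A usage sketch under those assumptions follows; it is not part of the commit and assumes an indexable ./docs directory and an OPENAI_API_KEY.

```python
# Sketch of the updated AutoQueryEngine.from_defaults call (not part of the commit).
from llama_index import SimpleDirectoryReader

from autollm.auto.query_engine import AutoQueryEngine

documents = SimpleDirectoryReader("./docs").load_data()

query_engine = AutoQueryEngine.from_defaults(
    documents=documents,
    llm_model="gpt-3.5-turbo",
    llm_temperature=0.1,
    embed_model="text-embedding-ada-002",  # plain LiteLLM embedding model name
    chunk_size=512,
    vector_store_type="LanceDBVectorStore",
    lancedb_uri="./.lancedb",        # local path; the feature branch presumably also accepts a LanceDB Cloud URI
    lancedb_table_name="vectors",
    exist_ok=True,                   # reuse the table if it already exists
    overwrite_existing=False,        # keep previously stored vectors
)

print(query_engine.query("What is this project about?"))
```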
4 changes: 2 additions & 2 deletions autollm/auto/service_context.py
@@ -32,7 +32,7 @@ def from_defaults(
query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
enable_cost_calculator: bool = False,
chunk_size: Optional[int] = 512,
chunk_overlap: Optional[int] = 200,
chunk_overlap: Optional[int] = 100,
context_window: Optional[int] = None,
enable_title_extractor: bool = False,
enable_summary_extractor: bool = False,
@@ -65,7 +65,7 @@
"""
if not system_prompt and not query_wrapper_prompt:
system_prompt, query_wrapper_prompt = set_default_prompt_template()
# Convert system_prompt to ChatPromptTemplate if it is a string
# Convert query_wrapper_prompt to PromptTemplate if it is a string
if isinstance(query_wrapper_prompt, str):
query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)

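The corrected comment describes the conversion applied when query_wrapper_prompt arrives as a plain string; roughly, the pattern is the following sketch (illustrative template text only, using the same PromptTemplate import that appears in the query_engine.py diff above).

```python
# Rough illustration of the string-to-PromptTemplate conversion noted in the comment.
from llama_index.prompts.base import PromptTemplate

query_wrapper_prompt = "Answer strictly from the given context.\n{query_str}"

if isinstance(query_wrapper_prompt, str):
    query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)

print(query_wrapper_prompt.format(query_str="What is autollm?"))
```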