Skip to content

Commit

Permalink
Changed type hints from dict to Dict.
Browse files Browse the repository at this point in the history
The former (the built-in dict used as a generic type hint) is only supported from Python 3.10 onwards, and we'd like to support Python 3.9 onwards.
  • Loading branch information
Siraj-Aizlewood committed May 3, 2024
1 parent 6440ceb commit 0212ba4
Show file tree
Hide file tree
Showing 16 changed files with 45 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docs/00-introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions docs/examples/unstructured-element-splitter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@
"source": [
"from unstructured.documents.elements import Element\n",
"from colorama import Fore, Style\n",
"from typing import List\n",
"from typing import List, Dict\n",
"\n",
"\n",
"def group_elements_by_title(elements: List[Element]) -> dict:\n",
"def group_elements_by_title(elements: List[Element]) -> Dict:\n",
" grouped_elements = {}\n",
" current_title = \"Untitled\" # Default title for initial text without a title\n",
"\n",
Expand Down Expand Up @@ -143,10 +143,10 @@
"outputs": [],
"source": [
"from semantic_router.splitters import RollingWindowSplitter\n",
"\n",
"from typing import Dict\n",
"\n",
"def create_title_chunks(\n",
" grouped_elements: dict, splitter: RollingWindowSplitter\n",
" grouped_elements: Dict, splitter: RollingWindowSplitter\n",
") -> list:\n",
" title_with_chunks = []\n",
" for title, elements in grouped_elements.items():\n",
Expand Down
4 changes: 2 additions & 2 deletions replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def replace_type_hints(file_path):
# Decode the file data with error handling
file_data = file_data.decode("utf-8", errors="ignore")

# Regular expression pattern to find 'dict[Type1, Type2] | None' and replace with 'Optional[dict[Type1, Type2]]'
# Regular expression pattern to find 'Dict[Type1, Type2] | None' and replace with 'Optional[Dict[Type1, Type2]]'.
file_data = re.sub(
r"dict\[(\w+), (\w+)\]\s*\|\s*None", r"Optional[dict[\1, \2]]", file_data
r"Dict\[(\w+), (\w+)\]\s*\|\s*None", r"Optional[Dict[\1, \2]]", file_data
)

with open(file_path, "w") as file:
Expand Down
8 changes: 4 additions & 4 deletions semantic_router/encoders/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

import numpy as np
from pydantic.v1 import PrivateAttr

from typing import Dict
from semantic_router.encoders import BaseEncoder


class CLIPEncoder(BaseEncoder):
name: str = "openai/clip-vit-base-patch16"
type: str = "huggingface"
score_threshold: float = 0.2
tokenizer_kwargs: dict = {}
processor_kwargs: dict = {}
model_kwargs: dict = {}
tokenizer_kwargs: Dict = {}
processor_kwargs: Dict = {}
model_kwargs: Dict = {}
device: Optional[str] = None
_tokenizer: Any = PrivateAttr()
_processor: Any = PrivateAttr()
Expand Down
6 changes: 3 additions & 3 deletions semantic_router/encoders/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import requests
import time
import os
from typing import Any, List, Optional
from typing import Any, List, Optional, Dict

from pydantic.v1 import PrivateAttr

Expand All @@ -35,8 +35,8 @@ class HuggingFaceEncoder(BaseEncoder):
name: str = "sentence-transformers/all-MiniLM-L6-v2"
type: str = "huggingface"
score_threshold: float = 0.5
tokenizer_kwargs: dict = {}
model_kwargs: dict = {}
tokenizer_kwargs: Dict = {}
model_kwargs: Dict = {}
device: Optional[str] = None
_tokenizer: Any = PrivateAttr()
_model: Any = PrivateAttr()
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/encoders/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def fit(self, routes: List[Route]):
self.word_index = self._build_word_index(docs)
self.idf = self._compute_idf(docs)

def _build_word_index(self, docs: List[str]) -> dict:
def _build_word_index(self, docs: List[str]) -> Dict:
words = set()
for doc in docs:
for word in doc.split():
Expand Down
6 changes: 3 additions & 3 deletions semantic_router/encoders/vit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, List, Optional, Dict

from pydantic.v1 import PrivateAttr

Expand All @@ -9,8 +9,8 @@ class VitEncoder(BaseEncoder):
name: str = "google/vit-base-patch16-224"
type: str = "huggingface"
score_threshold: float = 0.5
processor_kwargs: dict = {}
model_kwargs: dict = {}
processor_kwargs: Dict = {}
model_kwargs: Dict = {}
device: Optional[str] = None
_processor: Any = PrivateAttr()
_model: Any = PrivateAttr()
Expand Down
4 changes: 2 additions & 2 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Dict

import numpy as np
from pydantic.v1 import BaseModel
Expand Down Expand Up @@ -35,7 +35,7 @@ def delete(self, route_name: str):
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def describe(self) -> dict:
def describe(self) -> Dict:
"""
Returns a dictionary with index details such as type, dimensions, and total
vector count.
Expand Down
4 changes: 2 additions & 2 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict

import numpy as np

Expand Down Expand Up @@ -49,7 +49,7 @@ def get_routes(self) -> List[Tuple]:
raise ValueError("No routes have been added to the index.")
return list(zip(self.routes, self.utterances))

def describe(self) -> dict:
def describe(self) -> Dict:
return {
"type": self.type,
"dimensions": self.index.shape[1] if self.index is not None else 0,
Expand Down
4 changes: 2 additions & 2 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _init_index(self, force_create: bool = False) -> Union[Any, None]:
self.host = self.client.describe_index(self.index_name)["host"]
return index

def _batch_upsert(self, batch: List[dict]):
def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
if self.index is not None:
self.index.upsert(vectors=batch, namespace=self.namespace)
Expand Down Expand Up @@ -241,7 +241,7 @@ def delete(self, route_name: str):
def delete_all(self):
self.index.delete(delete_all=True, namespace=self.namespace)

def describe(self) -> dict:
def describe(self) -> Dict:
if self.index is not None:
stats = self.index.describe_index_stats()
return {
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def delete(self, route_name: str):
),
)

def describe(self) -> dict:
def describe(self) -> Dict:
collection_info = self.client.get_collection(self.index_name)

return {
Expand Down
8 changes: 4 additions & 4 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def _encode(self, text: str) -> Any:

def _retrieve(
self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
) -> List[dict]:
) -> List[Dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# get scores and routes
scores, routes = self.index.query(
Expand All @@ -448,7 +448,7 @@ def _set_aggregation_method(self, aggregation: str = "sum"):
f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
)

def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]:
def _semantic_classify(self, query_results: List[Dict]) -> Tuple[str, List[float]]:
scores_by_class = self.group_scores_by_class(query_results)

# Calculate total score for each class
Expand All @@ -473,7 +473,7 @@ def get(self, name: str) -> Optional[Route]:
return None

def _semantic_classify_multiple_routes(
self, query_results: List[dict]
self, query_results: List[Dict]
) -> List[Tuple[str, float]]:
scores_by_class = self.group_scores_by_class(query_results)

Expand All @@ -496,7 +496,7 @@ def _semantic_classify_multiple_routes(
return classes_above_threshold

def group_scores_by_class(
self, query_results: List[dict]
self, query_results: List[Dict]
) -> Dict[str, List[float]]:
scores_by_class: Dict[str, List[float]] = {}
for result in query_results:
Expand Down
10 changes: 5 additions & 5 deletions semantic_router/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Optional
from typing import Any, List, Optional, Dict

from pydantic.v1 import BaseModel

Expand All @@ -20,7 +20,7 @@ def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")

def _is_valid_inputs(
self, inputs: List[dict[str, Any]], function_schemas: List[dict[str, Any]]
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
"""Determine if the functions chosen by the LLM exist within the function_schemas,
and if the input arguments are valid for those functions."""
Expand Down Expand Up @@ -49,7 +49,7 @@ def _is_valid_inputs(
logger.error(f"Input validation error: {str(e)}")
return False

def _validate_single_function_inputs(self, inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool:
def _validate_single_function_inputs(self, inputs: Dict[str, Any], function_schema: Dict[str, Any]) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
Expand Down Expand Up @@ -78,8 +78,8 @@ def _extract_parameter_info(self, signature: str) -> tuple[List[str], List[str]]
return param_names, param_types

def extract_function_inputs(
self, query: str, function_schemas: List[dict[str, Any]]
) -> dict:
self, query: str, function_schemas: List[Dict[str, Any]]
) -> Dict:
logger.info("Extracting function input...")

prompt = f"""
Expand Down
6 changes: 3 additions & 3 deletions semantic_router/llms/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional, List
from typing import Any, Optional, List, Dict

from pydantic.v1 import PrivateAttr

Expand Down Expand Up @@ -79,8 +79,8 @@ def _grammar(self):
self.grammar = None

def extract_function_inputs(
self, query: str, function_schema: dict[str, Any]
) -> dict:
self, query: str, function_schema: Dict[str, Any]
) -> Dict:
with self._grammar():
return super().extract_function_inputs(
query=query, function_schema=function_schema
Expand Down
11 changes: 5 additions & 6 deletions semantic_router/llms/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Optional, Any
from typing import List, Optional, Any, Callable, Dict

import openai
from openai._types import NotGiven
Expand All @@ -11,7 +11,6 @@
import json
from semantic_router.utils.function_call import get_schema, convert_python_type_to_json_type
import inspect
from typing import Callable, Dict
import re

class OpenAILLM(BaseLLM):
Expand Down Expand Up @@ -41,7 +40,7 @@ def __init__(
self.temperature = temperature
self.max_tokens = max_tokens

def _extract_tool_calls_info(self, tool_calls: List[dict[str, Any]]) -> List[dict[str, Any]]:
def _extract_tool_calls_info(self, tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
tool_calls_info = []
for tool_call in tool_calls:
if tool_call.function.arguments is None:
Expand All @@ -57,7 +56,7 @@ def _extract_tool_calls_info(self, tool_calls: List[dict[str, Any]]) -> List[dic
def __call__(
self,
messages: List[Message],
function_schemas: Optional[List[dict[str, Any]]] = None,
function_schemas: Optional[List[Dict[str, Any]]] = None,
) -> str:
if self.client is None:
raise ValueError("OpenAI client is not initialized.")
Expand Down Expand Up @@ -99,8 +98,8 @@ def __call__(
raise Exception(f"LLM error: {e}") from e

def extract_function_inputs(
self, query: str, function_schemas: List[dict[str, Any]]
) -> dict:
self, query: str, function_schemas: List[Dict[str, Any]]
) -> Dict:
messages = []
system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
messages.append(Message(role="system", content=system_prompt))
Expand Down
6 changes: 3 additions & 3 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional, Union, Any
from typing import List, Optional, Union, Any, Dict
from pydantic.v1 import BaseModel


Expand All @@ -24,7 +24,7 @@ class EncoderInfo(BaseModel):

class RouteChoice(BaseModel):
name: Optional[str] = None
function_call: Optional[List[dict]] = None
function_call: Optional[List[Dict]] = None
similarity_score: Optional[float] = None


Expand Down Expand Up @@ -55,7 +55,7 @@ class DocumentSplit(BaseModel):
is_triggered: bool = False
triggered_score: Optional[float] = None
token_count: Optional[int] = None
metadata: Optional[dict] = None
metadata: Optional[Dict] = None

@property
def content(self) -> str:
Expand Down

0 comments on commit 0212ba4

Please sign in to comment.