Skip to content

Commit

Permalink
feat: simplify and align routers call methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Jan 3, 2025
1 parent 84d9a7d commit ec4c216
Show file tree
Hide file tree
Showing 3 changed files with 516 additions and 489 deletions.
113 changes: 28 additions & 85 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,19 @@ def __call__(
vector = self._encode(text=[text])
# convert to numpy array if not already
vector = xq_reshape(vector)
# calculate semantics
route, top_class_scores = self._retrieve_top_route(vector, route_filter)
# get scores and routes
scores, routes = self.index.query(
vector=vector[0], top_k=self.top_k, route_filter=route_filter
)
query_results = [
{"route": d, "score": s.item()} for d, s in zip(routes, scores)
]
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(
query_results=query_results
)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
passed = self._check_threshold(top_class_scores, route)
if passed and route is not None and not simulate_static:
if route.function_schemas and text is None:
Expand Down Expand Up @@ -473,10 +484,19 @@ async def acall(
vector = await self._async_encode(text=[text])
# convert to numpy array if not already
vector = xq_reshape(vector)
# calculate semantics
route, top_class_scores = await self._async_retrieve_top_route(
vector, route_filter
# get scores and routes
scores, routes = await self.index.aquery(
vector=vector[0], top_k=self.top_k, route_filter=route_filter
)
query_results = [
{"route": d, "score": s.item()} for d, s in zip(routes, scores)
]
# decide most relevant routes
top_class, top_class_scores = await self._async_semantic_classify(
query_results=query_results
)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
passed = self._check_threshold(top_class_scores, route)
if passed and route is not None and not simulate_static:
if route.function_schemas and text is None:
Expand All @@ -503,66 +523,6 @@ async def acall(
# if no route passes threshold, return empty route choice
return RouteChoice()

# TODO: add multiple routes return to __call__ and acall
@deprecated("This method is deprecated. Use `__call__` instead.")
def retrieve_multiple_routes(
self,
text: Optional[str] = None,
vector: Optional[List[float] | np.ndarray] = None,
) -> List[RouteChoice]:
if vector is None:
if text is None:
raise ValueError("Either text or vector must be provided")
vector = self._encode(text=[text])
# convert to numpy array if not already
vector = xq_reshape(vector)
# get relevant utterances
results = self._retrieve(xq=vector)
# decide most relevant routes
categories_with_scores = self._semantic_classify_multiple_routes(results)
return [
RouteChoice(name=category, similarity_score=score)
for category, score in categories_with_scores
]

# route_choices = []
# TODO JB: do we need this check? Maybe we should be returning directly
# for category, score in categories_with_scores:
# route = self.check_for_matching_routes(category)
# if route:
# route_choice = RouteChoice(name=route.name, similarity_score=score)
# route_choices.append(route_choice)

# return route_choices

def _retrieve_top_route(
self, vector: np.ndarray, route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
"""
Retrieve the top matching route based on the given vector.
Returns a tuple of the route (if any) and the scores of the top class.
"""
# get relevant results (scores and routes)
results = self._retrieve(xq=vector, top_k=self.top_k, route_filter=route_filter)
# decide most relevant routes
top_class, top_class_scores = self._semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores

async def _async_retrieve_top_route(
self, vector: np.ndarray, route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
# get relevant results (scores and routes)
results = await self._async_retrieve(
xq=vector, top_k=self.top_k, route_filter=route_filter
)
# decide most relevant routes
top_class, top_class_scores = await self._async_semantic_classify(results)
# TODO do we need this check?
route = self.check_for_matching_routes(top_class)
return route, top_class_scores

def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
"""Runs a sync of the local routes with the remote index.
Expand Down Expand Up @@ -1116,26 +1076,6 @@ async def _async_encode(self, text: list[str]) -> Any:
# TODO: should encode "content" rather than text
raise NotImplementedError("This method should be implemented by subclasses.")

def _retrieve(
self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
) -> List[Dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# get scores and routes
scores, routes = self.index.query(
vector=xq[0], top_k=top_k, route_filter=route_filter
)
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]

async def _async_retrieve(
self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
) -> List[Dict]:
"""Given a query vector, retrieve the top_k most similar records."""
# get scores and routes
scores, routes = await self.index.aquery(
vector=xq[0], top_k=top_k, route_filter=route_filter
)
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]

def _set_aggregation_method(self, aggregation: str = "sum"):
# TODO is this really needed?
if aggregation == "sum":
Expand All @@ -1149,6 +1089,7 @@ def _set_aggregation_method(self, aggregation: str = "sum"):
f"Unsupported aggregation method chosen: {aggregation}. Choose either 'SUM', 'MEAN', or 'MAX'."
)

# TODO JB allow return of multiple routes
def _semantic_classify(self, query_results: List[Dict]) -> Tuple[str, List[float]]:
"""Classify the query results into a single class based on the highest total score.
If no classification is found, return an empty string and an empty list.
Expand Down Expand Up @@ -1216,6 +1157,7 @@ def get(self, name: str) -> Optional[Route]:
logger.error(f"Route `{name}` not found")
return None

@deprecated("This method is deprecated. Use `semantic_classify` instead.")
def _semantic_classify_multiple_routes(
self, query_results: List[Dict]
) -> List[Tuple[str, float]]:
Expand Down Expand Up @@ -1243,6 +1185,7 @@ def group_scores_by_class(
self, query_results: List[Dict]
) -> Dict[str, List[float]]:
scores_by_class: Dict[str, List[float]] = {}
logger.warning(f"JBTEMP: {query_results=}")
for result in query_results:
score = result["score"]
route = result["route"]
Expand Down
13 changes: 8 additions & 5 deletions semantic_router/routers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def add(self, routes: List[Route] | Route):
if current_remote_hash.value == "":
# if remote hash is empty, the index is to be initialized
current_remote_hash = current_local_hash
logger.warning(f"JBTEMP: {routes}")
if isinstance(routes, Route):
routes = [routes]
# create embeddings for all routes
Expand Down Expand Up @@ -220,16 +221,18 @@ def __call__(
raise ValueError("Sparse vector is required for HybridLocalIndex.")
# TODO: add alpha as a parameter
scores, route_names = self.index.query(
vector=vector,
vector=vector[0],
top_k=self.top_k,
route_filter=route_filter,
sparse_vector=sparse_vector,
)
query_results = [
{"route": d, "score": s.item()} for d, s in zip(route_names, scores)
]
# TODO JB we should probably make _semantic_classify consume arrays rather than
# needing to convert to list here
top_class, top_class_scores = self._semantic_classify(
[
{"score": score, "route": route}
for score, route in zip(scores, route_names)
]
query_results=query_results
)
passed = self._pass_threshold(top_class_scores, self.score_threshold)
if passed:
Expand Down
Loading

0 comments on commit ec4c216

Please sign in to comment.