diff --git a/semantic_router/layer.py b/semantic_router/layer.py index bed1ef28..158c12be 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -435,8 +435,40 @@ def add(self, route: Route): def list_route_names(self) -> List[str]: return [route.name for route in self.routes] - def update(self, route_name: str, utterances: List[str]): - raise NotImplementedError("This method has not yet been implemented.") + def update( + self, + name: str, + threshold: Optional[float] = None, + utterances: Optional[List[str]] = None, + ): + """Updates the route specified in name. Allows the update of + threshold and/or utterances. If no values are provided via the + threshold or utterances parameters, those fields are not updated. + If neither field is provided raises a ValueError. + + The name must exist within the local RouteLayer, if not a + KeyError will be raised. + """ + + if threshold is None and utterances is None: + raise ValueError( + "At least one of 'threshold' or 'utterances' must be provided." + ) + if utterances: + raise NotImplementedError( + "The update method cannot be used for updating utterances yet." + ) + + route = self.get(name) + if route: + if threshold: + old_threshold = route.score_threshold + route.score_threshold = threshold + logger.info( + f"Updated threshold for route '{route.name}' from {old_threshold} to {threshold}" + ) + else: + raise ValueError(f"Route '{name}' not found. Nothing updated.") def delete(self, route_name: str): """Deletes a route given a specific route name. diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 12773995..3bfcd485 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -811,6 +811,45 @@ def test_refresh_routes_not_implemented(self, openai_encoder, routes, index_cls) ): route_layer._refresh_routes() + def test_update_threshold(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + route_name = "Route 1" + new_threshold = 0.8 + route_layer.update(name=route_name, threshold=new_threshold) + updated_route = route_layer.get(route_name) + assert ( + updated_route.score_threshold == new_threshold + ), f"Expected threshold to be updated to {new_threshold}, but got {updated_route.score_threshold}" + + def test_update_non_existent_route(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + non_existent_route = "Non-existent Route" + with pytest.raises( + ValueError, + match=f"Route '{non_existent_route}' not found. Nothing updated.", + ): + route_layer.update(name=non_existent_route, threshold=0.7) + + def test_update_without_parameters(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + with pytest.raises( + ValueError, + match="At least one of 'threshold' or 'utterances' must be provided.", + ): + route_layer.update(name="Route 1") + + def test_update_utterances_not_implemented(self, openai_encoder, routes, index_cls): + index = init_index(index_cls) + route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index) + with pytest.raises( + NotImplementedError, + match="The update method cannot be used for updating utterances yet.", + ): + route_layer.update(name="Route 1", utterances=["New utterance"]) + class TestLayerFit: def test_eval(self, openai_encoder, routes, test_data):