diff --git a/examples/evaluate_benchmark_with_custom_api.py b/examples/evaluate_benchmark_with_custom_provider.py
similarity index 77%
rename from examples/evaluate_benchmark_with_custom_api.py
rename to examples/evaluate_benchmark_with_custom_provider.py
index 379c6f605..371f97f51 100644
--- a/examples/evaluate_benchmark_with_custom_api.py
+++ b/examples/evaluate_benchmark_with_custom_provider.py
@@ -1,5 +1,5 @@
 from unitxt import evaluate, load_dataset
-from unitxt.inference import MultiAPIInferenceEngine
+from unitxt.inference import CrossProviderModel
 from unitxt.text_utils import print_dict

 data = load_dataset(
@@ -8,8 +8,8 @@
     disable_cache=False,
 )

-model = MultiAPIInferenceEngine(
-    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, api="watsonx"
+model = CrossProviderModel(
+    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, provider="watsonx"
 )

 predictions = model.infer(data)
diff --git a/prepare/engines/multi_api/llama3.py b/prepare/engines/multi_api/llama3.py
index 8ebaa4adf..8b3ee4494 100644
--- a/prepare/engines/multi_api/llama3.py
+++ b/prepare/engines/multi_api/llama3.py
@@ -1,16 +1,8 @@
 from unitxt.catalog import add_to_catalog
-from unitxt.inference import MultiAPIInferenceEngine
+from unitxt.inference import CrossProviderModel

-engine = MultiAPIInferenceEngine(
+engine = CrossProviderModel(
     model="llama-3-8b-instruct",
-    api_model_map={
-        "watsonx": {
-            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
-        },
-        "together-ai": {
-            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
-        },
-    },
 )

 add_to_catalog(engine, "engines.model.llama_3_8b_instruct", overwrite=True)
diff --git a/src/unitxt/catalog/engines/model/llama_3_8b_instruct.json b/src/unitxt/catalog/engines/model/llama_3_8b_instruct.json
index a6c2be46c..ab9eee536 100644
--- a/src/unitxt/catalog/engines/model/llama_3_8b_instruct.json
+++ b/src/unitxt/catalog/engines/model/llama_3_8b_instruct.json
@@ -1,12 +1,4 @@
 {
-    "__type__": "multi_api_inference_engine",
-    "model": "llama-3-8b-instruct",
-    "api_model_map": {
-        "watsonx": {
-            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct"
-        },
-        "together-ai": {
-            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
-        }
-    }
+    "__type__": "cross_provider_model",
+    "model": "llama-3-8b-instruct"
 }
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index 5acb50d01..15a308c6f 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -1669,8 +1669,8 @@ def _infer(
 _supported_apis = Literal["watsonx", "together-ai", "open-ai", "aws", "ollama"]


-class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
-    """Inference engine capable of dynamically switching between multiple APIs.
+class CrossProviderModel(InferenceEngine, StandardAPIParamsMixin):
+    """Inference engine capable of dynamically switching between multiple provider APIs.

     This class extends the InferenceEngine and OpenAiInferenceEngineParamsMixin
     to enable seamless integration with various API providers. The supported APIs are
@@ -1687,9 +1687,9 @@ class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     across different API backends.
""" - api: Optional[_supported_apis] = None + provider: Optional[_supported_apis] = None - api_model_map: Dict[_supported_apis, Dict[str, str]] = { + provider_model_map: Dict[_supported_apis, Dict[str, str]] = { "watsonx": { "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct", "llama-3-70b-instruct": "watsonx/meta-llama/llama-3-70b-instruct", @@ -1708,7 +1708,7 @@ class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin): }, } - _api_to_base_class = { + _provider_to_base_class = { "watsonx": LiteLLMInferenceEngine, "open-ai": LiteLLMInferenceEngine, "together-ai": LiteLLMInferenceEngine, @@ -1716,14 +1716,14 @@ class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "ollama": OllamaInferenceEngine, } - def get_api_name(self): - return self.api if self.api is not None else settings.default_inference_api + def get_provider_name(self): + return self.provider if self.provider is not None else settings.default_provider def prepare_engine(self): - api = self.get_api_name() - cls = self.__class__._api_to_base_class[api] + provider = self.get_provider_name() + cls = self.__class__._provider_to_base_class[provider] args = self.to_dict([StandardAPIParamsMixin]) - args["model"] = self.api_model_map[api][self.model] + args["model"] = self.provider_model_map[provider][self.model] self.engine = cls(**args) def _infer( @@ -1734,8 +1734,8 @@ def _infer( return self.engine._infer(dataset, return_meta_data) def get_engine_id(self): - api = self.get_api_name() - return get_model_and_label_id(self.api_model_map[api][self.model], api) + api = self.get_provider_name() + return get_model_and_label_id(self.provider_model_map[api][self.model], api) class HFOptionSelectingInferenceEngine(InferenceEngine): diff --git a/src/unitxt/settings_utils.py b/src/unitxt/settings_utils.py index 9a03cf81e..a95cacfe3 100644 --- a/src/unitxt/settings_utils.py +++ b/src/unitxt/settings_utils.py @@ -152,7 +152,7 @@ def __getattr__(self, key): settings.disable_hf_datasets_cache = (bool, True) settings.loader_cache_size = (int, 1) settings.task_data_as_text = (bool, True) - settings.default_inference_api = "watsonx" + settings.default_provider = "watsonx" settings.default_format = None if Constants.is_uninitilized():