Change api to provider
Signed-off-by: elronbandel <elronbandel@gmail.com>
elronbandel committed Nov 18, 2024
1 parent bd8e176 commit b686f95
Showing 5 changed files with 20 additions and 36 deletions.
@@ -1,5 +1,5 @@
 from unitxt import evaluate, load_dataset
-from unitxt.inference import MultiAPIInferenceEngine
+from unitxt.inference import CrossProviderModel
 from unitxt.text_utils import print_dict
 
 data = load_dataset(
@@ -8,8 +8,8 @@
     disable_cache=False,
 )
 
-model = MultiAPIInferenceEngine(
-    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, api="watsonx"
+model = CrossProviderModel(
+    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, provider="watsonx"
 )
 
 predictions = model.infer(data)
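With the rename, the model name stays generic and only the provider argument selects the backend. A minimal sketch of switching providers with the new class (credentials for each backend are assumed to be configured separately):

from unitxt.inference import CrossProviderModel

# Same generic model name; each provider maps it to its own model identifier
# via the built-in provider_model_map (see src/unitxt/inference.py below).
watsonx_model = CrossProviderModel(
    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, provider="watsonx"
)
together_model = CrossProviderModel(
    model="llama-3-8b-instruct", temperature=0.0, top_p=1.0, provider="together-ai"
)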
prepare/engines/multi_api/llama3.py (12 changes: 2 additions & 10 deletions)
@@ -1,16 +1,8 @@
 from unitxt.catalog import add_to_catalog
-from unitxt.inference import MultiAPIInferenceEngine
+from unitxt.inference import CrossProviderModel
 
-engine = MultiAPIInferenceEngine(
+engine = CrossProviderModel(
     model="llama-3-8b-instruct",
-    api_model_map={
-        "watsonx": {
-            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
-        },
-        "together-ai": {
-            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
-        },
-    },
 )
 
 add_to_catalog(engine, "engines.model.llama_3_8b_instruct", overwrite=True)
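Because the provider-to-model mapping now lives on the class itself, the per-engine api_model_map disappears from the prepare script and the catalog entry. A sketch of loading the registered engine back from the catalog, assuming the fetch_artifact helper in unitxt.artifact returns the artifact together with the catalog it came from:

from unitxt.artifact import fetch_artifact

# Hypothetical usage: fetch the engine registered above by its catalog name;
# fetch_artifact is assumed to return (artifact, catalog).
engine, _ = fetch_artifact("engines.model.llama_3_8b_instruct")
# The engine can then be used like the CrossProviderModel constructed directly:
# predictions = engine.infer(data)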
src/unitxt/catalog/engines/model/llama_3_8b_instruct.json (12 changes: 2 additions & 10 deletions)
@@ -1,12 +1,4 @@
 {
-    "__type__": "multi_api_inference_engine",
-    "model": "llama-3-8b-instruct",
-    "api_model_map": {
-        "watsonx": {
-            "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct"
-        },
-        "together-ai": {
-            "llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"
-        }
-    }
+    "__type__": "cross_provider_model",
+    "model": "llama-3-8b-instruct"
 }
src/unitxt/inference.py (24 changes: 12 additions & 12 deletions)
@@ -1669,8 +1669,8 @@ def _infer(
 _supported_apis = Literal["watsonx", "together-ai", "open-ai", "aws", "ollama"]
 
 
-class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
-    """Inference engine capable of dynamically switching between multiple APIs.
+class CrossProviderModel(InferenceEngine, StandardAPIParamsMixin):
+    """Inference engine capable of dynamically switching between multiple providers APIs.
 
     This class extends the InferenceEngine and OpenAiInferenceEngineParamsMixin
     to enable seamless integration with various API providers. The supported APIs are
@@ -1687,9 +1687,9 @@ class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
     across different API backends.
     """
 
-    api: Optional[_supported_apis] = None
+    provider: Optional[_supported_apis] = None
 
-    api_model_map: Dict[_supported_apis, Dict[str, str]] = {
+    provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
         "watsonx": {
             "llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct",
             "llama-3-70b-instruct": "watsonx/meta-llama/llama-3-70b-instruct",
@@ -1708,22 +1708,22 @@ class MultiAPIInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         },
     }
 
-    _api_to_base_class = {
+    _provider_to_base_class = {
         "watsonx": LiteLLMInferenceEngine,
         "open-ai": LiteLLMInferenceEngine,
         "together-ai": LiteLLMInferenceEngine,
         "aws": LiteLLMInferenceEngine,
         "ollama": OllamaInferenceEngine,
     }
 
-    def get_api_name(self):
-        return self.api if self.api is not None else settings.default_inference_api
+    def get_provider_name(self):
+        return self.provider if self.provider is not None else settings.default_provider
 
     def prepare_engine(self):
-        api = self.get_api_name()
-        cls = self.__class__._api_to_base_class[api]
+        provider = self.get_provider_name()
+        cls = self.__class__._provider_to_base_class[provider]
         args = self.to_dict([StandardAPIParamsMixin])
-        args["model"] = self.api_model_map[api][self.model]
+        args["model"] = self.provider_model_map[provider][self.model]
         self.engine = cls(**args)
 
     def _infer(
@@ -1734,8 +1734,8 @@ def _infer(
         return self.engine._infer(dataset, return_meta_data)
 
     def get_engine_id(self):
-        api = self.get_api_name()
-        return get_model_and_label_id(self.api_model_map[api][self.model], api)
+        api = self.get_provider_name()
+        return get_model_and_label_id(self.provider_model_map[api][self.model], api)
 
 
 class HFOptionSelectingInferenceEngine(InferenceEngine):
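The core of prepare_engine is a two-step lookup: pick the provider (falling back to the renamed default_provider setting when none is given) and translate the generic model name into the provider-specific identifier before delegating to the matching backend engine. A standalone sketch of that resolution logic, using a trimmed copy of the map above rather than the unitxt code itself:

provider_model_map = {
    "watsonx": {"llama-3-8b-instruct": "watsonx/meta-llama/llama-3-8b-instruct"},
    "together-ai": {"llama-3-8b-instruct": "together_ai/togethercomputer/llama-3-8b-instruct"},
}

def resolve_model(model, provider=None, default_provider="watsonx"):
    # Fall back to the default provider when none is given, then map the
    # generic model name to the provider-specific identifier.
    provider = provider if provider is not None else default_provider
    return provider_model_map[provider][model]

print(resolve_model("llama-3-8b-instruct"))
# watsonx/meta-llama/llama-3-8b-instruct
print(resolve_model("llama-3-8b-instruct", provider="together-ai"))
# together_ai/togethercomputer/llama-3-8b-instruct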
src/unitxt/settings_utils.py (2 changes: 1 addition & 1 deletion)
@@ -152,7 +152,7 @@ def __getattr__(self, key):
 settings.disable_hf_datasets_cache = (bool, True)
 settings.loader_cache_size = (int, 1)
 settings.task_data_as_text = (bool, True)
-settings.default_inference_api = "watsonx"
+settings.default_provider = "watsonx"
 settings.default_format = None
 
 if Constants.is_uninitilized():
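The renamed setting controls which backend CrossProviderModel uses when no provider is passed explicitly. A sketch of overriding it programmatically, assuming unitxt's settings object is obtained via get_settings() in unitxt.settings_utils (an equivalent UNITXT_DEFAULT_PROVIDER environment variable is likewise an assumption based on unitxt's usual env-var naming):

from unitxt.settings_utils import get_settings

settings = get_settings()
# Used by CrossProviderModel.get_provider_name() when self.provider is None.
settings.default_provider = "together-ai"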
