[InferenceClient] Add third-party providers support #2757
Conversation
Let's goooo
payload["json"].update( | ||
{ | ||
"model": model, | ||
"response_format": "base64", |
👀 👀
neat 🔥
Super nice seeing this taking shape! 🔥 I've started to review the PR and have 2 main comments:
- I think InferenceAPI should be considered as a provider, to factorize things as much as possible (and avoid the "if provider is not None: ..." branches).
- Using classes and inheritance could be avoided (since we don't use any inheritance benefit). BaseProvider should be more of an interface than a class.
You'll find my comments below. Hope I did not go too far into overthinking 🙈 Better to think this through thoroughly before publishing :)
@dataclass
class BaseProvider:
    """Base class defining the interface for inference providers."""

    BASE_URL: str = field(init=False)
    MODEL_IDS_MAPPING: Dict[str, str] = field(default_factory=dict, init=False)
I feel that in the current structure, BaseProvider being a class/dataclass is clunky and doesn't add much value compared to a base module. The fact that it's a class is not really used (it could be a singleton since it's always instantiated with PROVIDERS[name]()). The fact that it's a dataclass isn't really used either: __repr__ will likely be unusable (MODEL_IDS_MAPPING is too large) and the other dataclass benefits are not used (no comparisons, etc.). On the contrary, it brings extra complexity into the code, like MODEL_IDS_MAPPING: Dict[str, Dict[str, str]] = field(default_factory=lambda: {...}).
2 solutions I see here:
- Either use a Protocol:
# __init__.py
from typing import Any, Dict, Optional, Protocol, Union

from . import replicate, together, sambanova, fal_ai

class Provider(Protocol):
    """Protocol defining the interface for inference providers."""

    BASE_URL: str
    MODEL_IDS_MAPPING: Dict[str, Dict[str, str]]

    def build_url(self, task: Optional[str] = None, model: Optional[str] = None) -> str: ...
    def map_model(self, task: Optional[str] = None, model: Optional[str] = None) -> str: ...
    def prepare_headers(self, headers: Dict, task: Optional[str] = None, model: Optional[str] = None) -> Dict: ...
    def prepare_payload(self, input: str, parameters: Dict[str, Any], task: Optional[str] = None, model: Optional[str] = None) -> Dict[str, Any]: ...
    def get_response(self, response: Union[bytes, Dict], task: Optional[str] = None) -> Any: ...

PROVIDERS: Dict[str, Provider] = {
    "fal-ai": fal_ai,
    "together": together,
    "sambanova": sambanova,
    "replicate": replicate,
}
...
# replicate.py
from typing import Dict, Optional

BASE_URL = "https://api.replicate.com"

MODEL_IDS_MAPPING: Dict[str, Dict[str, str]] = {
    "text-to-image": {
        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
        "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
    },
}

# no need for "self"
def build_url(task: Optional[str] = None, model: Optional[str] = None) -> str:
    if model is not None and ":" in model:
        return f"{BASE_URL}/v1/predictions"
    return f"{BASE_URL}/v1/models/{model}/predictions"
...
Type annotations are still happy, and on the maintenance side that's fewer indentation levels, no unused self attribute, no need for __base__.py, no field(default_factory=dict, ...), etc.
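For illustration, a call site could then look roughly like this (a minimal sketch of how the module-as-provider idea would be consumed; the dispatch code and variable names are assumptions, not part of the PR):

# hypothetical call site: modules are looked up in PROVIDERS and used through the Provider protocol
provider = PROVIDERS["replicate"]
model = provider.map_model(task="text-to-image", model="black-forest-labs/FLUX.1-schnell")
url = provider.build_url(task="text-to-image", model=model)
headers = provider.prepare_headers({}, task="text-to-image", model=model)
payload = provider.prepare_payload("an astronaut riding a horse", parameters={}, task="text-to-image", model=model)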
- Either use classes, but instead of having them at the Provider level without parameters/attributes, we could have them at a task/model level. So it would be more of a ProviderTaskHelper (or something like this). My reasoning is that all methods (build_url, map_model, prepare_payload, etc.) heavily depend on task/model, so we're always passing them to each method.
def get_provider_helper(provider: str, task: str, model: Optional[str] = None) -> ProviderHelper:
    """Get provider helper instance by name."""
    # raise ValueError(...) if the provider is not supported
    # raise ValueError(...) if the task is not supported by the provider
    # raise ValueError(...) if the model is not supported by the provider
    return ...
I feel that with complexity growing (e.g. more providers, more tasks), relying on if task == "...": checks in the code will become more and more complex to maintain. Having 1 class per provider per task will make it more readable and self-contained (IMO).
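To illustrate this option, a per-provider, per-task helper could look roughly like this (a minimal sketch; the URL and model mapping are taken from the snippets above, the class and method names are assumptions):

# hypothetical Replicate text-to-image helper (one class per provider per task)
from typing import Dict, Optional

class ReplicateTextToImageHelper:
    BASE_URL: str = "https://api.replicate.com"
    MODEL_IDS_MAPPING: Dict[str, str] = {
        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
    }

    def map_model(self, model: Optional[str] = None) -> str:
        # fall back to the requested id if no mapping is defined
        return self.MODEL_IDS_MAPPING.get(model, model)

    def build_url(self, model: Optional[str] = None) -> str:
        # versioned model ids ("owner/name:version") go through the generic predictions endpoint
        if model is not None and ":" in model:
            return f"{self.BASE_URL}/v1/predictions"
        return f"{self.BASE_URL}/v1/models/{model}/predictions"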
We could also have a mix of 1. and 2. (not 100% sure though) like this:
from typing import Any, Dict, Optional, Protocol, Union

class TaskProviderHelper(Protocol):
    def build_url(self, model: Optional[str] = None) -> str: ...
    def map_model(self, model: Optional[str] = None) -> str: ...
    def prepare_headers(self, headers: Dict) -> Dict: ...
    def prepare_payload(self, input: str, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
    def get_response(self, response: Union[bytes, Dict]) -> Any: ...
and a folder structure like this:
from .replicate import text_to_image as replicate_text_to_image
from .together import conversational as together_conversational
from .together import text_to_image as together_text_to_image
(...)

PROVIDERS = {
    "replicate": {
        "text_to_image": replicate_text_to_image,
    },
    "together": {
        "conversational": together_conversational,
        "text_to_image": together_text_to_image,
    },
}

def get_provider_helper(provider: str, task: str) -> TaskProviderHelper:
    return PROVIDERS[provider][task]  # with more checks ofc
Side effect: get_response could be correctly type-annotated with the expected output for the given task.
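For illustration, the client side could then consume such a helper roughly like this (a hedged sketch; the HTTP call, prompt, and variable names are assumptions, not part of the PR):

# hypothetical usage inside InferenceClient
import requests

api_key = "<provider token>"  # hypothetical: the user's key for the chosen provider
helper = get_provider_helper(provider="together", task="text_to_image")
model = helper.map_model("black-forest-labs/FLUX.1-schnell")
url = helper.build_url(model=model)
headers = helper.prepare_headers({"Authorization": f"Bearer {api_key}"})
payload = helper.prepare_payload("an astronaut riding a horse", parameters={})
response = requests.post(url, headers=headers, json=payload)
image_bytes = helper.get_response(response.content)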
Yes, agree. To be honest, I was overthinking this part too much; I will revert back to using a Protocol instead.
Looking great! I've left a few comments, mostly related to how headers are handled (important to have them thread-safe) and the file structure.
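As an illustration of the thread-safety concern, a provider helper should return a fresh headers dict rather than mutating one shared across calls (a minimal sketch; the function signature and api_key parameter are assumptions):

from typing import Dict

def prepare_headers(headers: Dict[str, str], api_key: str) -> Dict[str, str]:
    # copy first so concurrent calls never mutate the same dict in place
    new_headers = dict(headers)
    new_headers["Authorization"] = f"Bearer {api_key}"
    return new_headers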
Resolved review threads:
- src/huggingface_hub/inference/_providers/fal_ai/text_to_image.py
- src/huggingface_hub/inference/_providers/hf_inference/_common.py
- src/huggingface_hub/inference/_providers/hf_inference/text_to_image.py
The task to perform on the inference. if you are passing a provider, `task` is required.
Verify which tasks are supported by the provider.For `hf-inference`, all available tasks
can be found [here](https://huggingface.co/tasks).
Suggested change:
The task to perform on the inference. if you are passing a provider, `task` is required.
Verify which tasks are supported by the provider.
Available tasks can be found [here](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-tasks).
(TODO in a subsequent PR: extend https://huggingface.co/docs/huggingface_hub/guides/inference#supported-tasks to document tasks per provider)
Except for #2757 (comment), about which I feel a bit strongly, and a few tests, I think we are close to being able to merge this PR.
As discussed offline, we'll have to take care of a few things:
- replace prepare_headers / prepare_payload / build_url by a unique prepare_request (see the sketch below)
- revert providers from module-based to class-based (same as "hf-inference")
- add documentation (examples with providers + maintain a provider <> tasks table)
- ASR parameters (+ likely T2I / TTS as well?)
- implement proxy-ed calls (+ make sure we never leak HF token to another provider)
- revamp VCR tests => server-side caching instead
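For the first point, a unified prepare_request could look roughly like this (a minimal sketch; the dataclass name and fields are assumptions, not the final API):

# hypothetical sketch: one prepare_request replacing prepare_headers / prepare_payload / build_url
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class RequestParameters:
    url: str
    headers: Dict[str, str]
    json: Optional[Dict[str, Any]] = None  # JSON body, if any
    data: Optional[bytes] = None            # raw body, if any

def prepare_request(
    inputs: Any,
    parameters: Dict[str, Any],
    model: Optional[str] = None,
    api_key: Optional[str] = None,
) -> RequestParameters:
    ...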
All of this can be done in subsequent PRs. This PR is already big enough as it is 😄
Thanks again for coordinating all this @hanouticelina! It takes InferenceClient to a whole new dimension 🚀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
🎉
🤯 🤯
Following huggingface.js#1077 and moon-landing#12072, this PR adds third-party inference provider support to huggingface_hub.InferenceClient.

This v0 adds third-party inference provider support in a modular way: each provider's code lives in its own self-contained file under src/huggingface_hub/inference/_providers/, which makes it easier for us to add or update a provider. Similarly, in a future PR, we should probably isolate the Inference API specific code and keep InferenceClient as generic as possible.

Note: for fal.ai, we currently call the blocking API endpoint, which has a 60s timeout limit. The same applies to Replicate. This limits the models we can use with these providers. In a future PR, we could add continuous polling support to use the non-blocking API endpoints, enabling support for longer-running models.
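For reference, the polling approach mentioned above could look roughly like this for Replicate's non-blocking predictions API (a hedged sketch, not part of this PR; the function name, intervals, and auth scheme are assumptions):

import time
import requests

def poll_replicate_prediction(create_response: dict, api_key: str, interval: float = 1.0, timeout: float = 300.0) -> dict:
    # the creation response exposes a polling URL under "urls" -> "get"
    url = create_response["urls"]["get"]
    headers = {"Authorization": f"Bearer {api_key}"}
    deadline = time.time() + timeout
    while time.time() < deadline:
        prediction = requests.get(url, headers=headers).json()
        if prediction["status"] in ("succeeded", "failed", "canceled"):
            return prediction
        time.sleep(interval)
    raise TimeoutError("Prediction did not complete in time.")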
TODO: