
[InferenceClient] Add third-party providers support #2757


Merged: 30 commits merged into main from inference-providers-compatibility on Jan 23, 2025

Conversation

hanouticelina
Contributor

@hanouticelina hanouticelina commented Jan 17, 2025

Following huggingface.js#1077 and moon-landing#12072, this PR adds third-party inference provider support to huggingface_hub.InferenceClient.

This v0 adds third-party inference provider support in a modular way: each provider's code lives in its own self-contained file under src/huggingface/inference/_providers/, making it easier for us to add or update a provider. In a future PR, we should probably also isolate the Inference API-specific code and keep InferenceClient as generic as possible.
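A rough usage sketch of what this enables (the `provider` / `api_key` argument names and the exact call signature are assumptions based on this description, not a definitive reference):

from huggingface_hub import InferenceClient

# Route the call through a third-party provider instead of the HF Inference API.
client = InferenceClient(provider="replicate", api_key="<provider-or-hf-token>")

# The Hub model id is mapped internally to the provider's own model id.
image = client.text_to_image(
    "An astronaut riding a horse on the moon",
    model="black-forest-labs/FLUX.1-schnell",
)
image.save("astronaut.png")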

Note: For fal.ai, we currently call the blocking API endpoint, which has a 60s timeout limit; the same applies to Replicate. This limits the models we can use with these providers. In a future PR, we could add polling support for the non-blocking API endpoints, enabling longer-running models.

TODO:

  • Add proxy calls to 3rd party providers. (in a following PR)
  • Add (VCR) tests.
  • Update Inference documentation. (in a following PR)

Contributor

@SBrandeis SBrandeis left a comment


Let's goooo

payload["json"].update(
{
"model": model,
"response_format": "base64",
Contributor


👀 👀

Member

@julien-c julien-c left a comment


neat 🔥

Contributor

@Wauplin Wauplin left a comment


Super nice seeing this taking shape! 🔥 I've started to review the PR and have 2 main comments:

  1. I think InferenceAPI should be considered as a provider to factorize things as much as possible (and avoid the `if provider is not None: ...` special-casing).
  2. Using classes and inheritance could be avoided (since we don't use any inheritance benefits). BaseProvider should be more of an interface than a class.

You'll find my comments below. Hope I did not go too far into overthinking 🙈 Better to think this through thoroughly before publishing :)

Comment on lines 5 to 10
@dataclass
class BaseProvider:
"""Base class defining the interface for inference providers."""

BASE_URL: str = field(init=False)
MODEL_IDS_MAPPING: Dict[str, str] = field(default_factory=dict, init=False)
Contributor


I feel that in the current structure, BaseProvider being a class/dataclass is clunky and doesn't add much value compared to a base module. The fact that it's a class is not really used (it could be a singleton since it's always instantiated with PROVIDERS[name]()), nor is the fact that it's a dataclass, since __repr__ will likely be unusable (MODEL_IDS_MAPPING is too large) and the other dataclass benefits are not used (no comparisons, etc.). On the contrary, it brings extra complexity into the code, like MODEL_IDS_MAPPING: Dict[str, Dict[str, str]] = field(default_factory=lambda: {...}).

2 solutions I see here:

Contributor


  1. Either use a Protocol:
# __init__.py
from typing import Any, Dict, Optional, Protocol, Union
from . import replicate, together, sambanova, fal_ai

class Provider(Protocol):
    """Protocol defining the interface for inference providers."""

    BASE_URL: str
    MODEL_IDS_MAPPING: Dict[str, Dict[str, str]]

    def build_url(self, task: Optional[str] = None, model: Optional[str] = None) -> str: ...
    def map_model(self, task: Optional[str] = None, model: Optional[str] = None) -> str: ...
    def prepare_headers(self, headers: Dict, task: Optional[str] = None, model: Optional[str] = None) -> Dict: ...
    def prepare_payload(self, input: str, parameters: Dict[str, Any], task: Optional[str] = None, model: Optional[str] = None) -> Dict[str, Any]: ...
    def get_response(self, response: Union[bytes, Dict], task: Optional[str] = None) -> Any: ...

PROVIDERS: Dict[str, Provider] = {
    "fal-ai": fal_ai,
    "together": togerther,
    "sambanova": sambanova,
    "replicate": replicate,
}

...
# replicate.py
BASE_URL = "https://api.replicate.com"
MODEL_IDS_MAPPING: Dict[str, Dict[str, str]] = {
    "text-to-image": {
        "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
        "ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step:5599ed30703defd1d160a25a63321b4dec97101d98b4674bcc56e41f62f35637",
    },
}

# no need for "self"
def build_url(task: Optional[str] = None, model: Optional[str] = None) -> str:
    if model is not None and ":" in model:
        return f"{self.BASE_URL}/v1/predictions"
    return f"{self.BASE_URL}/v1/models/{model}/predictions"

...

Type annotations are still happy, and on the maintenance side that's less indentation, no unused self argument, no need for __base__.py, no field(default_factory=dict, ...), etc.

Contributor


  2. Or use classes, but instead of having them at the Provider level without parameters/attributes, we could have them at a task/model level. So it would be more of a ProviderTaskHelper (or something like this). My reasoning is that all the methods build_url, map_model, prepare_payload, etc. heavily depend on task/model, so we're always passing them to each method.
def get_provider_helper(provider: str, task: str, model: Optional[str] = None) -> ProviderHelper:
    """Get provider helper instance by name."""
    if provider not in PROVIDERS:
        raise ValueError(...)  # provider not supported
    if task not in PROVIDERS[provider]:
        raise ValueError(...)  # task not supported by provider
    if model is not None and model not in PROVIDERS[provider][task].MODEL_IDS_MAPPING:
        raise ValueError(...)  # model not supported by provider
    return ...

I feel that with complexity growing (e.g. more providers, more tasks), relying on if task == "...": checks in the code will become more and more complex to maintain. Having one class per provider per task will make it more readable and self-contained (IMO).
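For illustration, a minimal sketch of what one such per-provider, per-task helper could look like (the class name, payload shape and endpoint details are hypothetical, not the PR's implementation):

from typing import Any, Dict, Optional

class ReplicateTextToImageHelper:
    """Hypothetical helper handling only text-to-image calls to Replicate."""

    BASE_URL = "https://api.replicate.com"

    def build_url(self, model: Optional[str] = None) -> str:
        # Versioned model ids ("owner/name:version") go to the generic predictions endpoint.
        if model is not None and ":" in model:
            return f"{self.BASE_URL}/v1/predictions"
        return f"{self.BASE_URL}/v1/models/{model}/predictions"

    def prepare_payload(self, input: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        # Assumed request body shape for a text-to-image prediction.
        return {"input": {"prompt": input, **parameters}}

With this layout, task-specific branching lives in the helper rather than in InferenceClient itself.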

Contributor


We could also have a mix of 1. and 2. (not 100% sure though) like this:

class TaskProviderHelper(Protocol):
    def build_url(self, model: Optional[str] = None) -> str: ...
    def map_model(self, model: Optional[str] = None) -> str: ...
    def prepare_headers(self, headers: Dict) -> Dict: ...
    def prepare_payload(self, input: str, parameters: Dict[str, Any]) -> Dict[str, Any]: ...
    def get_response(self, response: Union[bytes, Dict]) -> Any: ...

and a folder structure like this:

from .replicate import text_to_image as replicate_text_to_image
from .together import conversational as together_conversational
from .together import text_to_image as together_text_to_image
(...)

PROVIDERS = {
    "replicate": {
        "text_to_image": replicate_text_to_image,
    },
    "together": {
        "conversational": together_conversational,
        "text_to_image": together_text_to_image,
    },
}

def get_provider_helper(provider: str, task: str) -> TaskProviderHelper:
    return PROVIDERS[provider][task] # with more checks ofc

Side effect: get_response could be correctly type annotated with the expected output for the given task.
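To illustrate how the client side could then stay provider-agnostic, a rough sketch of the glue code (the function name, the auth handling and the use of requests are assumptions for illustration only; the helper methods are the ones from the protocol above):

import requests
from typing import Any, Dict

def run_task(provider: str, task: str, model: str, inputs: str, parameters: Dict[str, Any], api_key: str) -> Any:
    # Hypothetical glue: resolve the right helper, then let it shape the request and response.
    helper = get_provider_helper(provider, task)
    url = helper.build_url(helper.map_model(model))
    headers = helper.prepare_headers({"Authorization": f"Bearer {api_key}"})
    payload = helper.prepare_payload(inputs, parameters)
    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return helper.get_response(response.content)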

Contributor Author


Yes, agreed. To be honest, I was overthinking this part; I will revert to using a Protocol instead.

Contributor

@Wauplin Wauplin left a comment


Looking great! I've left a few comments, mostly related to how headers are handled (important to keep them thread-safe) and the file structure.

Comment on lines +281 to +283
The task to perform on the inference. if you are passing a provider, `task` is required.
Verify which tasks are supported by the provider.For `hf-inference`, all available tasks
can be found [here](https://huggingface.co/tasks).
Contributor


Suggested change
The task to perform on the inference. if you are passing a provider, `task` is required.
Verify which tasks are supported by the provider.For `hf-inference`, all available tasks
can be found [here](https://huggingface.co/tasks).
The task to perform on the inference. If you are passing a provider, `task` is required.
Verify which tasks are supported by the provider.
Available tasks can be found [here](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-tasks).

(TODO in a subsequent PR: extend https://huggingface.co/docs/huggingface_hub/guides/inference#supported-tasks to document tasks per provider)

Wauplin and others added 5 commits January 22, 2025 16:42
Contributor

@Wauplin Wauplin left a comment


Except for #2757 (comment), about which I feel a bit strongly, and a few tests, I think we are close to being able to merge this PR.

As discussed offline, we'll have to take care about a few things:

  • replace prepare_headers / prepare_payload / build_url with a unique prepare_request (see the sketch after this list)
  • revert providers from module-based to class-based (same as "hf-inference")
  • add documentation (examples with providers + maintain a provider <> tasks table)
  • ASR parameters (+ likely T2I / TTS as well?)
  • implement proxied calls (+ make sure we never leak the HF token to another provider)
  • revamp VCR tests => server-side caching instead
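For the first bullet, a rough sketch of what a unified prepare_request could return (the RequestParameters container and its fields are hypothetical):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class RequestParameters:
    # Hypothetical container bundling everything needed to actually send the call.
    url: str
    headers: Dict[str, str]
    json: Optional[Dict[str, Any]] = None
    data: Optional[bytes] = None

def prepare_request(
    inputs: Any,
    parameters: Dict[str, Any],
    headers: Dict[str, str],
    model: Optional[str] = None,
    api_key: Optional[str] = None,
) -> RequestParameters: ...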

All of this can be done in subsequent PRs. This PR is already big enough as it is 😄

Thanks again for coordinating all this @hanouticelina ! It takes InferenceClient to a whole new dimension 🚀

@hanouticelina hanouticelina marked this pull request as ready for review January 22, 2025 18:00
@hanouticelina hanouticelina requested a review from Wauplin January 22, 2025 19:54
Contributor

@Wauplin Wauplin left a comment


Awesome!

@hanouticelina hanouticelina merged commit 826f654 into main Jan 23, 2025
16 of 17 checks passed
@hanouticelina hanouticelina deleted the inference-providers-compatibility branch January 23, 2025 10:34
@Wauplin
Contributor

Wauplin commented Jan 23, 2025

🎉

@julien-c
Member

🤯 🤯
