Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions runpod/api/ctl_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,28 @@ def create_endpoint(
workers_min: int = 0,
workers_max: int = 3,
flashboot=False,
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
allowed_cuda_versions: list = None,
gpu_count: int = 1,
):
"""
Create an endpoint

:param allowed_cuda_versions: Comma-separated string of allowed CUDA versions (e.g., "12.4,12.5").
:param name: the name of the endpoint
:param template_id: the id of the template to use for the endpoint
:param gpu_ids: the ids of the GPUs to use for the endpoint
:param network_volume_id: the id of the network volume to use for the endpoint
:param locations: the locations to use for the endpoint
:param idle_timeout: the idle timeout for the endpoint
:param scaler_type: the scaler type for the endpoint
:param scaler_value: the scaler value for the endpoint
:param workers_min: the minimum number of workers for the endpoint
:param workers_max: the maximum number of workers for the endpoint
:param allowed_cuda_versions: List of allowed CUDA version strings (e.g., ["12.4", "12.5"]).
:param gpu_count: the number of GPUs to use for the endpoint

:example:

>>> endpoint_id = runpod.create_endpoint("test", "template_id")
"""
raw_response = run_graphql_query(
endpoint_mutations.generate_endpoint_mutation(
Expand Down
9 changes: 5 additions & 4 deletions runpod/api/mutations/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def generate_endpoint_mutation(
workers_min: int = 0,
workers_max: int = 3,
flashboot=False,
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
allowed_cuda_versions: list = None,
gpu_count: int = None,
):
"""Generate a string for a GraphQL mutation to create a new endpoint."""
Expand Down Expand Up @@ -46,9 +46,10 @@ def generate_endpoint_mutation(
input_fields.append(f"workersMin: {workers_min}")
input_fields.append(f"workersMax: {workers_max}")

if allowed_cuda_versions is not None:
input_fields.append(f'allowedCudaVersions: "{allowed_cuda_versions}"')

if allowed_cuda_versions:
cuda_versions = ", ".join(f'"{v}"' for v in allowed_cuda_versions)
input_fields.append(f"allowedCudaVersions: [{cuda_versions}]")

if gpu_count is not None:
input_fields.append(f"gpuCount: {gpu_count}")

Expand Down
Loading