Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions comfy_api_nodes/apis/wavespeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pydantic import BaseModel, Field


class SeedVR2ImageRequest(BaseModel):
image: str = Field(...)
target_resolution: str = Field(...)
output_format: str = Field("png")
enable_sync_mode: bool = Field(False)


class FlashVSRRequest(BaseModel):
target_resolution: str = Field(...)
video: str = Field(...)
duration: float = Field(...)


class TaskCreatedDataResponse(BaseModel):
id: str = Field(...)


class TaskCreatedResponse(BaseModel):
code: int = Field(...)
message: str = Field(...)
data: TaskCreatedDataResponse | None = Field(None)


class TaskResultDataResponse(BaseModel):
status: str = Field(...)
outputs: list[str] = Field([])


class TaskResultResponse(BaseModel):
code: int = Field(...)
message: str = Field(...)
data: TaskResultDataResponse | None = Field(None)
178 changes: 178 additions & 0 deletions comfy_api_nodes/nodes_wavespeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.wavespeed import (
FlashVSRRequest,
TaskCreatedResponse,
TaskResultResponse,
SeedVR2ImageRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_video_duration,
upload_images_to_comfyapi,
get_number_of_images,
download_url_to_image_tensor,
)


class WavespeedFlashVSRNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale",
category="api node/video/WaveSpeed",
description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
expr="""
(
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
{
"type":"usd",
"usd": $lookup($price_for_1sec, widgets.target_resolution),
"format":{"suffix": "/second", "approximate": true}
}
)
""",
),
)

@classmethod
async def execute(
cls,
video: Input.Video,
target_resolution: str,
) -> IO.NodeOutput:
validate_container_format_is_mp4(video)
validate_video_duration(video, min_duration=5, max_duration=60 * 10)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
response_model=TaskCreatedResponse,
data=FlashVSRRequest(
target_resolution=target_resolution.lower(),
video=await upload_video_to_comfyapi(cls, video),
duration=video.get_duration(),
),
)
if initial_res.code != 200:
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(
f"Task processing failed with code={final_response.code} and message={final_response.message}"
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))


class WavespeedImageUpscaleNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale",
category="api node/image/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
IO.Image.Input("image"),
IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
{"type":"usd", "usd": $lookup($prices, widgets.model)}
)
""",
),
)

@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
target_resolution: str,
) -> IO.NodeOutput:
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
if model == "SeedVR2":
model_path = "seedvr2/image"
else:
model_path = "ultimate-image-upscaler"
initial_res = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
response_model=TaskCreatedResponse,
data=SeedVR2ImageRequest(
target_resolution=target_resolution.lower(),
image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
),
)
if initial_res.code != 200:
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(
f"Task processing failed with code={final_response.code} and message={final_response.message}"
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))


class WavespeedExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
WavespeedFlashVSRNode,
WavespeedImageUpscaleNode,
]


async def comfy_entrypoint() -> WavespeedExtension:
return WavespeedExtension()