Skip to content

Commit f7f5585

Browse files
committed
feat: Add avatar-live and reference image support
- Add avatar-live realtime model with set_image() method
- Add lucy-restyle-v2v video model with reference_image support
- Add AvatarOptions for configuring avatar image on connect
- Add VideoRestyleInput with mutual exclusivity validation
- Add comprehensive tests for both features
1 parent 0eaaec7 commit f7f5585

File tree

12 files changed

+581
-19
lines changed

12 files changed

+581
-19
lines changed

decart/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
QueueResultError,
1313
TokenCreateError,
1414
)
15-
from .models import models, ModelDefinition
15+
from .models import models, ModelDefinition, VideoRestyleInput
1616
from .types import FileInput, ModelState, Prompt
1717
from .queue import (
1818
QueueClient,
@@ -31,6 +31,7 @@
3131
RealtimeClient,
3232
RealtimeConnectOptions,
3333
ConnectionState,
34+
AvatarOptions,
3435
)
3536

3637
REALTIME_AVAILABLE = True
@@ -39,6 +40,7 @@
3940
RealtimeClient = None # type: ignore
4041
RealtimeConnectOptions = None # type: ignore
4142
ConnectionState = None # type: ignore
43+
AvatarOptions = None # type: ignore
4244

4345
__version__ = "0.0.1"
4446

@@ -56,6 +58,7 @@
5658
"QueueResultError",
5759
"models",
5860
"ModelDefinition",
61+
"VideoRestyleInput",
5962
"FileInput",
6063
"ModelState",
6164
"Prompt",
@@ -75,5 +78,6 @@
7578
"RealtimeClient",
7679
"RealtimeConnectOptions",
7780
"ConnectionState",
81+
"AvatarOptions",
7882
]
7983
)

decart/models.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from typing import Literal, Optional, List, Generic, TypeVar
2-
from pydantic import BaseModel, Field, ConfigDict
2+
from pydantic import BaseModel, Field, ConfigDict, model_validator
33
from .errors import ModelNotFoundError
44
from .types import FileInput, MotionTrajectoryInput
55

66

7-
RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt"]
7+
RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt", "avatar-live"]
88
VideoModels = Literal[
99
"lucy-dev-i2v",
1010
"lucy-fast-v2v",
@@ -13,6 +13,7 @@
1313
"lucy-pro-v2v",
1414
"lucy-pro-flf2v",
1515
"lucy-motion",
16+
"lucy-restyle-v2v",
1617
]
1718
ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"]
1819
Model = Literal[RealTimeModels, VideoModels, ImageModels]
@@ -95,6 +96,38 @@ class ImageToMotionVideoInput(DecartBaseModel):
9596
resolution: Optional[str] = None
9697

9798

99+
class VideoRestyleInput(DecartBaseModel):
    """Input for lucy-restyle-v2v model.

    Must provide either `prompt` OR `reference_image`, but not both.
    `enhance_prompt` is only valid when using `prompt`, not `reference_image`.
    """

    # Text description of the restyle; mutually exclusive with reference_image.
    prompt: Optional[str] = Field(default=None, min_length=1, max_length=1000)
    # Image that drives the restyle; mutually exclusive with prompt.
    reference_image: Optional[FileInput] = None
    # The required file input to restyle.
    data: FileInput
    seed: Optional[int] = None
    resolution: Optional[str] = None
    # Only accepted together with prompt (enforced by the validator below).
    enhance_prompt: Optional[bool] = None

    @model_validator(mode="after")
    def validate_prompt_or_reference_image(self) -> "VideoRestyleInput":
        """Enforce the prompt / reference_image exclusivity rules.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If both or neither of ``prompt`` and
                ``reference_image`` are set, or if ``enhance_prompt`` is set
                while ``reference_image`` is used.
        """
        prompt_given = self.prompt is not None
        image_given = self.reference_image is not None

        # Exactly one of the two sources must be present.
        if not (prompt_given ^ image_given):
            raise ValueError(
                "Must provide either 'prompt' or 'reference_image', but not both"
            )

        # Past this point exactly one source is set, so "prompt given" means
        # "no reference image"; enhance_prompt is unrestricted in that case.
        if prompt_given or self.enhance_prompt is None:
            return self

        raise ValueError(
            "'enhance_prompt' is only valid when using 'prompt', not 'reference_image'"
        )
129+
130+
98131
class TextToImageInput(BaseModel):
99132
prompt: str = Field(
100133
...,
@@ -144,6 +177,14 @@ class ImageToImageInput(DecartBaseModel):
144177
height=704,
145178
input_schema=BaseModel,
146179
),
180+
"avatar-live": ModelDefinition(
181+
name="avatar-live",
182+
url_path="/v1/avatar-live/stream",
183+
fps=25,
184+
width=1280,
185+
height=720,
186+
input_schema=BaseModel,
187+
),
147188
},
148189
"video": {
149190
"lucy-dev-i2v": ModelDefinition(
@@ -202,6 +243,14 @@ class ImageToImageInput(DecartBaseModel):
202243
height=704,
203244
input_schema=ImageToMotionVideoInput,
204245
),
246+
"lucy-restyle-v2v": ModelDefinition(
247+
name="lucy-restyle-v2v",
248+
url_path="/v1/generate/lucy-restyle-v2v",
249+
fps=25,
250+
width=1280,
251+
height=704,
252+
input_schema=VideoRestyleInput,
253+
),
205254
},
206255
"image": {
207256
"lucy-pro-t2i": ModelDefinition(
@@ -247,6 +296,7 @@ def video(model: VideoModels) -> VideoModelDefinition:
247296
- "lucy-dev-i2v" - Image-to-video (Dev quality)
248297
- "lucy-fast-v2v" - Video-to-video (Fast quality)
249298
- "lucy-motion" - Image-to-motion-video
299+
- "lucy-restyle-v2v" - Video-to-video with prompt or reference image
250300
"""
251301
try:
252302
return _MODELS["video"][model] # type: ignore[return-value]

decart/queue/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async def submit_job(
2424

2525
for key, value in inputs.items():
2626
if value is not None:
27-
if key in ("data", "start", "end"):
27+
if key in ("data", "start", "end", "reference_image"):
2828
content, content_type = await file_input_to_bytes(value, session)
2929
form_data.add_field(key, content, content_type=content_type)
3030
else:

decart/realtime/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .client import RealtimeClient
2-
from .types import RealtimeConnectOptions, ConnectionState
2+
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions
33

44
__all__ = [
55
"RealtimeClient",
66
"RealtimeConnectOptions",
77
"ConnectionState",
8+
"AvatarOptions",
89
]

decart/realtime/client.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
from typing import Callable, Optional
22
import asyncio
3+
import base64
34
import logging
45
import uuid
6+
import aiohttp
57
from aiortc import MediaStreamTrack
68

79
from .webrtc_manager import WebRTCManager, WebRTCConfiguration
8-
from .messages import PromptMessage
10+
from .messages import PromptMessage, SetAvatarImageMessage
911
from .types import ConnectionState, RealtimeConnectOptions
12+
from ..types import FileInput
1013
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
14+
from ..process.request import file_input_to_bytes
1115

1216
logger = logging.getLogger(__name__)
1317

1418

1519
class RealtimeClient:
16-
def __init__(self, manager: WebRTCManager, session_id: str):
20+
def __init__(
21+
self,
22+
manager: WebRTCManager,
23+
session_id: str,
24+
http_session: Optional[aiohttp.ClientSession] = None,
25+
is_avatar_live: bool = False,
26+
):
1727
self._manager = manager
1828
self.session_id = session_id
29+
self._http_session = http_session
30+
self._is_avatar_live = is_avatar_live
1931
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
2032
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
2133

@@ -24,14 +36,16 @@ async def connect(
2436
cls,
2537
base_url: str,
2638
api_key: str,
27-
local_track: MediaStreamTrack,
39+
local_track: Optional[MediaStreamTrack],
2840
options: RealtimeConnectOptions,
2941
integration: Optional[str] = None,
3042
) -> "RealtimeClient":
3143
session_id = str(uuid.uuid4())
3244
ws_url = f"{base_url}{options.model.url_path}"
3345
ws_url += f"?api_key={api_key}&model={options.model.name}"
3446

47+
is_avatar_live = options.model.name == "avatar-live"
48+
3549
config = WebRTCConfiguration(
3650
webrtc_url=ws_url,
3751
api_key=api_key,
@@ -43,16 +57,33 @@ async def connect(
4357
initial_state=options.initial_state,
4458
customize_offer=options.customize_offer,
4559
integration=integration,
60+
is_avatar_live=is_avatar_live,
4661
)
4762

63+
# Create HTTP session for file conversions
64+
http_session = aiohttp.ClientSession()
65+
4866
manager = WebRTCManager(config)
49-
client = cls(manager=manager, session_id=session_id)
67+
client = cls(
68+
manager=manager,
69+
session_id=session_id,
70+
http_session=http_session,
71+
is_avatar_live=is_avatar_live,
72+
)
5073

5174
config.on_connection_state_change = client._emit_connection_change
5275
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
5376

5477
try:
55-
await manager.connect(local_track)
78+
# For avatar-live, convert and send avatar image before WebRTC connection
79+
avatar_image_base64: Optional[str] = None
80+
if is_avatar_live and options.avatar:
81+
image_bytes, _ = await file_input_to_bytes(
82+
options.avatar.avatar_image, http_session
83+
)
84+
avatar_image_base64 = base64.b64encode(image_bytes).decode("utf-8")
85+
86+
await manager.connect(local_track, avatar_image_base64=avatar_image_base64)
5687

5788
if options.initial_state:
5889
if options.initial_state.prompt:
@@ -61,6 +92,7 @@ async def connect(
6192
enrich=options.initial_state.prompt.enrich,
6293
)
6394
except Exception as e:
95+
await http_session.close()
6496
raise WebRTCError(str(e), cause=e)
6597

6698
return client
@@ -100,6 +132,47 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
100132
finally:
101133
self._manager.unregister_prompt_wait(prompt)
102134

135+
async def set_image(self, image: FileInput) -> None:
    """Set or replace the avatar image of the current session.

    The image is converted to bytes, base64-encoded and sent as a
    ``set_image`` message; the call then waits up to 15 seconds for the
    acknowledgment registered with the manager.

    Args:
        image: The image to set. Can be bytes, Path, URL string, or file-like object.

    Raises:
        InvalidInputError: If the client is not using the avatar-live model
            or has no HTTP session.
        DecartSDKError: If the acknowledgment times out or reports failure.
    """
    if not self._is_avatar_live:
        raise InvalidInputError("set_image() is only available for avatar-live model")
    if not self._http_session:
        raise InvalidInputError("HTTP session not available")

    # The content type returned alongside the bytes is not needed here.
    raw_image, _ = await file_input_to_bytes(image, self._http_session)
    encoded_image = base64.b64encode(raw_image).decode("utf-8")

    # Register the waiter before sending, and always unregister it afterwards.
    ack_event, ack = self._manager.register_image_set_wait()
    try:
        outgoing = SetAvatarImageMessage(type="set_image", image_data=encoded_image)
        await self._manager.send_message(outgoing)

        # Only a timeout of the wait itself is translated; errors raised by
        # send_message above propagate unchanged.
        try:
            await asyncio.wait_for(ack_event.wait(), timeout=15.0)
        except asyncio.TimeoutError:
            raise DecartSDKError("Image set acknowledgment timed out")

        if ack["success"]:
            return

        raise DecartSDKError(ack.get("status") or "Failed to set avatar image")
    finally:
        self._manager.unregister_image_set_wait()
175+
103176
def is_connected(self) -> bool:
104177
return self._manager.is_connected()
105178

@@ -108,6 +181,8 @@ def get_connection_state(self) -> ConnectionState:
108181

109182
async def disconnect(self) -> None:
    """Tear down the realtime connection and release the HTTP session.

    The HTTP session is closed in a ``finally`` clause so it is released even
    when the manager's cleanup raises; the cleanup error still propagates to
    the caller.
    """
    try:
        await self._manager.cleanup()
    finally:
        http_session = self._http_session
        # Skip when no session was attached or it is already closed.
        if http_session and not http_session.closed:
            await http_session.close()
111186

112187
def on(self, event: str, callback: Callable) -> None:
113188
if event == "connection_change":

decart/realtime/messages.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,22 @@ class PromptAckMessage(BaseModel):
5151
error: Optional[str] = None
5252

5353

54+
class ImageSetMessage(BaseModel):
    """Acknowledgment for avatar image set from server."""

    # Tag used by the IncomingMessage discriminated union to select this model.
    type: Literal["image_set"]
    # Status string carried by the acknowledgment.
    # NOTE(review): the set of possible values is not visible here; confirm
    # against the server protocol.
    status: str
59+
60+
5461
# Discriminated union for incoming messages
5562
IncomingMessage = Annotated[
56-
Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage],
63+
Union[
64+
AnswerMessage,
65+
IceCandidateMessage,
66+
SessionIdMessage,
67+
PromptAckMessage,
68+
ImageSetMessage,
69+
],
5770
Field(discriminator="type"),
5871
]
5972

@@ -79,8 +92,17 @@ class PromptMessage(BaseModel):
7992
enhance_prompt: bool = True
8093

8194

95+
class SetAvatarImageMessage(BaseModel):
    """Set avatar image message."""

    # Fixed message tag; the only accepted value is "set_image".
    type: Literal["set_image"]
    image_data: str  # Base64-encoded image
100+
101+
82102
# Outgoing message union (no discriminator needed - we know what we're sending)
83-
OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage]
103+
OutgoingMessage = Union[
104+
OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage
105+
]
84106

85107

86108
def parse_incoming_message(data: dict) -> IncomingMessage:

decart/realtime/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Literal, Callable, Optional
22
from dataclasses import dataclass
33
from ..models import ModelDefinition
4-
from ..types import ModelState
4+
from ..types import ModelState, FileInput
55

66
try:
77
from aiortc import MediaStreamTrack
@@ -12,9 +12,18 @@
1212
ConnectionState = Literal["connecting", "connected", "disconnected"]
1313

1414

15+
@dataclass
class AvatarOptions:
    """Options for avatar-live model."""

    # Read on connect for the avatar-live model and base64-encoded before the
    # WebRTC connection is made.
    avatar_image: FileInput
    """The avatar image to use. Can be bytes, Path, URL string, or file-like object."""
21+
22+
1523
@dataclass
class RealtimeConnectOptions:
    """Options passed to the realtime client when connecting."""

    # Model definition; its url_path and name are used to build the
    # connection URL.
    model: ModelDefinition
    # Callback that receives a MediaStreamTrack.
    # NOTE(review): presumably invoked with the remote track once it is
    # available; confirm in WebRTCManager.
    on_remote_stream: Callable[[MediaStreamTrack], None]
    # Optional initial model state; its prompt, when set, is applied during
    # connect.
    initial_state: Optional[ModelState] = None
    # Optional hook forwarded to the WebRTC configuration.
    # NOTE(review): expected signature is not visible here; confirm.
    customize_offer: Optional[Callable] = None
    # Avatar settings; only consulted when the model is avatar-live.
    avatar: Optional[AvatarOptions] = None

0 commit comments

Comments
 (0)