Skip to content

Commit 332d912

Browse files
committed
Add Allegro API server implementation with Prometheus metrics support
- Introduced `allegro_serve.py` to implement the Allegro API using LitServe. - Added Prometheus metrics for tracking request processing times. - Created `AllegroRequest` model for input validation and `AllegroAPI` class for handling inference requests. - Implemented error handling and logging for improved traceability. - Set up a temporary directory for video output and integrated S3 upload functionality. - Added an error log file for capturing runtime errors.
1 parent 0a887bf commit 332d912

File tree

8 files changed

+205
-1
lines changed

8 files changed

+205
-1
lines changed

api/allegro_serve.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import os
2+
import sys
3+
import time
4+
import tempfile
5+
from typing import Dict, Any, List, Union, Optional
6+
from pathlib import Path
7+
from pydantic import BaseModel, Field
8+
from litserve import LitAPI, LitServer, Logger
9+
from loguru import logger
10+
from prometheus_client import (
11+
CollectorRegistry,
12+
Histogram,
13+
make_asgi_app,
14+
multiprocess
15+
)
16+
17+
from configs.allegro_settings import AllegroSettings
18+
from scripts.allegro_diffusers import AllegroInference
19+
from scripts.mp4_to_s3_json import mp4_to_s3_json
20+
import torch
21+
22+
# Set up prometheus multiprocess mode
23+
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"
24+
if not os.path.exists("/tmp/prometheus_multiproc_dir"):
25+
os.makedirs("/tmp/prometheus_multiproc_dir")
26+
27+
# Initialize prometheus registry
28+
registry = CollectorRegistry()
29+
multiprocess.MultiProcessCollector(registry)
30+
31+
class PrometheusLogger(Logger):
32+
"""Custom logger for Prometheus metrics.
33+
34+
Implements metric collection for request processing times
35+
using Prometheus Histograms.
36+
Metrics are stored in a multi-process compatible registry.
37+
38+
Attributes:
39+
function_duration (Histogram): Prometheus histogram for tracking processing times
40+
"""
41+
42+
def __init__(self):
43+
super().__init__()
44+
self.function_duration = Histogram(
45+
"allegro_request_processing_seconds",
46+
"Time spent processing Allegro request",
47+
["function_name"],
48+
registry=registry
49+
)
50+
51+
def process(self, key: str, value: float) -> None:
52+
"""Process and record a metric value.
53+
54+
Args:
55+
key (str): The name of the function or operation being measured
56+
value (float): The duration or metric value to record
57+
"""
58+
self.function_duration.labels(function_name=key).observe(value)
59+
60+
class AllegroRequest(BaseModel):
61+
"""Model representing a request for the Allegro model.
62+
63+
Validates input parameters for Allegro model inference.
64+
65+
Attributes:
66+
prompt (str): Text prompt for inference
67+
negative_prompt (Optional[str]): Text prompt for elements to avoid
68+
num_inference_steps (int): Number of denoising steps (1-100)
69+
guidance_scale (float): Controls adherence to prompt (1.0-20.0)
70+
height (int): Image height (256-720, multiple of 32)
71+
width (int): Image width (256-1280, multiple of 32)
72+
seed (Optional[int]): Random seed for reproducibility
73+
"""
74+
prompt: str = Field(..., description="Main text prompt for generation")
75+
negative_prompt: Optional[str] = Field(
76+
"worst quality, blurry, distorted",
77+
description="Text description of what to avoid"
78+
)
79+
num_inference_steps: int = Field(50, ge=1, le=100, description="Number of inference steps")
80+
guidance_scale: float = Field(7.5, ge=1.0, le=20.0, description="Guidance scale")
81+
height: int = Field(512, ge=256, le=720, multiple_of=32, description="Image height")
82+
width: int = Field(512, ge=256, le=1280, multiple_of=32, description="Image width")
83+
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
84+
85+
class AllegroAPI(LitAPI):
86+
"""API implementation for Allegro model inference using LitServe.
87+
88+
Attributes:
89+
settings (AllegroSettings): Configuration for Allegro model
90+
engine (AllegroInference): Inference engine for Allegro
91+
"""
92+
93+
def setup(self, device: str) -> None:
94+
"""Initialize the Allegro inference engine.
95+
96+
Args:
97+
device (str): Target device for inference ('cuda', 'cpu', etc.)
98+
"""
99+
try:
100+
logger.info(f"Initializing Allegro model on device: {device}")
101+
self.settings = AllegroSettings(device=device)
102+
self.engine = AllegroInference(self.settings)
103+
logger.info("Allegro setup completed successfully")
104+
except Exception as e:
105+
logger.error(f"Error during Allegro setup: {e}")
106+
raise
107+
108+
def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
109+
"""Decode and validate the incoming request.
110+
111+
Args:
112+
request (dict): Input request dictionary
113+
114+
Returns:
115+
dict: Validated request
116+
"""
117+
try:
118+
return AllegroRequest(**request).dict()
119+
except Exception as e:
120+
logger.error(f"Request validation error: {e}")
121+
raise
122+
123+
def predict(self, inputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
124+
"""Perform inference using the Allegro model.
125+
126+
Args:
127+
inputs (list): List of validated requests
128+
129+
Returns:
130+
list: Results with URLs and metadata
131+
"""
132+
results = []
133+
for request in inputs:
134+
start_time = time.time()
135+
try:
136+
self.settings.update(request)
137+
with tempfile.TemporaryDirectory() as temp_dir:
138+
output_path = Path(temp_dir) / "output.mp4"
139+
self.settings.output_path = output_path
140+
self.engine.generate()
141+
142+
if not output_path.exists():
143+
raise FileNotFoundError(f"Output not found at {output_path}")
144+
145+
with open(output_path, 'rb') as video_file:
146+
s3_response = mp4_to_s3_json(video_file, output_path.name)
147+
148+
generation_time = time.time() - start_time
149+
150+
results.append({
151+
"status": "success",
152+
"video_url": s3_response["url"],
153+
"prompt": request["prompt"],
154+
"time_taken": generation_time
155+
})
156+
157+
except Exception as e:
158+
logger.error(f"Error during prediction: {e}")
159+
results.append({"status": "error", "error": str(e)})
160+
return results
161+
162+
def encode_response(self, output: List[Dict[str, Any]]) -> Dict[str, Any]:
163+
"""Encode the results into a response format.
164+
165+
Args:
166+
output (list): Results list
167+
168+
Returns:
169+
dict: Encoded response
170+
"""
171+
return {"results": output}
172+
173+
def main():
174+
prometheus_logger = PrometheusLogger()
175+
prometheus_logger.mount(
176+
path="/api/v1/metrics",
177+
app=make_asgi_app(registry=registry)
178+
)
179+
180+
logger.remove()
181+
logger.add(sys.stdout, format="<green>{time}</green> | <level>{message}</level>", level="INFO")
182+
logger.add("logs/error.log", format="<red>{time}</red> | <level>{message}</level>", level="ERROR")
183+
184+
try:
185+
api = AllegroAPI()
186+
server = LitServer(
187+
api,
188+
api_path='/api/v1/allegro',
189+
accelerator="auto",
190+
devices="auto",
191+
max_batch_size=4,
192+
loggers=prometheus_logger,
193+
)
194+
195+
logger.info("Starting Allegro API server on port 8000")
196+
server.run(port=8000)
197+
except Exception as e:
198+
logger.error(f"Failed to start server: {e}")
199+
sys.exit(1)
200+
201+
202+
203+
if __name__ == "__main__":
204+
main()

api/logs/error.log

Whitespace-only changes.
Binary file not shown.
-2 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
-2 Bytes
Binary file not shown.

scripts/allegro_diffusers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def generate_video(self, prompt: str, positive_prompt: str, negative_prompt: str
9595
logger.error(f"Error during video generation: {e}")
9696
raise
9797

98-
# Example usage (to be executed in a main script or testing environment)
98+
9999
if __name__ == "__main__":
100100
settings = AllegroSettings()
101101
inference = AllegroInference(settings)

0 commit comments

Comments
 (0)