22 changes: 10 additions & 12 deletions runpod/serverless/modules/rp_fastapi.py
@@ -16,8 +16,8 @@
from ...version import __version__ as runpod_version
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .rp_ping import Heartbeat
from .worker_state import Job, JobsProgress
from .worker_state import Job, get_jobs_progress
from .rp_ping import get_heartbeat

RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None)

@@ -96,8 +96,6 @@


# ------------------------------ Initializations ----------------------------- #
job_list = JobsProgress()
heartbeat = Heartbeat()


# ------------------------------- Input Objects ------------------------------ #
@@ -185,7 +183,7 @@ def __init__(self, config: Dict[str, Any]):
3. Sets the handler for processing jobs.
"""
# Start the heartbeat thread.
heartbeat.start_ping()
get_heartbeat().start_ping()

self.config = config

@@ -286,12 +284,12 @@ async def _realtime(self, job: Job):
Performs model inference on the input data using the provided handler.
If handler is not provided, returns an error message.
"""
job_list.add(job.id)
get_jobs_progress().add(job.id)

# Process the job using the provided handler, passing in the job input.
job_results = await run_job(self.config["handler"], job.__dict__)

job_list.remove(job.id)
get_jobs_progress().remove(job.id)

# Return the results of the job processing.
return jsonable_encoder(job_results)
@@ -304,7 +302,7 @@ async def _realtime(self, job: Job):
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
"""Development endpoint to simulate run behavior."""
assigned_job_id = f"test-{uuid.uuid4()}"
job_list.add({
get_jobs_progress().add({
"id": assigned_job_id,
"input": job_request.input,
"webhook": job_request.webhook
@@ -345,7 +343,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
# ---------------------------------- stream ---------------------------------- #
async def _sim_stream(self, job_id: str) -> StreamOutput:
"""Development endpoint to simulate stream behavior."""
stashed_job = job_list.get(job_id)
stashed_job = get_jobs_progress().get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -367,7 +365,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
}
)

job_list.remove(job.id)
get_jobs_progress().remove(job.id)

if stashed_job.webhook:
thread = threading.Thread(
@@ -384,7 +382,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
# ---------------------------------- status ---------------------------------- #
async def _sim_status(self, job_id: str) -> JobOutput:
"""Development endpoint to simulate status behavior."""
stashed_job = job_list.get(job_id)
stashed_job = get_jobs_progress().get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -400,7 +398,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
else:
job_output = await run_job(self.config["handler"], job.__dict__)

job_list.remove(job.id)
get_jobs_progress().remove(job.id)

if job_output.get("error", None):
return jsonable_encoder(
5 changes: 2 additions & 3 deletions runpod/serverless/modules/rp_job.py
@@ -18,12 +18,11 @@
from .rp_handler import is_generator
from .rp_http import send_result, stream_result
from .rp_tips import check_return_size
from .worker_state import WORKER_ID, REF_COUNT_ZERO, JobsProgress
from .worker_state import WORKER_ID, REF_COUNT_ZERO, get_jobs_progress

JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)

log = RunPodLogger()
job_progress = JobsProgress()


def _job_get_url(batch_size: int = 1):
@@ -43,7 +42,7 @@ def _job_get_url(batch_size: int = 1):
else:
job_take_url = JOB_GET_URL

job_in_progress = "1" if job_progress.get_job_list() else "0"
job_in_progress = "1" if get_jobs_progress().get_job_list() else "0"
job_take_url += f"&job_in_progress={job_in_progress}"

log.debug(f"rp_job | get_job: {job_take_url}")
22 changes: 19 additions & 3 deletions runpod/serverless/modules/rp_ping.py
@@ -11,11 +11,27 @@

from runpod.http_client import SyncClientSession
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress
from runpod.serverless.modules.worker_state import WORKER_ID, get_jobs_progress
from runpod.version import __version__ as runpod_version

log = RunPodLogger()
jobs = JobsProgress() # Contains the list of jobs that are currently running.

# Lazy loading for Heartbeat instance
_heartbeat_instance = None


def get_heartbeat():
"""Get the global Heartbeat instance with lazy initialization."""
global _heartbeat_instance
if _heartbeat_instance is None:
_heartbeat_instance = Heartbeat()
return _heartbeat_instance


def reset_heartbeat():
"""Reset the lazy-loaded Heartbeat instance (useful for testing)."""
global _heartbeat_instance
_heartbeat_instance = None


class Heartbeat:
@@ -97,7 +113,7 @@ def _send_ping(self):
"""
Sends a heartbeat to the Runpod server.
"""
job_ids = jobs.get_job_list()
job_ids = get_jobs_progress().get_job_list()
ping_params = {"job_id": job_ids, "runpod_version": runpod_version}

try:
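The `get_jobs_progress()` accessor used throughout this PR is imported from `worker_state.py`, whose changes are not shown in this diff. Presumably it follows the same lazy-initialization pattern as `get_heartbeat()` above; a minimal sketch under that assumption:

```python
# Hypothetical sketch of the accessor inside worker_state.py, assuming it
# mirrors the Heartbeat pattern shown above. JobsProgress is defined in the
# same module, so no import is needed here.
_jobs_progress_instance = None


def get_jobs_progress():
    """Get the global JobsProgress instance with lazy initialization."""
    global _jobs_progress_instance
    if _jobs_progress_instance is None:
        _jobs_progress_instance = JobsProgress()
    return _jobs_progress_instance
```

Deferring construction this way removes the import-time side effect of building the singleton and gives tests a seam to reset or swap shared state.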
11 changes: 5 additions & 6 deletions runpod/serverless/modules/rp_scale.py
@@ -12,10 +12,9 @@
from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsProgress, IS_LOCAL_TEST
from .worker_state import IS_LOCAL_TEST, get_jobs_progress

log = RunPodLogger()
job_progress = JobsProgress()


def _handle_uncaught_exception(exc_type, exc_value, exc_traceback):
@@ -101,7 +100,7 @@ def start(self):
signal.signal(signal.SIGTERM, self.handle_shutdown)
signal.signal(signal.SIGINT, self.handle_shutdown)
except ValueError:
log.warning("Signal handling is only supported in the main thread.")
log.warn("Signal handling is only supported in the main thread.")

# Start the main loop
# Run forever until the worker is signalled to shut down.
@@ -149,7 +148,7 @@ def kill_worker(self):

def current_occupancy(self) -> int:
current_queue_count = self.jobs_queue.qsize()
current_progress_count = job_progress.get_job_count()
current_progress_count = get_jobs_progress().get_job_count()

log.debug(
f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
@@ -188,7 +187,7 @@ async def get_jobs(self, session: ClientSession):

for job in acquired_jobs:
await self.jobs_queue.put(job)
job_progress.add(job)
get_jobs_progress().add(job)
log.debug("Job Queued", job["id"])

log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
@@ -268,6 +267,6 @@ async def handle_job(self, session: ClientSession, job: dict):
self.jobs_queue.task_done()

# Job is no longer in progress
job_progress.remove(job)
get_jobs_progress().remove(job)

log.debug("Finished Job", job["id"])