1 change: 1 addition & 0 deletions .gitignore
@@ -137,3 +137,4 @@ dmypy.json
 # Pyre type checker
 .pyre/
 runpod/_version.py
+.runpod_jobs.pkl
1 change: 0 additions & 1 deletion runpod/serverless/__init__.py
@@ -16,7 +16,6 @@
 from . import worker
 from .modules import rp_fastapi
 from .modules.rp_logger import RunPodLogger
-from .modules.rp_progress import progress_update

 log = RunPodLogger()

2 changes: 1 addition & 1 deletion runpod/serverless/modules/rp_job.py
@@ -149,7 +149,7 @@ async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dict:
         job_result["stopPod"] = True

     # If rp_debugger is set, debugger output will be returned.
-    if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict):
+    if config.get("rp_args", {}).get("rp_debugger", False) and isinstance(job_result, dict):
         job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
         log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])

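The guarded lookup above tolerates a worker config that has no `rp_args` key at all. A minimal sketch of the difference, using a hypothetical empty config:

    config = {}  # hypothetical worker config with no "rp_args" key

    # Old form: config["rp_args"] raises KeyError before .get() is ever reached.
    # New form: falls back to an empty dict, then to the False default.
    debugger_enabled = config.get("rp_args", {}).get("rp_debugger", False)
    assert debugger_enabled is False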
2 changes: 1 addition & 1 deletion runpod/serverless/modules/rp_ping.py
@@ -15,7 +15,6 @@
 from runpod.version import __version__ as runpod_version

 log = RunPodLogger()
-jobs = JobsProgress()  # Contains the list of jobs that are currently running.


 class Heartbeat:
@@ -97,6 +96,7 @@ def _send_ping(self):
         """
         Sends a heartbeat to the Runpod server.
         """
+        jobs = JobsProgress()  # Get the singleton instance
         job_ids = jobs.get_job_list()
         ping_params = {"job_id": job_ids, "runpod_version": runpod_version}

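The heartbeat now resolves the tracker inside `_send_ping` rather than holding a module-level reference created at import time. This relies on `JobsProgress.__new__` caching a single instance (see worker_state.py below), so the per-ping construction is effectively free. A minimal sketch of that contract:

    from runpod.serverless.modules.worker_state import JobsProgress

    first = JobsProgress()
    second = JobsProgress()
    assert first is second  # __new__ hands back the cached singleton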
8 changes: 4 additions & 4 deletions runpod/serverless/modules/rp_scale.py
@@ -15,7 +15,6 @@
 from .worker_state import JobsProgress, IS_LOCAL_TEST

 log = RunPodLogger()
-job_progress = JobsProgress()


 def _handle_uncaught_exception(exc_type, exc_value, exc_traceback):
@@ -47,6 +46,7 @@ def __init__(self, config: Dict[str, Any]):
         self._shutdown_event = asyncio.Event()
         self.current_concurrency = 1
         self.config = config
+        self.job_progress = JobsProgress()  # Cache the singleton instance

         self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)

@@ -149,7 +149,7 @@ def kill_worker(self):

     def current_occupancy(self) -> int:
         current_queue_count = self.jobs_queue.qsize()
-        current_progress_count = job_progress.get_job_count()
+        current_progress_count = self.job_progress.get_job_count()

         log.debug(
             f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
@@ -188,7 +188,7 @@ async def get_jobs(self, session: ClientSession):

                 for job in acquired_jobs:
                     await self.jobs_queue.put(job)
-                    job_progress.add(job)
+                    self.job_progress.add(job)
                     log.debug("Job Queued", job["id"])

                 log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")
@@ -268,6 +268,6 @@ async def handle_job(self, session: ClientSession, job: dict):
             self.jobs_queue.task_done()

             # Job is no longer in progress
-            job_progress.remove(job)
+            self.job_progress.remove(job)

             log.debug("Finished Job", job["id"])
224 changes: 113 additions & 111 deletions runpod/serverless/modules/worker_state.py
@@ -5,9 +5,10 @@
 import os
 import time
 import uuid
-from multiprocessing import Manager
-from multiprocessing.managers import SyncManager
-from typing import Any, Dict, Optional
+import pickle
+import fcntl
+import tempfile
+from typing import Any, Dict, Optional, Set

 from .rp_logger import RunPodLogger

@@ -63,149 +64,150 @@ def __str__(self) -> str:
 # ---------------------------------------------------------------------------- #
 #                                    Tracker                                    #
 # ---------------------------------------------------------------------------- #
-class JobsProgress:
-    """Track the state of current jobs in progress using shared memory."""
-
-    _instance: Optional['JobsProgress'] = None
-    _manager: SyncManager
-    _shared_data: Any
-    _lock: Any
+class JobsProgress(Set[Job]):
+    """Track the state of current jobs in progress with persistent state."""
+
+    _instance = None
+    _STATE_DIR = os.getcwd()
+    _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl")

     def __new__(cls):
-        if cls._instance is None:
-            instance = object.__new__(cls)
-            # Initialize instance variables
-            instance._manager = Manager()
-            instance._shared_data = instance._manager.dict()
-            instance._shared_data['jobs'] = instance._manager.list()
-            instance._lock = instance._manager.Lock()
-            cls._instance = instance
-        return cls._instance
+        if JobsProgress._instance is None:
+            os.makedirs(cls._STATE_DIR, exist_ok=True)
+            JobsProgress._instance = set.__new__(cls)
+            # Initialize as empty set before loading state
+            set.__init__(JobsProgress._instance)
+            JobsProgress._instance._load_state()
+        return JobsProgress._instance

     def __init__(self):
-        # Everything is already initialized in __new__
-        # This should never clear data in a singleton
+        # Don't call parent __init__ as it would clear the set
         pass

     def __repr__(self) -> str:
         return f"<{self.__class__.__name__}>: {self.get_job_list()}"

+    def _load_state(self):
+        """Load jobs state from pickle file with file locking."""
+        try:
+            if (
+                os.path.exists(self._STATE_FILE)
+                and os.path.getsize(self._STATE_FILE) > 0
+            ):
+                with open(self._STATE_FILE, "rb") as f:
+                    fcntl.flock(f, fcntl.LOCK_SH)
+                    try:
+                        loaded_jobs = pickle.load(f)
+                        # Clear current state and add loaded jobs
+                        super().clear()
+                        for job in loaded_jobs:
+                            set.add(
+                                self, job
+                            )  # Use set.add to avoid triggering _save_state
+
+                    except (EOFError, pickle.UnpicklingError):
+                        # Handle empty or corrupted file
+                        log.debug(
+                            "JobsProgress: Failed to load state file, starting with empty state"
+                        )
+                        pass
+                    finally:
+                        fcntl.flock(f, fcntl.LOCK_UN)
+
+        except FileNotFoundError:
+            log.debug("JobsProgress: No state file found, starting with empty state")
+            pass
+
+    def _save_state(self):
+        """Save jobs state to pickle file with atomic write and file locking."""
+        try:
+            # Use temporary file for atomic write
+            with tempfile.NamedTemporaryFile(
+                dir=self._STATE_DIR, delete=False, mode="wb"
+            ) as temp_f:
+                fcntl.flock(temp_f, fcntl.LOCK_EX)
+                try:
+                    pickle.dump(set(self), temp_f)
+                finally:
+                    fcntl.flock(temp_f, fcntl.LOCK_UN)
+
+            # Atomically replace the state file
+            os.replace(temp_f.name, self._STATE_FILE)
+        except Exception as e:
+            log.error(f"Failed to save job state: {e}")
+
     def clear(self) -> None:
-        with self._lock:
-            self._shared_data['jobs'][:] = []
+        super().clear()
+        self._save_state()

     def add(self, element: Any):
         """
         Adds a Job object to the set.
-        """
-        if isinstance(element, str):
-            job_dict = {'id': element}
-        elif isinstance(element, dict):
-            job_dict = element
-        elif hasattr(element, 'id'):
-            job_dict = {'id': element.id}
-        else:
-            raise TypeError("Only Job objects can be added to JobsProgress.")
-
-        with self._lock:
-            # Check if job already exists
-            job_list = self._shared_data['jobs']
-            for existing_job in job_list:
-                if existing_job['id'] == job_dict['id']:
-                    return  # Job already exists
-
-            # Add new job
-            job_list.append(job_dict)
-            log.debug(f"JobsProgress | Added job: {job_dict['id']}")
-
-    def get(self, element: Any) -> Optional[Job]:
-        """
-        Retrieves a Job object from the set.
-
-        If the element is a string, searches for Job with that id.
-        """
-        if isinstance(element, str):
-            search_id = element
-        elif isinstance(element, Job):
-            search_id = element.id
-        else:
-            raise TypeError("Only Job objects can be retrieved from JobsProgress.")
-
-        with self._lock:
-            for job_dict in self._shared_data['jobs']:
-                if job_dict['id'] == search_id:
-                    log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}")
-                    return Job(**job_dict)
-
-        return None
+
+        If the added element is a string, then `Job(id=element)` is added
+
+        If the added element is a dict, then `Job(**element)` is added
+        """
+        if isinstance(element, str):
+            element = Job(id=element)
+
+        if isinstance(element, dict):
+            element = Job(**element)
+
+        if not isinstance(element, Job):
+            raise TypeError("Only Job objects can be added to JobsProgress.")
+
+        result = super().add(element)
+        self._save_state()
+        return result

     def remove(self, element: Any):
         """
         Removes a Job object from the set.
-        """
-        if isinstance(element, str):
-            job_id = element
-        elif isinstance(element, dict):
-            job_id = element.get('id')
-        elif hasattr(element, 'id'):
-            job_id = element.id
-        else:
-            raise TypeError("Only Job objects can be removed from JobsProgress.")
-
-        with self._lock:
-            job_list = self._shared_data['jobs']
-            # Find and remove the job
-            for i, job_dict in enumerate(job_list):
-                if job_dict['id'] == job_id:
-                    del job_list[i]
-                    log.debug(f"JobsProgress | Removed job: {job_dict['id']}")
-                    break
+
+        If the element is a string, then `Job(id=element)` is removed
+
+        If the element is a dict, then `Job(**element)` is removed
+        """
+        if isinstance(element, str):
+            element = Job(id=element)
+
+        if isinstance(element, dict):
+            element = Job(**element)
+
+        if not isinstance(element, Job):
+            raise TypeError("Only Job objects can be removed from JobsProgress.")
+
+        result = super().discard(element)
+        self._save_state()
+        return result
+
+    def get(self, element: Any) -> Optional[Job]:
+        if isinstance(element, str):
+            element = Job(id=element)
+
+        if not isinstance(element, Job):
+            raise TypeError("Only Job objects can be retrieved from JobsProgress.")
+
+        for job in self:
+            if job == element:
+                return job
+        return None

     def get_job_list(self) -> Optional[str]:
         """
         Returns the list of job IDs as comma-separated string.
         """
-        with self._lock:
-            job_list = list(self._shared_data['jobs'])
-
-        if not job_list:
+        self._load_state()
+
+        if not len(self):
             return None

-        log.debug(f"JobsProgress | Jobs in progress: {job_list}")
-        return ",".join(str(job_dict['id']) for job_dict in job_list)
+        return ",".join(str(job) for job in self)

     def get_job_count(self) -> int:
         """
         Returns the number of jobs.
         """
-        with self._lock:
-            return len(self._shared_data['jobs'])
-
-    def __iter__(self):
-        """Make the class iterable - returns Job objects"""
-        with self._lock:
-            # Create a snapshot of jobs to avoid holding lock during iteration
-            job_dicts = list(self._shared_data['jobs'])
-
-        # Return an iterator of Job objects
-        return iter(Job(**job_dict) for job_dict in job_dicts)
-
-    def __len__(self):
-        """Support len() operation"""
-        return self.get_job_count()
-
-    def __contains__(self, element: Any) -> bool:
-        """Support 'in' operator"""
-        if isinstance(element, str):
-            search_id = element
-        elif isinstance(element, Job):
-            search_id = element.id
-        elif isinstance(element, dict):
-            search_id = element.get('id')
-        else:
-            return False
-
-        with self._lock:
-            for job_dict in self._shared_data['jobs']:
-                if job_dict['id'] == search_id:
-                    return True
-        return False
+        return len(self)
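The rewritten tracker is a `set` subclass whose mutations are mirrored to `.runpod_jobs.pkl`: saves go through a temporary file under an exclusive `fcntl` lock and land via an atomic `os.replace`, and loads take a shared lock, so readers never observe a half-written file and other processes (such as the heartbeat) can rehydrate the same state. A minimal usage sketch, assuming `Job` is hashable by its id and that `str(job)` yields the id (the id below is hypothetical):

    from runpod.serverless.modules.worker_state import JobsProgress

    tracker = JobsProgress()   # first call creates the singleton and loads any saved state
    tracker.add("job-123")     # coerced to Job(id="job-123"), then persisted to disk

    # get_job_list() re-reads the pickle before answering, so a fresh process
    # constructing JobsProgress() would report the same in-progress jobs.
    assert tracker.get_job_list() == "job-123"

    tracker.remove("job-123")  # discards the Job and persists the now-empty set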
10 changes: 10 additions & 0 deletions tests/test_serverless/local_sim/.env_example
@@ -0,0 +1,10 @@
+RUNPOD_AI_API_KEY=XXX
+RUNPOD_API_URL=http://localhost:8080/graphql
+RUNPOD_DEBUG_LEVEL=INFO
+RUNPOD_ENDPOINT_ID=test-endpoint
+RUNPOD_PING_INTERVAL=1000
+RUNPOD_POD_ID=test-worker
+RUNPOD_WEBHOOK_GET_JOB=http://localhost:8080/v2/test-endpoint/job-take/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
+RUNPOD_WEBHOOK_PING=http://localhost:8080/v2/test-endpoint/ping/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
+RUNPOD_WEBHOOK_POST_OUTPUT=http://localhost:8080/v2/test-endpoint/job-done/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
+RUNPOD_WEBHOOK_JOB_STREAM=http://localhost:8080/v2/test-endpoint/job-stream/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090
14 changes: 14 additions & 0 deletions tests/test_serverless/local_sim/Makefile
@@ -0,0 +1,14 @@
+.PHONY: localhost worker all clean
+
+all: localhost worker
+
+localhost:
+	python localhost.py &
+
+worker:
+	python worker.py
+
+clean:
+	find . -type f -name ".runpod_jobs.pkl" -delete
+	find . -type f -name "*.pyc" -delete
+	find . -type d -name "__pycache__" -delete
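For the local simulation, `make all` starts `localhost.py` in the background and then runs `worker.py` against it (presumably with the `.env_example` values exported into the environment), while `make clean` removes the persisted `.runpod_jobs.pkl` and bytecode artifacts so each run begins with empty job state.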