diff --git a/.gitignore b/.gitignore index e6b95300..2557451c 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,4 @@ dmypy.json # Pyre type checker .pyre/ runpod/_version.py +.runpod_jobs.pkl diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index fe428d65..5529dd2b 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -16,7 +16,6 @@ from . import worker from .modules import rp_fastapi from .modules.rp_logger import RunPodLogger -from .modules.rp_progress import progress_update log = RunPodLogger() diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py index f7f8feba..233c34fd 100644 --- a/runpod/serverless/modules/rp_job.py +++ b/runpod/serverless/modules/rp_job.py @@ -149,7 +149,7 @@ async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dic job_result["stopPod"] = True # If rp_debugger is set, debugger output will be returned. - if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict): + if config.get("rp_args", {}).get("rp_debugger", False) and isinstance(job_result, dict): job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output() log.debug("rp_debugger | Flag set, returning debugger output.", job["id"]) diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index ae1499f7..28517835 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -15,7 +15,6 @@ from runpod.version import __version__ as runpod_version log = RunPodLogger() -jobs = JobsProgress() # Contains the list of jobs that are currently running. class Heartbeat: @@ -97,6 +96,7 @@ def _send_ping(self): """ Sends a heartbeat to the Runpod server. """ + jobs = JobsProgress() # Get the singleton instance job_ids = jobs.get_job_list() ping_params = {"job_id": job_ids, "runpod_version": runpod_version} diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index 7c05ef9c..5c7d79cc 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -15,7 +15,6 @@ from .worker_state import JobsProgress, IS_LOCAL_TEST log = RunPodLogger() -job_progress = JobsProgress() def _handle_uncaught_exception(exc_type, exc_value, exc_traceback): @@ -47,6 +46,7 @@ def __init__(self, config: Dict[str, Any]): self._shutdown_event = asyncio.Event() self.current_concurrency = 1 self.config = config + self.job_progress = JobsProgress() # Cache the singleton instance self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) @@ -149,7 +149,7 @@ def kill_worker(self): def current_occupancy(self) -> int: current_queue_count = self.jobs_queue.qsize() - current_progress_count = job_progress.get_job_count() + current_progress_count = self.job_progress.get_job_count() log.debug( f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}" @@ -188,7 +188,7 @@ async def get_jobs(self, session: ClientSession): for job in acquired_jobs: await self.jobs_queue.put(job) - job_progress.add(job) + self.job_progress.add(job) log.debug("Job Queued", job["id"]) log.info(f"Jobs in queue: {self.jobs_queue.qsize()}") @@ -268,6 +268,6 @@ async def handle_job(self, session: ClientSession, job: dict): self.jobs_queue.task_done() # Job is no longer in progress - job_progress.remove(job) + self.job_progress.remove(job) log.debug("Finished Job", job["id"]) diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index be5dc9db..2dad0a07 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -5,9 +5,10 @@ import os import time import uuid -from multiprocessing import Manager -from multiprocessing.managers import SyncManager -from typing import Any, Dict, Optional +import pickle +import fcntl +import tempfile +from typing import Any, Dict, Optional, Set from .rp_logger import RunPodLogger @@ -63,149 +64,150 @@ def __str__(self) -> str: # ---------------------------------------------------------------------------- # # Tracker # # ---------------------------------------------------------------------------- # -class JobsProgress: - """Track the state of current jobs in progress using shared memory.""" - - _instance: Optional['JobsProgress'] = None - _manager: SyncManager - _shared_data: Any - _lock: Any +class JobsProgress(Set[Job]): + """Track the state of current jobs in progress with persistent state.""" + + _instance = None + _STATE_DIR = os.getcwd() + _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs.pkl") def __new__(cls): - if cls._instance is None: - instance = object.__new__(cls) - # Initialize instance variables - instance._manager = Manager() - instance._shared_data = instance._manager.dict() - instance._shared_data['jobs'] = instance._manager.list() - instance._lock = instance._manager.Lock() - cls._instance = instance - return cls._instance + if JobsProgress._instance is None: + os.makedirs(cls._STATE_DIR, exist_ok=True) + JobsProgress._instance = set.__new__(cls) + # Initialize as empty set before loading state + set.__init__(JobsProgress._instance) + JobsProgress._instance._load_state() + return JobsProgress._instance def __init__(self): - # Everything is already initialized in __new__ + # This should never clear data in a singleton + # Don't call parent __init__ as it would clear the set pass - + def __repr__(self) -> str: return f"<{self.__class__.__name__}>: {self.get_job_list()}" + def _load_state(self): + """Load jobs state from pickle file with file locking.""" + try: + if ( + os.path.exists(self._STATE_FILE) + and os.path.getsize(self._STATE_FILE) > 0 + ): + with open(self._STATE_FILE, "rb") as f: + fcntl.flock(f, fcntl.LOCK_SH) + try: + loaded_jobs = pickle.load(f) + # Clear current state and add loaded jobs + super().clear() + for job in loaded_jobs: + set.add( + self, job + ) # Use set.add to avoid triggering _save_state + + except (EOFError, pickle.UnpicklingError): + # Handle empty or corrupted file + log.debug( + "JobsProgress: Failed to load state file, starting with empty state" + ) + pass + finally: + fcntl.flock(f, fcntl.LOCK_UN) + + except FileNotFoundError: + log.debug("JobsProgress: No state file found, starting with empty state") + pass + + def _save_state(self): + """Save jobs state to pickle file with atomic write and file locking.""" + try: + # Use temporary file for atomic write + with tempfile.NamedTemporaryFile( + dir=self._STATE_DIR, delete=False, mode="wb" + ) as temp_f: + fcntl.flock(temp_f, fcntl.LOCK_EX) + try: + pickle.dump(set(self), temp_f) + finally: + fcntl.flock(temp_f, fcntl.LOCK_UN) + + # Atomically replace the state file + os.replace(temp_f.name, self._STATE_FILE) + except Exception as e: + log.error(f"Failed to save job state: {e}") + def clear(self) -> None: - with self._lock: - self._shared_data['jobs'][:] = [] + super().clear() + self._save_state() def add(self, element: Any): """ Adds a Job object to the set. - """ - if isinstance(element, str): - job_dict = {'id': element} - elif isinstance(element, dict): - job_dict = element - elif hasattr(element, 'id'): - job_dict = {'id': element.id} - else: - raise TypeError("Only Job objects can be added to JobsProgress.") - with self._lock: - # Check if job already exists - job_list = self._shared_data['jobs'] - for existing_job in job_list: - if existing_job['id'] == job_dict['id']: - return # Job already exists - - # Add new job - job_list.append(job_dict) - log.debug(f"JobsProgress | Added job: {job_dict['id']}") - - def get(self, element: Any) -> Optional[Job]: - """ - Retrieves a Job object from the set. + If the added element is a string, then `Job(id=element)` is added - If the element is a string, searches for Job with that id. + If the added element is a dict, that `Job(**element)` is added """ if isinstance(element, str): - search_id = element - elif isinstance(element, Job): - search_id = element.id - else: - raise TypeError("Only Job objects can be retrieved from JobsProgress.") + element = Job(id=element) - with self._lock: - for job_dict in self._shared_data['jobs']: - if job_dict['id'] == search_id: - log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}") - return Job(**job_dict) - - return None + if isinstance(element, dict): + element = Job(**element) + + if not isinstance(element, Job): + raise TypeError("Only Job objects can be added to JobsProgress.") + + result = super().add(element) + self._save_state() + return result def remove(self, element: Any): """ Removes a Job object from the set. + + If the element is a string, then `Job(id=element)` is removed + + If the element is a dict, then `Job(**element)` is removed """ if isinstance(element, str): - job_id = element - elif isinstance(element, dict): - job_id = element.get('id') - elif hasattr(element, 'id'): - job_id = element.id - else: + element = Job(id=element) + + if isinstance(element, dict): + element = Job(**element) + + if not isinstance(element, Job): raise TypeError("Only Job objects can be removed from JobsProgress.") - with self._lock: - job_list = self._shared_data['jobs'] - # Find and remove the job - for i, job_dict in enumerate(job_list): - if job_dict['id'] == job_id: - del job_list[i] - log.debug(f"JobsProgress | Removed job: {job_dict['id']}") - break + result = super().discard(element) + self._save_state() + return result + + def get(self, element: Any) -> Optional[Job]: + if isinstance(element, str): + element = Job(id=element) + + if not isinstance(element, Job): + raise TypeError("Only Job objects can be retrieved from JobsProgress.") + + for job in self: + if job == element: + return job + return None def get_job_list(self) -> Optional[str]: """ Returns the list of job IDs as comma-separated string. """ - with self._lock: - job_list = list(self._shared_data['jobs']) - - if not job_list: + self._load_state() + + if not len(self): return None - log.debug(f"JobsProgress | Jobs in progress: {job_list}") - return ",".join(str(job_dict['id']) for job_dict in job_list) + return ",".join(str(job) for job in self) def get_job_count(self) -> int: """ Returns the number of jobs. """ - with self._lock: - return len(self._shared_data['jobs']) - - def __iter__(self): - """Make the class iterable - returns Job objects""" - with self._lock: - # Create a snapshot of jobs to avoid holding lock during iteration - job_dicts = list(self._shared_data['jobs']) - - # Return an iterator of Job objects - return iter(Job(**job_dict) for job_dict in job_dicts) - - def __len__(self): - """Support len() operation""" - return self.get_job_count() - - def __contains__(self, element: Any) -> bool: - """Support 'in' operator""" - if isinstance(element, str): - search_id = element - elif isinstance(element, Job): - search_id = element.id - elif isinstance(element, dict): - search_id = element.get('id') - else: - return False - - with self._lock: - for job_dict in self._shared_data['jobs']: - if job_dict['id'] == search_id: - return True - return False + return len(self) diff --git a/tests/test_serverless/local_sim/.env_example b/tests/test_serverless/local_sim/.env_example new file mode 100644 index 00000000..b5a0f6a6 --- /dev/null +++ b/tests/test_serverless/local_sim/.env_example @@ -0,0 +1,10 @@ +RUNPOD_AI_API_KEY=XXX +RUNPOD_API_URL=http://localhost:8080/graphql +RUNPOD_DEBUG_LEVEL=INFO +RUNPOD_ENDPOINT_ID=test-endpoint +RUNPOD_PING_INTERVAL=1000 +RUNPOD_POD_ID=test-worker +RUNPOD_WEBHOOK_GET_JOB=http://localhost:8080/v2/test-endpoint/job-take/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090 +RUNPOD_WEBHOOK_PING=http://localhost:8080/v2/test-endpoint/ping/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090 +RUNPOD_WEBHOOK_POST_OUTPUT=http://localhost:8080/v2/test-endpoint/job-done/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090 +RUNPOD_WEBHOOK_JOB_STREAM=http://localhost:8080/v2/test-endpoint/job-stream/$RUNPOD_POD_ID?gpu=NVIDIA+GeForce+RTX+4090 diff --git a/tests/test_serverless/local_sim/Makefile b/tests/test_serverless/local_sim/Makefile new file mode 100644 index 00000000..72475872 --- /dev/null +++ b/tests/test_serverless/local_sim/Makefile @@ -0,0 +1,14 @@ +.PHONY: localhost worker all + +all: localhost worker + +localhost: + python localhost.py & + +worker: + python worker.py + +clean: + find . -type f -name ".runpod_jobs.pkl" -delete + find . -type f -name "*.pyc" -delete + find . -type d -name "__pycache__" -delete diff --git a/tests/test_serverless/local_sim/localhost.py b/tests/test_serverless/local_sim/localhost.py new file mode 100644 index 00000000..78833000 --- /dev/null +++ b/tests/test_serverless/local_sim/localhost.py @@ -0,0 +1,49 @@ +import random +import uvicorn +from fastapi import FastAPI, Request +from typing import Dict, Any +from faker import Faker + + +fake = Faker() +app = FastAPI() + + +def generate_fake_job() -> Dict[str, Any]: + delay = fake.random_digit_above_two() + return { + "id": fake.uuid4(), + "input": fake.sentence(), + "mock_delay": delay, + } + + +@app.get("/v2/{endpoint_id}/job-take/{worker_id}") +async def job_take(endpoint_id: str, worker_id: str): + """Accept GET request and return a random fake job as a dict""" + return generate_fake_job() + + +@app.get("/v2/{endpoint_id}/job-take-batch/{worker_id}") +async def job_take_batch(endpoint_id: str, worker_id: str, batch_size: int = 5): + """Accept GET request and return a random fake list of jobs""" + return [generate_fake_job() for _ in range(random.randint(1, batch_size))] + + +@app.post("/v2/{endpoint_id}/job-done/{worker_id}") +async def job_done(request: Request, endpoint_id: str, worker_id: str): + """Accept POST request and return the payload posted""" + payload = await request.json() + return payload + + +@app.get("/v2/{endpoint_id}/ping/{worker_id}") +async def ping_worker(endpoint_id: str, worker_id: str): + """Accept GET request and return ping response with extracted path values""" + return {"status": "pong"} + + +if __name__ == "__main__": + # Run with: python filename.py + # Or use: uvicorn filename:app --reload + uvicorn.run(app, host="0.0.0.0", port=8080) diff --git a/tests/test_serverless/local_sim/worker.py b/tests/test_serverless/local_sim/worker.py new file mode 100644 index 00000000..37df815e --- /dev/null +++ b/tests/test_serverless/local_sim/worker.py @@ -0,0 +1,43 @@ +import asyncio +import math +from faker import Faker +from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger +from runpod.serverless.modules.rp_ping import Heartbeat +from runpod.serverless.modules.worker_state import JobsProgress + + +fake = Faker() +logger = RunPodLogger() +heartbeat = Heartbeat() + +start = 3 + + +# sample concurrency modifier that loops +def collatz_conjecture(current_concurrency): + if current_concurrency == 1: + return start + + if current_concurrency % 2 == 0: + return math.floor(current_concurrency / 2) + else: + return current_concurrency * 3 + 1 + + +async def fake_handle_job(job): + await asyncio.sleep(job["mock_delay"]) # Simulates a blocking process + logger.info(f"Job handled ({job['mock_delay']}s): `{job['input']}`", job["id"]) + return job["input"] + + +job_scaler = JobScaler( + { + "concurrency_modifier": collatz_conjecture, + "handler": fake_handle_job, + } +) + +if __name__ == "__main__": + JobsProgress().clear() + heartbeat.start_ping() + job_scaler.start() diff --git a/tests/test_serverless/test_integration_worker_state.py b/tests/test_serverless/test_integration_worker_state.py new file mode 100644 index 00000000..11a7cd53 --- /dev/null +++ b/tests/test_serverless/test_integration_worker_state.py @@ -0,0 +1,249 @@ +""" +Integration test for worker state persistence between job_scaler and heartbeat. +This test mimics the runpod.serverless.worker.run_worker path. +""" + +import asyncio +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +from runpod.serverless.modules.rp_ping import Heartbeat +from runpod.serverless.modules.rp_scale import JobScaler +from runpod.serverless.modules.worker_state import JobsProgress + + +class TestWorkerStateIntegration: + """Test the integration between JobScaler and Heartbeat for state persistence.""" + + def setup_method(self): + """Setup test environment.""" + # Clear any existing singleton instance + JobsProgress._instance = None + + # Create a temporary directory for state files + self.temp_dir = tempfile.mkdtemp() + + # Mock environment variables for testing + self.env_patcher = patch.dict(os.environ, { + 'RUNPOD_AI_API_KEY': 'test_key', + 'RUNPOD_POD_ID': 'test_pod_id', + 'RUNPOD_WEBHOOK_PING': 'http://test.com/ping', + 'RUNPOD_PING_INTERVAL': '5000' + }) + self.env_patcher.start() + + # Mock the state directory to use our temp directory + self.state_dir_patcher = patch.object(JobsProgress, '_STATE_DIR', self.temp_dir) + self.state_file_patcher = patch.object(JobsProgress, '_STATE_FILE', + os.path.join(self.temp_dir, '.runpod_jobs.pkl')) + self.state_dir_patcher.start() + self.state_file_patcher.start() + + def teardown_method(self): + """Cleanup test environment.""" + self.env_patcher.stop() + self.state_dir_patcher.stop() + self.state_file_patcher.stop() + + # Clean up temp directory + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + # Reset singleton + JobsProgress._instance = None + + def test_jobs_progress_singleton_persistence(self): + """Test that JobsProgress maintains singleton behavior across processes.""" + jobs1 = JobsProgress() + jobs2 = JobsProgress() + + assert jobs1 is jobs2 + + # Add a job and verify it's visible in both instances + jobs1.add("test_job_1") + assert "test_job_1" in jobs2.get_job_list() + + def test_file_based_state_persistence(self): + """Test that job state persists to file and can be loaded.""" + # Create initial instance and add jobs + jobs1 = JobsProgress() + jobs1.add("job_1") + jobs1.add("job_2") + + # Verify state is saved to file + assert os.path.exists(jobs1._STATE_FILE) + + # Reset singleton to simulate new process + JobsProgress._instance = None + + # Create new instance and verify state is loaded + jobs2 = JobsProgress() + job_list = jobs2.get_job_list() + + assert "job_1" in job_list + assert "job_2" in job_list + + def test_jobs_progress_add_and_remove_jobs(self): + """Test JobsProgress job tracking functionality.""" + # Reset JobsProgress singleton + JobsProgress._instance = None + + # Get JobsProgress instance + jobs_progress = JobsProgress() + + # Test adding jobs + test_jobs = [ + {"id": "job_1", "input": {"test": "data1"}}, + {"id": "job_2", "input": {"test": "data2"}} + ] + + # Add jobs + for job in test_jobs: + jobs_progress.add(job) + + # Verify initial state + assert len(jobs_progress) == 2 + job_list = jobs_progress.get_job_list() + assert job_list is not None + assert "job_1" in job_list + assert "job_2" in job_list + + # Test removing jobs + jobs_progress.remove(test_jobs[0]) + + # Verify removal + assert len(jobs_progress) == 1 + + # Get remaining job + remaining_list = jobs_progress.get_job_list() + assert remaining_list is not None + assert "job_1" not in remaining_list + assert "job_2" in remaining_list + + # Test clearing jobs + jobs_progress.clear() + + # Verify clearing + assert len(jobs_progress) == 0 + assert jobs_progress.get_job_list() is None + + def test_heartbeat_reads_job_progress(self): + """Test that Heartbeat can read jobs from JobsProgress.""" + # Add jobs to progress + jobs_progress = JobsProgress() + jobs_progress.add("job_1") + jobs_progress.add("job_2") + + # Create heartbeat instance + heartbeat = Heartbeat() + + # Mock the session.get method to capture the ping parameters + with patch.object(heartbeat, '_session') as mock_session: + mock_response = MagicMock() + mock_response.url = "http://test.com/ping" + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + # Send a ping + heartbeat._send_ping() + + # Verify the ping was sent with job_ids + mock_session.get.assert_called_once() + call_args = mock_session.get.call_args + + # Check that job_id parameter contains our jobs + params = call_args[1]['params'] + assert 'job_id' in params + job_ids = params['job_id'] + assert "job_1" in job_ids + assert "job_2" in job_ids + + def test_multiprocess_heartbeat_state_access(self): + """Test that heartbeat process can access job state from main process.""" + # Add jobs in main process + main_jobs = JobsProgress() + main_jobs.add("main_job_1") + main_jobs.add("main_job_2") + + # Simulate what happens in the heartbeat process + # The process_loop creates a new Heartbeat instance + heartbeat = Heartbeat() + + # Mock the session to capture ping data + with patch.object(heartbeat, '_session') as mock_session: + mock_response = MagicMock() + mock_response.url = "http://test.com/ping" + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + # Send ping - this should read from the persisted state + heartbeat._send_ping() + + # Verify the ping includes jobs from main process + call_args = mock_session.get.call_args + params = call_args[1]['params'] + job_ids = params['job_id'] + + assert "main_job_1" in job_ids + assert "main_job_2" in job_ids + + @pytest.mark.asyncio + async def test_end_to_end_job_lifecycle(self): + """Test complete job lifecycle: add -> process -> remove -> ping.""" + # Mock job data + test_jobs = [{"id": "lifecycle_job", "input": {"test": "data"}}] + + async def mock_jobs_fetcher(session, count): + return test_jobs[:count] + + async def mock_job_handler(session, config, job): + await asyncio.sleep(0.1) # Simulate processing + + config = { + "handler": lambda x: x, + "jobs_fetcher": mock_jobs_fetcher, + "jobs_handler": mock_job_handler, + "jobs_fetcher_timeout": 1 + } + + # Create instances + job_scaler = JobScaler(config) + heartbeat = Heartbeat() + jobs_progress = JobsProgress() + + # Track ping calls + ping_calls = [] + + def capture_ping(*args, **kwargs): + job_ids = kwargs.get('params', {}).get('job_id', '') + ping_calls.append(job_ids) + mock_response = MagicMock() + mock_response.url = "http://test.com/ping" + mock_response.status_code = 200 + return mock_response + + with patch.object(heartbeat, '_session') as mock_session: + mock_session.get.side_effect = capture_ping + + # Start job processing + session = AsyncMock() + + # Add job + await job_scaler.jobs_queue.put(test_jobs[0]) + jobs_progress.add(test_jobs[0]["id"]) + + # Send ping with job active + heartbeat._send_ping() + + # Process job (this should remove it from progress) + await job_scaler.handle_job(session, test_jobs[0]) + + # Send ping after job completion + heartbeat._send_ping() + + # Verify ping behavior + assert len(ping_calls) == 2 + assert "lifecycle_job" in ping_calls[0] # Job was active + assert ping_calls[1] is None # Job completed, no active jobs \ No newline at end of file diff --git a/tests/test_serverless/test_modules/run_scale.py b/tests/test_serverless/test_modules/run_scale.py deleted file mode 100644 index 1310e463..00000000 --- a/tests/test_serverless/test_modules/run_scale.py +++ /dev/null @@ -1,63 +0,0 @@ -import asyncio -import math -from faker import Faker -from typing import Any, Dict, Optional, List - - -def main(start=1): - """Main function to run the job scaler""" - from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger - - fake = Faker() - log = RunPodLogger() - - # sample concurrency modifier that loops - def collatz_conjecture(current_concurrency): - if current_concurrency == 1: - return start - - if current_concurrency % 2 == 0: - return math.floor(current_concurrency / 2) - else: - return current_concurrency * 3 + 1 - - def fake_job(): - # Change this number to your desired delay - delay = fake.random_digit_above_two() - return { - "id": fake.uuid4(), - "input": fake.sentence(), - "mock_delay": delay, - } - - async def fake_get_job(session, num_jobs: int = 1) -> Optional[List[Dict[str, Any]]]: - # Change this number to your desired delay - delay = fake.random_digit_above_two() - 1 - - log.info(f"... artificial delay ({delay}s)") - await asyncio.sleep(delay) # Simulates a blocking process - - jobs = [fake_job() for _ in range(num_jobs)] - log.info(f"... Generated # jobs: {len(jobs)}") - return jobs - - async def fake_handle_job(session, config, job) -> dict: - await asyncio.sleep(job["mock_delay"]) # Simulates a blocking process - log.info(f"... Job handled ({job['mock_delay']}s)", job["id"]) - - job_scaler = JobScaler( - { - "concurrency_modifier": collatz_conjecture, - # "jobs_fetcher_timeout": 5, - "jobs_fetcher": fake_get_job, - "jobs_handler": fake_handle_job, - } - ) - job_scaler.start() - - -if __name__ == '__main__': - # This is required for multiprocessing on macOS/Windows - import multiprocessing - multiprocessing.set_start_method('spawn', force=True) - main(start=10) diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index 3a5447d3..a10b43df 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -52,8 +52,9 @@ def mock_session(self): @pytest.fixture def mock_jobs(self): """Mock the JobsProgress instance""" - with patch("runpod.serverless.modules.rp_ping.jobs") as mock: - mock.get_job_list.return_value = "job1,job2,job3" + with patch("runpod.serverless.modules.rp_ping.JobsProgress") as mock: + instance = mock.return_value + instance.get_job_list.return_value = "job1,job2,job3" yield mock @pytest.fixture @@ -242,7 +243,7 @@ def test_send_ping_no_jobs(self, mock_env, mock_worker_id, mock_session, mock_lo heartbeat = Heartbeat() # Mock no jobs - with patch("runpod.serverless.modules.rp_ping.jobs.get_job_list", return_value=None): + with patch("runpod.serverless.modules.rp_ping.JobsProgress.get_job_list", return_value=None): mock_response = MagicMock() mock_response.url = "https://test.com/ping/test_worker_123" mock_response.status_code = 200 diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index f3bb3372..94772bde 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -145,7 +145,7 @@ async def test_get_job(self): assert job1 in self.jobs async def test_get_job_list(self): - assert self.jobs.get_job_list() is None + assert not self.jobs.get_job_list() job1 = {"id": "123"} self.jobs.add(job1) @@ -158,4 +158,68 @@ async def test_get_job_list(self): async def test_get_job_count(self): # test job count contention when adding and removing jobs in parallel - pass \ No newline at end of file + pass + + async def test_state_persistence(self): + """Test state persistence across multiple JobsProgress instances""" + # First instance: add some jobs + jobs1 = JobsProgress() + jobs1.clear() # Ensure clean state + + job1 = {"id": "test_persistent_1"} + job2 = {"id": "test_persistent_2"} + + jobs1.add(job1) + jobs1.add(job2) + + # Reset singleton to simulate process restart + JobsProgress._instance = None + jobs2 = JobsProgress() + + # Debug: check jobs2 right after creation + print(f"DEBUG: jobs2 length right after creation: {len(jobs2)}") + print(f"DEBUG: jobs2 contents right after creation: {list(jobs2)}") + + # Check that jobs were persisted + assert jobs2.get_job_count() == 2, "Jobs should be persisted across instances" + + # Verify specific jobs are present + assert jobs2.get("test_persistent_1") is not None, "First job should be retrievable" + assert jobs2.get("test_persistent_2") is not None, "Second job should be retrievable" + + async def test_state_persistence_empty(self): + """Test state persistence when no jobs are present""" + # Clear any existing state + jobs1 = JobsProgress() + jobs1.clear() + + # Reset singleton to simulate process restart + JobsProgress._instance = None + jobs2 = JobsProgress() + + # Check that no jobs are present + assert jobs2.get_job_count() == 0, "No jobs should be present after clear" + assert jobs2.get_job_list() is None, "Job list should be None when empty" + + async def test_file_persistence_after_clear(self): + """Verify that clearing the jobs results in an empty persistent state""" + # Add some jobs + jobs1 = JobsProgress() + jobs1.clear() # Ensure clean state + + job1 = {"id": "to_be_cleared_1"} + job2 = {"id": "to_be_cleared_2"} + + jobs1.add(job1) + jobs1.add(job2) + + # Clear the jobs + jobs1.clear() + + # Reset singleton to simulate process restart + JobsProgress._instance = None + jobs2 = JobsProgress() + + # Verify that no jobs remain + assert jobs2.get_job_count() == 0, "Jobs should be cleared in persistent state" + assert jobs2.get_job_list() is None, "Job list should be None after clear" \ No newline at end of file diff --git a/tests/test_serverless/test_utils/test_download.py b/tests/test_serverless/test_utils/test_download.py index f0ab6d98..a4085a20 100644 --- a/tests/test_serverless/test_utils/test_download.py +++ b/tests/test_serverless/test_utils/test_download.py @@ -87,7 +87,7 @@ def test_download_files_from_urls(self, mock_open_file, mock_get, mock_makedirs) """ Tests download_files_from_urls """ - urls = ["https://example.com/picture.jpg", "https://example.com/file_without_extension"] + urls = ("https://example.com/picture.jpg", "https://example.com/file_without_extension",) downloaded_files = download_files_from_urls( JOB_ID, urls,