From cab13c47568bb94ef23a59e6a8b16cd189108d34 Mon Sep 17 00:00:00 2001
From: gustavogaldinoo
Date: Thu, 25 Jul 2024 16:56:10 +0000
Subject: [PATCH] Adding code for experiments using pubsub queues

---
 experiment/measurer/measure_manager.py      | 400 ++++++++++++++------
 experiment/measurer/measure_worker.py       |  78 +++-
 experiment/measurer/test_measure_manager.py |  62 ++-
 requirements.txt                            |   1 +
 service/experiment-config.yaml              |   4 +-
 service/gcbrun_experiment.py                |   1 +
 6 files changed, 415 insertions(+), 131 deletions(-)

diff --git a/experiment/measurer/measure_manager.py b/experiment/measurer/measure_manager.py
index 288148401..b1a974247 100644
--- a/experiment/measurer/measure_manager.py
+++ b/experiment/measurer/measure_manager.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=too-many-lines
 """Module for measuring snapshots from trial runners."""
 
 import collections
 import gc
 import glob
 import multiprocessing
+import multiprocessing.queues
 import json
 import os
 import pathlib
@@ -25,13 +27,13 @@ import tempfile
 import tarfile
 import time
-from typing import List
+from typing import List, Tuple
 import queue
 
 import psutil
 from sqlalchemy import func
 from sqlalchemy import orm
-
+from google.cloud import pubsub_v1
 from common import benchmark_utils
 from common import experiment_utils
 from common import experiment_path as exp_path
@@ -58,6 +60,7 @@
 SNAPSHOT_QUEUE_GET_TIMEOUT = 1
 SNAPSHOTS_BATCH_SAVE_SIZE = 100
 MEASUREMENT_LOOP_WAIT = 10
+GET_FROM_PUB_SUB_QUEUE_TIMEOUT = 5
 
 
 def exists_in_experiment_filestore(path: pathlib.Path) -> bool:
@@ -76,8 +79,16 @@ def measure_main(experiment_config):
     max_total_time = experiment_config['max_total_time']
     measurers_cpus = experiment_config['measurers_cpus']
     region_coverage = experiment_config['region_coverage']
-    measure_manager_loop(experiment, max_total_time, measurers_cpus,
-                         region_coverage)
+    cloud_project = experiment_config['cloud_project']
+    local_experiment = experiment_config['local_experiment']
+
+    measure_manager = LocalMeasureManager(experiment, region_coverage,
+                                          measurers_cpus)
+    if not local_experiment:
+        measure_manager = GoogleCloudMeasureManager(experiment, cloud_project,
+                                                    region_coverage,
+                                                    measurers_cpus)
+    measure_manager.measure_manager_loop(max_total_time)
 
     # Clean up resources.
     gc.collect()
@@ -672,72 +683,6 @@ def initialize_logs():
     })
 
 
-def consume_snapshots_from_response_queue(
-        response_queue, queued_snapshots) -> List[models.Snapshot]:
-    """Consume response_queue, allows retry objects to retried, and
-    return all measured snapshots in a list."""
-    measured_snapshots = []
-    while True:
-        try:
-            response_object = response_queue.get_nowait()
-            if isinstance(response_object, measurer_datatypes.RetryRequest):
-                # Need to retry measurement task, will remove identifier from
-                # the set so task can be retried in next loop iteration.
-                snapshot_identifier = (response_object.trial_id,
-                                       response_object.cycle)
-                queued_snapshots.remove(snapshot_identifier)
-                logger.info('Reescheduling task for trial %s and cycle %s',
-                            response_object.trial_id, response_object.cycle)
-            elif isinstance(response_object, models.Snapshot):
-                measured_snapshots.append(response_object)
-            else:
-                logger.error('Type of response object not mapped! %s',
-                             type(response_object))
-        except queue.Empty:
-            break
-    return measured_snapshots
-
-
-def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue,
-                               response_queue, queued_snapshots):
-    """Reads from database to determine which snapshots needs measuring. Write
-    measurements tasks to request queue, get results from response queue, and
-    write measured snapshots to database. Returns False if there's no more
-    snapshots left to be measured"""
-    initialize_logs()
-    # Read database to determine which snapshots needs measuring.
-    unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle)
-    logger.info('Retrieved %d unmeasured snapshots from measure manager',
-                len(unmeasured_snapshots))
-    # When there are no more snapshots left to be measured, should break loop.
-    if not unmeasured_snapshots:
-        return False
-
-    # Write measurements requests to request queue
-    for unmeasured_snapshot in unmeasured_snapshots:
-        # No need to insert fuzzer and benchmark info here as it's redundant
-        # (Can be retrieved through trial_id).
-        unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
-                                          unmeasured_snapshot.cycle)
-        # Checking if snapshot already was queued so workers will not repeat
-        # measurement for same snapshot
-        if unmeasured_snapshot_identifier not in queued_snapshots:
-            request_queue.put(unmeasured_snapshot)
-            queued_snapshots.add(unmeasured_snapshot_identifier)
-
-    # Read results from response queue.
-    measured_snapshots = consume_snapshots_from_response_queue(
-        response_queue, queued_snapshots)
-    logger.info('Retrieved %d measured snapshots from response queue',
-                len(measured_snapshots))
-
-    # Save measured snapshots to database.
-    if measured_snapshots:
-        db_utils.add_all(measured_snapshots)
-
-    return True
-
-
 def get_pool_args(measurers_cpus, runners_cpus):
     """Return pool args based on measurer cpus and runner cpus arguments."""
     if measurers_cpus is None or runners_cpus is None:
@@ -755,51 +700,294 @@ def get_pool_args(measurers_cpus, runners_cpus):
     return (measurers_cpus, _process_init, (cores_queue,))
 
 
-def measure_manager_loop(experiment: str,
-                         max_total_time: int,
-                         measurers_cpus=None,
-                         region_coverage=False):  # pylint: disable=too-many-locals
-    """Measure manager loop. Creates request and response queues, request
-    measurements tasks from workers, retrieve measurement results from response
-    queue and writes measured snapshots in database."""
-    logger.info('Starting measure manager loop.')
-    if not measurers_cpus:
-        measurers_cpus = multiprocessing.cpu_count()
-        logger.info('Number of measurer CPUs not passed as argument. using %d',
-                    measurers_cpus)
-    with multiprocessing.Pool() as pool, multiprocessing.Manager() as manager:
-        logger.info('Setting up coverage binaries')
-        set_up_coverage_binaries(pool, experiment)
-        request_queue = manager.Queue()
-        response_queue = manager.Queue()
-
-        config = {
-            'request_queue': request_queue,
-            'response_queue': response_queue,
-            'region_coverage': region_coverage,
-        }
-        local_measure_worker = measure_worker.LocalMeasureWorker(config)
-
-        # Since each worker is going to be in an infinite loop, we dont need
-        # result return. Workers' life scope will end automatically when there
-        # are no more snapshots left to measure.
-        logger.info('Starting measure worker loop for %d workers',
-                    measurers_cpus)
-        for _ in range(measurers_cpus):
-            _result = pool.apply_async(local_measure_worker.measure_worker_loop)
+class BaseMeasureManager:
+    """Base class for measure manager. Encapsulates core methods that will be
+    implemented for Local and Google Cloud measure managers."""
+
+    def __init__(self, experiment: str, region_coverage=False):
+        self.region_coverage = region_coverage
+        self.experiment = experiment
+
+    def initialize_queues(self):
+        """Initialize and return request and response queues, respectively."""
+        raise NotImplementedError
+
+    def start_workers(self, request_queue, response_queue):
+        """Initialize measure workers."""
+        raise NotImplementedError
+
+    def put_task_in_request_queue(self, task, request_queue):
+        """Put task in request queue. The request queue can be a pub sub queue
+        or a multiprocessing in-memory queue, depending on the
+        implementation."""
+        raise NotImplementedError
+
+    def get_result_from_response_queue(self, response_queue):
+        """Get result from response queue. Can be a pub sub queue or a
+        multiprocessing in-memory queue, depending on the implementation."""
+        raise NotImplementedError
+
+    def consume_snapshots_from_response_queue(
+            self, response_queue, queued_snapshots) -> List[models.Snapshot]:
+        """Consume response_queue, allows retry objects to be retried, and
+        return all measured snapshots in a list."""
+        measured_snapshots = []
+        while True:
+            try:
+                response_object = self.get_result_from_response_queue(
+                    response_queue)
+                if isinstance(response_object, measurer_datatypes.RetryRequest):
+                    # Need to retry measurement task, will remove identifier
+                    # from the set so task can be retried in next loop
+                    # iteration.
+                    snapshot_identifier = (response_object.trial_id,
+                                           response_object.cycle)
+                    queued_snapshots.remove(snapshot_identifier)
+                    logger.info('Rescheduling task for trial %s and cycle %s',
+                                response_object.trial_id, response_object.cycle)
+                elif isinstance(response_object, models.Snapshot):
+                    measured_snapshots.append(response_object)
+                else:
+                    logger.error('Type of response object not mapped! %s',
+                                 type(response_object))
+            except queue.Empty:
+                break
+        return measured_snapshots
+
+    def measure_manager_inner_loop(self, max_cycle: int, request_queue,
+                                   response_queue, queued_snapshots):
+        """Reads from database to determine which snapshots need measuring.
+        Write measurement tasks to request queue, get results from response
+        queue, and write measured snapshots to database. Returns False if
+        there's no more snapshots left to be measured."""
+        initialize_logs()
+        # Read database to determine which snapshots need measuring.
+        unmeasured_snapshots = get_unmeasured_snapshots(self.experiment,
+                                                        max_cycle)
+        logger.info('Retrieved %d unmeasured snapshots from measure manager.',
+                    len(unmeasured_snapshots))
+        # When there are no more snapshots left to be measured, should break
+        # loop.
+        if not unmeasured_snapshots:
+            return False
+
+        # Write measurement requests to request queue
+        for unmeasured_snapshot in unmeasured_snapshots:
+            # No need to insert fuzzer and benchmark info here as it's redundant
+            # (Can be retrieved through trial_id).
+            unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
+                                              unmeasured_snapshot.cycle)
+            # Checking if snapshot already was queued so workers will not repeat
+            # measurement for same snapshot
+            if unmeasured_snapshot_identifier not in queued_snapshots:
+                self.put_task_in_request_queue(unmeasured_snapshot,
+                                               request_queue)
+                queued_snapshots.add(unmeasured_snapshot_identifier)
+
+        # Read results from response queue.
+        measured_snapshots = self.consume_snapshots_from_response_queue(
+            response_queue, queued_snapshots)
+        logger.info('Retrieved %d measured snapshots from response queue.',
+                    len(measured_snapshots))
+
+        # Save measured snapshots to database.
+        if measured_snapshots:
+            db_utils.add_all(measured_snapshots)
+        return True
+
+    def measure_manager_loop(self, max_total_time: int):
+        """Measure manager loop. Creates request and response queues, requests
+        measurement tasks from workers, retrieves measurement results from the
+        response queue and writes measured snapshots to the database."""
+        logger.info('Starting measure manager loop.')
+        with multiprocessing.Pool() as pool:
+            set_up_coverage_binaries(pool, self.experiment)
+        (request_queue, response_queue) = self.initialize_queues()
+        self.start_workers(request_queue, response_queue)
 
         max_cycle = _time_to_cycle(max_total_time)
         queued_snapshots = set()
-        while not scheduler.all_trials_ended(experiment):
-            continue_inner_loop = measure_manager_inner_loop(
-                experiment, max_cycle, request_queue, response_queue,
-                queued_snapshots)
+        while not scheduler.all_trials_ended(self.experiment):
+            continue_inner_loop = self.measure_manager_inner_loop(
+                max_cycle, request_queue, response_queue, queued_snapshots)
             if not continue_inner_loop:
                 break
             time.sleep(MEASUREMENT_LOOP_WAIT)
         logger.info('All trials ended. Ending measure manager loop')
 
 
+class LocalMeasureManager(BaseMeasureManager):
+    """Class that holds implementations of core methods for running a measure
+    worker locally."""
+
+    def __init__(self,
+                 experiment: str,
+                 region_coverage=False,
+                 measurers_cpus=None):
+        super().__init__(experiment, region_coverage)
+        self.measurers_cpus = measurers_cpus
+
+    def initialize_queues(
+        self
+    ) -> Tuple[multiprocessing.queues.Queue, multiprocessing.queues.Queue]:
+        return (multiprocessing.Queue(), multiprocessing.Queue())
+
+    def start_workers(self, request_queue: multiprocessing.queues.Queue,
+                      response_queue: multiprocessing.queues.Queue):
+        if not self.measurers_cpus:
+            self.measurers_cpus = multiprocessing.cpu_count()
+            logger.info(
+                'Number of measurer CPUs not passed as argument. using %d',
+                self.measurers_cpus)
+        with multiprocessing.Pool(processes=self.measurers_cpus) as pool:
+            config = {
+                'request_queue': request_queue,
+                'response_queue': response_queue,
+                'region_coverage': self.region_coverage,
+            }
+            local_measure_worker = measure_worker.LocalMeasureWorker(config)
+
+            # Since each worker is going to be in an infinite loop, we don't need
+            # result return. Workers' life scope will end automatically when
+            # there are no more snapshots left to measure.
+            logger.info('Starting measure worker loop for %d workers',
+                        self.measurers_cpus)
+            for _ in range(self.measurers_cpus):
+                _result = pool.apply_async(
+                    local_measure_worker.measure_worker_loop)
+
+    def get_result_from_response_queue(
+            self, response_queue: multiprocessing.queues.Queue):
+        return response_queue.get_nowait()
+
+    def put_task_in_request_queue(
+            self, task: measurer_datatypes.SnapshotMeasureRequest,
+            request_queue: multiprocessing.queues.Queue):
+        request_queue.put_nowait(task)
+
+
+class GoogleCloudMeasureManager(BaseMeasureManager):  # pylint: disable=too-many-instance-attributes
+    """Measure manager implementation that subscribes and publishes to a
+    Google Cloud Pub/Sub queue, instead of a multiprocessing queue."""
+
+    def __init__(self,
+                 experiment: str,
+                 cloud_project: str,
+                 region_coverage=False,
+                 measurers_cpus=None):
+        super().__init__(experiment, region_coverage)
+        self.project_id = cloud_project
+        self.request_queue_topic_id = f'request-queue-topic-{self.experiment}'
+        self.response_queue_topic_id = f'response-queue-topic-{self.experiment}'
+        self.response_queue_subscription_id = (
+            f'response-queue-subscription-{self.experiment}')
+        self.subscriber_client = pubsub_v1.SubscriberClient()
+        self.publisher_client = pubsub_v1.PublisherClient()
+        self.subscription_path = self.subscriber_client.subscription_path(
+            self.project_id, self.response_queue_subscription_id)
+        self.measurers_cpus = measurers_cpus
+
+    def initialize_queues(self) -> Tuple[str, str]:
+        request_queue_topic_path = self.publisher_client.topic_path(
+            self.project_id, self.request_queue_topic_id)
+        request_queue_topic = self.publisher_client.create_topic(
+            request={'name': request_queue_topic_path})
+        response_queue_topic_path = self.subscriber_client.topic_path(
+            self.project_id, self.response_queue_topic_id)
+        response_queue_topic = self.publisher_client.create_topic(
+            request={'name': response_queue_topic_path})
+        return request_queue_topic.name, response_queue_topic.name
+
+    def _create_response_queue_subscription(self):
+        """Creates a new Pub/Sub subscription for the response queue."""
+        topic_path = self.response_queue_topic_id
+        subscription = self.subscriber_client.create_subscription(request={
+            'name': self.subscription_path,
+            'topic': topic_path
+        })
+        logger.info(f'Subscription {subscription.name} created successfully.')
+
+        return self.subscription_path
+
+    def start_workers(self, request_queue, response_queue):
+        self._create_response_queue_subscription()
+        if not self.measurers_cpus:
+            self.measurers_cpus = multiprocessing.cpu_count()
+            logger.info(
+                'Number of measurer CPUs not passed as argument. using %d',
+                self.measurers_cpus)
+        with multiprocessing.Pool(processes=self.measurers_cpus) as pool:
+            config = {
+                'request_queue_topic_id': self.request_queue_topic_id,
+                'response_queue_topic_id': self.response_queue_topic_id,
+                'region_coverage': self.region_coverage,
+                'project_id': self.project_id,
+                'experiment': self.experiment,
+            }
+            google_cloud_worker = measure_worker.GoogleCloudMeasureWorker(
+                config)
+
+            # Since each worker is going to be in an infinite loop, we don't need
+            # result return. Workers' life scope will end automatically when
+            # there are no more snapshots left to measure.
+            logger.info('Starting measure worker loop for %d workers',
+                        self.measurers_cpus)
+            for _ in range(self.measurers_cpus):
+                _result = pool.apply_async(
+                    google_cloud_worker.measure_worker_loop)
+
+    def get_result_from_response_queue(self, response_queue: str):
+
+        response = self.subscriber_client.pull(
+            request={
+                'subscription': self.subscription_path,
+                'max_messages': 1
+            },
+            timeout=GET_FROM_PUB_SUB_QUEUE_TIMEOUT)
+
+        if response.received_messages:
+            message = response.received_messages[0]
+            ack_ids = [message.ack_id]
+
+            # Acknowledge the received message to remove it from the queue.
+            self.subscriber_client.acknowledge(request={
+                'subscription': self.subscription_path,
+                'ack_ids': ack_ids
+            })
+
+            return message.message.data
+
+        return None
+
+    def _task_to_bytes(
+            self, task: measurer_datatypes.SnapshotMeasureRequest) -> bytes:
+        """Takes a snapshot measure request task and transforms it into bytes,
+        so it can be published in a pub sub queue."""
+        task_as_dict = task.__dict__
+        return json.dumps(task_as_dict).encode('utf-8')
+
+    def put_task_in_request_queue(
+            self, task: measurer_datatypes.SnapshotMeasureRequest,
+            request_queue: str):
+        topic_path = self.publisher_client.topic_path(self.project_id,
+                                                      request_queue)
+        try:
+            # Convert message data to bytes
+            message_as_bytes = self._task_to_bytes(task)
+            # Publish the message to the Pub/Sub topic
+            future = self.publisher_client.publish(topic_path,
+                                                   message_as_bytes,
+                                                   ordering_key=task.cycle)
+            message_id = future.result()  # Get the published message ID
+            logger.info(
+                'Manager successfully published task with message ID %s to %s.',
+                message_id, topic_path)
+        except Exception as error:  # pylint: disable=broad-except
+            logger.error(
+                'An error occurred when publishing task to request queue: %s.',
+                error)
+
+
 def main():
     """Measure the experiment."""
     initialize_logs()
diff --git a/experiment/measurer/measure_worker.py b/experiment/measurer/measure_worker.py
index cfa033d06..50f9c7e31 100644
--- a/experiment/measurer/measure_worker.py
+++ b/experiment/measurer/measure_worker.py
@@ -13,13 +13,16 @@
 # limitations under the License.
"""Module for measurer workers logic.""" import time +import json from typing import Dict, Optional +from google.cloud import pubsub_v1 from common import logs from database.models import Snapshot import experiment.measurer.datatypes as measurer_datatypes from experiment.measurer import measure_manager MEASUREMENT_TIMEOUT = 1 +GET_FROM_PUB_SUB_QUEUE_TIMEOUT = 3 logger = logs.Logger() # pylint: disable=invalid-name @@ -28,8 +31,6 @@ class BaseMeasureWorker: implemented for Local and Google Cloud measure workers.""" def __init__(self, config: Dict): - self.request_queue = config['request_queue'] - self.response_queue = config['response_queue'] self.region_coverage = config['region_coverage'] def get_task_from_request_queue(self): @@ -68,6 +69,11 @@ class LocalMeasureWorker(BaseMeasureWorker): """Class that holds implementations of core methods for running a measure worker locally.""" + def __init__(self, config: Dict): + self.request_queue = config['request_queue'] + self.response_queue = config['response_queue'] + super().__init__(config) + def get_task_from_request_queue( self) -> measurer_datatypes.SnapshotMeasureRequest: """Get item from request multiprocessing queue, block if necessary until @@ -86,3 +92,71 @@ def put_result_in_response_queue( request.fuzzer, request.benchmark, request.trial_id, request.cycle) self.response_queue.put(retry_request) + + +class GoogleCloudMeasureWorker(BaseMeasureWorker): # pylint: disable=too-many-instance-attributes + """Worker that consumes from a Google Cloud Pub/Sub Queue, instead of a + multiprocessing queue""" + + def __init__(self, config: Dict): + super().__init__(config) + self.request_queue_topic_id = config['request_queue_topic_id'] + self.response_queue_topic_id = config['response_queue_topic_id'] + self.project_id = config['project_id'] + self.experiment = config['experiment'] + self.request_queue_subscription = f"""request-queue-subscription- + {self.experiment}""" + self.publisher_client = pubsub_v1.PublisherClient() + self.subscriber_client = pubsub_v1.SubscriberClient() + self.subscription_path = self.subscriber_client.subscription_path( + self.project_id, self.request_queue_subscription) + self._create_request_queue_subscription() + + def _create_request_queue_subscription(self): + """Creates a new Pub/Sub subscription for the request queue.""" + topic_path = self.response_queue_topic_id + subscription = self.subscriber_client.create_subscription(request={ + 'name': self.subscription_path, + 'topic': topic_path + }) + logger.info(f'Subscription {subscription.name} created successfully.') + + return self.subscription_path + + def get_task_from_request_queue( + self) -> measurer_datatypes.SnapshotMeasureRequest: + while True: + response = self.subscriber_client.pull( + request={ + 'subscription': self.subscription_path, + 'max_messages': 1 + }, + timeout=GET_FROM_PUB_SUB_QUEUE_TIMEOUT) + + if response.received_messages: + message = response.received_messages[0] + ack_ids = [message.ack_id] + + # Acknowledge the received message to remove it from the queue. 
+                self.subscriber_client.acknowledge(request={
+                    'subscription': self.subscription_path,
+                    'ack_ids': ack_ids
+                })
+
+                return message.message.data
+
+    def put_result_in_response_queue(self, measured_snapshot, request):
+        topic_path = self.publisher_client.topic_path(
+            self.project_id, self.response_queue_topic_id)
+        if measured_snapshot:
+            logger.info('Put measured snapshot in response_queue')
+            measured_snapshot_encoded = json.dumps(
+                measured_snapshot.__dict__).encode('utf-8')
+            self.publisher_client.publish(topic_path, measured_snapshot_encoded)
+        else:
+            retry_request = measurer_datatypes.SnapshotMeasureRequest(
+                request.fuzzer, request.benchmark, request.trial_id,
+                request.cycle)
+            retry_request_encoded = json.dumps(
+                retry_request.__dict__).encode('utf-8')
+            self.publisher_client.publish(topic_path, retry_request_encoded)
diff --git a/experiment/measurer/test_measure_manager.py b/experiment/measurer/test_measure_manager.py
index 7b6521869..0e7d220a4 100644
--- a/experiment/measurer/test_measure_manager.py
+++ b/experiment/measurer/test_measure_manager.py
@@ -411,19 +411,27 @@ def test_path_exists_in_experiment_filestore(mocked_execute, environ):
         expect_zero=False)
 
 
-def test_consume_unmapped_type_from_response_queue():
+@pytest.fixture
+def local_measure_manager():
+    """Fixture for instantiating a local measure manager object"""
+    local_measure_manager = measure_manager.LocalMeasureManager(
+        'experiment', False, None)
+    return local_measure_manager
+
+
+def test_consume_unmapped_type_from_response_queue(local_measure_manager):
     """Tests the scenario where an unmapped type is retrieved from the
     response queue. This scenario is not expected to happen, so in this case no
     snapshots are returned."""
     # Use normal queue here as multiprocessing queue gives flaky tests.
     response_queue = queue.Queue()
     response_queue.put('unexpected string')
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, set())
     assert not snapshots
 
 
-def test_consume_retry_type_from_response_queue():
+def test_consume_retry_type_from_response_queue(local_measure_manager):
     """Tests the scenario where a retry object is retrieved from the response
     queue. In this scenario, we want to remove the snapshot identifier from
     the queued_snapshots set, as this allows the measurement task to be
@@ -435,13 +443,13 @@ def test_consume_retry_type_from_response_queue():
     snapshot_identifier = (TRIAL_NUM, CYCLE)
     response_queue.put(retry_request_object)
     queued_snapshots_set = set([snapshot_identifier])
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, queued_snapshots_set)
     assert not snapshots
     assert len(queued_snapshots_set) == 0
 
 
-def test_consume_snapshot_type_from_response_queue():
+def test_consume_snapshot_type_from_response_queue(local_measure_manager):
     """Tests the scenario where a measured snapshot is retrieved from the
     response queue. In this scenario, we want to return the snapshot in the
     function."""
@@ -452,31 +460,32 @@
     measured_snapshot = models.Snapshot(trial_id=TRIAL_NUM)
     response_queue.put(measured_snapshot)
     assert response_queue.qsize() == 1
-    snapshots = measure_manager.consume_snapshots_from_response_queue(
+    snapshots = local_measure_manager.consume_snapshots_from_response_queue(
         response_queue, queued_snapshots_set)
     assert len(snapshots) == 1
 
 
 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
 def test_measure_manager_inner_loop_break_condition(
-        mocked_get_unmeasured_snapshots):
+        mocked_get_unmeasured_snapshots, local_measure_manager):
     """Tests that the measure manager inner loop returns False when there's no
     more snapshots left to be measured."""
     # Empty list means no more snapshots left to be measured.
     mocked_get_unmeasured_snapshots.return_value = []
     request_queue = queue.Queue()
     response_queue = queue.Queue()
-    continue_inner_loop = measure_manager.measure_manager_inner_loop(
-        'experiment', 1, request_queue, response_queue, set())
+    continue_inner_loop = local_measure_manager.measure_manager_inner_loop(
+        1, request_queue, response_queue, set())
     assert not continue_inner_loop
 
 
 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
 @mock.patch(
-    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+    'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue'  # pylint: disable=line-too-long
+)
 def test_measure_manager_inner_loop_writes_to_request_queue(
         mocked_consume_snapshots_from_response_queue,
-        mocked_get_unmeasured_snapshots):
+        mocked_get_unmeasured_snapshots, local_measure_manager):
     """Tests that the measure manager inner loop is writing measurement tasks
     to request queue."""
     mocked_get_unmeasured_snapshots.return_value = [
@@ -485,18 +494,19 @@ def test_measure_manager_inner_loop_writes_to_request_queue(
     mocked_consume_snapshots_from_response_queue.return_value = []
     request_queue = queue.Queue()
     response_queue = queue.Queue()
-    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
-                                               response_queue, set())
+    local_measure_manager.measure_manager_inner_loop(1, request_queue,
+                                                     response_queue, set())
     assert request_queue.qsize() == 1
 
 
 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
 @mock.patch(
-    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+    'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue'  # pylint: disable=line-too-long
+)
 @mock.patch('database.utils.add_all')
 def test_measure_manager_inner_loop_dont_write_to_db(
         mocked_add_all, mocked_consume_snapshots_from_response_queue,
-        mocked_get_unmeasured_snapshots):
+        mocked_get_unmeasured_snapshots, local_measure_manager):
     """Tests that the measure manager inner loop does not call add_all to write
     to the database, when there are no measured snapshots to be written."""
     mocked_get_unmeasured_snapshots.return_value = [
@@ -505,18 +515,19 @@
     request_queue = queue.Queue()
     response_queue = queue.Queue()
     mocked_consume_snapshots_from_response_queue.return_value = []
-    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
-                                               response_queue, set())
+    local_measure_manager.measure_manager_inner_loop(1, request_queue,
+                                                     response_queue, set())
     mocked_add_all.not_called()
 
 
 @mock.patch('experiment.measurer.measure_manager.get_unmeasured_snapshots')
 @mock.patch(
-    'experiment.measurer.measure_manager.consume_snapshots_from_response_queue')
+    'experiment.measurer.measure_manager.BaseMeasureManager.consume_snapshots_from_response_queue'  # pylint: disable=line-too-long
+)
 @mock.patch('database.utils.add_all')
 def test_measure_manager_inner_loop_writes_to_db(
         mocked_add_all, mocked_consume_snapshots_from_response_queue,
-        mocked_get_unmeasured_snapshots):
+        mocked_get_unmeasured_snapshots, local_measure_manager):
     """Tests that the measure manager inner loop calls add_all to write to the
     database, when there are measured snapshots to be written."""
     mocked_get_unmeasured_snapshots.return_value = [
@@ -526,6 +537,15 @@
     response_queue = queue.Queue()
     snapshot_model = models.Snapshot(trial_id=1)
     mocked_consume_snapshots_from_response_queue.return_value = [snapshot_model]
-    measure_manager.measure_manager_inner_loop('experiment', 1, request_queue,
-                                               response_queue, set())
+    local_measure_manager.measure_manager_inner_loop(1, request_queue,
+                                                     response_queue, set())
     mocked_add_all.assert_called_with([snapshot_model])
+
+
+def test_google_cloud_measure_manager_init_clients():
+    """Tests that when we instantiate a GoogleCloudMeasureManager object, its
+    publisher and subscriber clients are initialized"""
+    google_cloud_measure_manager = measure_manager.GoogleCloudMeasureManager(
+        'experiment', False, None)
+    assert google_cloud_measure_manager.publisher_client
+    assert google_cloud_measure_manager.subscriber_client
diff --git a/requirements.txt b/requirements.txt
index 56b835357..a6280ea19 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ google-auth==2.12.0
 google-cloud-error-reporting==1.6.3
 google-cloud-logging==3.1.2
 google-cloud-secret-manager==2.12.6
+google-cloud-pubsub==2.19.2
 clusterfuzz==2.6.0
 Jinja2==3.1.2
 numpy==1.23.4
diff --git a/service/experiment-config.yaml b/service/experiment-config.yaml
index b9acb09f8..877c853e7 100644
--- a/service/experiment-config.yaml
+++ b/service/experiment-config.yaml
@@ -2,8 +2,8 @@
 # Unless you are a fuzzbench maintainer running this service, this
 # will not work with your setup.
 
-trials: 20
-max_total_time: 82800 # 23 hours, the default time for preemptible experiments.
+trials: 3
+max_total_time: 3660
 cloud_project: fuzzbench
 docker_registry: gcr.io/fuzzbench
 cloud_compute_zone: us-central1-c
diff --git a/service/gcbrun_experiment.py b/service/gcbrun_experiment.py
index f19ab493d..b30f12f28 100644
--- a/service/gcbrun_experiment.py
+++ b/service/gcbrun_experiment.py
@@ -16,6 +16,7 @@
 """Entrypoint for gcbrun into run_experiment. This script will get the command
 from the last PR comment containing "/gcbrun" and pass it to run_experiment.py
 which will run an experiment."""
+# Dummy comment to trigger run experiment action
 
 import logging
 import os