Skip to content

Commit 048a72d

Browse files
gustavogaldinoogustavogaldinoo
andauthored
Adding local implementation for queue based measuring (#1998)
Adding local implementation for queue based measuring and tests --------- Co-authored-by: gustavogaldinoo <gustavogaldino@google.com>
1 parent 22673aa commit 048a72d

File tree

5 files changed

+432
-29
lines changed

5 files changed

+432
-29
lines changed

experiment/measurer/datatypes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Module for common data types shared under the measurer module."""
15+
import collections
16+
17+
SnapshotMeasureRequest = collections.namedtuple(
18+
'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
19+
20+
RetryRequest = collections.namedtuple(
21+
'RetryRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])

experiment/measurer/measure_manager.py

Lines changed: 144 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@
4444
from database import models
4545
from experiment.build import build_utils
4646
from experiment.measurer import coverage_utils
47+
from experiment.measurer import measure_worker
4748
from experiment.measurer import run_coverage
4849
from experiment.measurer import run_crashes
4950
from experiment import scheduler
51+
import experiment.measurer.datatypes as measurer_datatypes
5052

5153
logger = logs.Logger()
5254

53-
SnapshotMeasureRequest = collections.namedtuple(
54-
'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id', 'cycle'])
55-
5655
NUM_RETRIES = 3
5756
RETRY_DELAY = 3
5857
FAIL_WAIT_SECONDS = 30
5958
SNAPSHOT_QUEUE_GET_TIMEOUT = 1
6059
SNAPSHOTS_BATCH_SAVE_SIZE = 100
60+
MEASUREMENT_LOOP_WAIT = 10
6161

6262

6363
def exists_in_experiment_filestore(path: pathlib.Path) -> bool:
@@ -75,10 +75,9 @@ def measure_main(experiment_config):
7575
experiment = experiment_config['experiment']
7676
max_total_time = experiment_config['max_total_time']
7777
measurers_cpus = experiment_config['measurers_cpus']
78-
runners_cpus = experiment_config['runners_cpus']
7978
region_coverage = experiment_config['region_coverage']
80-
measure_loop(experiment, max_total_time, measurers_cpus, runners_cpus,
81-
region_coverage)
79+
measure_manager_loop(experiment, max_total_time, measurers_cpus,
80+
region_coverage)
8281

8382
# Clean up resources.
8483
gc.collect()
@@ -104,18 +103,7 @@ def measure_loop(experiment: str,
104103
"""Continuously measure trials for |experiment|."""
105104
logger.info('Start measure_loop.')
106105

107-
pool_args = ()
108-
if measurers_cpus is not None and runners_cpus is not None:
109-
local_experiment = experiment_utils.is_local_experiment()
110-
if local_experiment:
111-
cores_queue = multiprocessing.Queue()
112-
logger.info('Scheduling measurers from core %d to %d.',
113-
runners_cpus, runners_cpus + measurers_cpus - 1)
114-
for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
115-
cores_queue.put(cpu)
116-
pool_args = (measurers_cpus, _process_init, (cores_queue,))
117-
else:
118-
pool_args = (measurers_cpus,)
106+
pool_args = get_pool_args(measurers_cpus, runners_cpus)
119107

120108
with multiprocessing.Pool(
121109
*pool_args) as pool, multiprocessing.Manager() as manager:
@@ -256,12 +244,13 @@ def _query_unmeasured_trials(experiment: str):
256244

257245

258246
def _get_unmeasured_first_snapshots(
259-
experiment: str) -> List[SnapshotMeasureRequest]:
247+
experiment: str) -> List[measurer_datatypes.SnapshotMeasureRequest]:
260248
"""Returns a list of unmeasured SnapshotMeasureRequests that are the first
261249
snapshot for their trial. The trials are trials in |experiment|."""
262250
trials_without_snapshots = _query_unmeasured_trials(experiment)
263251
return [
264-
SnapshotMeasureRequest(trial.fuzzer, trial.benchmark, trial.id, 0)
252+
measurer_datatypes.SnapshotMeasureRequest(trial.fuzzer, trial.benchmark,
253+
trial.id, 0)
265254
for trial in trials_without_snapshots
266255
]
267256

@@ -289,7 +278,8 @@ def _query_measured_latest_snapshots(experiment: str):
289278

290279

291280
def _get_unmeasured_next_snapshots(
292-
experiment: str, max_cycle: int) -> List[SnapshotMeasureRequest]:
281+
experiment: str,
282+
max_cycle: int) -> List[measurer_datatypes.SnapshotMeasureRequest]:
293283
"""Returns a list of the latest unmeasured SnapshotMeasureRequests of
294284
trials in |experiment| that have been measured at least once in
295285
|experiment|. |max_total_time| is used to determine if a trial has another
@@ -305,16 +295,15 @@ def _get_unmeasured_next_snapshots(
305295
if next_cycle > max_cycle:
306296
continue
307297

308-
snapshot_with_cycle = SnapshotMeasureRequest(snapshot.fuzzer,
309-
snapshot.benchmark,
310-
snapshot.trial_id,
311-
next_cycle)
298+
snapshot_with_cycle = measurer_datatypes.SnapshotMeasureRequest(
299+
snapshot.fuzzer, snapshot.benchmark, snapshot.trial_id, next_cycle)
312300
next_snapshots.append(snapshot_with_cycle)
313301
return next_snapshots
314302

315303

316-
def get_unmeasured_snapshots(experiment: str,
317-
max_cycle: int) -> List[SnapshotMeasureRequest]:
304+
def get_unmeasured_snapshots(
305+
experiment: str,
306+
max_cycle: int) -> List[measurer_datatypes.SnapshotMeasureRequest]:
318307
"""Returns a list of SnapshotMeasureRequests that need to be measured
319308
(assuming they have been saved already)."""
320309
# Measure the first snapshot of every started trial without any measured
@@ -683,6 +672,134 @@ def initialize_logs():
683672
})
684673

685674

675+
def consume_snapshots_from_response_queue(
676+
response_queue, queued_snapshots) -> List[models.Snapshot]:
677+
"""Consume response_queue, allows retry objects to retried, and
678+
return all measured snapshots in a list."""
679+
measured_snapshots = []
680+
while True:
681+
try:
682+
response_object = response_queue.get_nowait()
683+
if isinstance(response_object, measurer_datatypes.RetryRequest):
684+
# Need to retry measurement task, will remove identifier from
685+
# the set so task can be retried in next loop iteration.
686+
snapshot_identifier = (response_object.trial_id,
687+
response_object.cycle)
688+
queued_snapshots.remove(snapshot_identifier)
689+
logger.info('Reescheduling task for trial %s and cycle %s',
690+
response_object.trial_id, response_object.cycle)
691+
elif isinstance(response_object, models.Snapshot):
692+
measured_snapshots.append(response_object)
693+
else:
694+
logger.error('Type of response object not mapped! %s',
695+
type(response_object))
696+
except queue.Empty:
697+
break
698+
return measured_snapshots
699+
700+
701+
def measure_manager_inner_loop(experiment: str, max_cycle: int, request_queue,
702+
response_queue, queued_snapshots):
703+
"""Reads from database to determine which snapshots needs measuring. Write
704+
measurements tasks to request queue, get results from response queue, and
705+
write measured snapshots to database. Returns False if there's no more
706+
snapshots left to be measured"""
707+
initialize_logs()
708+
# Read database to determine which snapshots needs measuring.
709+
unmeasured_snapshots = get_unmeasured_snapshots(experiment, max_cycle)
710+
logger.info('Retrieved %d unmeasured snapshots from measure manager',
711+
len(unmeasured_snapshots))
712+
# When there are no more snapshots left to be measured, should break loop.
713+
if not unmeasured_snapshots:
714+
return False
715+
716+
# Write measurements requests to request queue
717+
for unmeasured_snapshot in unmeasured_snapshots:
718+
# No need to insert fuzzer and benchmark info here as it's redundant
719+
# (Can be retrieved through trial_id).
720+
unmeasured_snapshot_identifier = (unmeasured_snapshot.trial_id,
721+
unmeasured_snapshot.cycle)
722+
# Checking if snapshot already was queued so workers will not repeat
723+
# measurement for same snapshot
724+
if unmeasured_snapshot_identifier not in queued_snapshots:
725+
request_queue.put(unmeasured_snapshot)
726+
queued_snapshots.add(unmeasured_snapshot_identifier)
727+
728+
# Read results from response queue.
729+
measured_snapshots = consume_snapshots_from_response_queue(
730+
response_queue, queued_snapshots)
731+
logger.info('Retrieved %d measured snapshots from response queue',
732+
len(measured_snapshots))
733+
734+
# Save measured snapshots to database.
735+
if measured_snapshots:
736+
db_utils.add_all(measured_snapshots)
737+
738+
return True
739+
740+
741+
def get_pool_args(measurers_cpus, runners_cpus):
742+
"""Return pool args based on measurer cpus and runner cpus arguments."""
743+
if measurers_cpus is None or runners_cpus is None:
744+
return ()
745+
746+
local_experiment = experiment_utils.is_local_experiment()
747+
if not local_experiment:
748+
return (measurers_cpus,)
749+
750+
cores_queue = multiprocessing.Queue()
751+
logger.info('Scheduling measurers from core %d to %d.', runners_cpus,
752+
runners_cpus + measurers_cpus - 1)
753+
for cpu in range(runners_cpus, runners_cpus + measurers_cpus):
754+
cores_queue.put(cpu)
755+
return (measurers_cpus, _process_init, (cores_queue,))
756+
757+
758+
def measure_manager_loop(experiment: str,
759+
max_total_time: int,
760+
measurers_cpus=None,
761+
region_coverage=False): # pylint: disable=too-many-locals
762+
"""Measure manager loop. Creates request and response queues, request
763+
measurements tasks from workers, retrieve measurement results from response
764+
queue and writes measured snapshots in database."""
765+
logger.info('Starting measure manager loop.')
766+
if not measurers_cpus:
767+
measurers_cpus = multiprocessing.cpu_count()
768+
logger.info('Number of measurer CPUs not passed as argument. using %d',
769+
measurers_cpus)
770+
with multiprocessing.Pool() as pool, multiprocessing.Manager() as manager:
771+
logger.info('Setting up coverage binaries')
772+
set_up_coverage_binaries(pool, experiment)
773+
request_queue = manager.Queue()
774+
response_queue = manager.Queue()
775+
776+
config = {
777+
'request_queue': request_queue,
778+
'response_queue': response_queue,
779+
'region_coverage': region_coverage,
780+
}
781+
local_measure_worker = measure_worker.LocalMeasureWorker(config)
782+
783+
# Since each worker is going to be in an infinite loop, we dont need
784+
# result return. Workers' life scope will end automatically when there
785+
# are no more snapshots left to measure.
786+
logger.info('Starting measure worker loop for %d workers',
787+
measurers_cpus)
788+
for _ in range(measurers_cpus):
789+
_result = pool.apply_async(local_measure_worker.measure_worker_loop)
790+
791+
max_cycle = _time_to_cycle(max_total_time)
792+
queued_snapshots = set()
793+
while not scheduler.all_trials_ended(experiment):
794+
continue_inner_loop = measure_manager_inner_loop(
795+
experiment, max_cycle, request_queue, response_queue,
796+
queued_snapshots)
797+
if not continue_inner_loop:
798+
break
799+
time.sleep(MEASUREMENT_LOOP_WAIT)
800+
logger.info('All trials ended. Ending measure manager loop')
801+
802+
686803
def main():
687804
"""Measure the experiment."""
688805
initialize_logs()

experiment/measurer/measure_worker.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Module for measurer workers logic."""
15+
import time
16+
from typing import Dict, Optional
17+
from common import logs
18+
from database.models import Snapshot
19+
import experiment.measurer.datatypes as measurer_datatypes
20+
from experiment.measurer import measure_manager
21+
22+
MEASUREMENT_TIMEOUT = 1
23+
logger = logs.Logger() # pylint: disable=invalid-name
24+
25+
26+
class BaseMeasureWorker:
27+
"""Base class for measure worker. Encapsulates core methods that will be
28+
implemented for Local and Google Cloud measure workers."""
29+
30+
def __init__(self, config: Dict):
31+
self.request_queue = config['request_queue']
32+
self.response_queue = config['response_queue']
33+
self.region_coverage = config['region_coverage']
34+
35+
def get_task_from_request_queue(self):
36+
""""Get task from request queue"""
37+
raise NotImplementedError
38+
39+
def put_result_in_response_queue(self, measured_snapshot, request):
40+
"""Save measurement result in response queue, for the measure manager to
41+
retrieve"""
42+
raise NotImplementedError
43+
44+
def measure_worker_loop(self):
45+
"""Periodically retrieves request from request queue, measure it, and
46+
put result in response queue"""
47+
logs.initialize(default_extras={
48+
'component': 'measurer',
49+
'subcomponent': 'worker',
50+
})
51+
logger.info('Starting one measure worker loop')
52+
while True:
53+
# 'SnapshotMeasureRequest', ['fuzzer', 'benchmark', 'trial_id',
54+
# 'cycle']
55+
request = self.get_task_from_request_queue()
56+
logger.info(
57+
'Measurer worker: Got request %s %s %d %d from request queue',
58+
request.fuzzer, request.benchmark, request.trial_id,
59+
request.cycle)
60+
measured_snapshot = measure_manager.measure_snapshot_coverage(
61+
request.fuzzer, request.benchmark, request.trial_id,
62+
request.cycle, self.region_coverage)
63+
self.put_result_in_response_queue(measured_snapshot, request)
64+
time.sleep(MEASUREMENT_TIMEOUT)
65+
66+
67+
class LocalMeasureWorker(BaseMeasureWorker):
68+
"""Class that holds implementations of core methods for running a measure
69+
worker locally."""
70+
71+
def get_task_from_request_queue(
72+
self) -> measurer_datatypes.SnapshotMeasureRequest:
73+
"""Get item from request multiprocessing queue, block if necessary until
74+
an item is available"""
75+
request = self.request_queue.get(block=True)
76+
return request
77+
78+
def put_result_in_response_queue(
79+
self, measured_snapshot: Optional[Snapshot],
80+
request: measurer_datatypes.SnapshotMeasureRequest):
81+
if measured_snapshot:
82+
logger.info('Put measured snapshot in response_queue')
83+
self.response_queue.put(measured_snapshot)
84+
else:
85+
retry_request = measurer_datatypes.RetryRequest(
86+
request.fuzzer, request.benchmark, request.trial_id,
87+
request.cycle)
88+
self.response_queue.put(retry_request)

0 commit comments

Comments
 (0)