Skip to content

Commit 1d974f1

Browse files
committed
refresh cohorts based on flag configs in storage
1 parent 7043732 commit 1d974f1

File tree

3 files changed

+24
-26
lines changed

3 files changed

+24
-26
lines changed

src/amplitude_experiment/cohort/cohort_loader.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .cohort import Cohort
77
from .cohort_download_api import CohortDownloadApi
88
from .cohort_storage import CohortStorage
9-
from ..exception import CohortUpdateException
9+
from ..exception import CohortsDownloadException
1010

1111

1212
class CohortLoader:
@@ -30,39 +30,37 @@ def load_cohort(self, cohort_id: str) -> Future:
3030

3131
def _remove_job(self, cohort_id: str):
3232
if cohort_id in self.jobs:
33-
del self.jobs[cohort_id]
33+
with self.lock_jobs:
34+
self.jobs.pop(cohort_id, None)
3435

3536
def download_cohort(self, cohort_id: str) -> Cohort:
3637
cohort = self.cohort_storage.get_cohort(cohort_id)
3738
return self.cohort_download_api.get_cohort(cohort_id, cohort)
3839

39-
def update_stored_cohorts(self) -> Future:
40-
def update_task():
40+
def download_cohorts(self, cohort_ids: Set[str]) -> Future:
41+
def update_task(task_cohort_ids):
4142
errors = []
42-
cohort_ids = self.cohort_storage.get_cohort_ids()
43-
4443
futures = []
45-
with self.lock_jobs:
46-
for cohort_id in cohort_ids:
47-
future = self.load_cohort(cohort_id)
48-
futures.append(future)
44+
for cohort_id in task_cohort_ids:
45+
future = self.load_cohort(cohort_id)
46+
futures.append(future)
4947

5048
for future in as_completed(futures):
51-
cohort_id = next(c_id for c_id, f in self.jobs.items() if f == future)
5249
try:
5350
future.result()
5451
except Exception as e:
55-
errors.append((cohort_id, e))
52+
cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None)
53+
if cohort_id:
54+
errors.append((cohort_id, e))
5655

5756
if errors:
58-
raise CohortUpdateException(errors)
57+
raise CohortsDownloadException(errors)
5958

60-
return self.executor.submit(update_task)
59+
return self.executor.submit(update_task, cohort_ids)
6160

6261
def __load_cohort_internal(self, cohort_id):
6362
try:
6463
cohort = self.download_cohort(cohort_id)
65-
# None is returned when cohort is not modified
6664
if cohort is not None:
6765
self.cohort_storage.put_cohort(cohort)
6866
except Exception as e:

src/amplitude_experiment/deployment/deployment_runner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from ..flag.flag_config_api import FlagConfigApi
99
from ..flag.flag_config_storage import FlagConfigStorage
1010
from ..local.poller import Poller
11-
from ..util.flag_config import get_all_cohort_ids_from_flag
11+
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags
12+
13+
COHORT_POLLING_INTERVAL_MILLIS = 60000
1214

1315

1416
class DeploymentRunner:
@@ -29,7 +31,7 @@ def __init__(
2931
self.lock = threading.Lock()
3032
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
3133
if self.cohort_loader:
32-
self.cohort_poller = Poller(self.config.flag_config_polling_interval_millis / 1000,
34+
self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000,
3335
self.__update_cohorts)
3436
self.logger = logger
3537

@@ -71,15 +73,12 @@ def __update_flag_configs(self):
7173

7274
existing_cohort_ids = self.cohort_storage.get_cohort_ids()
7375
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids
74-
cohort_download_errors = []
7576

7677
# download all new cohorts
77-
for cohort_id in cohort_ids_to_download:
78-
try:
79-
self.cohort_loader.load_cohort(cohort_id).result()
80-
except Exception as e:
81-
cohort_download_errors.append((cohort_id, str(e)))
82-
self.logger.warning(f"Download cohort {cohort_id} failed: {e}")
78+
try:
79+
self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
80+
except Exception as e:
81+
self.logger.warning(f"Error while downloading cohorts: {e}")
8382

8483
# get updated set of cohort ids
8584
updated_cohort_ids = self.cohort_storage.get_cohort_ids()
@@ -97,8 +96,9 @@ def __update_flag_configs(self):
9796
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")
9897

9998
def __update_cohorts(self):
99+
cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values()))
100100
try:
101-
self.cohort_loader.update_stored_cohorts().result()
101+
self.cohort_loader.download_cohorts(cohort_ids).result()
102102
except Exception as e:
103103
self.logger.warning(f"Error while updating cohorts: {e}")
104104

src/amplitude_experiment/exception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, status_code, message):
1515
self.status_code = status_code
1616

1717

18-
class CohortUpdateException(Exception):
18+
class CohortsDownloadException(Exception):
1919
def __init__(self, errors):
2020
self.errors = errors
2121
super().__init__(self.__str__())

0 commit comments

Comments
 (0)