Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add stream flag #53

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
amplitude_analytics~=1.1.1
dataclasses-json~=0.6.7
sseclient-py~=1.8.0
96 changes: 27 additions & 69 deletions src/amplitude_experiment/deployment/deployment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@
from typing import Optional
import threading

from ..flag.flag_config_updater import FlagConfigPoller, FlagConfigStreamer, FlagConfigUpdaterFallbackRetryWrapper
from ..local.config import LocalEvaluationConfig
from ..cohort.cohort_loader import CohortLoader
from ..cohort.cohort_storage import CohortStorage
from ..flag.flag_config_api import FlagConfigApi
from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi
from ..flag.flag_config_storage import FlagConfigStorage
from ..local.poller import Poller
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags
from ..util.flag_config import get_all_cohort_ids_from_flags

DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 15000
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000


class DeploymentRunner:
def __init__(
self,
config: LocalEvaluationConfig,
flag_config_api: FlagConfigApi,
flag_config_stream_api: Optional[FlagConfigStreamApi],
flag_config_storage: FlagConfigStorage,
cohort_storage: CohortStorage,
logger: logging.Logger,
Expand All @@ -27,88 +32,41 @@ def __init__(
self.cohort_storage = cohort_storage
self.cohort_loader = cohort_loader
self.lock = threading.Lock()
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, config, logger),
None,
0, 0, config.flag_config_polling_interval_millis, 0,
logger
)
if flag_config_stream_api:
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger),
self.flag_updater,
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS, DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS,
config.flag_config_polling_interval_millis, 0,
logger
)

self.cohort_poller = None
if self.cohort_loader:
self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000,
self.__update_cohorts)
self.logger = logger

def start(self):
with self.lock:
self.__update_flag_configs()
self.flag_poller.start()
self.flag_updater.start(None)
if self.cohort_loader:
self.cohort_poller.start()

def stop(self):
self.flag_poller.stop()

def __periodic_flag_update(self):
try:
self.__update_flag_configs()
except Exception as e:
self.logger.warning(f"Error while updating flags: {e}")

def __update_flag_configs(self):
try:
flag_configs = self.flag_config_api.get_flag_configs()
except Exception as e:
self.logger.warning(f'Failed to fetch flag configs: {e}')
raise e

flag_keys = {flag.key for flag in flag_configs}
self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys)

if not self.cohort_loader:
for flag_config in flag_configs:
self.logger.debug(f"Putting non-cohort flag {flag_config.key}")
self.flag_config_storage.put_flag_config(flag_config)
return

new_cohort_ids = set()
for flag_config in flag_configs:
new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config))

existing_cohort_ids = self.cohort_storage.get_cohort_ids()
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids

# download all new cohorts
try:
self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
except Exception as e:
self.logger.warning(f"Error while downloading cohorts: {e}")

# get updated set of cohort ids
updated_cohort_ids = self.cohort_storage.get_cohort_ids()
# iterate through new flag configs and check if their required cohorts exist
for flag_config in flag_configs:
cohort_ids = get_all_cohort_ids_from_flag(flag_config)
self.logger.debug(f"Storing flag {flag_config.key}")
self.flag_config_storage.put_flag_config(flag_config)
missing_cohorts = cohort_ids - updated_cohort_ids
if missing_cohorts:
self.logger.warning(f"Flag {flag_config.key} - failed to load cohorts: {missing_cohorts}")

# delete unused cohorts
self._delete_unused_cohorts()
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")
self.flag_updater.stop()
if self.cohort_poller:
self.cohort_poller.stop()

def __update_cohorts(self):
cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values()))
try:
self.cohort_loader.download_cohorts(cohort_ids).result()
except Exception as e:
self.logger.warning(f"Error while updating cohorts: {e}")

def _delete_unused_cohorts(self):
flag_cohort_ids = set()
for flag in self.flag_config_storage.get_flag_configs().values():
flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag))

storage_cohorts = self.cohort_storage.get_cohorts()
deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids

for deleted_cohort_id in deleted_cohort_ids:
deleted_cohort = storage_cohorts.get(deleted_cohort_id)
if deleted_cohort is not None:
self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id)
2 changes: 2 additions & 0 deletions src/amplitude_experiment/flag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .flag_config_api import FlagConfigStreamApi
from .flag_config_updater import FlagConfigStreamer
186 changes: 182 additions & 4 deletions src/amplitude_experiment/flag/flag_config_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from typing import List
import threading
from http.client import HTTPResponse, HTTPConnection, HTTPSConnection
from typing import List, Optional, Callable, Mapping, Union, Tuple

from ..evaluation.types import EvaluationFlag
from ..version import __version__
import sseclient

from ..connection_pool import HTTPConnectionPool

from ..util.updater import get_duration_with_jitter
from ..evaluation.types import EvaluationFlag
from ..version import __version__

class FlagConfigApi:
def get_flag_configs(self) -> List[EvaluationFlag]:
Expand Down Expand Up @@ -46,3 +49,178 @@ def __setup_connection_pool(self):
timeout = self.flag_config_poller_request_timeout_millis / 1000
self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30,
read_timeout=timeout, scheme=scheme)


DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000
DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000
DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000


class EventSource:
def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int,
max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS,
keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS):
self.keep_alive_timer: Optional[threading.Timer] = None
self.server_url = server_url
self.path = path
self.headers = headers
self.conn_timeout_millis = conn_timeout_millis
self.max_conn_duration_millis = max_conn_duration_millis
self.max_jitter_millis = max_jitter_millis
self.keep_alive_timeout_millis = keep_alive_timeout_millis

self.sse: Optional[sseclient.SSEClient] = None
self.conn: Optional[HTTPConnection | HTTPSConnection] = None
self.thread: Optional[threading.Thread] = None
self._stopped = False
self.lock = threading.RLock()

def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
with self.lock:
if self.sse is not None:
self.sse.close()
if self.conn is not None:
self.conn.close()

self.conn, response = self._get_conn()
if response.status != 200:
on_error(f"[Experiment] Stream flagConfigs - received error response: ${response.status}: ${response.read().decode('utf-8')}")
return

self.sse = sseclient.SSEClient(response, char_enc='utf-8')
self._stopped = False
self.thread = threading.Thread(target=self._run, args=[on_update, on_error])
self.thread.start()
self.reset_keep_alive_timer(on_error)

def stop(self):
with self.lock:
self._stopped = True
if self.sse:
self.sse.close()
if self.conn:
self.conn.close()
if self.keep_alive_timer:
self.keep_alive_timer.cancel()
self.sse = None
self.conn = None
# No way to stop self.thread, on self.conn.close(),
# the loop in thread will raise exception, which will terminate the thread.

def reset_keep_alive_timer(self, on_error: Callable[[str], None]):
with self.lock:
if self.keep_alive_timer:
self.keep_alive_timer.cancel()
self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000, self.keep_alive_timed_out,
args=[on_error])
self.keep_alive_timer.start()

def keep_alive_timed_out(self, on_error: Callable[[str], None]):
with self.lock:
if not self._stopped:
self.stop()
on_error("[Experiment] Stream flagConfigs - Keep alive timed out")

def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
try:
for event in self.sse.events():
with self.lock:
if self._stopped:
return
self.reset_keep_alive_timer(on_error)
if event.data == ' ':
continue
on_update(event.data)
except TimeoutError:
# Due to connection max time reached, open another one.
with self.lock:
if self._stopped:
return
self.stop()
self.start(on_update, on_error)
except Exception as e:
# Closing connection can result in exception here as a way to stop generator.
with self.lock:
if self._stopped:
return
on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e))

def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]:
scheme, _, host = self.server_url.split('/', 3)
connection = HTTPConnection if scheme == 'http:' else HTTPSConnection

body = None

conn = connection(host, timeout=get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000)
try:
conn.request('GET', self.path, body, self.headers)
response = conn.getresponse()
except Exception as e:
conn.close()
raise e

return conn, response


class FlagConfigStreamApi:
def __init__(self,
deployment_key: str,
server_url: str,
conn_timeout_millis: int,
max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS):
self.deployment_key = deployment_key
self.server_url = server_url
self.conn_timeout_millis = conn_timeout_millis
self.max_conn_duration_millis = max_conn_duration_millis
self.max_jitter_millis = max_jitter_millis

self.lock = threading.RLock()

headers = {
'Authorization': f"Api-Key {self.deployment_key}",
'Content-Type': 'application/json;charset=utf-8',
'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
}

self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis)

def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]):
with self.lock:
init_finished_event = threading.Event()
init_error_event = threading.Event()
init_updated_event = threading.Event()

def _on_update(data):
response_json = json.loads(data)
flags = EvaluationFlag.schema().load(response_json, many=True)
if init_finished_event.is_set():
on_update(flags)
else:
init_finished_event.set()
on_update(flags)
init_updated_event.set()

def _on_error(data):
if init_finished_event.is_set():
on_error(data)
else:
init_error_event.set()
init_finished_event.set()
on_error(data)

t = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error])
t.start()
init_finished_event.wait(self.conn_timeout_millis / 1000)
if t.is_alive() or not init_finished_event.is_set() or init_error_event.is_set():
self.stop()
on_error("stream connection timeout error")
return

# Wait for first update callback to finish before returning.
init_updated_event.wait()

def stop(self):
with self.lock:
threading.Thread(target=lambda: self.eventsource.stop()).start()
Loading
Loading