From aad0f2c92da6e0338126c1810d48f27b0dfbcec7 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 1 May 2024 17:21:44 +0100 Subject: [PATCH 1/7] feat(flags): Basic flags service --- Cargo.lock | 26 ++++++ Cargo.toml | 1 + feature-flags/Cargo.toml | 35 ++++++++ feature-flags/src/api.rs | 59 +++++++++++++ feature-flags/src/config.rs | 27 ++++++ feature-flags/src/lib.rs | 7 ++ feature-flags/src/main.rs | 42 ++++++++++ feature-flags/src/redis.rs | 78 +++++++++++++++++ feature-flags/src/router.rs | 50 +++++++++++ feature-flags/src/server.rs | 34 ++++++++ feature-flags/src/v0_endpoint.rs | 88 ++++++++++++++++++++ feature-flags/src/v0_request.rs | 138 +++++++++++++++++++++++++++++++ 12 files changed, 585 insertions(+) create mode 100644 feature-flags/Cargo.toml create mode 100644 feature-flags/src/api.rs create mode 100644 feature-flags/src/config.rs create mode 100644 feature-flags/src/lib.rs create mode 100644 feature-flags/src/main.rs create mode 100644 feature-flags/src/redis.rs create mode 100644 feature-flags/src/router.rs create mode 100644 feature-flags/src/server.rs create mode 100644 feature-flags/src/v0_endpoint.rs create mode 100644 feature-flags/src/v0_request.rs diff --git a/Cargo.lock b/Cargo.lock index 82c787a..c0139c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,6 +706,32 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +[[package]] +name = "feature-flags" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "axum 0.7.5", + "axum-client-ip", + "base64 0.22.0", + "bytes", + "envconfig", + "flate2", + "governor", + "metrics", + "rand", + "rdkafka", + "redis", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "finl_unicode" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index d34cd0a..265fe5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "capture", "capture-server", "common/health", + "feature-flags", "hook-api", "hook-common", "hook-janitor", diff --git a/feature-flags/Cargo.toml b/feature-flags/Cargo.toml new file mode 100644 index 0000000..088b069 --- /dev/null +++ b/feature-flags/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "feature-flags" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +axum = { workspace = true } +axum-client-ip = { workspace = true } +envconfig = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } +base64 = { workspace = true } +bytes = { workspace = true } +flate2 = { workspace = true } +governor = { workspace = true } +metrics = { workspace = true } +rand = { workspace = true } +rdkafka = { workspace = true } +redis = { version = "0.23.3", features = [ + "tokio-comp", + "cluster", + "cluster-async", +] } +serde = { workspace = true } +serde_json = { workspace = true } +serde_urlencoded = { workspace = true } +thiserror = { workspace = true } + +[lints] +workspace = true diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs new file mode 100644 index 0000000..dc016b2 --- /dev/null +++ b/feature-flags/src/api.rs @@ -0,0 +1,59 @@ +use std::collections::HashMap; + +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +pub enum FlagsResponseCode { + Ok = 1, +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct FlagsResponse { + pub error_while_computing_flags: bool, + // TODO: better typing here, support bool responses + pub feature_flags: HashMap +} + +#[derive(Error, Debug)] +pub enum FlagError { + #[error("failed to decode request: {0}")] + RequestDecodingError(String), + #[error("failed to parse request: {0}")] + RequestParsingError(#[from] serde_json::Error), + + #[error("Empty distinct_id in request")] + EmptyDistinctId, + #[error("No distinct_id in request")] + MissingDistinctId, + + #[error("No api_key in request")] + NoTokenError, + #[error("API key is not valid")] + TokenValidationError, + + #[error("rate limited")] + RateLimited, +} + +impl IntoResponse for FlagError { + fn into_response(self) -> Response { + match self { + FlagError::RequestDecodingError(_) + | FlagError::RequestParsingError(_) + | FlagError::EmptyDistinctId + | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), + + FlagError::NoTokenError + | FlagError::TokenValidationError => (StatusCode::UNAUTHORIZED, self.to_string()), + + FlagError::RateLimited => { + (StatusCode::TOO_MANY_REQUESTS, self.to_string()) + } + } + .into_response() + } +} \ No newline at end of file diff --git a/feature-flags/src/config.rs b/feature-flags/src/config.rs new file mode 100644 index 0000000..46d2983 --- /dev/null +++ b/feature-flags/src/config.rs @@ -0,0 +1,27 @@ +use std::{net::SocketAddr, num::NonZeroU32}; + +use envconfig::Envconfig; + +#[derive(Envconfig, Clone)] +pub struct Config { + #[envconfig(default = "false")] + pub print_sink: bool, + + #[envconfig(default = "127.0.0.1:3001")] + pub address: SocketAddr, + + #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + pub write_database_url: String, + + #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + pub read_database_url: String, + + #[envconfig(default = "1024")] + pub max_concurrent_jobs: usize, + + #[envconfig(default = "100")] + pub max_pg_connections: u32, + + #[envconfig(default = "redis://localhost:6379/")] + pub redis_url: String, +} diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs new file mode 100644 index 0000000..ebda41c --- /dev/null +++ b/feature-flags/src/lib.rs @@ -0,0 +1,7 @@ +pub mod config; +pub mod server; +pub mod router; +pub mod v0_endpoint; +pub mod v0_request; +pub mod api; +pub mod redis; \ No newline at end of file diff --git a/feature-flags/src/main.rs b/feature-flags/src/main.rs new file mode 100644 index 0000000..716f464 --- /dev/null +++ b/feature-flags/src/main.rs @@ -0,0 +1,42 @@ + +use envconfig::Envconfig; +use tokio::signal; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, Layer}; + +use feature_flags::config::Config; +use feature_flags::server::serve; + +async fn shutdown() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register SIGTERM handler"); + + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register SIGINT handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + + tracing::info!("Shutting down gracefully..."); +} + +#[tokio::main] +async fn main() { + let config = Config::init_from_env().expect("Invalid configuration:"); + + // Basic logging for now: + // - stdout with a level configured by the RUST_LOG envvar (default=ERROR) + let log_layer = tracing_subscriber::fmt::layer().with_filter(EnvFilter::from_default_env()); + tracing_subscriber::registry() + .with(log_layer) + .init(); + + // Open the TCP port and start the server + let listener = tokio::net::TcpListener::bind(config.address) + .await + .expect("could not bind port"); + serve(config, listener, shutdown()).await +} diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs new file mode 100644 index 0000000..457c008 --- /dev/null +++ b/feature-flags/src/redis.rs @@ -0,0 +1,78 @@ +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use redis::AsyncCommands; +use tokio::time::timeout; + +// average for all commands is <10ms, check grafana +const REDIS_TIMEOUT_MILLISECS: u64 = 10; + +/// A simple redis wrapper +/// Copied from capture/src/redis.rs. +/// TODO: Modify this to support hincrby, get, and set commands. + +#[async_trait] +pub trait Client { + // A very simplified wrapper, but works for our usage + async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result>; +} + +pub struct RedisClient { + client: redis::Client, +} + +impl RedisClient { + pub fn new(addr: String) -> Result { + let client = redis::Client::open(addr)?; + + Ok(RedisClient { client }) + } +} + +#[async_trait] +impl Client for RedisClient { + async fn zrangebyscore(&self, k: String, min: String, max: String) -> Result> { + let mut conn = self.client.get_async_connection().await?; + + let results = conn.zrangebyscore(k, min, max); + let fut = timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?; + + Ok(fut?) + } +} + +// TODO: Find if there's a better way around this. +// mockall got really annoying with async and results so I'm just gonna do my own +#[derive(Clone)] +pub struct MockRedisClient { + zrangebyscore_ret: Vec, +} + +impl MockRedisClient { + pub fn new() -> MockRedisClient { + MockRedisClient { + zrangebyscore_ret: Vec::new(), + } + } + + pub fn zrangebyscore_ret(&mut self, ret: Vec) -> Self { + self.zrangebyscore_ret = ret; + + self.clone() + } +} + +impl Default for MockRedisClient { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Client for MockRedisClient { + // A very simplified wrapper, but works for our usage + async fn zrangebyscore(&self, _k: String, _min: String, _max: String) -> Result> { + Ok(self.zrangebyscore_ret.clone()) + } +} diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs new file mode 100644 index 0000000..0088968 --- /dev/null +++ b/feature-flags/src/router.rs @@ -0,0 +1,50 @@ +use std::future::ready; +use std::sync::Arc; + +use axum::http::Method; +use axum::{ + routing::{get, post}, + Router, +}; + +use crate::{ + redis::Client, v0_endpoint, +}; + + +#[derive(Clone)] +pub struct State { + pub redis: Arc, + // TODO: Add pgClient when ready +} + + + +pub fn router< + R: Client + Send + Sync + 'static, +>( + redis: Arc, +) -> Router { + let state = State { + redis, + }; + + // // Very permissive CORS policy, as old SDK versions + // // and reverse proxies might send funky headers. + // let cors = CorsLayer::new() + // .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) + // .allow_headers(AllowHeaders::mirror_request()) + // .allow_credentials(true) + // .allow_origin(AllowOrigin::mirror_request()); + + let router = Router::new() + .route( + "/flags", + post(v0_endpoint::flags) + .get(v0_endpoint::flags) + ) + + .with_state(state); + + router +} diff --git a/feature-flags/src/server.rs b/feature-flags/src/server.rs new file mode 100644 index 0000000..fcebce8 --- /dev/null +++ b/feature-flags/src/server.rs @@ -0,0 +1,34 @@ +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::net::TcpListener; + +use crate::config::Config; + +use crate::redis::RedisClient; +use crate::router; + +pub async fn serve(config: Config, listener: TcpListener, shutdown: F) +where + F: Future + Send + 'static, +{ + + let redis_client = + Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client")); + + let app = router::router( + redis_client, + ); + + // run our app with hyper + // `axum::Server` is a re-export of `hyper::Server` + tracing::info!("listening on {:?}", listener.local_addr().unwrap()); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(shutdown) + .await + .unwrap() +} diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs new file mode 100644 index 0000000..837621d --- /dev/null +++ b/feature-flags/src/v0_endpoint.rs @@ -0,0 +1,88 @@ +use std::collections::HashMap; +use std::ops::Deref; +use std::sync::Arc; + +use axum::{debug_handler, Json}; +use bytes::Bytes; +// TODO: stream this instead +use axum::extract::{MatchedPath, Query, State}; +use axum::http::{HeaderMap, Method}; +use axum_client_ip::InsecureClientIp; +use base64::Engine; +use tracing::instrument; + + +use crate::{ + api::{FlagError, FlagsResponse}, + router, + v0_request::{FlagsQueryParams, FlagRequest}, +}; + +/// Feature flag evaluation endpoint. +/// Only supports a specific shape of data, and rejects any malformed data. + +#[instrument( + skip_all, + fields( + path, + token, + batch_size, + user_agent, + content_encoding, + content_type, + version, + compression, + historical_migration + ) +)] +#[debug_handler] +pub async fn flags( + state: State, + InsecureClientIp(ip): InsecureClientIp, + meta: Query, + headers: HeaderMap, + method: Method, + path: MatchedPath, + body: Bytes, +) -> Result, FlagError> { + let user_agent = headers + .get("user-agent") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + let content_encoding = headers + .get("content-encoding") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + + + tracing::Span::current().record("user_agent", user_agent); + tracing::Span::current().record("content_encoding", content_encoding); + // tracing::Span::current().record("version", meta.lib_version.clone()); + tracing::Span::current().record("method", method.as_str()); + tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); + + let request = match headers + .get("content-type") + .map_or("", |v| v.to_str().unwrap_or("")) + { + "application/x-www-form-urlencoded" => { + return Err(FlagError::RequestDecodingError(String::from("invalid form data"))); + } + ct => { + tracing::Span::current().record("content_type", ct); + + FlagRequest::from_bytes(body) + } + }?; + + let token = request.extract_and_verify_token()?; + + tracing::Span::current().record("token", &token); + + tracing::debug!("request: {:?}", request); + + // TODO: Some actual processing for evaluating the feature flag + + Ok(Json(FlagsResponse { + error_while_computing_flags: false, + feature_flags: HashMap::from([("beta-feature".to_string(), "variant-1".to_string()), ("rollout-flag".to_string(), true.to_string())]), + })) +} \ No newline at end of file diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs new file mode 100644 index 0000000..bdcd35d --- /dev/null +++ b/feature-flags/src/v0_request.rs @@ -0,0 +1,138 @@ +use std::collections::{HashMap}; +// use std::io::prelude::*; + +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +// use time::format_description::well_known::Iso8601; +// use time::OffsetDateTime; +use tracing::instrument; +// use uuid::Uuid; + +use crate::api::FlagError; + + +#[derive(Deserialize, Default)] +pub struct FlagsQueryParams { + + #[serde(alias = "v")] + pub version: Option, + + #[serde(alias = "_")] + sent_at: Option, +} + +#[derive(Default, Debug, Deserialize, Serialize)] +pub struct FlagRequest { + #[serde( + alias = "$token", + alias = "api_key", + skip_serializing_if = "Option::is_none" + )] + pub token: Option, + #[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")] + pub distinct_id: Option, + pub geoip_disable: Option, + #[serde(default)] + pub person_properties: Option>, + #[serde(default)] + pub groups: Option>, + // TODO: better type this since we know its going to be a nested json + #[serde(default)] + pub group_properties: Option>, + #[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")] + pub anon_distinct_id: Option, +} + +impl FlagRequest { + /// Takes a request payload and tries to decompress and unmarshall it. + /// While posthog-js sends a compression query param, a sizable portion of requests + /// fail due to it being missing when the body is compressed. + /// Instead of trusting the parameter, we peek at the payload's first three bytes to + /// detect gzip, fallback to uncompressed utf8 otherwise. + #[instrument(skip_all)] + pub fn from_bytes(bytes: Bytes) -> Result { + tracing::debug!(len = bytes.len(), "decoding new request"); + // TODO: Add base64 decoding + let payload = String::from_utf8(bytes.into()).map_err(|e| { + tracing::error!("failed to decode body: {}", e); + FlagError::RequestDecodingError(String::from("invalid body encoding")) + })?; + + tracing::debug!(json = payload, "decoded event data"); + Ok(serde_json::from_str::(&payload)?) + } + + pub fn extract_and_verify_token(&self) -> Result { + let token = match self { + FlagRequest { token: Some(token), .. } => token.to_string(), + _ => return Err(FlagError::NoTokenError), + }; + // TODO: Get tokens from redis, confirm this one is valid + // validate_token(&token)?; + Ok(token) + } + +} + + +#[cfg(test)] +mod tests { + use base64::Engine as _; + use bytes::Bytes; + use rand::distributions::Alphanumeric; + use rand::Rng; + use serde_json::json; + + use super::FlagError; + use super::FlagRequest; + + #[test] + fn extract_and_verify_token() { + let parse_and_extract = |input: &'static str| -> Result { + FlagRequest::from_bytes(input.into()) + .expect("failed to parse") + .extract_and_verify_token() + }; + + let assert_extracted_token = |input: &'static str, expected: &str| { + let id = parse_and_extract(input).expect("failed to extract"); + assert_eq!(id, expected); + }; + + // Return NoTokenError if not found + assert!(matches!( + parse_and_extract(r#"{"distinct_id": "xyz"}"#), + Err(FlagError::NoTokenError) + )); + + // Return TokenValidationError if token empty + assert!(matches!( + parse_and_extract(r#"{"api_key": "", "batch":[{"event": "e"}]}"#), + Err(FlagError::TokenValidationError) + )); + + // Return TokenValidationError if personal apikey + assert!(matches!( + parse_and_extract(r#"[{"event": "e", "token": "phx_hellothere"}]"#), + Ok(_) + )); + + + // Return token from array if consistent + assert_extracted_token( + r#"[{"event":"e","token":"token1"},{"event":"e","token":"token1"}]"#, + "token1", + ); + + // Return token from batch if present + assert_extracted_token( + r#"{"batch":[{"event":"e","token":"token1"}],"api_key":"batched"}"#, + "batched", + ); + + // Return token from single event if present + assert_extracted_token(r#"{"event":"e","$token":"single_token"}"#, "single_token"); + assert_extracted_token(r#"{"event":"e","api_key":"single_token"}"#, "single_token"); + } +} From b5591f5630b317028b3e24203dda2c9c4881c3e9 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 1 May 2024 17:25:55 +0100 Subject: [PATCH 2/7] format --- feature-flags/src/api.rs | 13 ++++++------- feature-flags/src/lib.rs | 6 +++--- feature-flags/src/main.rs | 5 +---- feature-flags/src/router.rs | 24 ++++-------------------- feature-flags/src/server.rs | 5 +---- feature-flags/src/v0_endpoint.rs | 15 +++++++++------ feature-flags/src/v0_request.rs | 17 +++++++---------- 7 files changed, 31 insertions(+), 54 deletions(-) diff --git a/feature-flags/src/api.rs b/feature-flags/src/api.rs index dc016b2..c94eed6 100644 --- a/feature-flags/src/api.rs +++ b/feature-flags/src/api.rs @@ -15,7 +15,7 @@ pub enum FlagsResponseCode { pub struct FlagsResponse { pub error_while_computing_flags: bool, // TODO: better typing here, support bool responses - pub feature_flags: HashMap + pub feature_flags: HashMap, } #[derive(Error, Debug)] @@ -47,13 +47,12 @@ impl IntoResponse for FlagError { | FlagError::EmptyDistinctId | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), - FlagError::NoTokenError - | FlagError::TokenValidationError => (StatusCode::UNAUTHORIZED, self.to_string()), - - FlagError::RateLimited => { - (StatusCode::TOO_MANY_REQUESTS, self.to_string()) + FlagError::NoTokenError | FlagError::TokenValidationError => { + (StatusCode::UNAUTHORIZED, self.to_string()) } + + FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), } .into_response() } -} \ No newline at end of file +} diff --git a/feature-flags/src/lib.rs b/feature-flags/src/lib.rs index ebda41c..9175b5c 100644 --- a/feature-flags/src/lib.rs +++ b/feature-flags/src/lib.rs @@ -1,7 +1,7 @@ +pub mod api; pub mod config; -pub mod server; +pub mod redis; pub mod router; +pub mod server; pub mod v0_endpoint; pub mod v0_request; -pub mod api; -pub mod redis; \ No newline at end of file diff --git a/feature-flags/src/main.rs b/feature-flags/src/main.rs index 716f464..980db69 100644 --- a/feature-flags/src/main.rs +++ b/feature-flags/src/main.rs @@ -1,4 +1,3 @@ - use envconfig::Envconfig; use tokio::signal; use tracing_subscriber::layer::SubscriberExt; @@ -30,9 +29,7 @@ async fn main() { // Basic logging for now: // - stdout with a level configured by the RUST_LOG envvar (default=ERROR) let log_layer = tracing_subscriber::fmt::layer().with_filter(EnvFilter::from_default_env()); - tracing_subscriber::registry() - .with(log_layer) - .init(); + tracing_subscriber::registry().with(log_layer).init(); // Open the TCP port and start the server let listener = tokio::net::TcpListener::bind(config.address) diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs index 0088968..33c7a36 100644 --- a/feature-flags/src/router.rs +++ b/feature-flags/src/router.rs @@ -7,10 +7,7 @@ use axum::{ Router, }; -use crate::{ - redis::Client, v0_endpoint, -}; - +use crate::{redis::Client, v0_endpoint}; #[derive(Clone)] pub struct State { @@ -18,16 +15,8 @@ pub struct State { // TODO: Add pgClient when ready } - - -pub fn router< - R: Client + Send + Sync + 'static, ->( - redis: Arc, -) -> Router { - let state = State { - redis, - }; +pub fn router(redis: Arc) -> Router { + let state = State { redis }; // // Very permissive CORS policy, as old SDK versions // // and reverse proxies might send funky headers. @@ -38,12 +27,7 @@ pub fn router< // .allow_origin(AllowOrigin::mirror_request()); let router = Router::new() - .route( - "/flags", - post(v0_endpoint::flags) - .get(v0_endpoint::flags) - ) - + .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) .with_state(state); router diff --git a/feature-flags/src/server.rs b/feature-flags/src/server.rs index fcebce8..ffe6b0e 100644 --- a/feature-flags/src/server.rs +++ b/feature-flags/src/server.rs @@ -13,13 +13,10 @@ pub async fn serve(config: Config, listener: TcpListener, shutdown: F) where F: Future + Send + 'static, { - let redis_client = Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client")); - let app = router::router( - redis_client, - ); + let app = router::router(redis_client); // run our app with hyper // `axum::Server` is a re-export of `hyper::Server` diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs index 837621d..2d71e9e 100644 --- a/feature-flags/src/v0_endpoint.rs +++ b/feature-flags/src/v0_endpoint.rs @@ -11,11 +11,10 @@ use axum_client_ip::InsecureClientIp; use base64::Engine; use tracing::instrument; - use crate::{ api::{FlagError, FlagsResponse}, router, - v0_request::{FlagsQueryParams, FlagRequest}, + v0_request::{FlagRequest, FlagsQueryParams}, }; /// Feature flag evaluation endpoint. @@ -52,7 +51,6 @@ pub async fn flags( .get("content-encoding") .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); - tracing::Span::current().record("user_agent", user_agent); tracing::Span::current().record("content_encoding", content_encoding); // tracing::Span::current().record("version", meta.lib_version.clone()); @@ -64,7 +62,9 @@ pub async fn flags( .map_or("", |v| v.to_str().unwrap_or("")) { "application/x-www-form-urlencoded" => { - return Err(FlagError::RequestDecodingError(String::from("invalid form data"))); + return Err(FlagError::RequestDecodingError(String::from( + "invalid form data", + ))); } ct => { tracing::Span::current().record("content_type", ct); @@ -83,6 +83,9 @@ pub async fn flags( Ok(Json(FlagsResponse { error_while_computing_flags: false, - feature_flags: HashMap::from([("beta-feature".to_string(), "variant-1".to_string()), ("rollout-flag".to_string(), true.to_string())]), + feature_flags: HashMap::from([ + ("beta-feature".to_string(), "variant-1".to_string()), + ("rollout-flag".to_string(), true.to_string()), + ]), })) -} \ No newline at end of file +} diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs index bdcd35d..36177a7 100644 --- a/feature-flags/src/v0_request.rs +++ b/feature-flags/src/v0_request.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap}; +use std::collections::HashMap; // use std::io::prelude::*; use bytes::Bytes; @@ -11,10 +11,8 @@ use tracing::instrument; use crate::api::FlagError; - #[derive(Deserialize, Default)] pub struct FlagsQueryParams { - #[serde(alias = "v")] pub version: Option, @@ -55,9 +53,9 @@ impl FlagRequest { tracing::debug!(len = bytes.len(), "decoding new request"); // TODO: Add base64 decoding let payload = String::from_utf8(bytes.into()).map_err(|e| { - tracing::error!("failed to decode body: {}", e); - FlagError::RequestDecodingError(String::from("invalid body encoding")) - })?; + tracing::error!("failed to decode body: {}", e); + FlagError::RequestDecodingError(String::from("invalid body encoding")) + })?; tracing::debug!(json = payload, "decoded event data"); Ok(serde_json::from_str::(&payload)?) @@ -65,17 +63,17 @@ impl FlagRequest { pub fn extract_and_verify_token(&self) -> Result { let token = match self { - FlagRequest { token: Some(token), .. } => token.to_string(), + FlagRequest { + token: Some(token), .. + } => token.to_string(), _ => return Err(FlagError::NoTokenError), }; // TODO: Get tokens from redis, confirm this one is valid // validate_token(&token)?; Ok(token) } - } - #[cfg(test)] mod tests { use base64::Engine as _; @@ -118,7 +116,6 @@ mod tests { Ok(_) )); - // Return token from array if consistent assert_extracted_token( r#"[{"event":"e","token":"token1"},{"event":"e","token":"token1"}]"#, From b3308d35260ae822249ce32f984d7e6d568cdf4a Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Wed, 1 May 2024 17:37:10 +0100 Subject: [PATCH 3/7] rm old test --- feature-flags/src/v0_endpoint.rs | 2 +- feature-flags/src/v0_request.rs | 60 -------------------------------- 2 files changed, 1 insertion(+), 61 deletions(-) diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs index 2d71e9e..a677278 100644 --- a/feature-flags/src/v0_endpoint.rs +++ b/feature-flags/src/v0_endpoint.rs @@ -53,7 +53,7 @@ pub async fn flags( tracing::Span::current().record("user_agent", user_agent); tracing::Span::current().record("content_encoding", content_encoding); - // tracing::Span::current().record("version", meta.lib_version.clone()); + tracing::Span::current().record("version", meta.version.clone()); tracing::Span::current().record("method", method.as_str()); tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs index 36177a7..7a869d0 100644 --- a/feature-flags/src/v0_request.rs +++ b/feature-flags/src/v0_request.rs @@ -73,63 +73,3 @@ impl FlagRequest { Ok(token) } } - -#[cfg(test)] -mod tests { - use base64::Engine as _; - use bytes::Bytes; - use rand::distributions::Alphanumeric; - use rand::Rng; - use serde_json::json; - - use super::FlagError; - use super::FlagRequest; - - #[test] - fn extract_and_verify_token() { - let parse_and_extract = |input: &'static str| -> Result { - FlagRequest::from_bytes(input.into()) - .expect("failed to parse") - .extract_and_verify_token() - }; - - let assert_extracted_token = |input: &'static str, expected: &str| { - let id = parse_and_extract(input).expect("failed to extract"); - assert_eq!(id, expected); - }; - - // Return NoTokenError if not found - assert!(matches!( - parse_and_extract(r#"{"distinct_id": "xyz"}"#), - Err(FlagError::NoTokenError) - )); - - // Return TokenValidationError if token empty - assert!(matches!( - parse_and_extract(r#"{"api_key": "", "batch":[{"event": "e"}]}"#), - Err(FlagError::TokenValidationError) - )); - - // Return TokenValidationError if personal apikey - assert!(matches!( - parse_and_extract(r#"[{"event": "e", "token": "phx_hellothere"}]"#), - Ok(_) - )); - - // Return token from array if consistent - assert_extracted_token( - r#"[{"event":"e","token":"token1"},{"event":"e","token":"token1"}]"#, - "token1", - ); - - // Return token from batch if present - assert_extracted_token( - r#"{"batch":[{"event":"e","token":"token1"}],"api_key":"batched"}"#, - "batched", - ); - - // Return token from single event if present - assert_extracted_token(r#"{"event":"e","$token":"single_token"}"#, "single_token"); - assert_extracted_token(r#"{"event":"e","api_key":"single_token"}"#, "single_token"); - } -} From 68fe2bbfd2e191b39825c6449c3c49ef1c849dee Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 2 May 2024 18:00:48 +0100 Subject: [PATCH 4/7] clean add test --- Cargo.lock | 33 ++++++++-------- Cargo.toml | 2 +- feature-flags/Cargo.toml | 11 +++--- feature-flags/src/config.rs | 7 +--- feature-flags/src/router.rs | 7 +--- feature-flags/src/v0_endpoint.rs | 6 +-- feature-flags/src/v0_request.rs | 7 ---- feature-flags/tests/common.rs | 66 +++++++++++++++++++++++++++++++ feature-flags/tests/test_flags.rs | 43 ++++++++++++++++++++ 9 files changed, 137 insertions(+), 45 deletions(-) create mode 100644 feature-flags/tests/common.rs create mode 100644 feature-flags/tests/test_flags.rs diff --git a/Cargo.lock b/Cargo.lock index 4de877a..5fb1287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,7 +182,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-util", "itoa", "matchit", @@ -273,7 +273,7 @@ dependencies = [ "bytes", "http 1.1.0", "http-body 1.0.0", - "hyper 1.1.0", + "hyper 1.3.1", "reqwest 0.11.24", "serde", "tokio", @@ -352,9 +352,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "capture" @@ -711,21 +711,19 @@ name = "feature-flags" version = "0.1.0" dependencies = [ "anyhow", + "assert-json-diff", "async-trait", "axum 0.7.5", "axum-client-ip", "base64 0.22.0", "bytes", "envconfig", - "flate2", - "governor", - "metrics", + "once_cell", "rand", - "rdkafka", "redis", + "reqwest 0.12.3", "serde", "serde_json", - "serde_urlencoded", "thiserror", "tokio", "tracing", @@ -1266,9 +1264,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" dependencies = [ "bytes", "futures-channel", @@ -1280,6 +1278,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "smallvec", "tokio", "want", ] @@ -1318,7 +1317,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-util", "native-tls", "tokio", @@ -1337,7 +1336,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.1.0", + "hyper 1.3.1", "pin-project-lite", "socket2 0.5.5", "tokio", @@ -1559,7 +1558,7 @@ checksum = "5d58e362dc7206e9456ddbcdbd53c71ba441020e62104703075a69151e38d85f" dependencies = [ "base64 0.22.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-tls", "hyper-util", "indexmap 2.2.2", @@ -2350,7 +2349,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.1.0", + "hyper 1.3.1", "hyper-tls", "hyper-util", "ipnet", @@ -3092,9 +3091,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" dependencies = [ "backtrace", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 265fe5f..54e216a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ metrics = "0.22.0" metrics-exporter-prometheus = "0.14.0" rand = "0.8.5" rdkafka = { version = "0.36.0", features = ["cmake-build", "ssl", "tracing"] } -reqwest = { version = "0.12.3" } +reqwest = { version = "0.12.3", features = ["json"] } serde = { version = "1.0", features = ["derive"] } serde_derive = { version = "1.0" } serde_json = { version = "1.0" } diff --git a/feature-flags/Cargo.toml b/feature-flags/Cargo.toml index 088b069..f1f03ac 100644 --- a/feature-flags/Cargo.toml +++ b/feature-flags/Cargo.toml @@ -16,11 +16,7 @@ tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } base64 = { workspace = true } bytes = { workspace = true } -flate2 = { workspace = true } -governor = { workspace = true } -metrics = { workspace = true } rand = { workspace = true } -rdkafka = { workspace = true } redis = { version = "0.23.3", features = [ "tokio-comp", "cluster", @@ -28,8 +24,13 @@ redis = { version = "0.23.3", features = [ ] } serde = { workspace = true } serde_json = { workspace = true } -serde_urlencoded = { workspace = true } thiserror = { workspace = true } [lints] workspace = true + +[dev-dependencies] +assert-json-diff = { workspace = true } +once_cell = "1.18.0" +reqwest = { workspace = true } + diff --git a/feature-flags/src/config.rs b/feature-flags/src/config.rs index 46d2983..3fa6f50 100644 --- a/feature-flags/src/config.rs +++ b/feature-flags/src/config.rs @@ -1,13 +1,10 @@ -use std::{net::SocketAddr, num::NonZeroU32}; +use std::net::SocketAddr; use envconfig::Envconfig; #[derive(Envconfig, Clone)] pub struct Config { - #[envconfig(default = "false")] - pub print_sink: bool, - - #[envconfig(default = "127.0.0.1:3001")] + #[envconfig(default = "127.0.0.1:0")] pub address: SocketAddr, #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs index 33c7a36..58dc8cb 100644 --- a/feature-flags/src/router.rs +++ b/feature-flags/src/router.rs @@ -1,11 +1,6 @@ -use std::future::ready; use std::sync::Arc; -use axum::http::Method; -use axum::{ - routing::{get, post}, - Router, -}; +use axum::{routing::post, Router}; use crate::{redis::Client, v0_endpoint}; diff --git a/feature-flags/src/v0_endpoint.rs b/feature-flags/src/v0_endpoint.rs index a677278..8f77611 100644 --- a/feature-flags/src/v0_endpoint.rs +++ b/feature-flags/src/v0_endpoint.rs @@ -1,6 +1,4 @@ use std::collections::HashMap; -use std::ops::Deref; -use std::sync::Arc; use axum::{debug_handler, Json}; use bytes::Bytes; @@ -8,7 +6,6 @@ use bytes::Bytes; use axum::extract::{MatchedPath, Query, State}; use axum::http::{HeaderMap, Method}; use axum_client_ip::InsecureClientIp; -use base64::Engine; use tracing::instrument; use crate::{ @@ -36,7 +33,7 @@ use crate::{ )] #[debug_handler] pub async fn flags( - state: State, + _state: State, InsecureClientIp(ip): InsecureClientIp, meta: Query, headers: HeaderMap, @@ -56,6 +53,7 @@ pub async fn flags( tracing::Span::current().record("version", meta.version.clone()); tracing::Span::current().record("method", method.as_str()); tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); + tracing::Span::current().record("ip", ip.to_string()); let request = match headers .get("content-type") diff --git a/feature-flags/src/v0_request.rs b/feature-flags/src/v0_request.rs index 7a869d0..f2269df 100644 --- a/feature-flags/src/v0_request.rs +++ b/feature-flags/src/v0_request.rs @@ -1,13 +1,9 @@ use std::collections::HashMap; -// use std::io::prelude::*; use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::Value; -// use time::format_description::well_known::Iso8601; -// use time::OffsetDateTime; use tracing::instrument; -// use uuid::Uuid; use crate::api::FlagError; @@ -15,9 +11,6 @@ use crate::api::FlagError; pub struct FlagsQueryParams { #[serde(alias = "v")] pub version: Option, - - #[serde(alias = "_")] - sent_at: Option, } #[derive(Default, Debug, Deserialize, Serialize)] diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common.rs new file mode 100644 index 0000000..f66a11f --- /dev/null +++ b/feature-flags/tests/common.rs @@ -0,0 +1,66 @@ +use std::net::SocketAddr; +use std::str::FromStr; +use std::string::ToString; +use std::sync::Arc; + +use once_cell::sync::Lazy; +use rand::distributions::Alphanumeric; +use rand::Rng; +use tokio::net::TcpListener; +use tokio::sync::Notify; + +use feature_flags::config::Config; +use feature_flags::server::serve; + +pub static DEFAULT_CONFIG: Lazy = Lazy::new(|| Config { + address: SocketAddr::from_str("127.0.0.1:0").unwrap(), + redis_url: "redis://localhost:6379/".to_string(), + write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), + read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), + max_concurrent_jobs: 1024, + max_pg_connections: 100, +}); + +pub struct ServerHandle { + pub addr: SocketAddr, + shutdown: Arc, +} + +impl ServerHandle { + pub async fn for_config(config: Config) -> ServerHandle { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let notify = Arc::new(Notify::new()); + let shutdown = notify.clone(); + + tokio::spawn(async move { + serve(config, listener, async move { notify.notified().await }).await + }); + ServerHandle { addr, shutdown } + } + + pub async fn send_flags_request>(&self, body: T) -> reqwest::Response { + let client = reqwest::Client::new(); + client + .post(format!("http://{:?}/flags", self.addr)) + .body(body) + .send() + .await + .expect("failed to send request") + } +} + +impl Drop for ServerHandle { + fn drop(&mut self) { + self.shutdown.notify_one() + } +} + +pub fn random_string(prefix: &str, length: usize) -> String { + let suffix: String = rand::thread_rng() + .sample_iter(Alphanumeric) + .take(length) + .map(char::from) + .collect(); + format!("{}_{}", prefix, suffix) +} diff --git a/feature-flags/tests/test_flags.rs b/feature-flags/tests/test_flags.rs new file mode 100644 index 0000000..82f41f0 --- /dev/null +++ b/feature-flags/tests/test_flags.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use assert_json_diff::assert_json_include; + +use reqwest::StatusCode; +use serde_json::{json, Value}; + +use crate::common::*; +mod common; + +#[tokio::test] +async fn it_sends_flag_request() -> Result<()> { + let token = random_string("token", 16); + let distinct_id = "user_distinct_id".to_string(); + + let config = DEFAULT_CONFIG.clone(); + + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + // We don't want to deserialize the data into a flagResponse struct here, + // because we want to assert the shape of the raw json data. + let json_data = res.json::().await?; + + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "beta-feature": "variant-1", + "rollout-flag": "true", + } + }) + ); + + Ok(()) +} From 877e4df335f9f5e2eaf42a14b5ac4cfef55e5716 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 2 May 2024 18:07:47 +0100 Subject: [PATCH 5/7] rm not needed dep --- Cargo.lock | 1 - feature-flags/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5fb1287..4a12b5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -715,7 +715,6 @@ dependencies = [ "async-trait", "axum 0.7.5", "axum-client-ip", - "base64 0.22.0", "bytes", "envconfig", "once_cell", diff --git a/feature-flags/Cargo.toml b/feature-flags/Cargo.toml index f1f03ac..ddfe070 100644 --- a/feature-flags/Cargo.toml +++ b/feature-flags/Cargo.toml @@ -14,7 +14,6 @@ envconfig = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } -base64 = { workspace = true } bytes = { workspace = true } rand = { workspace = true } redis = { version = "0.23.3", features = [ From 777788f96ff42f850b48f2136c29ade756e142e5 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Fri, 3 May 2024 17:02:37 +0100 Subject: [PATCH 6/7] fix lint --- feature-flags/src/router.rs | 6 ++---- feature-flags/tests/common.rs | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs index 58dc8cb..1f5941e 100644 --- a/feature-flags/src/router.rs +++ b/feature-flags/src/router.rs @@ -21,9 +21,7 @@ pub fn router(redis: Arc) -> Router { // .allow_credentials(true) // .allow_origin(AllowOrigin::mirror_request()); - let router = Router::new() + Router::new() .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) - .with_state(state); - - router + .with_state(state) } diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common.rs index f66a11f..b9afb87 100644 --- a/feature-flags/tests/common.rs +++ b/feature-flags/tests/common.rs @@ -28,7 +28,7 @@ pub struct ServerHandle { impl ServerHandle { pub async fn for_config(config: Config) -> ServerHandle { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:3001").await.unwrap(); let addr = listener.local_addr().unwrap(); let notify = Arc::new(Notify::new()); let shutdown = notify.clone(); From 1c3481f4b1f79e01d19d454af6be2e488e3a3967 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Tue, 7 May 2024 15:58:54 +0100 Subject: [PATCH 7/7] address comment --- feature-flags/src/redis.rs | 1 - feature-flags/src/router.rs | 8 -------- feature-flags/tests/common.rs | 2 +- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/feature-flags/src/redis.rs b/feature-flags/src/redis.rs index 457c008..8c03820 100644 --- a/feature-flags/src/redis.rs +++ b/feature-flags/src/redis.rs @@ -43,7 +43,6 @@ impl Client for RedisClient { } // TODO: Find if there's a better way around this. -// mockall got really annoying with async and results so I'm just gonna do my own #[derive(Clone)] pub struct MockRedisClient { zrangebyscore_ret: Vec, diff --git a/feature-flags/src/router.rs b/feature-flags/src/router.rs index 1f5941e..8824d44 100644 --- a/feature-flags/src/router.rs +++ b/feature-flags/src/router.rs @@ -13,14 +13,6 @@ pub struct State { pub fn router(redis: Arc) -> Router { let state = State { redis }; - // // Very permissive CORS policy, as old SDK versions - // // and reverse proxies might send funky headers. - // let cors = CorsLayer::new() - // .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) - // .allow_headers(AllowHeaders::mirror_request()) - // .allow_credentials(true) - // .allow_origin(AllowOrigin::mirror_request()); - Router::new() .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) .with_state(state) diff --git a/feature-flags/tests/common.rs b/feature-flags/tests/common.rs index b9afb87..f66a11f 100644 --- a/feature-flags/tests/common.rs +++ b/feature-flags/tests/common.rs @@ -28,7 +28,7 @@ pub struct ServerHandle { impl ServerHandle { pub async fn for_config(config: Config) -> ServerHandle { - let listener = TcpListener::bind("127.0.0.1:3001").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let notify = Arc::new(Notify::new()); let shutdown = notify.clone();