Skip to content

Commit

Permalink
Merge branch 'main' into skyzh/force-switch-handler
Browse files Browse the repository at this point in the history
  • Loading branch information
skyzh authored May 22, 2024
2 parents e02b160 + 014f822 commit d6cb086
Show file tree
Hide file tree
Showing 22 changed files with 844 additions and 107 deletions.
20 changes: 9 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ license = "Apache-2.0"

## All dependency versions, used in the project
[workspace.dependencies]
ahash = "0.8"
anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
Expand Down Expand Up @@ -74,6 +75,7 @@ clap = { version = "4.0", features = ["derive"] }
comfy-table = "6.1"
const_format = "0.2"
crc32c = "0.6"
crossbeam-deque = "0.8.5"
crossbeam-utils = "0.8.5"
dashmap = { version = "5.5.0", features = ["raw-api"] }
either = "1.8"
Expand Down
45 changes: 26 additions & 19 deletions libs/pageserver_api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::{
collections::HashMap,
io::{BufRead, Read},
num::{NonZeroU64, NonZeroUsize},
str::FromStr,
sync::atomic::AtomicUsize,
time::{Duration, SystemTime},
};
Expand Down Expand Up @@ -334,14 +333,28 @@ pub struct TenantConfig {
/// Unset -> V1
/// -> V2
/// -> CrossValidation -> V2
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[derive(
Eq,
PartialEq,
Debug,
Copy,
Clone,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
)]
#[strum(serialize_all = "kebab-case")]
pub enum AuxFilePolicy {
/// V1 aux file policy: store everything in AUX_FILE_KEY
#[strum(ascii_case_insensitive)]
V1,
/// V2 aux file policy: store in the AUX_FILE keyspace
#[strum(ascii_case_insensitive)]
V2,
/// Cross validation runs both formats on the write path and does validation
/// on the read path.
#[strum(ascii_case_insensitive)]
CrossValidation,
}

Expand Down Expand Up @@ -407,23 +420,6 @@ impl AuxFilePolicy {
}
}

impl FromStr for AuxFilePolicy {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.to_lowercase();
if s == "v1" {
Ok(Self::V1)
} else if s == "v2" {
Ok(Self::V2)
} else if s == "crossvalidation" || s == "cross_validation" {
Ok(Self::CrossValidation)
} else {
anyhow::bail!("cannot parse {} to aux file policy", s)
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum EvictionPolicy {
Expand Down Expand Up @@ -1405,6 +1401,7 @@ impl PagestreamBeMessage {
#[cfg(test)]
mod tests {
use serde_json::json;
use std::str::FromStr;

use super::*;

Expand Down Expand Up @@ -1667,4 +1664,14 @@ mod tests {
AuxFilePolicy::V2
));
}

#[test]
fn test_aux_parse() {
assert_eq!(AuxFilePolicy::from_str("V2").unwrap(), AuxFilePolicy::V2);
assert_eq!(AuxFilePolicy::from_str("v2").unwrap(), AuxFilePolicy::V2);
assert_eq!(
AuxFilePolicy::from_str("cross-validation").unwrap(),
AuxFilePolicy::CrossValidation
);
}
}
4 changes: 3 additions & 1 deletion proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ default = []
testing = []

[dependencies]
ahash.workspace = true
anyhow.workspace = true
async-compression.workspace = true
async-trait.workspace = true
Expand All @@ -24,6 +25,7 @@ camino.workspace = true
chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
crossbeam-deque.workspace = true
dashmap.workspace = true
env_logger.workspace = true
framed-websockets.workspace = true
Expand Down Expand Up @@ -52,7 +54,6 @@ opentelemetry.workspace = true
parking_lot.workspace = true
parquet.workspace = true
parquet_derive.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
pin-project-lite.workspace = true
postgres_backend.workspace = true
pq_proto.workspace = true
Expand Down Expand Up @@ -106,6 +107,7 @@ workspace_hack.workspace = true
camino-tempfile.workspace = true
fallible-iterator.workspace = true
tokio-tungstenite.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
Expand Down
10 changes: 7 additions & 3 deletions proxy/src/auth/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,10 @@ async fn authenticate_with_secret(
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret).await?;
let ep = EndpointIdInt::from(&info.endpoint);

let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
Expand All @@ -386,7 +389,7 @@ async fn authenticate_with_secret(
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
}

// Finally, proceed with the main auth flow (SCRAM-based).
Expand Down Expand Up @@ -554,7 +557,7 @@ mod tests {
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::ServerSecret,
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};

Expand Down Expand Up @@ -596,6 +599,7 @@ mod tests {
}

static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
Expand Down
11 changes: 10 additions & 1 deletion proxy/src/auth/backend/hacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use super::{
};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
Expand All @@ -20,15 +22,22 @@ pub async fn authenticate_cleartext(
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

// pause the timer while we communicate with the client
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);

let ep = EndpointIdInt::from(&info.endpoint);

let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
.begin(auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.
Expand Down
24 changes: 19 additions & 5 deletions proxy/src/auth/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

Expand Down Expand Up @@ -53,7 +55,11 @@ impl AuthMethod for PasswordHack {

/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword(pub AuthSecret);
pub struct CleartextPassword {
pub pool: Arc<ThreadPool>,
pub endpoint: EndpointIdInt,
pub secret: AuthSecret,
}

impl AuthMethod for CleartextPassword {
#[inline(always)]
Expand Down Expand Up @@ -126,7 +132,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;

let outcome = validate_password_and_exchange(password, self.state.0).await?;
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
password,
self.state.secret,
)
.await?;

if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
Expand Down Expand Up @@ -181,6 +193,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}

pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
Expand All @@ -194,7 +208,7 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(&scram_secret, password).await?;
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;

let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
Expand Down
8 changes: 8 additions & 0 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
Expand Down Expand Up @@ -132,6 +133,9 @@ struct ProxyCliArgs {
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
Expand Down Expand Up @@ -489,6 +493,9 @@ async fn main() -> anyhow::Result<()> {

/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());

let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
Expand Down Expand Up @@ -624,6 +631,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
};
let authentication_config = AuthenticationConfig {
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
Expand Down
2 changes: 2 additions & 0 deletions proxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
rate_limiter::RateBucketInfo,
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
Expand Down Expand Up @@ -61,6 +62,7 @@ pub struct HttpConfig {
}

pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
Expand Down
Loading

0 comments on commit d6cb086

Please sign in to comment.