diff --git a/Cargo.lock b/Cargo.lock index 9e01084ed6..fc5c708a4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1364,6 +1364,7 @@ version = "0.1.0" dependencies = [ "futures", "linkerd-conditional", + "linkerd-dns-name", "linkerd-error", "linkerd-identity", "linkerd-io", @@ -1403,6 +1404,7 @@ name = "linkerd-meshtls-rustls" version = "0.1.0" dependencies = [ "futures", + "linkerd-dns-name", "linkerd-error", "linkerd-identity", "linkerd-io", @@ -1577,6 +1579,7 @@ version = "0.1.0" dependencies = [ "futures", "http-body", + "linkerd-dns-name", "linkerd-error", "linkerd-identity", "linkerd-metrics", diff --git a/linkerd/app/admin/src/stack.rs b/linkerd/app/admin/src/stack.rs index f703155e63..15b657c24b 100644 --- a/linkerd/app/admin/src/stack.rs +++ b/linkerd/app/admin/src/stack.rs @@ -36,7 +36,7 @@ struct NonHttpClient(Remote); #[derive(Debug, Error)] #[error("Unexpected TLS connection to {} from {}", self.0, self.1)] -struct UnexpectedSni(tls::ServerId, Remote); +struct UnexpectedSni(tls::ServerName, Remote); #[derive(Clone, Debug)] struct Tcp { diff --git a/linkerd/app/gateway/src/http.rs b/linkerd/app/gateway/src/http.rs index 275d92bda7..f5c7f571c0 100644 --- a/linkerd/app/gateway/src/http.rs +++ b/linkerd/app/gateway/src/http.rs @@ -1,7 +1,6 @@ use super::Gateway; use inbound::{GatewayAddr, GatewayDomainInvalid}; use linkerd_app_core::{ - identity, metrics::ServerLabel, profiles, proxy::{ @@ -104,9 +103,9 @@ impl Gateway { // Discard `T` and its associated client-specific metadata. .push_map_target(Target::discard_parent) // Add headers to prevent loops. - .push(NewHttpGateway::layer(identity::LocalId( - self.inbound.identity().name().clone(), - ))) + .push(NewHttpGateway::layer( + self.inbound.identity().server_name().clone().into(), + )) .push_on_service(svc::LoadShed::layer()) .lift_new() // After protocol-downgrade, we need to build an inner stack for diff --git a/linkerd/app/gateway/src/http/gateway.rs b/linkerd/app/gateway/src/http/gateway.rs index a7e579b4b9..266a440966 100644 --- a/linkerd/app/gateway/src/http/gateway.rs +++ b/linkerd/app/gateway/src/http/gateway.rs @@ -1,5 +1,6 @@ use futures::{future, TryFutureExt}; use linkerd_app_core::{ + identity as id, proxy::http, svc::{self, layer}, tls, Error, @@ -15,7 +16,7 @@ use std::{ #[derive(Clone, Debug)] pub(crate) struct NewHttpGateway { inner: N, - local_id: tls::LocalId, + local_id: id::Id, } /// A `Service` middleware that fails requests that would loop. It reads and @@ -25,7 +26,7 @@ pub(crate) struct HttpGateway { inner: S, host: String, client_id: tls::ClientId, - local_id: tls::LocalId, + local_id: id::Id, } type ResponseFuture = Pin> + Send + 'static>>; @@ -33,11 +34,11 @@ type ResponseFuture = Pin> + Send + // === impl NewHttpGateway === impl NewHttpGateway { - pub fn new(inner: N, local_id: tls::LocalId) -> Self { + pub fn new(inner: N, local_id: id::Id) -> Self { Self { inner, local_id } } - pub fn layer(local_id: tls::LocalId) -> impl layer::Layer + Clone { + pub fn layer(local_id: id::Id) -> impl layer::Layer + Clone { layer::mk(move |inner| Self::new(inner, local_id.clone())) } } @@ -77,7 +78,7 @@ where #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { // If the client ID is the same as the gateway's, then we're in a loop. - if *self.client_id == *self.local_id { + if *self.client_id == self.local_id { return Poll::Ready(Err(GatewayLoop.into())); } @@ -95,7 +96,7 @@ where { if let Some(by) = fwd_by(forwarded) { tracing::info!(%forwarded); - if by == self.local_id.as_str() { + if by == self.local_id.to_str() { return Box::pin(future::err(GatewayLoop.into())); } } diff --git a/linkerd/app/gateway/src/http/tests.rs b/linkerd/app/gateway/src/http/tests.rs index 1a9423b842..17a2a9e684 100644 --- a/linkerd/app/gateway/src/http/tests.rs +++ b/linkerd/app/gateway/src/http/tests.rs @@ -219,7 +219,7 @@ impl Test { let new = NewHttpGateway::new( move |_: _| outbound.clone(), - tls::LocalId("gateway.id.test".parse().unwrap()), + "gateway.id.test".parse().unwrap(), ); #[derive(Clone, Debug)] diff --git a/linkerd/app/inbound/src/detect.rs b/linkerd/app/inbound/src/detect.rs index 0341d81b08..bc87d07be9 100644 --- a/linkerd/app/inbound/src/detect.rs +++ b/linkerd/app/inbound/src/detect.rs @@ -396,8 +396,8 @@ impl svc::Param for Http { } } -impl svc::Param> for Http { - fn param(&self) -> Option { +impl svc::Param> for Http { + fn param(&self) -> Option { self.tls .status .value() diff --git a/linkerd/app/inbound/src/http.rs b/linkerd/app/inbound/src/http.rs index 9471cb1d94..0cf11764fc 100644 --- a/linkerd/app/inbound/src/http.rs +++ b/linkerd/app/inbound/src/http.rs @@ -252,8 +252,8 @@ pub mod fuzz { } } - impl svc::Param> for Target { - fn param(&self) -> Option { + impl svc::Param> for Target { + fn param(&self) -> Option { None } } diff --git a/linkerd/app/inbound/src/http/set_identity_header.rs b/linkerd/app/inbound/src/http/set_identity_header.rs index 9db9f3014b..0a60eb53f2 100644 --- a/linkerd/app/inbound/src/http/set_identity_header.rs +++ b/linkerd/app/inbound/src/http/set_identity_header.rs @@ -43,7 +43,7 @@ where .and_then(|tls| match tls { tls::ServerTls::Established { client_id, .. } => { client_id.as_ref().and_then(|id| { - match http::HeaderValue::from_str(id.as_str()) { + match http::HeaderValue::from_str(&id.to_str()) { Ok(v) => Some(v), Err(error) => { tracing::warn!(%error, "identity not a valid header value"); diff --git a/linkerd/app/inbound/src/http/tests.rs b/linkerd/app/inbound/src/http/tests.rs index d21f005af5..f5c9351240 100644 --- a/linkerd/app/inbound/src/http/tests.rs +++ b/linkerd/app/inbound/src/http/tests.rs @@ -770,8 +770,8 @@ impl svc::Param for Target { } } -impl svc::Param> for Target { - fn param(&self) -> Option { +impl svc::Param> for Target { + fn param(&self) -> Option { None } } diff --git a/linkerd/app/inbound/src/policy.rs b/linkerd/app/inbound/src/policy.rs index d267418204..be55a18f3b 100644 --- a/linkerd/app/inbound/src/policy.rs +++ b/linkerd/app/inbound/src/policy.rs @@ -177,7 +177,8 @@ fn is_authorized( client_id: Some(tls::server::ClientId(ref id)), .. }) => { - identities.contains(id.as_str()) || suffixes.iter().any(|s| s.contains(id.as_str())) + identities.contains(&*id.to_str()) + || suffixes.iter().any(|s| s.contains(&id.to_str())) } _ => false, }, diff --git a/linkerd/app/outbound/src/http/concrete.rs b/linkerd/app/outbound/src/http/concrete.rs index e690530f6e..c2fc6c75a6 100644 --- a/linkerd/app/outbound/src/http/concrete.rs +++ b/linkerd/app/outbound/src/http/concrete.rs @@ -375,15 +375,13 @@ impl svc::Param for Endpoint { .identity() .cloned() .map(move |server_id| { - tls::ConditionalClientTls::Some(tls::ClientTls { - server_id, - alpn: if use_transport_header { - use linkerd_app_core::transport_header::PROTOCOL; - Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()])) - } else { - None - }, - }) + let alpn = if use_transport_header { + use linkerd_app_core::transport_header::PROTOCOL; + Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()])) + } else { + None + }; + tls::ConditionalClientTls::Some(tls::ClientTls::new(server_id, alpn)) }) .unwrap_or(tls::ConditionalClientTls::None( tls::NoClientTls::NotProvidedByServiceDiscovery, diff --git a/linkerd/app/outbound/src/http/require_id_header.rs b/linkerd/app/outbound/src/http/require_id_header.rs index bc9c431039..8ac1cdc5c3 100644 --- a/linkerd/app/outbound/src/http/require_id_header.rs +++ b/linkerd/app/outbound/src/http/require_id_header.rs @@ -1,5 +1,5 @@ use futures::{future, TryFutureExt}; -use linkerd_app_core::{identity, svc, tls, Conditional, Error}; +use linkerd_app_core::{dns, identity, svc, tls, Conditional, Error}; use std::task::{Context, Poll}; use thiserror::Error; use tracing::{debug, trace}; @@ -57,9 +57,10 @@ type ResponseFuture = impl RequireIdentity { #[inline] - fn extract_id(req: &mut http::Request) -> Option { + fn extract_id(req: &mut http::Request) -> Option { let v = req.headers_mut().remove(HEADER_NAME)?; - v.to_str().ok()?.parse().ok() + let n = v.to_str().ok()?.parse::().ok()?; + Some(n.into()) } } diff --git a/linkerd/app/outbound/src/opaq/concrete.rs b/linkerd/app/outbound/src/opaq/concrete.rs index c71bc67b85..566390afea 100644 --- a/linkerd/app/outbound/src/opaq/concrete.rs +++ b/linkerd/app/outbound/src/opaq/concrete.rs @@ -284,15 +284,13 @@ impl svc::Param for Endpoint { .identity() .cloned() .map(move |server_id| { - tls::ConditionalClientTls::Some(tls::ClientTls { - server_id, - alpn: if use_transport_header { - use linkerd_app_core::transport_header::PROTOCOL; - Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()])) - } else { - None - }, - }) + let alpn = if use_transport_header { + use linkerd_app_core::transport_header::PROTOCOL; + Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()])) + } else { + None + }; + tls::ConditionalClientTls::Some(tls::ClientTls::new(server_id, alpn)) }) .unwrap_or(tls::ConditionalClientTls::None( tls::NoClientTls::NotProvidedByServiceDiscovery, diff --git a/linkerd/app/outbound/src/tcp/tagged_transport.rs b/linkerd/app/outbound/src/tcp/tagged_transport.rs index fc053592af..d10d0a7ba2 100644 --- a/linkerd/app/outbound/src/tcp/tagged_transport.rs +++ b/linkerd/app/outbound/src/tcp/tagged_transport.rs @@ -142,7 +142,6 @@ mod test { use super::*; use futures::future; use linkerd_app_core::{ - identity, io::{self, AsyncWriteExt}, tls, transport::{ClientAddr, Local}, @@ -163,12 +162,10 @@ mod test { self.server_id .clone() .map(|server_id| { - tls::ConditionalClientTls::Some(tls::ClientTls { - server_id, - alpn: Some(tls::client::AlpnProtocols(vec![ - transport_header::PROTOCOL.into() - ])), - }) + let alpn = Some(tls::client::AlpnProtocols(vec![ + transport_header::PROTOCOL.into() + ])); + tls::ConditionalClientTls::Some(tls::ClientTls::new(server_id, alpn)) }) .unwrap_or(tls::ConditionalClientTls::None( tls::NoClientTls::NotProvidedByServiceDiscovery, @@ -261,9 +258,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: None, proto: None, }; @@ -285,9 +280,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: Some(http::uri::Authority::from_str("foo.bar.example.com:5555").unwrap()), proto: None, }; @@ -309,9 +302,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: None, proto: None, }; @@ -333,9 +324,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: None, proto: Some(SessionProtocol::Http1), }; @@ -357,9 +346,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: Some(http::uri::Authority::from_str("foo.bar.example.com:5555").unwrap()), proto: Some(SessionProtocol::Http1), }; @@ -381,9 +368,7 @@ mod test { let e = Endpoint { port_override: Some(4143), - server_id: Some(tls::ServerId( - identity::Name::from_str("server.id").unwrap(), - )), + server_id: Some(tls::ServerId("server.id".parse().unwrap())), authority: None, proto: Some(SessionProtocol::Http1), }; diff --git a/linkerd/app/src/env.rs b/linkerd/app/src/env.rs index 479d6243db..969d3c656b 100644 --- a/linkerd/app/src/env.rs +++ b/linkerd/app/src/env.rs @@ -982,8 +982,15 @@ fn parse_port_range_set(s: &str) -> Result, ParseError> { Ok(set) } -pub(super) fn parse_identity(s: &str) -> Result { - identity::Name::from_str(s).map_err(|identity::InvalidName| { +pub(super) fn parse_dns_name(s: &str) -> Result { + s.parse().map_err(|_| { + error!("Not a valid identity name: {}", s); + ParseError::NameError + }) +} + +pub(super) fn parse_identity(s: &str) -> Result { + s.parse().map_err(|_| { error!("Not a valid identity name: {}", s); ParseError::NameError }) @@ -1130,7 +1137,7 @@ pub fn parse_control_addr( base: &str, ) -> Result, EnvError> { let a = parse(strings, &format!("{}_ADDR", base), parse_addr)?; - let n = parse(strings, &format!("{}_NAME", base), parse_identity)?; + let n = parse(strings, &format!("{}_NAME", base), parse_dns_name)?; match (a, n) { (None, None) => Ok(None), (Some(ref addr), _) if addr.is_loopback() => Ok(Some(ControlAddr { @@ -1139,7 +1146,7 @@ pub fn parse_control_addr( })), (Some(addr), Some(name)) => Ok(Some(ControlAddr { addr, - identity: Conditional::Some(tls::ServerId(name).into()), + identity: Conditional::Some(tls::ClientTls::new(tls::ServerId(name.into()), None)), })), _ => { error!("{}_ADDR and {}_NAME must be specified together", base, base); @@ -1165,7 +1172,7 @@ pub fn parse_identity_config( ParseError::InvalidTokenSource }) }); - let li = parse(strings, ENV_IDENTITY_IDENTITY_LOCAL_NAME, parse_identity); + let li = parse(strings, ENV_IDENTITY_IDENTITY_LOCAL_NAME, parse_dns_name); let min_refresh = parse(strings, ENV_IDENTITY_MIN_REFRESH, parse_duration); let max_refresh = parse(strings, ENV_IDENTITY_MAX_REFRESH, parse_duration); @@ -1235,7 +1242,7 @@ pub fn parse_identity_config( max_refresh: max_refresh.unwrap_or(DEFAULT_IDENTITY_MAX_REFRESH), }; let docs = identity::Documents { - id: identity::LocalId(local_name), + server_name: local_name, trust_anchors_pem, key_pkcs8: key?, csr_der: csr?, diff --git a/linkerd/app/src/identity.rs b/linkerd/app/src/identity.rs index 2d476faa1a..3b4052040b 100644 --- a/linkerd/app/src/identity.rs +++ b/linkerd/app/src/identity.rs @@ -1,6 +1,6 @@ pub use linkerd_app_core::identity::{ client::{certify, TokenSource}, - InvalidName, LocalId, Name, + Id, }; use linkerd_app_core::{ control, dns, @@ -25,7 +25,7 @@ pub struct Config { #[derive(Clone)] pub struct Documents { - pub id: LocalId, + pub server_name: dns::Name, pub trust_anchors_pem: String, pub key_pkcs8: Vec, pub csr_der: Vec, @@ -55,7 +55,7 @@ struct NotifyReady { impl Config { pub fn build(self, dns: dns::Resolver, client_metrics: ClientMetrics) -> Result { - let name = (*self.documents.id).clone(); + let name = self.documents.server_name.clone(); let (store, receiver) = Mode::default().watch( name.clone(), &self.documents.trust_anchors_pem, @@ -116,7 +116,7 @@ impl Credentials for NotifyReady { impl std::fmt::Debug for Documents { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Documents") - .field("id", &self.id) + .field("server_name", &self.server_name) .field("trust_anchors_pem", &self.trust_anchors_pem) .finish() } diff --git a/linkerd/app/src/lib.rs b/linkerd/app/src/lib.rs index 1ccde86449..fe9719c4fc 100644 --- a/linkerd/app/src/lib.rs +++ b/linkerd/app/src/lib.rs @@ -316,8 +316,8 @@ impl App { &self.dst } - pub fn local_identity(&self) -> identity::Name { - self.identity.receiver().name().clone() + pub fn local_server_name(&self) -> dns::Name { + self.identity.receiver().server_name().clone() } pub fn identity_addr(&self) -> ControlAddr { @@ -369,7 +369,7 @@ impl App { // Kick off the identity so that the process can become ready. let local = identity.receiver(); - let local_id = local.name().clone(); + let local_name = local.server_name().clone(); let ready = identity.ready(); tokio::spawn( identity @@ -382,7 +382,7 @@ impl App { ready .map(move |()| { latch.release(); - info!(id = %local_id, "Certified identity"); + info!(id = %local_name, "Certified identity"); }) .instrument(info_span!("identity").or_current()), ); diff --git a/linkerd/http-access-log/src/lib.rs b/linkerd/http-access-log/src/lib.rs index 8a785d64fd..1383fd8ebc 100644 --- a/linkerd/http-access-log/src/lib.rs +++ b/linkerd/http-access-log/src/lib.rs @@ -28,7 +28,7 @@ pub struct NewAccessLog { pub struct AccessLogContext { inner: S, client_addr: SocketAddr, - client_id: Option, + client_id: Option, } struct ResponseFutureInner { @@ -118,9 +118,15 @@ where .unwrap_or_default() }; + let client_id: std::borrow::Cow<'_, str> = self + .client_id + .as_ref() + .map(|n| n.to_str()) + .unwrap_or(std::borrow::Cow::Borrowed("-")); + let span = span!(target: TRACE_TARGET, Level::INFO, "http", client.addr = %self.client_addr, - client.id = self.client_id.as_ref().map(|n| n.as_str()).unwrap_or("-"), + client.id = %client_id, timestamp = %now(), method = request.method().as_str(), uri = %request.uri(), diff --git a/linkerd/identity/src/lib.rs b/linkerd/identity/src/lib.rs index 4f65c8f9c8..b86ce23774 100644 --- a/linkerd/identity/src/lib.rs +++ b/linkerd/identity/src/lib.rs @@ -2,12 +2,49 @@ #![forbid(unsafe_code)] mod credentials; -mod local; -mod name; - -pub use self::{ - credentials::{Credentials, DerX509}, - local::LocalId, - name::Name, -}; -pub use linkerd_dns_name::InvalidName; + +pub use self::credentials::{Credentials, DerX509}; + +/// An authenticated endpoint identity descriptor. +/// +/// Practically speaking, this is a DNS-like identity string that matches an +/// x509 DNS SAN. This will in the future be updated to support SPIFFE IDs as +/// well. +/// +/// This isn't restricted to TLS or x509 uses. An authenticated Id could be +/// provided by, e.g., a JWT. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct Id(pub linkerd_dns_name::Name); + +// === impl Id === + +impl std::str::FromStr for Id { + type Err = linkerd_error::Error; + + fn from_str(s: &str) -> Result { + // TODO support spiffe:// URIs. + let n = linkerd_dns_name::Name::from_str(s)?; + if n.ends_with('.') { + return Err(linkerd_dns_name::InvalidName.into()); + } + Ok(Self(n)) + } +} + +impl Id { + pub fn to_str(&self) -> std::borrow::Cow<'_, str> { + self.0.as_str().into() + } +} + +impl From for Id { + fn from(n: linkerd_dns_name::Name) -> Self { + Self(n) + } +} + +impl std::fmt::Display for Id { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.without_trailing_dot().fmt(f) + } +} diff --git a/linkerd/identity/src/local.rs b/linkerd/identity/src/local.rs deleted file mode 100644 index 520e4f0e43..0000000000 --- a/linkerd/identity/src/local.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::Name; -use std::{fmt, ops::Deref}; - -/// A newtype for local server identities. -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct LocalId(pub Name); - -// === impl LocalId === - -impl From for LocalId { - fn from(n: Name) -> Self { - Self(n) - } -} - -impl Deref for LocalId { - type Target = Name; - - fn deref(&self) -> &Name { - &self.0 - } -} - -impl fmt::Display for LocalId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl From for Name { - fn from(LocalId(name): LocalId) -> Name { - name - } -} diff --git a/linkerd/identity/src/name.rs b/linkerd/identity/src/name.rs deleted file mode 100644 index 591d0f6af3..0000000000 --- a/linkerd/identity/src/name.rs +++ /dev/null @@ -1,46 +0,0 @@ -use linkerd_dns_name::InvalidName; -use std::{fmt, ops::Deref, str::FromStr, sync::Arc}; - -/// An endpoint's identity. -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Name(Arc); - -// === impl Name === - -impl From for Name { - fn from(n: linkerd_dns_name::Name) -> Self { - Name(Arc::new(n)) - } -} - -impl FromStr for Name { - type Err = InvalidName; - - fn from_str(s: &str) -> Result { - if s.ends_with('.') { - return Err(InvalidName); // SNI hostnames are implicitly absolute. - } - - linkerd_dns_name::Name::from_str(s).map(|n| Name(Arc::new(n))) - } -} - -impl Deref for Name { - type Target = linkerd_dns_name::Name; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl fmt::Debug for Name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - fmt::Debug::fmt(&self.0, f) - } -} - -impl fmt::Display for Name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - fmt::Display::fmt(&self.0, f) - } -} diff --git a/linkerd/meshtls/Cargo.toml b/linkerd/meshtls/Cargo.toml index eba1a9c2ae..a087bd1841 100644 --- a/linkerd/meshtls/Cargo.toml +++ b/linkerd/meshtls/Cargo.toml @@ -15,6 +15,9 @@ __has_any_tls_impls = [] [dependencies] futures = { version = "0.3", default-features = false } +pin-project = "1" + +linkerd-dns-name = { path = "../dns/name" } linkerd-error = { path = "../error" } linkerd-identity = { path = "../identity" } linkerd-io = { path = "../io" } @@ -22,12 +25,12 @@ linkerd-meshtls-boring = { path = "boring", optional = true } linkerd-meshtls-rustls = { path = "rustls", optional = true } linkerd-stack = { path = "../stack" } linkerd-tls = { path = "../tls" } -pin-project = "1" [dev-dependencies] +tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } +tracing = "0.1" + linkerd-conditional = { path = "../conditional" } linkerd-proxy-transport = { path = "../proxy/transport" } linkerd-tls-test-util = { path = "../tls/test-util" } linkerd-tracing = { path = "../tracing", features = ["ansi"] } -tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } -tracing = "0.1" diff --git a/linkerd/meshtls/boring/src/client.rs b/linkerd/meshtls/boring/src/client.rs index 2806e34ce4..9415f9af2d 100644 --- a/linkerd/meshtls/boring/src/client.rs +++ b/linkerd/meshtls/boring/src/client.rs @@ -1,8 +1,7 @@ use crate::creds::CredsRx; -use linkerd_identity::Name; use linkerd_io as io; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef, ServerId}; +use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef, ServerName}; use std::{future::Future, pin::Pin, sync::Arc, task::Context}; use tracing::debug; @@ -13,7 +12,7 @@ pub struct NewClient(CredsRx); pub struct Connect { rx: CredsRx, alpn: Option]>>, - server_id: Name, + server: ServerName, } pub type ConnectFuture = Pin>> + Send>>; @@ -47,12 +46,10 @@ impl std::fmt::Debug for NewClient { impl Connect { pub(crate) fn new(client_tls: ClientTls, rx: CredsRx) -> Self { - let ServerId(server_id) = client_tls.server_id; - let alpn = client_tls.alpn.map(|AlpnProtocols(ps)| ps.into()); Self { rx, - alpn, - server_id, + alpn: client_tls.alpn.map(|AlpnProtocols(ps)| ps.into()), + server: client_tls.server_name, } } } @@ -70,7 +67,7 @@ where } fn call(&mut self, io: I) -> Self::Future { - let id = self.server_id.clone(); + let server_name = self.server.clone(); let connector = self .rx .borrow() @@ -80,7 +77,7 @@ where let config = conn .configure() .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let io = tokio_boring::connect(config, id.as_str(), io) + let io = tokio_boring::connect(config, server_name.as_str(), io) .await .map_err(|e| match e.as_io_error() { // TODO(ver) boring should let us take ownership of the error directly. diff --git a/linkerd/meshtls/boring/src/creds.rs b/linkerd/meshtls/boring/src/creds.rs index 5178dba126..ac19be44e4 100644 --- a/linkerd/meshtls/boring/src/creds.rs +++ b/linkerd/meshtls/boring/src/creds.rs @@ -7,13 +7,13 @@ use boring::{ ssl, x509::{store::X509StoreBuilder, X509}, }; +use linkerd_dns_name as dns; use linkerd_error::Result; -use linkerd_identity as id; use std::sync::Arc; use tokio::sync::watch; pub fn watch( - identity: id::Name, + identity: dns::Name, roots_pem: &str, key_pkcs8: &[u8], csr: &[u8], diff --git a/linkerd/meshtls/boring/src/creds/receiver.rs b/linkerd/meshtls/boring/src/creds/receiver.rs index 9e78dea160..953e081bf7 100644 --- a/linkerd/meshtls/boring/src/creds/receiver.rs +++ b/linkerd/meshtls/boring/src/creds/receiver.rs @@ -1,20 +1,20 @@ use super::CredsRx; use crate::{NewClient, Server}; -use linkerd_identity::Name; +use linkerd_dns_name as dns; #[derive(Clone)] pub struct Receiver { - name: Name, + name: dns::Name, rx: CredsRx, } impl Receiver { - pub(crate) fn new(name: Name, rx: CredsRx) -> Self { + pub(crate) fn new(name: dns::Name, rx: CredsRx) -> Self { Self { name, rx } } /// Returns the local identity. - pub fn name(&self) -> &Name { + pub fn server_name(&self) -> &dns::Name { &self.name } diff --git a/linkerd/meshtls/boring/src/creds/store.rs b/linkerd/meshtls/boring/src/creds/store.rs index 80d05b29ba..a9223ea662 100644 --- a/linkerd/meshtls/boring/src/creds/store.rs +++ b/linkerd/meshtls/boring/src/creds/store.rs @@ -1,5 +1,6 @@ use super::{BaseCreds, Certs, Creds, CredsTx}; use boring::x509::{X509StoreContext, X509}; +use linkerd_dns_name as dns; use linkerd_error::Result; use linkerd_identity as id; use std::sync::Arc; @@ -7,14 +8,14 @@ use std::sync::Arc; pub struct Store { creds: Arc, csr: Vec, - name: id::Name, + name: dns::Name, tx: CredsTx, } // === impl Store === impl Store { - pub(super) fn new(creds: Arc, csr: &[u8], name: id::Name, tx: CredsTx) -> Self { + pub(super) fn new(creds: Arc, csr: &[u8], name: dns::Name, tx: CredsTx) -> Self { Self { creds, csr: csr.into(), @@ -26,8 +27,8 @@ impl Store { fn cert_matches_name(&self, cert: &X509) -> bool { for san in cert.subject_alt_names().into_iter().flatten() { if let Some(n) = san.dnsname() { - if let Ok(name) = n.parse::() { - if name == *self.name { + if let Ok(name) = n.parse::() { + if name == self.name { return true; } } diff --git a/linkerd/meshtls/boring/src/server.rs b/linkerd/meshtls/boring/src/server.rs index 623daf849b..0203106667 100644 --- a/linkerd/meshtls/boring/src/server.rs +++ b/linkerd/meshtls/boring/src/server.rs @@ -1,14 +1,14 @@ use crate::creds::CredsRx; -use linkerd_identity::Name; +use linkerd_dns_name as dns; use linkerd_io as io; use linkerd_stack::{Param, Service}; -use linkerd_tls::{ClientId, LocalId, NegotiatedProtocol, ServerTls}; +use linkerd_tls::{ClientId, NegotiatedProtocol, ServerName, ServerTls}; use std::{future::Future, pin::Pin, sync::Arc, task::Context}; use tracing::debug; #[derive(Clone)] pub struct Server { - name: Name, + name: dns::Name, rx: CredsRx, alpn: Option]>>, } @@ -22,7 +22,7 @@ pub struct ServerIo(tokio_boring::SslStream); // === impl Server === impl Server { - pub(crate) fn new(name: Name, rx: CredsRx) -> Self { + pub(crate) fn new(name: dns::Name, rx: CredsRx) -> Self { Self { name, rx, @@ -41,9 +41,9 @@ impl Server { } } -impl Param for Server { - fn param(&self) -> LocalId { - LocalId(self.name.clone()) +impl Param for Server { + fn param(&self) -> ServerName { + ServerName(self.name.clone()) } } @@ -119,7 +119,11 @@ impl ServerIo { None })?; sans.into_iter() - .filter_map(|san| san.dnsname()?.parse().ok()) + .filter_map(|san| { + let dns = san.dnsname()?; + let name = dns.parse::().ok()?; + Some(ClientId(name.into())) + }) .next() .or_else(|| { debug!("Peer certificate missing DNS SANs"); diff --git a/linkerd/meshtls/rustls/Cargo.toml b/linkerd/meshtls/rustls/Cargo.toml index 51e56a1826..1378f158f1 100644 --- a/linkerd/meshtls/rustls/Cargo.toml +++ b/linkerd/meshtls/rustls/Cargo.toml @@ -11,19 +11,21 @@ test-util = ["linkerd-tls-test-util"] [dependencies] futures = { version = "0.3", default-features = false } -linkerd-error = { path = "../../error" } -linkerd-io = { path = "../../io" } -linkerd-identity = { path = "../../identity" } -linkerd-stack = { path = "../../stack" } -linkerd-tls = { path = "../../tls" } -linkerd-tls-test-util = { path = "../../tls/test-util", optional = true } ring = { version = "0.16", features = ["std"] } rustls-pemfile = "1.0" -rustls-webpki = { version = "0.101.5", features = [ "std"] } +rustls-webpki = { version = "0.101.5", features = ["std"] } thiserror = "1" tokio = { version = "1", features = ["macros", "rt", "sync"] } tokio-rustls = { version = "0.24", features = ["dangerous_configuration"] } tracing = "0.1" +linkerd-dns-name = { path = "../../dns/name" } +linkerd-error = { path = "../../error" } +linkerd-io = { path = "../../io" } +linkerd-identity = { path = "../../identity" } +linkerd-stack = { path = "../../stack" } +linkerd-tls = { path = "../../tls" } +linkerd-tls-test-util = { path = "../../tls/test-util", optional = true } + [dev-dependencies] linkerd-tls-test-util = { path = "../../tls/test-util" } diff --git a/linkerd/meshtls/rustls/src/client.rs b/linkerd/meshtls/rustls/src/client.rs index a5eef085fc..7a6d0470e8 100644 --- a/linkerd/meshtls/rustls/src/client.rs +++ b/linkerd/meshtls/rustls/src/client.rs @@ -68,7 +68,7 @@ impl Connect { } }; - let server_id = rustls::ServerName::try_from(client_tls.server_id.as_str()) + let server_id = rustls::ServerName::try_from(&*client_tls.server_id.to_str()) .expect("identity must be a valid DNS name"); Self { server_id, config } diff --git a/linkerd/meshtls/rustls/src/creds.rs b/linkerd/meshtls/rustls/src/creds.rs index 084efecae3..322f0264f1 100644 --- a/linkerd/meshtls/rustls/src/creds.rs +++ b/linkerd/meshtls/rustls/src/creds.rs @@ -2,8 +2,8 @@ mod receiver; mod store; pub use self::{receiver::Receiver, store::Store}; +use linkerd_dns_name as dns; use linkerd_error::Result; -use linkerd_identity as id; use ring::{error::KeyRejected, signature::EcdsaKeyPair}; use std::sync::Arc; use thiserror::Error; @@ -20,7 +20,7 @@ pub struct InvalidKey(KeyRejected); pub struct InvalidTrustRoots(()); pub fn watch( - identity: id::Name, + server_name: dns::Name, roots_pem: &str, key_pkcs8: &[u8], csr: &[u8], @@ -80,13 +80,13 @@ pub fn watch( watch::channel(store::server_config(roots.clone(), empty_resolver)) }; - let rx = Receiver::new(identity.clone(), client_rx, server_rx); + let rx = Receiver::new(server_name.clone(), client_rx, server_rx); let store = Store::new( roots, server_cert_verifier, key, csr, - identity, + server_name, client_tx, server_tx, ); diff --git a/linkerd/meshtls/rustls/src/creds/receiver.rs b/linkerd/meshtls/rustls/src/creds/receiver.rs index 51c1a1f558..f8a0e8d43b 100644 --- a/linkerd/meshtls/rustls/src/creds/receiver.rs +++ b/linkerd/meshtls/rustls/src/creds/receiver.rs @@ -1,5 +1,5 @@ use crate::{NewClient, Server}; -use linkerd_identity::Name; +use linkerd_dns_name as dns; use std::sync::Arc; use tokio::sync::watch; use tokio_rustls::rustls; @@ -7,7 +7,7 @@ use tokio_rustls::rustls; /// Receives TLS config updates to build `NewClient` and `Server` types. #[derive(Clone)] pub struct Receiver { - name: Name, + name: dns::Name, client_rx: watch::Receiver>, server_rx: watch::Receiver>, } @@ -16,7 +16,7 @@ pub struct Receiver { impl Receiver { pub(super) fn new( - name: Name, + name: dns::Name, client_rx: watch::Receiver>, server_rx: watch::Receiver>, ) -> Self { @@ -27,8 +27,8 @@ impl Receiver { } } - /// Returns the local identity. - pub fn name(&self) -> &Name { + /// Returns the local server name (i.e. used in mTLS). + pub fn server_name(&self) -> &dns::Name { &self.name } diff --git a/linkerd/meshtls/rustls/src/creds/store.rs b/linkerd/meshtls/rustls/src/creds/store.rs index 887b61599e..533e1d770a 100644 --- a/linkerd/meshtls/rustls/src/creds/store.rs +++ b/linkerd/meshtls/rustls/src/creds/store.rs @@ -1,4 +1,5 @@ use super::params::*; +use linkerd_dns_name as dns; use linkerd_error::Result; use linkerd_identity as id; use ring::{rand, signature::EcdsaKeyPair}; @@ -12,7 +13,7 @@ pub struct Store { server_cert_verifier: Arc, key: Arc, csr: Arc<[u8]>, - name: id::Name, + server_name: dns::Name, client_tx: watch::Sender>, server_tx: watch::Sender>, } @@ -75,7 +76,7 @@ impl Store { server_cert_verifier: Arc, key: EcdsaKeyPair, csr: &[u8], - name: id::Name, + server_name: dns::Name, client_tx: watch::Sender>, server_tx: watch::Sender>, ) -> Self { @@ -84,7 +85,7 @@ impl Store { key: Arc::new(key), server_cert_verifier, csr: csr.into(), - name, + server_name, client_tx, server_tx, } @@ -105,7 +106,7 @@ impl Store { /// Ensures the certificate is valid for the services we terminate for TLS. This assumes that /// server cert validation does the same or more validation than client cert validation. fn validate(&self, certs: &[rustls::Certificate]) -> Result<()> { - let name = rustls::ServerName::try_from(self.name.as_str()) + let name = rustls::ServerName::try_from(self.server_name.as_str()) .expect("server name must be a valid DNS name"); static NO_OCSP: &[u8] = &[]; let end_entity = &certs[0]; diff --git a/linkerd/meshtls/rustls/src/server.rs b/linkerd/meshtls/rustls/src/server.rs index 1b98216634..60f7b456da 100644 --- a/linkerd/meshtls/rustls/src/server.rs +++ b/linkerd/meshtls/rustls/src/server.rs @@ -1,8 +1,8 @@ use futures::prelude::*; -use linkerd_identity::{LocalId, Name}; +use linkerd_dns_name as dns; use linkerd_io as io; use linkerd_stack::{Param, Service}; -use linkerd_tls::{ClientId, NegotiatedProtocol, NegotiatedProtocolRef, ServerTls}; +use linkerd_tls::{ClientId, NegotiatedProtocol, NegotiatedProtocolRef, ServerName, ServerTls}; use std::{convert::TryFrom, pin::Pin, sync::Arc, task::Context}; use thiserror::Error; use tokio::sync::watch; @@ -12,7 +12,7 @@ use tracing::debug; /// A Service that terminates TLS connections using a dynamically updated server configuration. #[derive(Clone)] pub struct Server { - name: Name, + name: dns::Name, rx: watch::Receiver>, } @@ -29,7 +29,7 @@ pub struct ServerIo(tokio_rustls::server::TlsStream); pub struct LostStore(()); impl Server { - pub(crate) fn new(name: Name, rx: watch::Receiver>) -> Self { + pub(crate) fn new(name: dns::Name, rx: watch::Receiver>) -> Self { Self { name, rx } } @@ -82,9 +82,9 @@ impl Server { } } -impl Param for Server { - fn param(&self) -> LocalId { - LocalId(self.name.clone()) +impl Param for Server { + fn param(&self) -> ServerName { + ServerName(self.name.clone()) } } @@ -144,11 +144,15 @@ fn client_identity(tls: &tokio_rustls::server::TlsStream) -> Option() .map_err(|error| tracing::warn!(%error, "Client certificate contained an invalid DNS name")) - .ok() - .map(ClientId) + .ok()?; + Some(ClientId(n.into())) } // === impl ServerIo === diff --git a/linkerd/meshtls/src/creds.rs b/linkerd/meshtls/src/creds.rs index 5759163b5b..803573f634 100644 --- a/linkerd/meshtls/src/creds.rs +++ b/linkerd/meshtls/src/creds.rs @@ -1,6 +1,7 @@ use crate::{NewClient, Server}; +use linkerd_dns_name as dns; use linkerd_error::Result; -use linkerd_identity::{Credentials, DerX509, Name}; +use linkerd_identity::{Credentials, DerX509}; #[cfg(feature = "boring")] pub use crate::boring; @@ -79,13 +80,13 @@ impl From for Receiver { } impl Receiver { - pub fn name(&self) -> &Name { + pub fn server_name(&self) -> &dns::Name { match self { #[cfg(feature = "boring")] - Self::Boring(receiver) => receiver.name(), + Self::Boring(receiver) => receiver.server_name(), #[cfg(feature = "rustls")] - Self::Rustls(receiver) => receiver.name(), + Self::Rustls(receiver) => receiver.server_name(), #[cfg(not(feature = "__has_any_tls_impls"))] _ => crate::no_tls!(), } diff --git a/linkerd/meshtls/src/lib.rs b/linkerd/meshtls/src/lib.rs index 1ee2c3d754..e7b76b5e75 100644 --- a/linkerd/meshtls/src/lib.rs +++ b/linkerd/meshtls/src/lib.rs @@ -20,8 +20,8 @@ pub use self::{ client::{ClientIo, Connect, ConnectFuture, NewClient}, server::{Server, ServerIo, TerminateFuture}, }; +use linkerd_dns_name as dns; use linkerd_error::{Error, Result}; -use linkerd_identity::Name; use std::str::FromStr; #[cfg(feature = "boring")] @@ -82,7 +82,7 @@ impl Default for Mode { impl Mode { pub fn watch( self, - identity: Name, + server_name: dns::Name, roots_pem: &str, key_pkcs8: &[u8], csr: &[u8], @@ -90,7 +90,8 @@ impl Mode { match self { #[cfg(feature = "boring")] Self::Boring => { - let (store, receiver) = boring::creds::watch(identity, roots_pem, key_pkcs8, csr)?; + let (store, receiver) = + boring::creds::watch(server_name, roots_pem, key_pkcs8, csr)?; Ok(( creds::Store::Boring(store), creds::Receiver::Boring(receiver), @@ -99,7 +100,8 @@ impl Mode { #[cfg(feature = "rustls")] Self::Rustls => { - let (store, receiver) = rustls::creds::watch(identity, roots_pem, key_pkcs8, csr)?; + let (store, receiver) = + rustls::creds::watch(server_name, roots_pem, key_pkcs8, csr)?; Ok(( creds::Store::Rustls(store), creds::Receiver::Rustls(receiver), @@ -107,7 +109,7 @@ impl Mode { } #[cfg(not(feature = "__has_any_tls_impls"))] - _ => no_tls!(identity, roots_pem, key_pkcs8, csr), + _ => no_tls!(server_name, roots_pem, key_pkcs8, csr), } } } diff --git a/linkerd/meshtls/src/server.rs b/linkerd/meshtls/src/server.rs index cceb9c570a..f4725ee7fa 100644 --- a/linkerd/meshtls/src/server.rs +++ b/linkerd/meshtls/src/server.rs @@ -1,8 +1,7 @@ use linkerd_error::Result; -use linkerd_identity::LocalId; use linkerd_io as io; use linkerd_stack::{Param, Service}; -use linkerd_tls::ServerTls; +use linkerd_tls::{ServerName, ServerTls}; use std::{ future::Future, pin::Pin, @@ -57,9 +56,9 @@ pub enum ServerIo { // === impl Server === -impl Param for Server { +impl Param for Server { #[inline] - fn param(&self) -> LocalId { + fn param(&self) -> ServerName { match self { #[cfg(feature = "boring")] Self::Boring(srv) => srv.param(), diff --git a/linkerd/meshtls/tests/util.rs b/linkerd/meshtls/tests/util.rs index 0152da2aa4..0032ac253e 100644 --- a/linkerd/meshtls/tests/util.rs +++ b/linkerd/meshtls/tests/util.rs @@ -3,8 +3,9 @@ use futures::prelude::*; use linkerd_conditional::Conditional; +use linkerd_dns_name as dns; use linkerd_error::Infallible; -use linkerd_identity::{Credentials, DerX509, Name}; +use linkerd_identity::{Credentials, DerX509}; use linkerd_io::{self as io, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use linkerd_meshtls as meshtls; use linkerd_proxy_transport::{ @@ -60,10 +61,7 @@ pub async fn proxy_to_proxy_tls_works(mode: meshtls::Mode) { .await; assert_eq!( client_result.tls, - Some(Conditional::Some(tls::ClientTls { - server_id, - alpn: None, - })) + Some(Conditional::Some(tls::ClientTls::new(server_id, None))) ); assert_eq!(&client_result.result.expect("pong")[..], PONG); assert_eq!( @@ -82,11 +80,11 @@ pub async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match(mode: // Misuse the client's identity instead of the server's identity. Any // identity other than `server_tls.server_identity` would work. let (_bar, client_tls, _) = load(mode, &test_util::BAR_NS1); - let sni = test_util::BAR_NS1.name.parse::().unwrap(); + let server_name = test_util::BAR_NS1.name.parse::().unwrap(); let (client_result, server_result) = run_test( client_tls, - Conditional::Some(tls::ServerId(sni.clone())), + Conditional::Some(tls::ServerId(server_name.clone().into())), |conn| write_then_read(conn, PING), server_tls, |(_, conn)| read_then_write(conn, START_OF_TLS.len(), PONG), @@ -100,7 +98,7 @@ pub async fn proxy_to_proxy_tls_pass_through_when_identity_does_not_match(mode: assert_eq!( server_result.tls, Some(Conditional::Some(tls::ServerTls::Passthru { - sni: tls::ServerId(sni) + sni: tls::ServerName(server_name) })) ); assert_eq!(&server_result.result.unwrap()[..], START_OF_TLS); @@ -230,11 +228,18 @@ where // parallels the server side. let (sender, receiver) = mpsc::channel::>(); - let tls = Some(client_server_id.clone().map(Into::into)); + let tls = Some( + client_server_id + .clone() + .map(|s| tls::ClientTls::new(s, None)), + ); let client = async move { let conn = tls::Client::layer(client_tls) .layer(ConnectTcp::new(Keepalive(None))) - .oneshot(Target(server_addr.into(), client_server_id.map(Into::into))) + .oneshot(Target( + server_addr.into(), + client_server_id.map(|s| tls::ClientTls::new(s, None)), + )) .await; match conn { Err(e) => { diff --git a/linkerd/proxy/identity-client/Cargo.toml b/linkerd/proxy/identity-client/Cargo.toml index bf6b8f04c6..e46ee0b4b8 100644 --- a/linkerd/proxy/identity-client/Cargo.toml +++ b/linkerd/proxy/identity-client/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] futures = { version = "0.3", default-features = false } linkerd2-proxy-api = { version = "0.11", features = ["identity"] } +linkerd-dns-name = { path = "../../dns/name" } linkerd-error = { path = "../../error" } linkerd-identity = { path = "../../identity" } linkerd-metrics = { path = "../../metrics" } diff --git a/linkerd/proxy/identity-client/src/certify.rs b/linkerd/proxy/identity-client/src/certify.rs index 92f11c4232..0577ef796c 100644 --- a/linkerd/proxy/identity-client/src/certify.rs +++ b/linkerd/proxy/identity-client/src/certify.rs @@ -1,8 +1,9 @@ use crate::{Metrics, TokenSource}; use http_body::Body; use linkerd2_proxy_api::identity::{self as api, identity_client::IdentityClient}; +use linkerd_dns_name::Name; use linkerd_error::{Error, Result}; -use linkerd_identity::{Credentials, DerX509, Name}; +use linkerd_identity::{Credentials, DerX509}; use linkerd_stack::NewService; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use thiserror::Error; diff --git a/linkerd/tls/src/client.rs b/linkerd/tls/src/client.rs index 3243393d1e..ca18c84e37 100644 --- a/linkerd/tls/src/client.rs +++ b/linkerd/tls/src/client.rs @@ -1,6 +1,7 @@ -use crate::NegotiatedProtocol; +use crate::{NegotiatedProtocol, ServerName}; use futures::prelude::*; use linkerd_conditional::Conditional; +use linkerd_dns_name as dns; use linkerd_identity as id; use linkerd_io as io; use linkerd_stack::{layer, MakeConnection, NewService, Oneshot, Param, Service, ServiceExt}; @@ -14,13 +15,14 @@ use std::{ }; use tracing::debug; -/// A newtype for target server identities. +/// Describes the authenticated identity of a remote server. #[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct ServerId(pub id::Name); +pub struct ServerId(pub id::Id); /// A stack parameter that configures a `Client` to establish a TLS connection. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ClientTls { + pub server_name: ServerName, pub server_id: ServerId, pub alpn: Option, } @@ -75,11 +77,14 @@ pub struct ConnectMeta { // === impl ClientTls === -impl From for ClientTls { - fn from(server_id: ServerId) -> Self { +impl ClientTls { + // XXX(ver) We'll have to change this when ServerIds are not necessarily DNS names. + pub fn new(server_id: ServerId, alpn: Option) -> Self { + let ServerId(linkerd_identity::Id(name)) = server_id.clone(); Self { + server_name: ServerName(name), server_id, - alpn: None, + alpn, } } } @@ -180,30 +185,37 @@ where // === impl ServerId === -impl From for ServerId { - fn from(n: id::Name) -> Self { - Self(n) +impl From for ServerId { + fn from(id: id::Id) -> Self { + Self(id) } } -impl From for id::Name { - fn from(ServerId(name): ServerId) -> id::Name { - name +impl From for id::Id { + fn from(ServerId(id): ServerId) -> id::Id { + id } } -impl Deref for ServerId { - type Target = id::Name; +impl ServerId { + pub fn to_str(&self) -> std::borrow::Cow<'_, str> { + self.0.to_str() + } +} - fn deref(&self) -> &id::Name { +impl Deref for ServerId { + type Target = id::Id; + fn deref(&self) -> &Self::Target { &self.0 } } impl FromStr for ServerId { - type Err = id::InvalidName; + type Err = dns::InvalidName; fn from_str(s: &str) -> Result { - id::Name::from_str(s).map(ServerId) + // TODO Handle SPIFFE IDs. + let n = dns::Name::from_str(s)?; + Ok(Self(n.into())) } } diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 9aad11f13d..0e54d86442 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -4,13 +4,23 @@ pub mod client; pub mod server; -pub use linkerd_identity::LocalId; - pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; +use linkerd_dns_name as dns; + +/// Describes the the Server Name Indication (SNI) value used by both clients +/// and servers. +/// +/// Clients use this type to describe the SNI value to be sent to a server. +/// +/// Servers use this type to describe the SNI value that they expect from +/// clients. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct ServerName(pub dns::Name); + #[derive(Clone, Eq, PartialEq, Hash)] pub struct NegotiatedProtocol(pub Vec); @@ -18,6 +28,8 @@ pub struct NegotiatedProtocol(pub Vec); #[derive(Copy, Clone, Eq, PartialEq, Hash)] pub struct NegotiatedProtocolRef<'t>(pub &'t [u8]); +// === impl NegotiatedProtocol === + impl NegotiatedProtocol { pub fn as_ref(&self) -> NegotiatedProtocolRef<'_> { NegotiatedProtocolRef(&self.0) @@ -50,3 +62,38 @@ impl std::fmt::Debug for NegotiatedProtocolRef<'_> { } } } + +// === impl ServerName === + +impl From for ServerName { + fn from(n: dns::Name) -> Self { + Self(n) + } +} + +impl From for dns::Name { + fn from(ServerName(name): ServerName) -> dns::Name { + name + } +} + +impl std::ops::Deref for ServerName { + type Target = dns::Name; + + fn deref(&self) -> &dns::Name { + &self.0 + } +} + +impl std::str::FromStr for ServerName { + type Err = dns::InvalidName; + fn from_str(s: &str) -> Result { + dns::Name::from_str(s).map(ServerName) + } +} + +impl std::fmt::Display for ServerName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index e0e0ea6a5e..04862401f9 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -1,6 +1,6 @@ mod client_hello; -use crate::{NegotiatedProtocol, ServerId}; +use crate::{NegotiatedProtocol, ServerName}; use bytes::BytesMut; use futures::prelude::*; use linkerd_conditional::Conditional; @@ -12,16 +12,15 @@ use std::{ fmt, ops::Deref, pin::Pin, - str::FromStr, task::{Context, Poll}, }; use thiserror::Error; use tokio::time::{self, Duration}; use tracing::{debug, trace, warn}; -/// A newtype for remote client idenities. +/// Describes the authenticated identity of a remote client. #[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct ClientId(pub id::Name); +pub struct ClientId(pub id::Id); /// Indicates a server-side connection's TLS status. #[derive(Clone, Debug, Eq, PartialEq, Hash)] @@ -31,7 +30,7 @@ pub enum ServerTls { negotiated_protocol: Option, }, Passthru { - sni: ServerId, + sni: ServerName, }, } @@ -134,7 +133,7 @@ where T: Clone + Send + 'static, P: InsertParam + Clone + Send + Sync + 'static, P::Target: Send + 'static, - L: Param + Clone + Send + 'static, + L: Param + Clone + Send + 'static, L: Service, Response = (ServerTls, LIo), Error = io::Error>, L::Future: Send, LIo: io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin + 'static, @@ -164,10 +163,10 @@ where Box::pin(async move { let (sni, io) = detect.await.map_err(|_| ServerTlsTimeoutError(()))??; - let id::LocalId(id) = tls.param(); + let local_server_name = tls.param(); let (peer, io) = match sni { // If we detected an SNI matching this proxy, terminate TLS. - Some(ServerId(sni)) if sni == id => { + Some(sni) if sni == local_server_name => { trace!("Identified local SNI"); let (peer, io) = tls.oneshot(io).await?; (Conditional::Some(peer), EitherIo::Left(io)) @@ -193,7 +192,7 @@ where } /// Peek or buffer the provided stream to determine an SNI value. -async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> +async fn detect_sni(mut io: I) -> io::Result<(Option, DetectIo)> where I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin, { @@ -249,22 +248,22 @@ where // === impl ClientId === -impl From for ClientId { - fn from(n: id::Name) -> Self { +impl From for ClientId { + fn from(n: id::Id) -> Self { Self(n) } } -impl From for id::Name { - fn from(ClientId(name): ClientId) -> id::Name { - name +impl From for id::Id { + fn from(ClientId(id): ClientId) -> id::Id { + id } } impl Deref for ClientId { - type Target = id::Name; + type Target = id::Id; - fn deref(&self) -> &id::Name { + fn deref(&self) -> &id::Id { &self.0 } } @@ -275,10 +274,10 @@ impl fmt::Display for ClientId { } } -impl FromStr for ClientId { - type Err = id::InvalidName; +impl std::str::FromStr for ClientId { + type Err = linkerd_error::Error; fn from_str(s: &str) -> Result { - id::Name::from_str(s).map(Self) + id::Id::from_str(s).map(Self) } } @@ -310,7 +309,6 @@ impl ServerTls { mod tests { use super::*; use linkerd_io::AsyncWriteExt; - use std::str::FromStr; #[tokio::test(flavor = "current_thread")] async fn detect_buffered() { @@ -330,8 +328,7 @@ mod tests { .await .expect("SNI detection must not fail"); - let identity = id::Name::from_str("example.com").unwrap(); - assert_eq!(sni, Some(ServerId(identity))); + assert_eq!(sni, Some(ServerName("example.com".parse().unwrap()))); match io { EitherIo::Left(_) => panic!("Detected IO should be buffered"), diff --git a/linkerd/tls/src/server/client_hello.rs b/linkerd/tls/src/server/client_hello.rs index bc77aa2988..af7491b102 100644 --- a/linkerd/tls/src/server/client_hello.rs +++ b/linkerd/tls/src/server/client_hello.rs @@ -1,5 +1,5 @@ -use crate::ServerId; -use linkerd_identity as id; +use crate::ServerName; +use linkerd_dns_name as dns; use tracing::trace; #[derive(Debug, Eq, PartialEq)] @@ -18,7 +18,7 @@ pub struct Incomplete; /// This assumes that the ClientHello is small and is sent in a single TLS record, which is what all /// reasonable implementations do. (If they were not to, they wouldn't interoperate with picky /// servers.) -pub fn parse_sni(input: &[u8]) -> Result, Incomplete> { +pub fn parse_sni(input: &[u8]) -> Result, Incomplete> { let r = untrusted::Input::from(input).read_all(untrusted::EndOfInput, |input| { let r = extract_sni(input); input.skip_to_end(); // Ignore anything after what we parsed. @@ -28,13 +28,13 @@ pub fn parse_sni(input: &[u8]) -> Result, Incomplete> { Ok(Some(sni)) => { let sni = match std::str::from_utf8(sni.as_slice_less_safe()) .ok() - .and_then(|n| n.parse::().ok()) + .and_then(|n| n.parse::().ok()) { Some(sni) => sni, None => return Ok(None), }; trace!(?sni, "parse_sni: parsed correctly up to SNI"); - Ok(Some(ServerId(sni))) + Ok(Some(ServerName(sni))) } Ok(None) => { trace!("parse_sni: failed to parse up to SNI"); @@ -202,7 +202,7 @@ mod tests { #[test] fn check_all_prefixes() { let input = include_bytes!("testdata/example-com-client-hello.bin"); - let identity = id::Name::from_str("example.com").unwrap(); + let identity = dns::Name::from_str("example.com").unwrap(); let mut i = 0; while let Err(Incomplete) = parse_sni(&input[..i]) { @@ -211,7 +211,10 @@ mod tests { // The same result will be returned for all longer prefixes. for i in i..input.len() { - assert_eq!(Ok(Some(ServerId(identity.clone()))), parse_sni(&input[..i])) + assert_eq!( + Ok(Some(ServerName(identity.clone()))), + parse_sni(&input[..i]) + ) } } } diff --git a/linkerd2-proxy/src/main.rs b/linkerd2-proxy/src/main.rs index 0733cd48c8..2a2a0c7aa2 100644 --- a/linkerd2-proxy/src/main.rs +++ b/linkerd2-proxy/src/main.rs @@ -93,7 +93,8 @@ fn main() { Some(addr) => info!("Tap interface on {}", addr), } - info!("Local identity is {}", app.local_identity()); + // TODO distinguish ServerName and Identity. + info!("Local identity is {}", app.local_server_name()); let addr = app.identity_addr(); match addr.identity.value() { None => info!("Identity verified via {}", addr.addr),