Skip to content

Commit

Permalink
proto: Pass ConnectionId by value internally
Browse files Browse the repository at this point in the history
  • Loading branch information
gretchenfrage authored and djc committed Dec 23, 2024
1 parent f99ca19 commit 7caa30b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 32 deletions.
8 changes: 4 additions & 4 deletions quinn-proto/src/crypto/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl TlsSession {

impl crypto::Session for TlsSession {
fn initial_keys(&self, dst_cid: &ConnectionId, side: Side) -> Keys {
initial_keys(self.version, dst_cid, side, &self.suite)
initial_keys(self.version, *dst_cid, side, &self.suite)
}

fn handshake_data(&self) -> Option<Box<dyn Any>> {
Expand Down Expand Up @@ -504,7 +504,7 @@ impl crypto::ServerConfig for QuicServerConfig {
dst_cid: &ConnectionId,
) -> Result<Keys, UnsupportedVersion> {
let version = interpret_version(version)?;
Ok(initial_keys(version, dst_cid, Side::Server, &self.initial))
Ok(initial_keys(version, *dst_cid, Side::Server, &self.initial))
}

fn retry_tag(&self, version: u32, orig_dst_cid: &ConnectionId, packet: &[u8]) -> [u8; 16] {
Expand Down Expand Up @@ -564,11 +564,11 @@ fn to_vec(params: &TransportParameters) -> Vec<u8> {

pub(crate) fn initial_keys(
version: Version,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
side: Side,
suite: &Suite,
) -> Keys {
let keys = suite.keys(dst_cid, side.into(), version);
let keys = suite.keys(&dst_cid, side.into(), version);
Keys {
header: KeyPair {
local: Box::new(keys.local.header),
Expand Down
20 changes: 10 additions & 10 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl Endpoint {
RetireConnectionId(now, seq, allow_more_cids) => {
if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) {
trace!("peer retired CID {}: {}", seq, cid);
self.index.retire(&cid);
self.index.retire(cid);
if allow_more_cids {
return Some(self.send_new_identifiers(now, ch, 1));
}
Expand Down Expand Up @@ -243,7 +243,7 @@ impl Endpoint {
let Some(server_config) = &self.server_config else {
debug!("packet for unrecognized connection {}", dst_cid);
return self
.stateless_reset(now, datagram_len, addresses, dst_cid, buf)
.stateless_reset(now, datagram_len, addresses, *dst_cid, buf)
.map(DatagramEvent::Response);
};

Expand Down Expand Up @@ -306,7 +306,7 @@ impl Endpoint {

if !dst_cid.is_empty() {
return self
.stateless_reset(now, datagram_len, addresses, dst_cid, buf)
.stateless_reset(now, datagram_len, addresses, *dst_cid, buf)
.map(DatagramEvent::Response);
}

Expand All @@ -319,7 +319,7 @@ impl Endpoint {
now: Instant,
inciting_dgram_len: usize,
addresses: FourTuple,
dst_cid: &ConnectionId,
dst_cid: ConnectionId,
buf: &mut Vec<u8>,
) -> Option<Transmit> {
if self
Expand Down Expand Up @@ -441,7 +441,7 @@ impl Endpoint {
ids.push(IssuedCid {
sequence,
id,
reset_token: ResetToken::new(&*self.config.reset_key, &id),
reset_token: ResetToken::new(&*self.config.reset_key, id),
});
}
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
Expand Down Expand Up @@ -603,7 +603,7 @@ impl Endpoint {
Some(&server_config),
&mut self.rng,
);
params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid));
params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, loc_cid));
params.original_dst_cid = Some(incoming.token.orig_dst_cid);
params.retry_src_cid = incoming.token.retry_src_cid;
let mut pref_addr_cid = None;
Expand All @@ -616,7 +616,7 @@ impl Endpoint {
address_v4: server_config.preferred_address_v4,
address_v6: server_config.preferred_address_v6,
connection_id: cid,
stateless_reset_token: ResetToken::new(&*self.config.reset_key, &cid),
stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
});
}

Expand Down Expand Up @@ -749,7 +749,7 @@ impl Endpoint {
.encode(
&*server_config.token_key,
incoming.addresses.remote,
&loc_cid,
loc_cid,
);

let header = Header::Retry {
Expand Down Expand Up @@ -1056,8 +1056,8 @@ impl ConnectionIndex {
}

/// Discard a connection ID
fn retire(&mut self, dst_cid: &ConnectionId) {
self.connection_ids.remove(dst_cid);
fn retire(&mut self, dst_cid: ConnectionId) {
self.connection_ids.remove(&dst_cid);
}

/// Remove all references to a connection
Expand Down
16 changes: 8 additions & 8 deletions quinn-proto/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,14 +435,14 @@ impl Header {
)
}

pub(crate) fn dst_cid(&self) -> &ConnectionId {
pub(crate) fn dst_cid(&self) -> ConnectionId {
use Header::*;
match *self {
Initial(InitialHeader { ref dst_cid, .. }) => dst_cid,
Long { ref dst_cid, .. } => dst_cid,
Retry { ref dst_cid, .. } => dst_cid,
Short { ref dst_cid, .. } => dst_cid,
VersionNegotiate { ref dst_cid, .. } => dst_cid,
Initial(InitialHeader { dst_cid, .. }) => dst_cid,
Long { dst_cid, .. } => dst_cid,
Retry { dst_cid, .. } => dst_cid,
Short { dst_cid, .. } => dst_cid,
VersionNegotiate { dst_cid, .. } => dst_cid,
}
}

Expand Down Expand Up @@ -949,7 +949,7 @@ mod tests {
let provider = default_provider();

let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
let client = initial_keys(Version::V1, &dcid, Side::Client, &suite);
let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
let mut buf = Vec::new();
let header = Header::Initial(InitialHeader {
number: PacketNumber::U8(0),
Expand Down Expand Up @@ -979,7 +979,7 @@ mod tests {
)[..]
);

let server = initial_keys(Version::V1, &dcid, Side::Server, &suite);
let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
let decode = PartialDecode::new(
buf.as_slice().into(),
Expand Down
20 changes: 10 additions & 10 deletions quinn-proto/src/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl IncomingToken {
let result = RetryToken::from_bytes(
&*server_config.token_key,
remote_address,
&header.dst_cid,
header.dst_cid,
&header.token,
);

Expand Down Expand Up @@ -78,9 +78,9 @@ impl RetryToken {
&self,
key: &dyn HandshakeTokenKey,
address: SocketAddr,
retry_src_cid: &ConnectionId,
retry_src_cid: ConnectionId,
) -> Vec<u8> {
let aead_key = key.aead_from_hkdf(retry_src_cid);
let aead_key = key.aead_from_hkdf(&retry_src_cid);

let mut buf = Vec::new();
encode_addr(&mut buf, address);
Expand All @@ -100,10 +100,10 @@ impl RetryToken {
fn from_bytes(
key: &dyn HandshakeTokenKey,
address: SocketAddr,
retry_src_cid: &ConnectionId,
retry_src_cid: ConnectionId,
raw_token_bytes: &[u8],
) -> Result<Self, ValidationError> {
let aead_key = key.aead_from_hkdf(retry_src_cid);
let aead_key = key.aead_from_hkdf(&retry_src_cid);
let mut sealed_token = raw_token_bytes.to_vec();

let data = aead_key.open(&mut sealed_token, &[])?;
Expand Down Expand Up @@ -192,9 +192,9 @@ impl From<CryptoError> for ValidationError {
pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);

impl ResetToken {
pub(crate) fn new(key: &dyn HmacKey, id: &ConnectionId) -> Self {
pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
let mut signature = vec![0; key.signature_len()];
key.sign(id, &mut signature);
key.sign(&id, &mut signature);
// TODO: Server ID??
let mut result = [0; RESET_TOKEN_SIZE];
result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
Expand Down Expand Up @@ -262,9 +262,9 @@ mod test {
orig_dst_cid: RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid(),
issued: UNIX_EPOCH + Duration::new(42, 0), // Fractional seconds would be lost
};
let encoded = token.encode(&prk, addr, &retry_src_cid);
let encoded = token.encode(&prk, addr, retry_src_cid);

let decoded = RetryToken::from_bytes(&prk, addr, &retry_src_cid, &encoded)
let decoded = RetryToken::from_bytes(&prk, addr, retry_src_cid, &encoded)
.expect("token didn't validate");
assert_eq!(token.orig_dst_cid, decoded.orig_dst_cid);
assert_eq!(token.issued, decoded.issued);
Expand Down Expand Up @@ -295,6 +295,6 @@ mod test {
invalid_token.put_slice(&random_data);

// Assert: garbage sealed data returns err
assert!(RetryToken::from_bytes(&prk, addr, &retry_src_cid, &invalid_token).is_err());
assert!(RetryToken::from_bytes(&prk, addr, retry_src_cid, &invalid_token).is_err());
}
}

0 comments on commit 7caa30b

Please sign in to comment.