Skip to content

Commit

Permalink
Improvements to secure session establishment (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdomnitz committed Dec 28, 2024
1 parent 0981e0c commit 5c9f405
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 42 deletions.
20 changes: 20 additions & 0 deletions MatterDotNet/MatterDotNet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,24 @@
<GenerateDocumentationFile>True</GenerateDocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Debug|net8.0|AnyCPU'">
<IsTrimmable>True</IsTrimmable>
<IsAotCompatible>True</IsAotCompatible>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Debug|net9.0|AnyCPU'">
<IsTrimmable>True</IsTrimmable>
<IsAotCompatible>True</IsAotCompatible>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Release|net8.0|AnyCPU'">
<IsTrimmable>True</IsTrimmable>
<IsAotCompatible>True</IsAotCompatible>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Release|net9.0|AnyCPU'">
<IsTrimmable>True</IsTrimmable>
<IsAotCompatible>True</IsAotCompatible>
</PropertyGroup>

</Project>
44 changes: 23 additions & 21 deletions MatterDotNet/Protocol/Cryptography/CASE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class CASE(SessionContext unsecureSession)
InitiatorSessionParams = SessionManager.GetDefaultSessionParams()
};

Frame sigma1 = new Frame(Msg1, (byte)SecureOpCodes.CASESigma1) { Flags = MessageFlags.SourceNodeID };
Frame sigma1 = new Frame(Msg1, (byte)SecureOpCodes.CASESigma1);
await exchange.SendFrame(sigma1);
Console.WriteLine("Sent SIGMA 1");
resp = await exchange.Read();
Expand Down Expand Up @@ -111,12 +111,9 @@ public class CASE(SessionContext unsecureSession)
return null;
}

PayloadWriter noc = new PayloadWriter(512);
fabric.Commissioner!.ToMatterCertificate().Serialize(noc);

Sigma3Tbsdata s3sig = new Sigma3Tbsdata()
{
InitiatorNOC = noc.GetPayload().ToArray(),
InitiatorNOC = fabric.Commissioner!.GetMatterCertBytes(),
InitiatorEphPubKey = ephKeys.Public,
ResponderEphPubKey = Msg2.ResponderEphPubKey
};
Expand All @@ -129,20 +126,22 @@ public class CASE(SessionContext unsecureSession)
Signature = Crypto.Sign(fabric.Commissioner.GetPrivateKey()!, s3sigBytes.GetPayload().ToArray())
};
PayloadWriter s3tbeBytes = new PayloadWriter(512);
s3sig.Serialize(s3tbeBytes);
s3tbe.Serialize(s3tbeBytes);

byte[] Msg1Msg2 = new byte[Msg1Bytes.Length + Msg2Bytes.Length];
Msg1Bytes.GetPayload().CopyTo(Msg1Msg2);
Msg2Bytes.GetPayload().CopyTo(Msg1Msg2.AsMemory(Msg1Bytes.Length));
byte[] transcript3 = Crypto.Hash(Msg1Msg2);
byte[] Msg1Msg2Bytes = new byte[Msg1Bytes.Length + Msg2Bytes.Length];
Msg1Bytes.GetPayload().CopyTo(Msg1Msg2Bytes);
Msg2Bytes.GetPayload().CopyTo(Msg1Msg2Bytes.AsMemory(Msg1Bytes.Length));
byte[] transcript3 = Crypto.Hash(Msg1Msg2Bytes);
byte[] salt3 = new byte[Crypto.SYMMETRIC_KEY_LENGTH_BYTES + Crypto.HASH_LEN_BYTES];
Array.Copy(fabric.OperationalIdentityProtectionKey, salt3, Crypto.SYMMETRIC_KEY_LENGTH_BYTES);
Array.Copy(transcript3, 0, salt3, Crypto.SYMMETRIC_KEY_LENGTH_BYTES, Crypto.HASH_LEN_BYTES);
byte[] S3K = Crypto.KDF(SharedSecret, salt3, S3K_Info, Crypto.SYMMETRIC_KEY_LENGTH_BITS);

Memory<byte> mic = Crypto.AEAD_GenerateEncrypt(S3K, s3tbeBytes.GetPayload().Span, [], TBEData3_Nonce).ToArray();
s3tbeBytes.Write(mic);
Sigma3 Msg3 = new Sigma3()
{
Encrypted3 = Crypto.AEAD_GenerateEncrypt(S3K, s3tbeBytes.GetPayload().Span, [], TBEData3_Nonce).ToArray()
Encrypted3 = s3tbeBytes.GetPayload().ToArray()
};
PayloadWriter Msg3Bytes = new PayloadWriter(1024);
Msg3.Serialize(Msg3Bytes);
Expand All @@ -153,24 +152,27 @@ public class CASE(SessionContext unsecureSession)

StatusPayload s3resp = (StatusPayload)resp.Message.Payload!;
if (s3resp.GeneralCode != GeneralCode.SUCCESS)
throw new IOException("CASE step 3 failed with status: " + (SecureStatusCodes)s3resp.ProtocolCode);

byte[] salt = new byte[Crypto.SYMMETRIC_KEY_LENGTH_BYTES + Msg1Msg2.Length + Msg3Bytes.Length];
fabric.OperationalIdentityProtectionKey.CopyTo(salt, 0);
Msg1Msg2.CopyTo(salt, Crypto.SYMMETRIC_KEY_LENGTH_BYTES);
Msg3Bytes.GetPayload().CopyTo(salt.AsMemory(Crypto.SYMMETRIC_KEY_LENGTH_BYTES + Msg1Msg2.Length));
byte[] sessionKeys = Crypto.KDF(SharedSecret, salt, SEKeys_Info, Crypto.SYMMETRIC_KEY_LENGTH_BITS * 3);
throw new IOException("CASE step 3 failed with status: " + (SecureStatusCodes)s3resp.ProtocolCode);;

byte[] Msg1Msg2Msg3 = new byte[Msg1Msg2Bytes.Length + Msg3Bytes.Length];
Msg1Msg2Bytes.CopyTo((Memory<byte>)Msg1Msg2Msg3);
Msg3Bytes.GetPayload().CopyTo(Msg1Msg2Msg3.AsMemory(Msg1Msg2Bytes.Length));
byte[] transcript4 = Crypto.Hash(Msg1Msg2Msg3);
byte[] salt4 = new byte[Crypto.SYMMETRIC_KEY_LENGTH_BYTES + Crypto.HASH_LEN_BYTES];
Array.Copy(fabric.OperationalIdentityProtectionKey, salt4, Crypto.SYMMETRIC_KEY_LENGTH_BYTES);
Array.Copy(transcript4, 0, salt4, Crypto.SYMMETRIC_KEY_LENGTH_BYTES, Crypto.HASH_LEN_BYTES);
byte[] sessionKeys = Crypto.KDF(SharedSecret, salt4, SEKeys_Info, Crypto.SYMMETRIC_KEY_LENGTH_BITS * 3);
attestationChallenge = sessionKeys.AsSpan(2 * Crypto.SYMMETRIC_KEY_LENGTH_BYTES, Crypto.SYMMETRIC_KEY_LENGTH_BYTES).ToArray();

uint activeInterval = Msg2.ResponderSessionParams?.SessionActiveInterval ?? SessionManager.GetDefaultSessionParams().SessionActiveInterval!.Value;
uint activeThreshold = Msg2.ResponderSessionParams?.SessionActiveThreshold ?? SessionManager.GetDefaultSessionParams().SessionActiveThreshold!.Value;
uint idleInterval = Msg2.ResponderSessionParams?.SessionIdleInterval ?? SessionManager.GetDefaultSessionParams().SessionIdleInterval!.Value;

Console.WriteLine("Created CASE session");
SecureSession? session = SessionManager.CreateSession(unsecureSession.Connection, true, Msg1.InitiatorSessionId, Msg2.ResponderSessionId,
sessionKeys.AsSpan(0, Crypto.SYMMETRIC_KEY_LENGTH_BYTES).ToArray(),
SecureSession? session = SessionManager.CreateSession(unsecureSession.Connection, false, true, Msg1.InitiatorSessionId, Msg2.ResponderSessionId,
sessionKeys.AsSpan(0, Crypto.SYMMETRIC_KEY_LENGTH_BYTES).ToArray(),
sessionKeys.AsSpan(Crypto.SYMMETRIC_KEY_LENGTH_BYTES, Crypto.SYMMETRIC_KEY_LENGTH_BYTES).ToArray(),
false, idleInterval, activeInterval, activeThreshold);
fabric.Commissioner.NodeID, nodeId, SharedSecret, false, idleInterval, activeInterval, activeThreshold);
unsecureSession.Dispose();
return session;
}
Expand Down
11 changes: 5 additions & 6 deletions MatterDotNet/Protocol/Cryptography/PASE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ public class PASE(SessionContext unsecureSession)
StatusPayload status = (StatusPayload)resp.Message.Payload!;
if (status.GeneralCode != GeneralCode.SUCCESS)
throw new IOException("PASE failed with status: " + (SecureStatusCodes)status.ProtocolCode);
exchange.Dispose();
ushort localSessionID = ((PBKDFParamReq)paramReq.Message.Payload!).InitiatorSessionId;

uint activeInterval = paramResp.ResponderSessionParams?.SessionActiveInterval ?? SessionManager.GetDefaultSessionParams().SessionActiveInterval!.Value;
uint activeThreshold = paramResp.ResponderSessionParams?.SessionActiveThreshold ?? SessionManager.GetDefaultSessionParams().SessionActiveThreshold!.Value;
uint idleInterval = paramResp.ResponderSessionParams?.SessionIdleInterval ?? SessionManager.GetDefaultSessionParams().SessionIdleInterval!.Value;

SecureSession ? session = SessionManager.CreateSession(unsecureSession.Connection, true, localSessionID, paramResp.ResponderSessionId, SessionKeys.I2RKey, SessionKeys.R2IKey, false, idleInterval, activeInterval, activeThreshold);
unsecureSession.Dispose();
return session;
return SessionManager.CreateSession(unsecureSession.Connection, true, true, localSessionID, paramResp.ResponderSessionId, SessionKeys.I2RKey, SessionKeys.R2IKey, 0, 0, [], false, idleInterval, activeInterval, activeThreshold);
}

public byte[] GetAttestationChallenge()
Expand All @@ -71,7 +70,7 @@ private Frame GeneratePake1(PBKDFParamResp paramResp)
Console.WriteLine("Iterations: " + (int)paramResp.Pbkdf_parameters!.Iterations);
BigIntegerPoint pA = spake.PAKEValues_Initiator(36331256, (int)paramResp.Pbkdf_parameters!.Iterations, paramResp.Pbkdf_parameters!.Salt);
Pake1 pk1 = new Pake1() { PA = pA.ToBytes(false) };
Frame frame = new Frame(pk1, (byte)SecureOpCodes.PASEPake1) { Flags = MessageFlags.SourceNodeID };
Frame frame = new Frame(pk1, (byte)SecureOpCodes.PASEPake1);
return frame;
}

Expand All @@ -84,7 +83,7 @@ private Frame GeneratePake3(Pake1 pake1, Pake2 pake2, PBKDFParamReq req, PBKDFPa
Pake3 pake3 = new Pake3() {
CA = SessionKeys.cA
};
Frame frame = new Frame(pake3, (byte)SecureOpCodes.PASEPake3) { Flags = MessageFlags.SourceNodeID };
Frame frame = new Frame(pake3, (byte)SecureOpCodes.PASEPake3);
return frame;
}

Expand All @@ -98,7 +97,7 @@ private Frame GenerateParamRequest(bool hasOnboardingPayload = false)
HasPBKDFParameters = hasOnboardingPayload
};

Frame frame = new Frame(req, (byte)SecureOpCodes.PBKDFParamRequest) { Flags = MessageFlags.SourceNodeID };
Frame frame = new Frame(req, (byte)SecureOpCodes.PBKDFParamRequest);
return frame;
}
}
Expand Down
16 changes: 9 additions & 7 deletions MatterDotNet/Protocol/InteractionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ public static async Task<List<AttributeReportIB>> GetAttributes(SecureSession se
AttributeRequests = paths,
};
Frame readFrame = new Frame(read, (byte)IMOpCodes.ReadRequest);
readFrame.Flags |= MessageFlags.SourceNodeID;
readFrame.Message.Protocol = ProtocolType.InteractionModel;
readFrame.SourceNodeID = session.InitiatorNodeID;
readFrame.DestinationNodeID = session.ResponderNodeID;
await secExchange.SendFrame(readFrame);
List<AttributeReportIB> results = new List<AttributeReportIB>();
bool more = false;
Expand All @@ -52,7 +53,6 @@ public static async Task<List<AttributeReportIB>> GetAttributes(SecureSession se
{
var status = new StatusResponseMessage() { InteractionModelRevision = Constants.MATTER_13_REVISION, Status = (byte)IMStatusCode.SUCCESS };
Frame statusFrame = new Frame(status, (byte)IMOpCodes.StatusResponse);
readFrame.Flags |= MessageFlags.SourceNodeID;
readFrame.Message.Protocol = ProtocolType.InteractionModel;
await secExchange.SendFrame(statusFrame);
}
Expand All @@ -73,8 +73,9 @@ public static async Task<AttributeReportIB> GetAttribute(SecureSession session,
AttributeRequests = [new AttributePathIB() { Node = session.InitiatorNodeID, Endpoint = endpoint, Cluster = cluster, Attribute = attribute }]
};
Frame readFrame = new Frame(read, (byte)IMOpCodes.ReadRequest);
readFrame.Flags |= MessageFlags.SourceNodeID;
readFrame.Message.Protocol = ProtocolType.InteractionModel;
readFrame.SourceNodeID = session.InitiatorNodeID;
readFrame.DestinationNodeID = session.ResponderNodeID;
await secExchange.SendFrame(readFrame);
while (true)
{
Expand Down Expand Up @@ -103,10 +104,11 @@ public static async Task SendCommand(Exchange exchange, ushort endpoint, uint cl
InteractionModelRevision = Constants.MATTER_13_REVISION,
InvokeRequests = [new CommandDataIB() { CommandFields = payload, CommandPath = new CommandPathIB() { Endpoint = endpoint, Cluster = cluster, Command = command } }]
};
Frame readFrame = new Frame(run, (byte)IMOpCodes.InvokeRequest);
readFrame.Flags |= MessageFlags.SourceNodeID;
readFrame.Message.Protocol = ProtocolType.InteractionModel;
await exchange.SendFrame(readFrame);
Frame invokeFrame = new Frame(run, (byte)IMOpCodes.InvokeRequest);
invokeFrame.Message.Protocol = ProtocolType.InteractionModel;
invokeFrame.SourceNodeID = exchange.Session.InitiatorNodeID;
invokeFrame.DestinationNodeID = exchange.Session.ResponderNodeID;
await exchange.SendFrame(invokeFrame);
}

public static async Task<InvokeResponseIB> ExecCommand(SecureSession secSession, ushort endpoint, uint cluster, uint command, TLVPayload? payload = null)
Expand Down
16 changes: 8 additions & 8 deletions MatterDotNet/Protocol/Sessions/SessionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,28 @@ public static class SessionManager
private static ConcurrentDictionary<IPEndPoint, IConnection> connections = new ConcurrentDictionary<IPEndPoint, IConnection>();
private static ConcurrentDictionary<ushort, SessionContext> sessions = new ConcurrentDictionary<ushort, SessionContext>();

public static SessionContext GetUnsecureSession(IPEndPoint ep, bool initiator, uint initiatorNodeId, uint responderNodeId)
public static SessionContext GetUnsecureSession(IPEndPoint ep, bool initiator)
{
return GetUnsecureSession(GetConnection(ep), initiator, initiatorNodeId, responderNodeId);
return GetUnsecureSession(GetConnection(ep), initiator);
}

internal static SessionContext GetUnsecureSession(IConnection connection, bool initiator, uint initiatorNodeId, uint responderNodeId)
internal static SessionContext GetUnsecureSession(IConnection connection, bool initiator)
{
SessionContext ctx = new SessionContext(connection, initiator, initiatorNodeId, responderNodeId, 0, 0, new MessageState());
SessionContext ctx = new SessionContext(connection, initiator, 0, 0, 0, 0, new MessageState());
sessions.TryAdd(0, ctx);
return ctx;
}

public static SecureSession? CreateSession(IPEndPoint ep, bool initiator, ushort initiatorSessionId, ushort responderSessionId, byte[] i2r, byte[] r2i, bool group, uint idleInterval, uint activeInterval, uint activeThreshold)
public static SecureSession? CreateSession(IPEndPoint ep, bool PASE, bool initiator, ushort initiatorSessionId, ushort responderSessionId, byte[] i2r, byte[] r2i, ulong localNodeId, ulong peerNodeId, byte[] sharedSecret, bool group, uint idleInterval, uint activeInterval, uint activeThreshold)
{
return CreateSession(GetConnection(ep), initiator, initiatorSessionId, responderSessionId, i2r, r2i, group, idleInterval, activeInterval, activeThreshold);
return CreateSession(GetConnection(ep), PASE, initiator, initiatorSessionId, responderSessionId, i2r, r2i, localNodeId, peerNodeId, sharedSecret, group, idleInterval, activeInterval, activeThreshold);
}

internal static SecureSession? CreateSession(IConnection connection, bool initiator, ushort initiatorSessionId, ushort responderSessionId, byte[] i2r, byte[] r2i, bool group, uint idleInterval, uint activeInterval, uint activeThreshold)
internal static SecureSession? CreateSession(IConnection connection, bool PASE, bool initiator, ushort initiatorSessionId, ushort responderSessionId, byte[] i2r, byte[] r2i, ulong localNodeId, ulong peerNodeId, byte[] sharedSecret, bool group, uint idleInterval, uint activeInterval, uint activeThreshold)
{
if (group == false && initiatorSessionId == 0)
return null; //Unsecured session
SecureSession ctx = new SecureSession(connection, false, initiator, initiator ? initiatorSessionId : responderSessionId, initiator ? responderSessionId : initiatorSessionId, i2r, r2i, [], 0, new MessageState(), 0, 0, idleInterval, activeInterval, activeThreshold);
SecureSession ctx = new SecureSession(connection, PASE, initiator, initiator ? initiatorSessionId : responderSessionId, initiator ? responderSessionId : initiatorSessionId, i2r, r2i, sharedSecret, 0, new MessageState(), localNodeId, peerNodeId, idleInterval, activeInterval, activeThreshold);
Console.WriteLine("Secure Session Created: " + ctx.LocalSessionID);
sessions.TryAdd(ctx.LocalSessionID, ctx);
return ctx;
Expand Down

0 comments on commit 5c9f405

Please sign in to comment.