diff --git a/MatterDotNet/Protocol/Connection/IConnection.cs b/MatterDotNet/Protocol/Connection/IConnection.cs index f56c585..3dccd5c 100644 --- a/MatterDotNet/Protocol/Connection/IConnection.cs +++ b/MatterDotNet/Protocol/Connection/IConnection.cs @@ -17,7 +17,6 @@ namespace MatterDotNet.Protocol.Connection { internal interface IConnection : IDisposable { - Task Read(); Task SendFrame(Exchange exchange, Frame frame, bool reliable); Task CloseExchange(Exchange exchange); } diff --git a/MatterDotNet/Protocol/Connection/MRPConnection.cs b/MatterDotNet/Protocol/Connection/MRPConnection.cs index 652717c..5513c91 100644 --- a/MatterDotNet/Protocol/Connection/MRPConnection.cs +++ b/MatterDotNet/Protocol/Connection/MRPConnection.cs @@ -32,7 +32,6 @@ internal class MRPConnection : IConnection ConcurrentDictionary<(ushort, ushort), Retransmission> Retransmissions = new ConcurrentDictionary<(ushort, ushort), Retransmission>(); ConcurrentDictionary AckTable = new ConcurrentDictionary(); - Channel channel = Channel.CreateUnbounded(); CancellationTokenSource cts = new CancellationTokenSource(); UdpClient client; @@ -70,7 +69,7 @@ public async Task SendFrame(Exchange exchange, Frame frame, bool reliable) } } } - Console.WriteLine("SENT: " + frame.ToString()); + Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " SENT: " + frame.ToString()); await client.SendAsync(writer.GetPayload()); exchange.Session.Timestamp = DateTime.Now; while (reliable) @@ -111,22 +110,22 @@ public async Task SendAck(SessionContext? session, ushort exchange, uint counter ack.Message.ExchangeID = exchange; ack.Message.Flags = ExchangeFlags.Acknowledgement; if (initiator) + { + ack.Flags |= MessageFlags.SourceNodeID; ack.Message.Flags |= ExchangeFlags.Initiator; + } + else + ack.Flags |= MessageFlags.DestinationNodeID; ack.Message.AckCounter = counter; ack.Message.Protocol = Payloads.ProtocolType.SecureChannel; PayloadWriter writer = new PayloadWriter(Frame.MAX_SIZE + 4); ack.Serialize(writer, session!); if (AckTable.TryGetValue(exchange, out uint ctr) && ctr == counter) AckTable.TryRemove(exchange, out _); - Console.WriteLine("Sent standalone ack: " + ack.ToString()); + Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Sent standalone ack: " + ack.ToString()); await client.SendAsync(writer.GetPayload()); } - public async Task Read() - { - return await channel.Reader.ReadAsync(); - } - public async Task Run() { try @@ -135,10 +134,20 @@ public async Task Run() { UdpReceiveResult result = await client.ReceiveAsync(); Frame frame = new Frame(result.Buffer); + if (!frame.Valid) + { + Console.WriteLine("Invalid frame received"); + continue; + } + SessionContext? session = SessionManager.GetSession(frame.SessionID); + bool ack = false; if ((frame.Message.Flags & ExchangeFlags.Reliability) == ExchangeFlags.Reliability) { if (!AckTable.TryAdd(frame.Message.ExchangeID, frame.Counter)) - await SendAck(SessionManager.GetSession(frame.SessionID), frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0); + { + ack = true; + await SendAck(session, frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0); + } } if ((frame.Message.Flags & ExchangeFlags.Acknowledgement) == ExchangeFlags.Acknowledgement) { @@ -148,20 +157,23 @@ public async Task Run() transmission.Ack.Release(); } } - if (frame.SessionID != 0) - SessionManager.SessionActive(frame.SessionID); - Console.WriteLine("Received: " + frame.ToString()); - channel.Writer.TryWrite(frame); + Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString()); + if (session == null) + { + Console.WriteLine("Unknown Session: " + frame.SessionID); + continue; + } + if (!session.ProcessFrame(frame) && !ack) + await SendAck(session, frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0); + + session.Timestamp = DateTime.Now; + session.LastActive = DateTime.Now; } } catch (Exception e) { Console.WriteLine(e.ToString()); } - finally - { - channel.Writer.Complete(); - } } public async Task CloseExchange(Exchange exchange) diff --git a/MatterDotNet/Protocol/Connection/TCPConnection.cs b/MatterDotNet/Protocol/Connection/TCPConnection.cs index e09e3a1..bed8bc0 100644 --- a/MatterDotNet/Protocol/Connection/TCPConnection.cs +++ b/MatterDotNet/Protocol/Connection/TCPConnection.cs @@ -24,7 +24,6 @@ internal class TCPConnection : IConnection TcpClient client; NetworkStream stream; CancellationTokenSource cts = new CancellationTokenSource(); - Channel channel = Channel.CreateUnbounded(); public TCPConnection(IPEndPoint destination) { client = new TcpClient(); @@ -43,11 +42,6 @@ public async Task SendFrame(Exchange exchange, Frame frame, bool reliable) exchange.Session.Timestamp = DateTime.Now; } - public async Task Read() - { - return await channel.Reader.ReadAsync(); - } - public async Task Run() { byte[] len = new byte[4]; @@ -58,13 +52,18 @@ public async Task Run() await stream.ReadExactlyAsync(len); frameLen = BinaryPrimitives.ReadInt32LittleEndian(len); await stream.ReadExactlyAsync(data.Slice(0, frameLen)); - Console.WriteLine("READ: " + Convert.ToHexString(data.Slice(0, frameLen).Span)); Frame frame = new Frame(data.Slice(0, frameLen).Span); - channel.Writer.TryWrite(frame); - if (frame.SessionID != 0) - SessionManager.SessionActive(frame.SessionID); + Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString()); + SessionContext? session = SessionManager.GetSession(frame.SessionID); + if (session == null) + { + Console.WriteLine("Unknown Session: " + frame.SessionID); + continue; + } + session.ProcessFrame(frame); + session.Timestamp = DateTime.Now; + session.LastActive = DateTime.Now; } - channel.Writer.Complete(); } public void Dispose() diff --git a/MatterDotNet/Protocol/Sessions/Exchange.cs b/MatterDotNet/Protocol/Sessions/Exchange.cs index 9ee8822..5bcc0f6 100644 --- a/MatterDotNet/Protocol/Sessions/Exchange.cs +++ b/MatterDotNet/Protocol/Sessions/Exchange.cs @@ -12,20 +12,21 @@ using MatterDotNet.Protocol.Payloads; using MatterDotNet.Protocol.Payloads.Flags; +using System.Threading.Channels; namespace MatterDotNet.Protocol.Sessions { internal class Exchange : IDisposable { - private const int MSG_COUNTER_WINDOW_SIZE = 32; - public ushort ID { get; init; } public SessionContext Session {get; init;} + internal Channel Messages { get; init;} internal Exchange(SessionContext session, ushort id) { Session = session; ID = id; + Messages = Channel.CreateBounded(10); } public async Task SendFrame(Frame frame, bool reliable = true) @@ -38,62 +39,9 @@ public async Task SendFrame(Frame frame, bool reliable = true) await Session.Connection.SendFrame(this, frame, reliable); } - public async Task Read() + public async Task Read(CancellationToken token = default) { - Frame? frame = null; - while (frame == null) - { - frame = await Session.Connection.Read(); - MessageState state = Session.PeerMessageCtr; - if (!state.Initialized) - { - state.Initialized = true; - state.CounterWindow = uint.MaxValue; - state.MaxMessageCounter = frame.Counter; - } - else if (frame.Counter > state.MaxMessageCounter) - { - int offset = (int)Math.Min(frame.Counter - state.MaxMessageCounter, MSG_COUNTER_WINDOW_SIZE); - state.MaxMessageCounter = frame.Counter; - state.CounterWindow <<= offset; - if (offset < MSG_COUNTER_WINDOW_SIZE) - state.CounterWindow |= (uint)(1 << (int)offset - 1); - } - else if (frame.Counter == state.MaxMessageCounter) - { - Console.WriteLine("DROPPED DUPLICATE: " + frame); - frame = null; - } - else - { - uint offset = (state.MaxMessageCounter - frame.Counter); - if (offset > MSG_COUNTER_WINDOW_SIZE) - { - if (Session is SecureSession) - { - Console.WriteLine("DROPPED DUPLICATE : " + frame); - frame = null; - } - else - { - state.MaxMessageCounter = frame.Counter; - state.CounterWindow = uint.MaxValue; - } - } - else - { - if ((state.CounterWindow & (uint)(1 << (int)offset - 1)) != 0x0) - { - Console.WriteLine("DROPPED DUPLICATE : " + frame); - frame = null; - } - else - state.CounterWindow |= (uint)(1 << (int)offset - 1); - } - } - Session.PeerMessageCtr = state; - } - return frame; + return await Messages.Reader.ReadAsync(token); } /// diff --git a/MatterDotNet/Protocol/Sessions/SecureSession.cs b/MatterDotNet/Protocol/Sessions/SecureSession.cs index 8c99f7e..8c4b6d4 100644 --- a/MatterDotNet/Protocol/Sessions/SecureSession.cs +++ b/MatterDotNet/Protocol/Sessions/SecureSession.cs @@ -12,6 +12,7 @@ using MatterDotNet.Protocol.Connection; using MatterDotNet.Protocol.Cryptography; +using MatterDotNet.Protocol.Payloads; using System.Buffers.Binary; using System.Security.Cryptography; @@ -49,6 +50,12 @@ internal SecureSession(IConnection connection, bool PASE, bool initiator, ushort } } + internal override bool HandleBehindWindow(ref MessageState state, Frame frame) + { + Console.WriteLine("DROPPED DUPLICATE : " + frame); + return true; + } + internal override uint GetSessionCounter() { return LocalMessageCtr; diff --git a/MatterDotNet/Protocol/Sessions/SessionContext.cs b/MatterDotNet/Protocol/Sessions/SessionContext.cs index 735274f..e2de52c 100644 --- a/MatterDotNet/Protocol/Sessions/SessionContext.cs +++ b/MatterDotNet/Protocol/Sessions/SessionContext.cs @@ -20,6 +20,8 @@ namespace MatterDotNet.Protocol.Sessions { public class SessionContext : IDisposable { + private const int MSG_COUNTER_WINDOW_SIZE = 32; + public bool Initiator { get; init; } public ulong InitiatorNodeID { get; init; } public ushort LocalSessionID { get; init; } @@ -66,6 +68,64 @@ internal Exchange CreateExchange() return ret; } + internal bool ProcessFrame(Frame frame) + { + MessageState state = PeerMessageCtr; + if (!state.Initialized) + { + state.Initialized = true; + state.CounterWindow = uint.MaxValue; + state.MaxMessageCounter = frame.Counter; + } + else if (frame.Counter > state.MaxMessageCounter) + { + int offset = (int)Math.Min(frame.Counter - state.MaxMessageCounter, MSG_COUNTER_WINDOW_SIZE); + state.MaxMessageCounter = frame.Counter; + state.CounterWindow <<= offset; + if (offset < MSG_COUNTER_WINDOW_SIZE) + state.CounterWindow |= (uint)(1 << (int)offset - 1); + } + else if (frame.Counter == state.MaxMessageCounter) + { + Console.WriteLine("DROPPED DUPLICATE : " + frame); + return false; + } + else + { + uint offset = (state.MaxMessageCounter - frame.Counter); + if (offset > MSG_COUNTER_WINDOW_SIZE) + { + if (HandleBehindWindow(ref state, frame)) + return false; + } + else + { + if ((state.CounterWindow & (uint)(1 << (int)offset - 1)) != 0x0) + { + Console.WriteLine("DROPPED DUPLICATE : " + frame); + return false; + } + else + state.CounterWindow |= (uint)(1 << (int)offset - 1); + } + } + PeerMessageCtr = state; + if (frame.Message.Protocol == ProtocolType.SecureChannel && (SecureOpCodes?)frame.Message.OpCode == SecureOpCodes.MRPStandaloneAcknowledgement) + return true; //Standalone Ack + if (exchanges.TryGetValue(frame.Message.ExchangeID, out Exchange? exchange)) + exchange.Messages.Writer.TryWrite(frame); + else + Console.WriteLine("Unknown Exchange " + frame.Message.ExchangeID); + return true; + } + + internal virtual bool HandleBehindWindow(ref MessageState state, Frame frame) + { + state.MaxMessageCounter = frame.Counter; + state.CounterWindow = uint.MaxValue; + return false; + } + internal async Task DeleteExchange(Exchange exchange) { await Connection.CloseExchange(exchange); diff --git a/MatterDotNet/Protocol/Sessions/SessionManager.cs b/MatterDotNet/Protocol/Sessions/SessionManager.cs index 3889afb..319ff98 100644 --- a/MatterDotNet/Protocol/Sessions/SessionManager.cs +++ b/MatterDotNet/Protocol/Sessions/SessionManager.cs @@ -107,14 +107,5 @@ public static SessionParameter GetDefaultSessionParams() param.SpecificationVersion = 0; return param; } - - internal static void SessionActive(ushort sessionID) - { - if (sessions.TryGetValue(sessionID, out SessionContext? context)) - { - context.Timestamp = DateTime.Now; - context.LastActive = DateTime.Now; - } - } } }