Skip to content

Commit

Permalink
Properly multiplex exchanges and sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
jdomnitz committed Jan 3, 2025
1 parent 014ae24 commit f637332
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 95 deletions.
1 change: 0 additions & 1 deletion MatterDotNet/Protocol/Connection/IConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ namespace MatterDotNet.Protocol.Connection
{
internal interface IConnection : IDisposable
{
Task<Frame> Read();
Task SendFrame(Exchange exchange, Frame frame, bool reliable);
Task CloseExchange(Exchange exchange);
}
Expand Down
46 changes: 29 additions & 17 deletions MatterDotNet/Protocol/Connection/MRPConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ internal class MRPConnection : IConnection

ConcurrentDictionary<(ushort, ushort), Retransmission> Retransmissions = new ConcurrentDictionary<(ushort, ushort), Retransmission>();
ConcurrentDictionary<ushort, uint> AckTable = new ConcurrentDictionary<ushort, uint>();
Channel<Frame> channel = Channel.CreateUnbounded<Frame>();
CancellationTokenSource cts = new CancellationTokenSource();

UdpClient client;
Expand Down Expand Up @@ -70,7 +69,7 @@ public async Task SendFrame(Exchange exchange, Frame frame, bool reliable)
}
}
}
Console.WriteLine("SENT: " + frame.ToString());
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " SENT: " + frame.ToString());
await client.SendAsync(writer.GetPayload());
exchange.Session.Timestamp = DateTime.Now;
while (reliable)
Expand Down Expand Up @@ -111,22 +110,22 @@ public async Task SendAck(SessionContext? session, ushort exchange, uint counter
ack.Message.ExchangeID = exchange;
ack.Message.Flags = ExchangeFlags.Acknowledgement;
if (initiator)
{
ack.Flags |= MessageFlags.SourceNodeID;
ack.Message.Flags |= ExchangeFlags.Initiator;
}
else
ack.Flags |= MessageFlags.DestinationNodeID;
ack.Message.AckCounter = counter;
ack.Message.Protocol = Payloads.ProtocolType.SecureChannel;
PayloadWriter writer = new PayloadWriter(Frame.MAX_SIZE + 4);
ack.Serialize(writer, session!);
if (AckTable.TryGetValue(exchange, out uint ctr) && ctr == counter)
AckTable.TryRemove(exchange, out _);
Console.WriteLine("Sent standalone ack: " + ack.ToString());
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Sent standalone ack: " + ack.ToString());
await client.SendAsync(writer.GetPayload());
}

public async Task<Frame> Read()
{
return await channel.Reader.ReadAsync();
}

public async Task Run()
{
try
Expand All @@ -135,10 +134,20 @@ public async Task Run()
{
UdpReceiveResult result = await client.ReceiveAsync();
Frame frame = new Frame(result.Buffer);
if (!frame.Valid)
{
Console.WriteLine("Invalid frame received");
continue;
}
SessionContext? session = SessionManager.GetSession(frame.SessionID);
bool ack = false;
if ((frame.Message.Flags & ExchangeFlags.Reliability) == ExchangeFlags.Reliability)
{
if (!AckTable.TryAdd(frame.Message.ExchangeID, frame.Counter))
await SendAck(SessionManager.GetSession(frame.SessionID), frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0);
{
ack = true;
await SendAck(session, frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0);
}
}
if ((frame.Message.Flags & ExchangeFlags.Acknowledgement) == ExchangeFlags.Acknowledgement)
{
Expand All @@ -148,20 +157,23 @@ public async Task Run()
transmission.Ack.Release();
}
}
if (frame.SessionID != 0)
SessionManager.SessionActive(frame.SessionID);
Console.WriteLine("Received: " + frame.ToString());
channel.Writer.TryWrite(frame);
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString());
if (session == null)
{
Console.WriteLine("Unknown Session: " + frame.SessionID);
continue;
}
if (!session.ProcessFrame(frame) && !ack)
await SendAck(session, frame.Message.ExchangeID, frame.Counter, (frame.Message.Flags & ExchangeFlags.Initiator) == 0);

session.Timestamp = DateTime.Now;
session.LastActive = DateTime.Now;
}
}
catch (Exception e)
{
Console.WriteLine(e.ToString());
}
finally
{
channel.Writer.Complete();
}
}

public async Task CloseExchange(Exchange exchange)
Expand Down
21 changes: 10 additions & 11 deletions MatterDotNet/Protocol/Connection/TCPConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ internal class TCPConnection : IConnection
TcpClient client;
NetworkStream stream;
CancellationTokenSource cts = new CancellationTokenSource();
Channel<Frame> channel = Channel.CreateUnbounded<Frame>();
public TCPConnection(IPEndPoint destination)
{
client = new TcpClient();
Expand All @@ -43,11 +42,6 @@ public async Task SendFrame(Exchange exchange, Frame frame, bool reliable)
exchange.Session.Timestamp = DateTime.Now;
}

public async Task<Frame> Read()
{
return await channel.Reader.ReadAsync();
}

public async Task Run()
{
byte[] len = new byte[4];
Expand All @@ -58,13 +52,18 @@ public async Task Run()
await stream.ReadExactlyAsync(len);
frameLen = BinaryPrimitives.ReadInt32LittleEndian(len);
await stream.ReadExactlyAsync(data.Slice(0, frameLen));
Console.WriteLine("READ: " + Convert.ToHexString(data.Slice(0, frameLen).Span));
Frame frame = new Frame(data.Slice(0, frameLen).Span);
channel.Writer.TryWrite(frame);
if (frame.SessionID != 0)
SessionManager.SessionActive(frame.SessionID);
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString());
SessionContext? session = SessionManager.GetSession(frame.SessionID);
if (session == null)
{
Console.WriteLine("Unknown Session: " + frame.SessionID);
continue;
}
session.ProcessFrame(frame);
session.Timestamp = DateTime.Now;
session.LastActive = DateTime.Now;
}
channel.Writer.Complete();
}

public void Dispose()
Expand Down
62 changes: 5 additions & 57 deletions MatterDotNet/Protocol/Sessions/Exchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@

using MatterDotNet.Protocol.Payloads;
using MatterDotNet.Protocol.Payloads.Flags;
using System.Threading.Channels;

namespace MatterDotNet.Protocol.Sessions
{
internal class Exchange : IDisposable
{
private const int MSG_COUNTER_WINDOW_SIZE = 32;

public ushort ID { get; init; }
public SessionContext Session {get; init;}
internal Channel<Frame> Messages { get; init;}

internal Exchange(SessionContext session, ushort id)
{
Session = session;
ID = id;
Messages = Channel.CreateBounded<Frame>(10);
}

public async Task SendFrame(Frame frame, bool reliable = true)
Expand All @@ -38,62 +39,9 @@ public async Task SendFrame(Frame frame, bool reliable = true)
await Session.Connection.SendFrame(this, frame, reliable);
}

public async Task<Frame> Read()
public async Task<Frame> Read(CancellationToken token = default)
{
Frame? frame = null;
while (frame == null)
{
frame = await Session.Connection.Read();
MessageState state = Session.PeerMessageCtr;
if (!state.Initialized)
{
state.Initialized = true;
state.CounterWindow = uint.MaxValue;
state.MaxMessageCounter = frame.Counter;
}
else if (frame.Counter > state.MaxMessageCounter)
{
int offset = (int)Math.Min(frame.Counter - state.MaxMessageCounter, MSG_COUNTER_WINDOW_SIZE);
state.MaxMessageCounter = frame.Counter;
state.CounterWindow <<= offset;
if (offset < MSG_COUNTER_WINDOW_SIZE)
state.CounterWindow |= (uint)(1 << (int)offset - 1);
}
else if (frame.Counter == state.MaxMessageCounter)
{
Console.WriteLine("DROPPED DUPLICATE: " + frame);
frame = null;
}
else
{
uint offset = (state.MaxMessageCounter - frame.Counter);
if (offset > MSG_COUNTER_WINDOW_SIZE)
{
if (Session is SecureSession)
{
Console.WriteLine("DROPPED DUPLICATE <behind window>: " + frame);
frame = null;
}
else
{
state.MaxMessageCounter = frame.Counter;
state.CounterWindow = uint.MaxValue;
}
}
else
{
if ((state.CounterWindow & (uint)(1 << (int)offset - 1)) != 0x0)
{
Console.WriteLine("DROPPED DUPLICATE <within window>: " + frame);
frame = null;
}
else
state.CounterWindow |= (uint)(1 << (int)offset - 1);
}
}
Session.PeerMessageCtr = state;
}
return frame;
return await Messages.Reader.ReadAsync(token);
}

/// <inheritdoc />
Expand Down
7 changes: 7 additions & 0 deletions MatterDotNet/Protocol/Sessions/SecureSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

using MatterDotNet.Protocol.Connection;
using MatterDotNet.Protocol.Cryptography;
using MatterDotNet.Protocol.Payloads;
using System.Buffers.Binary;
using System.Security.Cryptography;

Expand Down Expand Up @@ -49,6 +50,12 @@ internal SecureSession(IConnection connection, bool PASE, bool initiator, ushort
}
}

internal override bool HandleBehindWindow(ref MessageState state, Frame frame)
{
Console.WriteLine("DROPPED DUPLICATE <behind window>: " + frame);
return true;
}

internal override uint GetSessionCounter()
{
return LocalMessageCtr;
Expand Down
60 changes: 60 additions & 0 deletions MatterDotNet/Protocol/Sessions/SessionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace MatterDotNet.Protocol.Sessions
{
public class SessionContext : IDisposable
{
private const int MSG_COUNTER_WINDOW_SIZE = 32;

public bool Initiator { get; init; }
public ulong InitiatorNodeID { get; init; }
public ushort LocalSessionID { get; init; }
Expand Down Expand Up @@ -66,6 +68,64 @@ internal Exchange CreateExchange()
return ret;
}

internal bool ProcessFrame(Frame frame)
{
MessageState state = PeerMessageCtr;
if (!state.Initialized)
{
state.Initialized = true;
state.CounterWindow = uint.MaxValue;
state.MaxMessageCounter = frame.Counter;
}
else if (frame.Counter > state.MaxMessageCounter)
{
int offset = (int)Math.Min(frame.Counter - state.MaxMessageCounter, MSG_COUNTER_WINDOW_SIZE);
state.MaxMessageCounter = frame.Counter;
state.CounterWindow <<= offset;
if (offset < MSG_COUNTER_WINDOW_SIZE)
state.CounterWindow |= (uint)(1 << (int)offset - 1);
}
else if (frame.Counter == state.MaxMessageCounter)
{
Console.WriteLine("DROPPED DUPLICATE <repeated last>: " + frame);
return false;
}
else
{
uint offset = (state.MaxMessageCounter - frame.Counter);
if (offset > MSG_COUNTER_WINDOW_SIZE)
{
if (HandleBehindWindow(ref state, frame))
return false;
}
else
{
if ((state.CounterWindow & (uint)(1 << (int)offset - 1)) != 0x0)
{
Console.WriteLine("DROPPED DUPLICATE <within window>: " + frame);
return false;
}
else
state.CounterWindow |= (uint)(1 << (int)offset - 1);
}
}
PeerMessageCtr = state;
if (frame.Message.Protocol == ProtocolType.SecureChannel && (SecureOpCodes?)frame.Message.OpCode == SecureOpCodes.MRPStandaloneAcknowledgement)
return true; //Standalone Ack
if (exchanges.TryGetValue(frame.Message.ExchangeID, out Exchange? exchange))
exchange.Messages.Writer.TryWrite(frame);
else
Console.WriteLine("Unknown Exchange " + frame.Message.ExchangeID);
return true;
}

internal virtual bool HandleBehindWindow(ref MessageState state, Frame frame)
{
state.MaxMessageCounter = frame.Counter;
state.CounterWindow = uint.MaxValue;
return false;
}

internal async Task DeleteExchange(Exchange exchange)
{
await Connection.CloseExchange(exchange);
Expand Down
9 changes: 0 additions & 9 deletions MatterDotNet/Protocol/Sessions/SessionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,5 @@ public static SessionParameter GetDefaultSessionParams()
param.SpecificationVersion = 0;
return param;
}

internal static void SessionActive(ushort sessionID)
{
if (sessions.TryGetValue(sessionID, out SessionContext? context))
{
context.Timestamp = DateTime.Now;
context.LastActive = DateTime.Now;
}
}
}
}

0 comments on commit f637332

Please sign in to comment.