Skip to content

Commit

Permalink
Seperate IPv4 and IPv6 resolution. Try both where possible due to pos…
Browse files Browse the repository at this point in the history
…sible routing issues. Scope sessions by EndPoint
jdomnitz committed Jan 12, 2025
1 parent e575cbe commit 675906f
Showing 12 changed files with 120 additions and 65 deletions.
2 changes: 1 addition & 1 deletion MatterDotNet/Entities/EndPoint.cs
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ namespace MatterDotNet.Entities
/// </summary>
public class EndPoint
{
private Node? node;
private Node? node; //TODO - Remove if not used when groups are implemented
private Dictionary<uint, ClusterBase> clusters;
private ushort index;
private Dictionary<ushort, EndPoint> children;
9 changes: 7 additions & 2 deletions MatterDotNet/Entities/Node.cs
Original file line number Diff line number Diff line change
@@ -57,9 +57,14 @@ private Node(ODNode connection, Fabric fabric, OperationalCertificate noc)
/// <exception cref="IOException"></exception>
public async Task<SecureSession> GetCASESession()
{
if (connection.IPAddress != null)
if (connection.IP6Address != null)
{
using (SessionContext session = SessionManager.GetUnsecureSession(new IPEndPoint(connection.IPAddress!, connection.Port), true))
using (SessionContext session = SessionManager.GetUnsecureSession(new IPEndPoint(connection.IP6Address!, connection.Port), true))
return await GetCASESession(session);
}
else if(connection.IP4Address != null)
{
using (SessionContext session = SessionManager.GetUnsecureSession(new IPEndPoint(connection.IP4Address!, connection.Port), true))
return await GetCASESession(session);
}
else
20 changes: 14 additions & 6 deletions MatterDotNet/OperationalDiscovery/IPDiscoveryService.cs
Original file line number Diff line number Diff line change
@@ -110,9 +110,15 @@ public async Task<List<ODNode>> Find(uint discriminator, bool fullLen)
Console.WriteLine("Looking for " + operationalInstanceName);
List<ODNode> results;
int length = extendedSearch ? 20 : 10; // 60 / 30 seconds
DNSRecordType[] records;
if (extendedSearch)
records = [DNSRecordType.ANY, DNSRecordType.SRV, DNSRecordType.TXT, DNSRecordType.AAAA];
else
records = [DNSRecordType.SRV, DNSRecordType.TXT, DNSRecordType.A, DNSRecordType.AAAA];
string domain = operationalInstanceName + "._matter._tcp.local";
for (int i = 0; i < length; i++)
{
results = Parse(await mdns.ResolveServiceInstance(operationalInstanceName, "_matter._tcp", "local"));
results = Parse(await mdns.ResolveQuery(domain, false, records));
if (results.Count > 0)
return results[0];
}
@@ -154,19 +160,21 @@ private List<ODNode> Parse(List<Message> msgs)
node.Port = service.Port;
else if (answer is TxtRecord txt)
PopulateText(txt, ref node);
else if (answer is AAAARecord AAAA)
node.IP6Address = AAAA.Address;
}
foreach (ResourceRecord additional in msg.Additionals)
{
if (node.Port == 0 && additional is SRVRecord service)
node.Port = service.Port;
else if (node.IPAddress == null && additional is ARecord A)
node.IPAddress = A.Address;
else if (node.IPAddress == null && additional is AAAARecord AAAA)
node.IPAddress = AAAA.Address;
else if (node.IP4Address == null && additional is ARecord A)
node.IP4Address = A.Address;
else if (node.IP6Address == null && additional is AAAARecord AAAA)
node.IP6Address = AAAA.Address;
else if (additional is TxtRecord txt)
PopulateText(txt, ref node);
}
if (node.IPAddress == null || node.Port == 0)
if ((node.IP4Address == null && node.IP6Address == null) || node.Port == 0)
continue;
ret.Add(node);
}
12 changes: 8 additions & 4 deletions MatterDotNet/OperationalDiscovery/ODNode.cs
Original file line number Diff line number Diff line change
@@ -22,17 +22,21 @@ namespace MatterDotNet.OperationalDiscovery
public record ODNode
{
/// <summary>
/// Discovered IP Address
/// Discovered IPv6 Address
/// </summary>
public IPAddress? IPAddress { get; set; }
public IPAddress? IP6Address { get; set; }
/// <summary>
/// Discovered IPv4 Address
/// </summary>
public IPAddress? IP4Address { get; set; }
/// <summary>
/// Discovered Port
/// </summary>
public ushort Port { get; set; }
/// <summary>
/// Discovered BT LE Address
/// </summary>
public string BTAddress { get; set; }
public string? BTAddress { get; set; }
/// <summary>
/// Idle Session Interval
/// </summary>
@@ -80,7 +84,7 @@ public record ODNode

public override string ToString()
{
return $"Vendor: {Vendor}, Product: {Product}, Discriminator: {Discriminator:X3}, Name: {DeviceName}, Address: {(BTAddress != null ? BTAddress : $"{IPAddress}:{Port}")}, Type: {Type}, Mode: {CommissioningMode}";
return $"Vendor: {Vendor}, Product: {Product}, Discriminator: {Discriminator:X3}, Name: {DeviceName}, Address: {(BTAddress != null ? BTAddress : $"{IP6Address}:{Port}")}, Type: {Type}, Mode: {CommissioningMode}";
}
}
}
53 changes: 26 additions & 27 deletions MatterDotNet/Protocol/Connection/BTPConnection.cs
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
using MatterDotNet.Protocol.Payloads.Flags;
using MatterDotNet.Protocol.Payloads.OpCodes;
using MatterDotNet.Protocol.Sessions;
using System.Net;
using System.Threading.Channels;

namespace MatterDotNet.Protocol.Connection
@@ -43,9 +44,11 @@ internal class BTPConnection : IConnection
SemaphoreSlim WriteLock = new SemaphoreSlim(1, 1);
bool connected;
BluetoothDevice? device;
BLEEndPoint destination;

public BTPConnection(BLEEndPoint bleDevice)
{
destination = bleDevice;
Connect(bleDevice.Address).Wait();
AckTimer = new Timer(SendAck, null, ACK_TIME, ACK_TIME);
if (Read == null || Write == null)
@@ -60,7 +63,7 @@ private async Task Connect(string deviceID)
}

private async Task Connect()
{
{
if (!device!.Gatt.IsConnected)
await device.Gatt.ConnectAsync();
MTU = (ushort)Math.Min(device.Gatt.Mtu, 244);
@@ -83,29 +86,22 @@ private void Device_GattServerDisconnected(object? sender, EventArgs e)

private async Task SendHandshake()
{
try
{
Console.WriteLine("Send Handshake Request");
BTPFrame handshake = new BTPFrame(BTPFlags.Handshake | BTPFlags.Management | BTPFlags.Beginning | BTPFlags.Ending);
handshake.OpCode = BTPManagementOpcode.Handshake;
handshake.WindowSize = 8;
handshake.ATT_MTU = MTU;
await Write.WriteValueWithResponseAsync(handshake.Serialize(9));
Read.CharacteristicValueChanged += Read_CharacteristicValueChanged;
await Read.StartNotificationsAsync();

BTPFrame frame = await instream.Reader.ReadAsync();
MTU = frame.ATT_MTU;
ServerWindow = frame.WindowSize;
if (frame.Version != BTPFrame.MATTER_BT_VERSION1)
throw new NotSupportedException($"Version {frame.Version} not supported");
connected = true;
Console.WriteLine($"MTU: {MTU}, Window: {ServerWindow}");
}
catch(Exception ex)
{
Console.WriteLine(ex.Message);
}
Console.WriteLine("Send Handshake Request");
BTPFrame handshake = new BTPFrame(BTPFlags.Handshake | BTPFlags.Management | BTPFlags.Beginning | BTPFlags.Ending);
handshake.OpCode = BTPManagementOpcode.Handshake;
handshake.WindowSize = 8;
handshake.ATT_MTU = MTU;
await Write.WriteValueWithResponseAsync(handshake.Serialize(9));
Read.CharacteristicValueChanged += Read_CharacteristicValueChanged;
await Read.StartNotificationsAsync();

BTPFrame frame = await instream.Reader.ReadAsync();
MTU = frame.ATT_MTU;
ServerWindow = frame.WindowSize;
if (frame.Version != BTPFrame.MATTER_BT_VERSION1)
throw new NotSupportedException($"Version {frame.Version} not supported");
connected = true;
Console.WriteLine($"MTU: {MTU}, Window: {ServerWindow}");
}

private async void SendAck(object? state)
@@ -125,6 +121,7 @@ private async void SendAck(object? state)
Console.WriteLine("[StandaloneAck] Wrote Segment: " + segment);
await Write.WriteValueWithResponseAsync(segment.Serialize(MTU));
}
catch (OperationCanceledException) { }
finally
{
WriteLock.Release();
@@ -201,13 +198,13 @@ public async Task Run()
foreach (BTPFrame part in segments)
buffer.Write(part.Payload);
segments.Clear();
Frame frame = new Frame(buffer.GetPayload().Span);
Frame frame = new Frame(buffer.GetPayload().Span, destination);
if (!frame.Valid)
{
Console.WriteLine("Invalid frame received");
continue;
}
SessionContext? session = SessionManager.GetSession(frame.SessionID);
SessionContext? session = SessionManager.GetSession(frame.SessionID, destination);
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString());
if (session == null)
{
@@ -231,7 +228,9 @@ public Task CloseExchange(Exchange exchange)
return Task.CompletedTask;
}

public bool Connected { get { return connected; } }
public bool Connected { get { return connected; } }

public EndPoint EndPoint { get { return destination; } }

/// <inheritdoc />
public void Dispose()
2 changes: 2 additions & 0 deletions MatterDotNet/Protocol/Connection/IConnection.cs
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@

using MatterDotNet.Protocol.Payloads;
using MatterDotNet.Protocol.Sessions;
using System.Net;

namespace MatterDotNet.Protocol.Connection
{
@@ -20,5 +21,6 @@ internal interface IConnection : IDisposable
Task SendFrame(Exchange exchange, Frame frame, bool reliable);
Task CloseExchange(Exchange exchange);
bool Connected { get; }
EndPoint EndPoint { get; }
}
}
10 changes: 7 additions & 3 deletions MatterDotNet/Protocol/Connection/MRPConnection.cs
Original file line number Diff line number Diff line change
@@ -34,9 +34,11 @@ internal class MRPConnection : IConnection
CancellationTokenSource cts = new CancellationTokenSource();

UdpClient client;
IPEndPoint destination;

public MRPConnection(IPEndPoint ep)
{
destination = ep;
client = new UdpClient(ep.AddressFamily);
client.Connect(ep);
Task.Factory.StartNew(Run);
@@ -68,7 +70,7 @@ public async Task SendFrame(Exchange exchange, Frame frame, bool reliable)
}
}
}
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " SENT: " + frame.ToString());
Console.WriteLine("MRP SENT: " + frame.ToString());
await client.SendAsync(writer.GetPayload());
exchange.Session.Timestamp = DateTime.Now;
while (reliable)
@@ -132,13 +134,13 @@ public async Task Run()
while (!cts.IsCancellationRequested)
{
UdpReceiveResult result = await client.ReceiveAsync();
Frame frame = new Frame(result.Buffer);
Frame frame = new Frame(result.Buffer, destination);
if (!frame.Valid)
{
Console.WriteLine("Invalid frame received");
continue;
}
SessionContext? session = SessionManager.GetSession(frame.SessionID);
SessionContext? session = SessionManager.GetSession(frame.SessionID, destination);
bool ack = false;
if ((frame.Message.Flags & ExchangeFlags.Reliability) == ExchangeFlags.Reliability)
{
@@ -183,6 +185,8 @@ public async Task CloseExchange(Exchange exchange)

public bool Connected { get { return !cts.IsCancellationRequested; } }

public EndPoint EndPoint { get { return destination; } }

/// <inheritdoc />
public void Dispose()
{
8 changes: 6 additions & 2 deletions MatterDotNet/Protocol/Connection/TCPConnection.cs
Original file line number Diff line number Diff line change
@@ -21,11 +21,13 @@ namespace MatterDotNet.Protocol.Connection
{
internal class TCPConnection : IConnection
{
IPEndPoint destination;
TcpClient client;
NetworkStream stream;
CancellationTokenSource cts = new CancellationTokenSource();
public TCPConnection(IPEndPoint destination)
{
this.destination = destination;
client = new TcpClient();
client.Connect(destination);
stream = client.GetStream();
@@ -52,9 +54,9 @@ public async Task Run()
await stream.ReadExactlyAsync(len);
frameLen = BinaryPrimitives.ReadInt32LittleEndian(len);
await stream.ReadExactlyAsync(data.Slice(0, frameLen));
Frame frame = new Frame(data.Slice(0, frameLen).Span);
Frame frame = new Frame(data.Slice(0, frameLen).Span, destination);
Console.WriteLine(DateTime.Now.ToString("h:mm:ss") + " Received: " + frame.ToString());
SessionContext? session = SessionManager.GetSession(frame.SessionID);
SessionContext? session = SessionManager.GetSession(frame.SessionID, destination);
if (session == null)
{
Console.WriteLine("Unknown Session: " + frame.SessionID);
@@ -68,6 +70,8 @@ public async Task Run()

public bool Connected { get { return client.Connected; } }

public EndPoint EndPoint { get { return destination; } }

public void Dispose()
{
cts.Cancel();
5 changes: 3 additions & 2 deletions MatterDotNet/Protocol/Payloads/Frame.cs
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
using MatterDotNet.Protocol.Sessions;
using System.Buffers;
using System.Buffers.Binary;
using System.Net;
using System.Text;

namespace MatterDotNet.Protocol.Payloads
@@ -130,14 +131,14 @@ public Frame(IPayload? payload, byte opCode)
Message = new Version1Payload(payload, opCode);
}

public Frame(Span<byte> payload)
public Frame(Span<byte> payload, EndPoint endPoint)
{
Valid = true;
Flags = (MessageFlags)payload[0];
SessionID = BinaryPrimitives.ReadUInt16LittleEndian(payload.Slice(1, 2));
Security = (SecurityFlags)payload[3];

SecureSession? session = SessionManager.GetSession(SessionID) as SecureSession;
SecureSession? session = SessionManager.GetSession(SessionID, endPoint) as SecureSession;

if ((Security & SecurityFlags.Privacy) == SecurityFlags.Privacy)
{
19 changes: 13 additions & 6 deletions MatterDotNet/Protocol/Sessions/Exchange.cs
Original file line number Diff line number Diff line change
@@ -31,12 +31,19 @@ internal Exchange(SessionContext session, ushort id)

public async Task SendFrame(Frame frame, bool reliable = true)
{
frame.SessionID = Session.RemoteSessionID;
if (Session.Initiator)
frame.Message.Flags |= ExchangeFlags.Initiator;
frame.Message.ExchangeID = ID;
frame.Counter = Session.GetSessionCounter();
await Session.Connection.SendFrame(this, frame, reliable);
try
{
frame.SessionID = Session.RemoteSessionID;
if (Session.Initiator)
frame.Message.Flags |= ExchangeFlags.Initiator;
frame.Message.ExchangeID = ID;
frame.Counter = Session.GetSessionCounter();
await Session.Connection.SendFrame(this, frame, reliable);
}
catch(OperationCanceledException e)
{
Console.WriteLine("Failed to send frame: " + e.ToString());
}
}

public async Task<Frame> Read(CancellationToken token = default)
2 changes: 1 addition & 1 deletion MatterDotNet/Protocol/Sessions/SessionContext.cs
Original file line number Diff line number Diff line change
@@ -141,7 +141,7 @@ public void Dispose()
var keys = exchanges.Keys;
foreach (var key in keys)
exchanges[key].Dispose();
SessionManager.RemoveSession(LocalSessionID);
SessionManager.RemoveSession(LocalSessionID, Connection.EndPoint);
GC.SuppressFinalize(this);
}
}
43 changes: 32 additions & 11 deletions MatterDotNet/Protocol/Sessions/SessionManager.cs
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ public static class SessionManager
{
private static uint globalCtr;
private static ConcurrentDictionary<EndPoint, IConnection> connections = new ConcurrentDictionary<EndPoint, IConnection>();
private static ConcurrentDictionary<ushort, SessionContext> sessions = new ConcurrentDictionary<ushort, SessionContext>();
private static ConcurrentDictionary<EndPoint, ConcurrentDictionary<ushort, SessionContext>> sessions = new ConcurrentDictionary<EndPoint, ConcurrentDictionary<ushort, SessionContext>>();

public static SessionContext GetUnsecureSession(EndPoint ep, bool initiator)
{
@@ -33,10 +33,16 @@ public static SessionContext GetUnsecureSession(EndPoint ep, bool initiator)

internal static SessionContext GetUnsecureSession(IConnection connection, bool initiator)
{
if (sessions.TryGetValue(0, out SessionContext? existing))
return existing;
ConcurrentDictionary<ushort, SessionContext>? existing;
if (!sessions.TryGetValue(connection.EndPoint, out existing))
{
existing = new ConcurrentDictionary<ushort, SessionContext>();
sessions.TryAdd(connection.EndPoint, existing);
}
if (existing.TryGetValue(0, out SessionContext? existingSession))
return existingSession;
SessionContext ctx = new SessionContext(connection, initiator, 0, 0, 0, 0, new MessageState());
sessions.TryAdd(0, ctx);
existing.TryAdd(0, ctx);
return ctx;
}

@@ -49,27 +55,42 @@ internal static SessionContext GetUnsecureSession(IConnection connection, bool i
{
if (group == false && initiatorSessionId == 0)
return null; //Unsecured session
ConcurrentDictionary<ushort, SessionContext>? existing;
if (!sessions.TryGetValue(connection.EndPoint, out existing))
{
existing = new ConcurrentDictionary<ushort, SessionContext>();
sessions.TryAdd(connection.EndPoint, existing);
}
if (existing.TryGetValue(initiator ? initiatorSessionId : responderSessionId, out SessionContext? existingSession) && existingSession is SecureSession secure && secure.PASE == PASE)
return secure;

SecureSession ctx = new SecureSession(connection, PASE, initiator, initiator ? initiatorSessionId : responderSessionId, initiator ? responderSessionId : initiatorSessionId, i2r, r2i, sharedSecret, resumptionId, 0, new MessageState(), localNodeId, peerNodeId, idleInterval, activeInterval, activeThreshold);
Console.WriteLine("Secure Session Created: " + ctx.LocalSessionID);
sessions.TryAdd(ctx.LocalSessionID, ctx);
existing.TryAdd(ctx.LocalSessionID, ctx);
return ctx;
}

public static SessionContext? GetSession(ushort sessionId)
public static SessionContext? GetSession(ushort sessionId, EndPoint endPoint)
{
if (sessions.TryGetValue(sessionId, out SessionContext? ctx))
return ctx;
ConcurrentDictionary<ushort, SessionContext>? existing;
if (!sessions.TryGetValue(endPoint, out existing))
return null;
if (existing.TryGetValue(sessionId, out SessionContext? existingSession))
return existingSession;
return null;
}

internal static void RemoveSession(ushort sessionId)
internal static void RemoveSession(ushort sessionId, EndPoint endPoint)
{
sessions.TryRemove(sessionId, out _);
ConcurrentDictionary<ushort, SessionContext>? existing;
if (!sessions.TryGetValue(endPoint, out existing))
return;
existing.TryRemove(sessionId, out _);
}

internal static ushort GetAvailableSessionID()
{
return (ushort)Random.Shared.Next(0, ushort.MaxValue);
return (ushort)Random.Shared.Next(1, ushort.MaxValue);
}

public static uint GlobalUnencryptedCounter

0 comments on commit 675906f

Please sign in to comment.