diff --git a/src/WebSocketTunnel.Client/Program.cs b/src/WebSocketTunnel.Client/Program.cs index e352ed1..e688235 100644 --- a/src/WebSocketTunnel.Client/Program.cs +++ b/src/WebSocketTunnel.Client/Program.cs @@ -7,26 +7,33 @@ namespace WebSocketTunnel.Client; public class Program { - private static readonly HttpClient HttpClient = new(); private static HubConnection? Connection; - private static readonly Guid InstanceId = Guid.NewGuid(); - private static readonly string Server = "https://localhost:7193"; + private static readonly Guid ClientId = Guid.NewGuid(); + private static readonly string Server = "https://tunnelite.azurewebsites.net/"; + //private static readonly string Server = "https://localhost:7193"; private static readonly int ChunkSize = 512 * 1024; // 512KB + private static readonly HttpClientHandler LocalHttpClientHandler = new() + { + ServerCertificateCustomValidationCallback = (message, cert, chain, sslPolicyErrors) => true, + }; + private static readonly HttpClient ServerHttpClient = new(); + private static readonly HttpClient LocalHttpClient = new(LocalHttpClientHandler); + public static async Task Main(string[] args) { var localUrlArgument = new Argument("localUrl", "The local URL to tunnel to."); var rootCommand = new RootCommand { - localUrlArgument + localUrlArgument, }; rootCommand.Description = "CLI tool to create a tunnel to a local server."; rootCommand.SetHandler(async (string localUrl) => { - await ConnectToServerAsync(localUrl, Server, InstanceId); + await ConnectToServerAsync(localUrl, Server, ClientId); }, localUrlArgument); @@ -35,14 +42,14 @@ public static async Task Main(string[] args) Console.ReadLine(); } - private static async Task RegisterTunnelAsync(string localUrl, string publicUrl, Guid instanceId) + private static async Task RegisterTunnelAsync(string localUrl, string publicUrl, Guid clientId) { - var response = await HttpClient.PostAsJsonAsync( + var response = await ServerHttpClient.PostAsJsonAsync( $"{publicUrl}/register-tunnel", new Tunnel { LocalUrl = localUrl, - InstanceId = instanceId, + ClientId = clientId, }); var message = await response.Content.ReadAsStringAsync(); @@ -61,10 +68,10 @@ private static async Task RegisterTunnelAsync(string localUrl, string publ } } - private static async Task ConnectToServerAsync(string localUrl, string publicUrl, Guid instanceId) + private static async Task ConnectToServerAsync(string localUrl, string publicUrl, Guid clientId) { Connection = new HubConnectionBuilder() - .WithUrl($"{publicUrl}/wstunnel?instanceId={instanceId}", options => + .WithUrl($"{publicUrl}/wstunnel?clientId={clientId}", options => { options.TransportMaxBufferSize = ChunkSize; options.ApplicationMaxBufferSize = ChunkSize; @@ -83,7 +90,7 @@ private static async Task ConnectToServerAsync(string localUrl, string publicUrl { Console.WriteLine($"Reconnected. New ConnectionId {connectionId}"); - await RegisterTunnelAsync(localUrl, publicUrl, instanceId); + await RegisterTunnelAsync(localUrl, publicUrl, clientId); }; Connection.Closed += async (error) => @@ -94,13 +101,13 @@ private static async Task ConnectToServerAsync(string localUrl, string publicUrl if (await ConnectWithRetryAsync(Connection, CancellationToken.None)) { - await RegisterTunnelAsync(localUrl, publicUrl, instanceId); + await RegisterTunnelAsync(localUrl, publicUrl, clientId); } }; if (await ConnectWithRetryAsync(Connection, CancellationToken.None)) { - await RegisterTunnelAsync(localUrl, publicUrl, instanceId); + await RegisterTunnelAsync(localUrl, publicUrl, clientId); } } @@ -137,7 +144,7 @@ private static async Task TunnelRequestAsync(RequestMetadata requestMetadata) requestMessage.Content.Headers.ContentType = new MediaTypeHeaderValue(requestMetadata.ContentType); } - var response = await HttpClient.SendAsync(requestMessage); + var response = await LocalHttpClient.SendAsync(requestMessage); var responseMetadata = new ResponseMetadata { diff --git a/src/WebSocketTunnel.Client/Tunnel.cs b/src/WebSocketTunnel.Client/Tunnel.cs index 9ca453d..893027b 100644 --- a/src/WebSocketTunnel.Client/Tunnel.cs +++ b/src/WebSocketTunnel.Client/Tunnel.cs @@ -3,7 +3,7 @@ namespace WebSocketTunnel.Client { public class Tunnel { - public Guid? InstanceId { get; set; } + public Guid? ClientId { get; set; } public string LocalUrl { get; set; } } diff --git a/src/WebSocketTunnel.Server/Program.cs b/src/WebSocketTunnel.Server/Program.cs index ec0dd98..8bc6033 100644 --- a/src/WebSocketTunnel.Server/Program.cs +++ b/src/WebSocketTunnel.Server/Program.cs @@ -32,10 +32,10 @@ return; } - if (payload.InstanceId == Guid.Empty) + if (payload.ClientId == Guid.Empty) { context.Response.StatusCode = StatusCodes.Status400BadRequest; - await context.Response.WriteAsync("Missing or invalid 'InstanceId' property."); + await context.Response.WriteAsync("Missing or invalid 'ClientId' property."); return; } @@ -44,6 +44,7 @@ payload.LocalUrl = payload.LocalUrl.TrimEnd(['/']); tunnelStore.Tunnels.AddOrUpdate(subdomain, payload, (key, oldValue) => payload); + tunnelStore.Clients.AddOrUpdate(payload.ClientId, subdomain, (key, oldValue) => subdomain); var baseUrl = $"{context.Request.Scheme}://{context.Request.Host}{context.Request.PathBase}"; @@ -83,7 +84,8 @@ static async Task ProxyRequestAsync(HttpContext context, IHubContext var subdomain = context.Request.Host.Host.Split('.')[0]; - if (subdomain.Equals("tunnelite", StringComparison.OrdinalIgnoreCase)) + if (subdomain.Equals("tunnelite", StringComparison.OrdinalIgnoreCase) || + subdomain.Equals("localhost", StringComparison.OrdinalIgnoreCase)) { tunnel = tunnelStore.Tunnels.FirstOrDefault().Value; } @@ -100,7 +102,7 @@ static async Task ProxyRequestAsync(HttpContext context, IHubContext return; } - if (!tunnelStore.Connections.TryGetValue(tunnel!.InstanceId, out var connectionId)) + if (!tunnelStore.Connections.TryGetValue(tunnel!.ClientId, out var connectionId)) { context.Response.StatusCode = StatusCodes.Status404NotFound; await context.Response.WriteAsync("Client disconnected!"); diff --git a/src/WebSocketTunnel.Server/TunnelHub.cs b/src/WebSocketTunnel.Server/TunnelHub.cs index 0930689..19577a6 100644 --- a/src/WebSocketTunnel.Server/TunnelHub.cs +++ b/src/WebSocketTunnel.Server/TunnelHub.cs @@ -11,16 +11,26 @@ public class TunnelHub(RequestsQueue requestsQueue, TunnelStore tunnelStore) : H public override Task OnConnectedAsync() { - var instanceId = Context.GetHttpContext()!.Request.Query["instanceId"].ToString(); + var clientId = Context.GetHttpContext()!.Request.Query["clientId"].ToString(); - _tunnelStore.Connections.AddOrUpdate(Guid.Parse(instanceId), Context.ConnectionId, (key, oldValue) => Context.ConnectionId); + _tunnelStore.Connections.AddOrUpdate(Guid.Parse(clientId), Context.ConnectionId, (key, oldValue) => Context.ConnectionId); return base.OnConnectedAsync(); } public override Task OnDisconnectedAsync(Exception? exception) { - // todo + var clientIdQuery = Context.GetHttpContext()!.Request.Query["clientId"].ToString(); + + var clientId = Guid.Parse(clientIdQuery); + + if (_tunnelStore.Clients.TryGetValue(clientId, out var subdomain)) + { + _tunnelStore.Tunnels.Remove(subdomain, out var _); + _tunnelStore.Connections.Remove(clientId, out var _); + _tunnelStore.Clients.Remove(clientId, out _); + } + return base.OnDisconnectedAsync(exception); } diff --git a/src/WebSocketTunnel.Server/TunnelStore.cs b/src/WebSocketTunnel.Server/TunnelStore.cs index aba74fb..5f9dd28 100644 --- a/src/WebSocketTunnel.Server/TunnelStore.cs +++ b/src/WebSocketTunnel.Server/TunnelStore.cs @@ -4,14 +4,19 @@ namespace WebSocketTunnel.Server { public class TunnelStore { + // subdomain, [clientId, localUrl] public ConcurrentDictionary Tunnels = new(); + // clientId, connectionId public ConcurrentDictionary Connections = new(); + + // clientId, subdomain + public ConcurrentDictionary Clients = new(); } public class Tunnel { - public Guid InstanceId { get; set; } + public Guid ClientId { get; set; } public string? LocalUrl { get; set; } }