Skip to content

Commit

Permalink
Handle client disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
cristipufu committed Jul 20, 2024
1 parent 738e738 commit fe9db85
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 23 deletions.
35 changes: 21 additions & 14 deletions src/WebSocketTunnel.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,33 @@ namespace WebSocketTunnel.Client;

public class Program
{
private static readonly HttpClient HttpClient = new();
private static HubConnection? Connection;
private static readonly Guid InstanceId = Guid.NewGuid();
private static readonly string Server = "https://localhost:7193";
private static readonly Guid ClientId = Guid.NewGuid();
private static readonly string Server = "https://tunnelite.azurewebsites.net/";
//private static readonly string Server = "https://localhost:7193";
private static readonly int ChunkSize = 512 * 1024; // 512KB

private static readonly HttpClientHandler LocalHttpClientHandler = new()
{
ServerCertificateCustomValidationCallback = (message, cert, chain, sslPolicyErrors) => true,
};
private static readonly HttpClient ServerHttpClient = new();
private static readonly HttpClient LocalHttpClient = new(LocalHttpClientHandler);

public static async Task Main(string[] args)
{
var localUrlArgument = new Argument<string>("localUrl", "The local URL to tunnel to.");

var rootCommand = new RootCommand
{
localUrlArgument
localUrlArgument,
};

rootCommand.Description = "CLI tool to create a tunnel to a local server.";

rootCommand.SetHandler(async (string localUrl) =>
{
await ConnectToServerAsync(localUrl, Server, InstanceId);
await ConnectToServerAsync(localUrl, Server, ClientId);

}, localUrlArgument);

Expand All @@ -35,14 +42,14 @@ public static async Task Main(string[] args)
Console.ReadLine();
}

private static async Task<bool> RegisterTunnelAsync(string localUrl, string publicUrl, Guid instanceId)
private static async Task<bool> RegisterTunnelAsync(string localUrl, string publicUrl, Guid clientId)
{
var response = await HttpClient.PostAsJsonAsync(
var response = await ServerHttpClient.PostAsJsonAsync(
$"{publicUrl}/register-tunnel",
new Tunnel
{
LocalUrl = localUrl,
InstanceId = instanceId,
ClientId = clientId,
});

var message = await response.Content.ReadAsStringAsync();
Expand All @@ -61,10 +68,10 @@ private static async Task<bool> RegisterTunnelAsync(string localUrl, string publ
}
}

private static async Task ConnectToServerAsync(string localUrl, string publicUrl, Guid instanceId)
private static async Task ConnectToServerAsync(string localUrl, string publicUrl, Guid clientId)
{
Connection = new HubConnectionBuilder()
.WithUrl($"{publicUrl}/wstunnel?instanceId={instanceId}", options =>
.WithUrl($"{publicUrl}/wstunnel?clientId={clientId}", options =>
{
options.TransportMaxBufferSize = ChunkSize;
options.ApplicationMaxBufferSize = ChunkSize;
Expand All @@ -83,7 +90,7 @@ private static async Task ConnectToServerAsync(string localUrl, string publicUrl
{
Console.WriteLine($"Reconnected. New ConnectionId {connectionId}");

await RegisterTunnelAsync(localUrl, publicUrl, instanceId);
await RegisterTunnelAsync(localUrl, publicUrl, clientId);
};

Connection.Closed += async (error) =>
Expand All @@ -94,13 +101,13 @@ private static async Task ConnectToServerAsync(string localUrl, string publicUrl

if (await ConnectWithRetryAsync(Connection, CancellationToken.None))
{
await RegisterTunnelAsync(localUrl, publicUrl, instanceId);
await RegisterTunnelAsync(localUrl, publicUrl, clientId);
}
};

if (await ConnectWithRetryAsync(Connection, CancellationToken.None))
{
await RegisterTunnelAsync(localUrl, publicUrl, instanceId);
await RegisterTunnelAsync(localUrl, publicUrl, clientId);
}
}

Expand Down Expand Up @@ -137,7 +144,7 @@ private static async Task TunnelRequestAsync(RequestMetadata requestMetadata)
requestMessage.Content.Headers.ContentType = new MediaTypeHeaderValue(requestMetadata.ContentType);
}

var response = await HttpClient.SendAsync(requestMessage);
var response = await LocalHttpClient.SendAsync(requestMessage);

var responseMetadata = new ResponseMetadata
{
Expand Down
2 changes: 1 addition & 1 deletion src/WebSocketTunnel.Client/Tunnel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ namespace WebSocketTunnel.Client
{
public class Tunnel
{
public Guid? InstanceId { get; set; }
public Guid? ClientId { get; set; }

public string LocalUrl { get; set; }
}
Expand Down
10 changes: 6 additions & 4 deletions src/WebSocketTunnel.Server/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
return;
}

if (payload.InstanceId == Guid.Empty)
if (payload.ClientId == Guid.Empty)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsync("Missing or invalid 'InstanceId' property.");
await context.Response.WriteAsync("Missing or invalid 'ClientId' property.");
return;
}

Expand All @@ -44,6 +44,7 @@
payload.LocalUrl = payload.LocalUrl.TrimEnd(['/']);

tunnelStore.Tunnels.AddOrUpdate(subdomain, payload, (key, oldValue) => payload);
tunnelStore.Clients.AddOrUpdate(payload.ClientId, subdomain, (key, oldValue) => subdomain);

var baseUrl = $"{context.Request.Scheme}://{context.Request.Host}{context.Request.PathBase}";

Expand Down Expand Up @@ -83,7 +84,8 @@ static async Task ProxyRequestAsync(HttpContext context, IHubContext<TunnelHub>

var subdomain = context.Request.Host.Host.Split('.')[0];

if (subdomain.Equals("tunnelite", StringComparison.OrdinalIgnoreCase))
if (subdomain.Equals("tunnelite", StringComparison.OrdinalIgnoreCase) ||
subdomain.Equals("localhost", StringComparison.OrdinalIgnoreCase))
{
tunnel = tunnelStore.Tunnels.FirstOrDefault().Value;
}
Expand All @@ -100,7 +102,7 @@ static async Task ProxyRequestAsync(HttpContext context, IHubContext<TunnelHub>
return;
}

if (!tunnelStore.Connections.TryGetValue(tunnel!.InstanceId, out var connectionId))
if (!tunnelStore.Connections.TryGetValue(tunnel!.ClientId, out var connectionId))
{
context.Response.StatusCode = StatusCodes.Status404NotFound;
await context.Response.WriteAsync("Client disconnected!");
Expand Down
16 changes: 13 additions & 3 deletions src/WebSocketTunnel.Server/TunnelHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,26 @@ public class TunnelHub(RequestsQueue requestsQueue, TunnelStore tunnelStore) : H

public override Task OnConnectedAsync()
{
var instanceId = Context.GetHttpContext()!.Request.Query["instanceId"].ToString();
var clientId = Context.GetHttpContext()!.Request.Query["clientId"].ToString();

_tunnelStore.Connections.AddOrUpdate(Guid.Parse(instanceId), Context.ConnectionId, (key, oldValue) => Context.ConnectionId);
_tunnelStore.Connections.AddOrUpdate(Guid.Parse(clientId), Context.ConnectionId, (key, oldValue) => Context.ConnectionId);

return base.OnConnectedAsync();
}

public override Task OnDisconnectedAsync(Exception? exception)
{
// todo
var clientIdQuery = Context.GetHttpContext()!.Request.Query["clientId"].ToString();

var clientId = Guid.Parse(clientIdQuery);

if (_tunnelStore.Clients.TryGetValue(clientId, out var subdomain))
{
_tunnelStore.Tunnels.Remove(subdomain, out var _);
_tunnelStore.Connections.Remove(clientId, out var _);
_tunnelStore.Clients.Remove(clientId, out _);
}

return base.OnDisconnectedAsync(exception);
}

Expand Down
7 changes: 6 additions & 1 deletion src/WebSocketTunnel.Server/TunnelStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ namespace WebSocketTunnel.Server
{
public class TunnelStore
{
// subdomain, [clientId, localUrl]
public ConcurrentDictionary<string, Tunnel> Tunnels = new();

// clientId, connectionId
public ConcurrentDictionary<Guid, string> Connections = new();

// clientId, subdomain
public ConcurrentDictionary<Guid, string> Clients = new();
}

public class Tunnel
{
public Guid InstanceId { get; set; }
public Guid ClientId { get; set; }

public string? LocalUrl { get; set; }
}
Expand Down

0 comments on commit fe9db85

Please sign in to comment.