Skip to content

Commit

Permalink
Merge pull request #88 from brnbs/master
Browse files Browse the repository at this point in the history
Fix buffer handling for large WebSocket messages
  • Loading branch information
twitchax authored Jan 21, 2022
2 parents 572fc3e + 5d1ef86 commit 804525f
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 16 deletions.
26 changes: 21 additions & 5 deletions src/Core/Extensions/Ws.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
Expand Down Expand Up @@ -70,37 +71,52 @@ internal static async Task ExecuteWsProxyOperationAsync(this HttpContext context

private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken)
{
var buffer = new byte[bufferSize];
using var ms = new MemoryStream();
var receiveBuffer = WebSocket.CreateServerBuffer(bufferSize);

while (true)
{
WebSocketReceiveResult result;

try
{
result = await source.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken).ConfigureAwait(false);
ms.SetLength(0);

do
{
result = await source.ReceiveAsync(receiveBuffer, cancellationToken).ConfigureAwait(false);
ms.Write(receiveBuffer.Array!, receiveBuffer.Offset, result.Count);
}
while (!result.EndOfMessage);
}
catch (Exception e)
{
var closeMessageBytes = Encoding.UTF8.GetBytes($"WebSocket failure.\n\n{e.Message}\n\n{e.StackTrace}");
var closeMessage = Encoding.UTF8.GetString(closeMessageBytes, 0, Math.Min(closeMessageBytes.Length, CloseMessageMaxSize));
await destination.CloseOutputAsync(WebSocketCloseStatus.EndpointUnavailable, closeMessage, cancellationToken).ConfigureAwait(false);

return;
}

if(destination.State != WebSocketState.Open && destination.State != WebSocketState.CloseReceived)
if (destination.State != WebSocketState.Open && destination.State != WebSocketState.CloseReceived)
{
return;
}

if (result.MessageType == WebSocketMessageType.Close)
{
await destination.CloseOutputAsync(source.CloseStatus.Value, source.CloseStatusDescription, cancellationToken).ConfigureAwait(false);
var closeStatus = source.CloseStatus ?? WebSocketCloseStatus.Empty;
await destination.CloseOutputAsync(closeStatus, source.CloseStatusDescription, cancellationToken).ConfigureAwait(false);

return;
}

var sendBuffer = new ArraySegment<byte>(ms.GetBuffer(), 0, (int)ms.Length);

// TODO: Add handlers here to allow the developer to edit message before forwarding, and vice versa?
// Possibly in the future, if deemed useful.

await destination.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, cancellationToken).ConfigureAwait(false);
await destination.SendAsync(sendBuffer, result.MessageType, result.EndOfMessage, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
29 changes: 26 additions & 3 deletions src/Test/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Net.WebSockets;
using System.Text;
Expand Down Expand Up @@ -37,13 +36,37 @@ internal static async Task<string> ReceiveShortMessageAsync(this WebSocket socke
return Encoding.UTF8.GetString(buffer, 0, result.Count);
}

internal static Task SendMessageAsync(this WebSocket socket, string message)
{
return socket.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes(message)), WebSocketMessageType.Text, true, CancellationToken.None);
}

internal static async Task<string> ReceiveMessageAsync(this WebSocket socket)
{
var buffer = new ArraySegment<byte>(new byte[BUFFER_SIZE]);
WebSocketReceiveResult result;

using var ms = new MemoryStream();
do
{
result = await socket.ReceiveAsync(buffer, CancellationToken.None);
ms.Write(buffer.Array!, buffer.Offset, result.Count);
}
while (!result.EndOfMessage);

ms.Seek(0, SeekOrigin.Begin);

using var reader = new StreamReader(ms, Encoding.UTF8);
return await reader.ReadToEndAsync();
}

internal static async Task SocketBoomerang(this HttpContext context)
{
var socket = await context.WebSockets.AcceptWebSocketAsync(SupportedProtocol);

while(true)
{
var message = await socket.ReceiveShortMessageAsync();
var message = await socket.ReceiveMessageAsync();

if(message == CloseMessage)
{
Expand All @@ -57,7 +80,7 @@ internal static async Task SocketBoomerang(this HttpContext context)
}

// Basically, this server just always sends back a message that is the message it received wrapped with "[]".
await socket.SendShortMessageAsync($"[{message}]");
await socket.SendMessageAsync($"[{message}]");
}
}

Expand Down
2 changes: 0 additions & 2 deletions src/Test/Http/HttpHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.DependencyInjection;
using AspNetCore.Proxy;
using AspNetCore.Proxy.Options;
using System.Diagnostics.CodeAnalysis;

Expand Down
3 changes: 0 additions & 3 deletions src/Test/RunProxy/RunProxyHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using AspNetCore.Proxy;

namespace AspNetCore.Proxy.Tests
{
Expand Down
1 change: 0 additions & 1 deletion src/Test/Unit/Endpoints.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using AspNetCore.Proxy.Endpoints;
using Xunit;

Expand Down
11 changes: 10 additions & 1 deletion src/Test/Ws/WsHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore;
Expand All @@ -6,7 +7,6 @@
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using AspNetCore.Proxy;
using Microsoft.AspNetCore.Http;

namespace AspNetCore.Proxy.Tests
Expand Down Expand Up @@ -78,5 +78,14 @@ internal static Task RunWsServers(CancellationToken token)

return Task.WhenAll(proxiedServerTask, proxyServerTask);
}

internal static string GetRandomBase64String(int sizeInKb)
{
var rnd = new Random();
var b = new byte[sizeInKb * 1024];
rnd.NextBytes(b);

return Convert.ToBase64String(b);
}
}
}
36 changes: 35 additions & 1 deletion src/Test/Ws/WsIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -71,6 +70,41 @@ public async Task CanDoWebSockets(string server)
await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None);
}

[Theory]
[InlineData("ws://localhost:5001/ws")]
[InlineData("ws://localhost:5001/api/ws")]
[InlineData("ws://localhost:5001/api/ws2")]
public async Task CanDoWebSocketsWithLargeDataChunks(string server)
{
var send1 = WsHelpers.GetRandomBase64String(10);
var expected1 = $"[{send1}]";

var send2 = WsHelpers.GetRandomBase64String(500);
var expected2 = $"[{send2}]";

await _client.ConnectAsync(new Uri(server), CancellationToken.None);
Assert.Equal(Extensions.SupportedProtocol, _client.SubProtocol);

// Send a message.
await _client.SendMessageAsync(send1);
await _client.SendMessageAsync(send2);
await _client.SendShortMessageAsync(Extensions.CloseMessage);

// Receive responses.
var response1 = await _client.ReceiveMessageAsync();
Assert.Equal(expected1, response1);
var response2 = await _client.ReceiveMessageAsync();
Assert.Equal(expected2, response2);

// Receive close.
var result = await _client.ReceiveAsync(new ArraySegment<byte>(new byte[4096]), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus);
Assert.Equal(Extensions.CloseDescription, result.CloseStatusDescription);

await _client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, Extensions.CloseDescription, CancellationToken.None);
}

[Fact]
public async Task CanCatchAbruptClose()
{
Expand Down

0 comments on commit 804525f

Please sign in to comment.