From ed6ad60fa9f95fc7e957d82f04f4f981bfb266ac Mon Sep 17 00:00:00 2001 From: Alexey Akimov Date: Fri, 16 Sep 2022 13:59:49 +0300 Subject: [PATCH 1/2] Added version checking in handshake during connection to debug server --- VSRAD.DebugServer/Logging/ClientLogger.cs | 27 +++++++ VSRAD.DebugServer/Server.cs | 74 ++++++++++++++++++- VSRAD.Package/Constants.cs | 1 + VSRAD.Package/Server/CommunicationChannel.cs | 76 +++++++++++++++++++- 4 files changed, 174 insertions(+), 4 deletions(-) diff --git a/VSRAD.DebugServer/Logging/ClientLogger.cs b/VSRAD.DebugServer/Logging/ClientLogger.cs index 42bbe5b7d..2f7066218 100644 --- a/VSRAD.DebugServer/Logging/ClientLogger.cs +++ b/VSRAD.DebugServer/Logging/ClientLogger.cs @@ -71,6 +71,33 @@ public void CommandProcessed() Console.WriteLine($"{Environment.NewLine}Time Elapsed: {_timer.ElapsedMilliseconds}ms"); } + public void ParseVersionError(String version) + { + Console.WriteLine($"{Environment.NewLine}Invalid Version on handshake attempt: {version}"); + } + + public void InvalidVersion(String receivedVersion, String minimalVersion) + { + Console.WriteLine($"{Environment.NewLine}Version mismatch. Client version: {receivedVersion}," + + $" expected version greater then {minimalVersion} "); + } + + public void ClientRejectedServerVersion(String serverVersion, String clientVersion) + { + Console.WriteLine($"{Environment.NewLine}Client rejected server version{Environment.NewLine}" + + $"client version: {clientVersion}, server version: {serverVersion}"); + } + + public void HandshakeFailed(EndPoint clientEndpoint) + { + Console.WriteLine($"{Environment.NewLine}Handshake in connection with {clientEndpoint} failed"); + } + + public void ConnectionTimeoutOnHandShake() + { + Console.WriteLine($"{Environment.NewLine}Connection timeout on handshake attempt"); + } + private void Print(string message) => Console.WriteLine("===" + Environment.NewLine + $"[Client #{_clientId}] {message}"); } diff --git a/VSRAD.DebugServer/Server.cs b/VSRAD.DebugServer/Server.cs index a87752eea..84405a4a6 100644 --- a/VSRAD.DebugServer/Server.cs +++ b/VSRAD.DebugServer/Server.cs @@ -4,6 +4,8 @@ using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; +using System.IO; +using System.Text; namespace VSRAD.DebugServer { @@ -12,6 +14,8 @@ public sealed class Server private readonly SemaphoreSlim _commandExecutionLock = new SemaphoreSlim(1, 1); private readonly TcpListener _listener; private readonly bool _verboseLogging; + private static Version _serverVersion = typeof(Server).Assembly.GetName().Version; + private static Version _minimalAcceptedClientVersion = new Version("2021.12.8"); const uint ENABLE_QUICK_EDIT = 0x0040; const int STD_INPUT_HANDLE = -10; @@ -31,6 +35,14 @@ public Server(IPAddress ip, int port, bool verboseLogging = false) _verboseLogging = verboseLogging; } + public enum HandShakeStatus + { + client_accepted, + client_not_accepted, + server_accepted, + server_not_accepted + } + public async Task LoopAsync() { /* disable Quick Edit cmd feature to prevent server hanging */ @@ -49,19 +61,77 @@ public async Task LoopAsync() while (true) { var client = await _listener.AcceptTcpClientAsync().ConfigureAwait(false); - ClientConnected(client, clientsCount); + TcpClientConnected(client, clientsCount); clientsCount++; } } - private void ClientConnected(TcpClient tcpClient, uint clientId) + private void TcpClientConnected(TcpClient tcpClient, uint clientId) { var networkClient = new NetworkClient(tcpClient, clientId); var clientLog = new ClientLogger(clientId, _verboseLogging); + if (!Task.Run(() => TryProcessServerHandshake(tcpClient, clientLog)).Result) + { + clientLog.HandshakeFailed(networkClient.EndPoint); + return; + } clientLog.ConnectionEstablished(networkClient.EndPoint); Task.Run(() => BeginClientLoopAsync(networkClient, clientLog)); } + private async Task TryProcessServerHandshake(TcpClient client, ClientLogger clientLog) + { + try + { + StreamWriter writer = new StreamWriter(client.GetStream(), Encoding.UTF8) { AutoFlush = true }; + StreamReader reader = new StreamReader(client.GetStream(), Encoding.UTF8); + + // Send server version to client + // + await writer.WriteLineAsync(_serverVersion.ToString()).ConfigureAwait(false); + + // Obtain client version + // + String clientResponse = await reader.ReadLineAsync().ConfigureAwait(false); + + Version clientVersion = null; + if (!Version.TryParse(clientResponse, out clientVersion)) + { + clientLog.ParseVersionError(clientResponse); + // Inform client that server declines client's version + // + await writer.WriteLineAsync(HandShakeStatus.server_not_accepted.ToString()).ConfigureAwait(false); + return false; + } + + if (clientVersion.CompareTo(_minimalAcceptedClientVersion) < 0) + { + clientLog.InvalidVersion(clientVersion.ToString(), _minimalAcceptedClientVersion.ToString()); + // Inform client that server declines client's version + // + await writer.WriteLineAsync(HandShakeStatus.server_not_accepted.ToString()).ConfigureAwait(false); + return false; + } + + // Inform client that server accepts client's version + // + await writer.WriteLineAsync(HandShakeStatus.server_accepted.ToString()).ConfigureAwait(false); + + // Check if client accepts server version + // + if (await reader.ReadLineAsync() != HandShakeStatus.client_accepted.ToString()) + { + clientLog.ClientRejectedServerVersion(_serverVersion.ToString(), clientVersion.ToString()); + return false; + } + } catch (Exception) + { + clientLog.ConnectionTimeoutOnHandShake(); + return false; + } + return true; + } + private async Task BeginClientLoopAsync(NetworkClient client, ClientLogger clientLog) { while (true) diff --git a/VSRAD.Package/Constants.cs b/VSRAD.Package/Constants.cs index 9f9ac2bfc..c5eaf140d 100644 --- a/VSRAD.Package/Constants.cs +++ b/VSRAD.Package/Constants.cs @@ -66,5 +66,6 @@ public static class Constants public const string ToolbarIconStripResourcePackUri = "pack://application:,,,/RadeonAsmDebugger;component/Resources/DebugVisualizerWindowCommand.png"; public const string CurrentStatementIconResourcePackUri = "pack://application:,,,/RadeonAsmDebugger;component/Resources/CurrentStatement.png"; + public static readonly Version MinimalRequiredServerVersion = new Version("2021.3.3"); }; } \ No newline at end of file diff --git a/VSRAD.Package/Server/CommunicationChannel.cs b/VSRAD.Package/Server/CommunicationChannel.cs index b0de4d8be..efe21769a 100644 --- a/VSRAD.Package/Server/CommunicationChannel.cs +++ b/VSRAD.Package/Server/CommunicationChannel.cs @@ -12,6 +12,9 @@ using VSRAD.Package.Options; using VSRAD.Package.ProjectSystem; using Task = System.Threading.Tasks.Task; +using System.IO; +using System.Text; +using System.Text.RegularExpressions; namespace VSRAD.Package.Server { @@ -37,6 +40,20 @@ public ConnectionRefusedException(ServerConnectionOptions connection) : { } } + public sealed class UnsupportedServerVersionException : System.IO.IOException + { + public UnsupportedServerVersionException(ServerConnectionOptions connection, Version serverVersion) : + base($"The debug server on host {connection} is out of date and missing critical features. Please update it to the {serverVersion} or above version.") + { } + } + + public sealed class UnsupportedDebuggerVersionException : System.IO.IOException + { + public UnsupportedDebuggerVersionException(ServerConnectionOptions connection, Version serverVersion) : + base($"This extension is out of date and missing critical features to work with debug server on host {connection}. Please update it to the {serverVersion} or above version.") + { } + } + public enum ClientState { Disconnected, @@ -44,6 +61,14 @@ public enum ClientState Connected } + public enum HandShakeStatus + { + client_accepted, + client_not_accepted, + server_accepted, + server_not_accepted + } + [Export(typeof(ICommunicationChannel))] [AppliesTo(Constants.RadOrVisualCProjectCapability)] public sealed class CommunicationChannel : ICommunicationChannel @@ -51,7 +76,7 @@ public sealed class CommunicationChannel : ICommunicationChannel public event Action ConnectionStateChanged; public ServerConnectionOptions ConnectionOptions => _project.Options.Profile?.General?.Connection ?? new ServerConnectionOptions("Remote address is not specified", 0); - + private Version _extensionVersion; private ClientState _state = ClientState.Disconnected; public ClientState ConnectionState { @@ -64,6 +89,7 @@ public ClientState ConnectionState } private static readonly TimeSpan _connectionTimeout = new TimeSpan(hours: 0, minutes: 0, seconds: 5); + private static readonly Regex _extensionVersionRegex = new Regex(@".*\/(?.*)\/RadeonAsmDebugger\.dll", RegexOptions.Compiled); private readonly OutputWindowWriter _outputWindowWriter; private readonly IProject _project; @@ -81,11 +107,56 @@ public CommunicationChannel(SVsServiceProvider provider, IProject project) _project = project; _project.RunWhenLoaded((options) => options.PropertyChanged += (s, e) => { if (e.PropertyName == nameof(options.ActiveProfile)) ForceDisconnect(); }); + var match = _extensionVersionRegex.Match(typeof(IProject).Assembly.CodeBase); + if (!match.Success) + throw new Exception("Error while getting current extension version."); + _extensionVersion = new Version(match.Groups["version"].Value); } public Task SendWithReplyAsync(ICommand command) where T : IResponse => SendWithReplyAsync(command, tryReconnect: true); + private async Task TryProcessClientHandshake(TcpClient client) + { + StreamWriter writer = new StreamWriter(client.GetStream(), Encoding.UTF8) { AutoFlush = true }; + StreamReader reader = new StreamReader(client.GetStream(), Encoding.UTF8); + + // Send client version to server + // + await writer.WriteLineAsync(_extensionVersion.ToString()).ConfigureAwait(false); + // Obtain server version + // + String serverResponse = await reader.ReadLineAsync().ConfigureAwait(false); + Version serverVersion = null; + if (!Version.TryParse(serverResponse, out serverVersion)) + { + // Inform server that client declines serve's version + // + await writer.WriteLineAsync(HandShakeStatus.client_not_accepted.ToString()).ConfigureAwait(false); + throw new UnsupportedServerVersionException(ConnectionOptions, Constants.MinimalRequiredServerVersion); + } + + if (serverVersion.CompareTo(Constants.MinimalRequiredServerVersion) < 0) + { + // Inform client that server declines client's version + // + await writer.WriteLineAsync(HandShakeStatus.client_not_accepted.ToString()).ConfigureAwait(false); + throw new UnsupportedServerVersionException(ConnectionOptions, Constants.MinimalRequiredServerVersion); + } + + // Inform client that server accepts client's version + // + await writer.WriteLineAsync(HandShakeStatus.client_accepted.ToString()).ConfigureAwait(false); + + // Check if client accepts server version + // + if (await reader.ReadLineAsync() != HandShakeStatus.server_accepted.ToString()) + { + throw new UnsupportedDebuggerVersionException(ConnectionOptions, serverVersion); + } + return true; + } + private async Task SendWithReplyAsync(ICommand command, bool tryReconnect) where T : IResponse { await _mutex.WaitAsync(); @@ -156,11 +227,12 @@ private async Task EstablishServerConnectionAsync() using (cts.Token.Register(() => client.Dispose())) { await client.ConnectAsync(ConnectionOptions.RemoteMachine, ConnectionOptions.Port); + await TryProcessClientHandshake(client); _connection = client; ConnectionState = ClientState.Connected; } } - catch (Exception) + catch (Exception e) when (!(e is UnsupportedDebuggerVersionException) && !(e is UnsupportedServerVersionException)) { ConnectionState = ClientState.Disconnected; throw new ConnectionRefusedException(ConnectionOptions); From 139ecabf1c65c67e03d6c5d2101ba7b514b8b9c8 Mon Sep 17 00:00:00 2001 From: Alexey Akimov Date: Thu, 22 Sep 2022 11:17:00 +0300 Subject: [PATCH 2/2] Added action level lock on the server side This will prevent interleaved execution of commands from different users --- VSRAD.DebugServer/Logging/ClientLogger.cs | 12 +++- VSRAD.DebugServer/NetworkClient.cs | 4 ++ VSRAD.DebugServer/Server.cs | 70 +++++++++++++------- VSRAD.Package/Server/ActionRunner.cs | 2 + VSRAD.Package/Server/CommunicationChannel.cs | 50 ++++++++++++-- 5 files changed, 110 insertions(+), 28 deletions(-) diff --git a/VSRAD.DebugServer/Logging/ClientLogger.cs b/VSRAD.DebugServer/Logging/ClientLogger.cs index 2f7066218..e1f12ed96 100644 --- a/VSRAD.DebugServer/Logging/ClientLogger.cs +++ b/VSRAD.DebugServer/Logging/ClientLogger.cs @@ -35,7 +35,7 @@ public void FatalClientException(Exception e) => Print("An exception has occurred while processing the command. Connection has been terminated." + Environment.NewLine + e.ToString()); public void CliendDisconnected() => - Print("client has been disconnected"); + Console.WriteLine($"{Environment.NewLine}client #{_clientId} has been disconnected"); public void ExecutionStarted() { @@ -98,6 +98,16 @@ public void ConnectionTimeoutOnHandShake() Console.WriteLine($"{Environment.NewLine}Connection timeout on handshake attempt"); } + public void LockAcquired() + { + Console.WriteLine($"{Environment.NewLine}client#{_clientId} acquired lock"); + } + + public void LockReleased() + { + Console.WriteLine($"{Environment.NewLine}client#{_clientId} released lock"); + } + private void Print(string message) => Console.WriteLine("===" + Environment.NewLine + $"[Client #{_clientId}] {message}"); } diff --git a/VSRAD.DebugServer/NetworkClient.cs b/VSRAD.DebugServer/NetworkClient.cs index 80db5a15d..57b9f4566 100644 --- a/VSRAD.DebugServer/NetworkClient.cs +++ b/VSRAD.DebugServer/NetworkClient.cs @@ -38,6 +38,10 @@ public async Task ReceiveCommandAsync() } } + public NetworkStream GetStream() + { + return _socket.GetStream(); + } public Task SendResponseAsync(IPC.Responses.IResponse response) => _socket.GetStream().WriteSerializedMessageAsync(response); diff --git a/VSRAD.DebugServer/Server.cs b/VSRAD.DebugServer/Server.cs index 84405a4a6..98db1dcd1 100644 --- a/VSRAD.DebugServer/Server.cs +++ b/VSRAD.DebugServer/Server.cs @@ -12,6 +12,7 @@ namespace VSRAD.DebugServer public sealed class Server { private readonly SemaphoreSlim _commandExecutionLock = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim _actionExecutionLock = new SemaphoreSlim(1, 1); private readonly TcpListener _listener; private readonly bool _verboseLogging; private static Version _serverVersion = typeof(Server).Assembly.GetName().Version; @@ -42,6 +43,11 @@ public enum HandShakeStatus server_accepted, server_not_accepted } + public enum LockStatus + { + lock_not_ackquired, + lock_acquired + } public async Task LoopAsync() { @@ -79,6 +85,27 @@ private void TcpClientConnected(TcpClient tcpClient, uint clientId) Task.Run(() => BeginClientLoopAsync(networkClient, clientLog)); } + private async Task AcquireLock(NetworkClient client, ClientLogger clientLog) + { + try + { + await _actionExecutionLock.WaitAsync(); + clientLog.LockAcquired(); + StreamWriter writer = new StreamWriter(client.GetStream(), Encoding.UTF8) { AutoFlush = true }; + await writer.WriteLineAsync(LockStatus.lock_acquired.ToString()); + } + catch (Exception e) + { + _actionExecutionLock.Release(); + clientLog.LockReleased(); + clientLog.FatalClientException(e); + client.Disconnect(); + clientLog.CliendDisconnected(); + return false; + } + return true; + } + private async Task TryProcessServerHandshake(TcpClient client, ClientLogger clientLog) { try @@ -127,6 +154,7 @@ private async Task TryProcessServerHandshake(TcpClient client, ClientLogge } catch (Exception) { clientLog.ConnectionTimeoutOnHandShake(); + return false; } return true; @@ -134,17 +162,18 @@ private async Task TryProcessServerHandshake(TcpClient client, ClientLogge private async Task BeginClientLoopAsync(NetworkClient client, ClientLogger clientLog) { - while (true) + // Client closed connection during acquire lock phase + // + if (!await AcquireLock(client, clientLog)) + return; + + try { - bool lockAcquired = false; - try + while (true) { var command = await client.ReceiveCommandAsync().ConfigureAwait(false); clientLog.CommandReceived(command); - await _commandExecutionLock.WaitAsync(); - lockAcquired = true; - var response = await Dispatcher.DispatchAsync(command, clientLog).ConfigureAwait(false); if (response != null) // commands like Deploy do not return a response { @@ -153,24 +182,19 @@ private async Task BeginClientLoopAsync(NetworkClient client, ClientLogger clien } clientLog.CommandProcessed(); } - catch (ConnectionFailedException) - { - client.Disconnect(); - clientLog.CliendDisconnected(); - break; - } - catch (Exception e) - { - client.Disconnect(); - clientLog.FatalClientException(e); - break; - } - finally - { - if (lockAcquired) - _commandExecutionLock.Release(); - } } + catch (ConnectionFailedException) + { + clientLog.CliendDisconnected(); + } + catch (Exception e) + { + clientLog.FatalClientException(e); + } + + client.Disconnect(); + _actionExecutionLock.Release(); + clientLog.LockReleased(); } } } diff --git a/VSRAD.Package/Server/ActionRunner.cs b/VSRAD.Package/Server/ActionRunner.cs index 06ae00a2c..e6d983ad7 100644 --- a/VSRAD.Package/Server/ActionRunner.cs +++ b/VSRAD.Package/Server/ActionRunner.cs @@ -21,6 +21,7 @@ public sealed class ActionRunner private readonly Dictionary _initialTimestamps = new Dictionary(); private readonly ActionEnvironment _environment; private readonly IProject _project; + private readonly VsStatusBarWriter _statusBar; public ActionRunner(ICommunicationChannel channel, SVsServiceProvider serviceProvider, ActionEnvironment environment, IProject project) { @@ -28,6 +29,7 @@ public ActionRunner(ICommunicationChannel channel, SVsServiceProvider servicePro _serviceProvider = serviceProvider; _environment = environment; _project = project; + _statusBar = new VsStatusBarWriter(serviceProvider); } public DateTime GetInitialFileTimestamp(string file) => diff --git a/VSRAD.Package/Server/CommunicationChannel.cs b/VSRAD.Package/Server/CommunicationChannel.cs index efe21769a..7c558bb24 100644 --- a/VSRAD.Package/Server/CommunicationChannel.cs +++ b/VSRAD.Package/Server/CommunicationChannel.cs @@ -15,6 +15,8 @@ using System.IO; using System.Text; using System.Text.RegularExpressions; +using VSRAD.Package.Utils; +using System.Windows; namespace VSRAD.Package.Server { @@ -54,6 +56,13 @@ public UnsupportedDebuggerVersionException(ServerConnectionOptions connection, V { } } + public sealed class UnaquiredLockException : System.IO.IOException + { + public UnaquiredLockException(ServerConnectionOptions connection) : + base($"Unable to acquire lock at {connection}") + { } + } + public enum ClientState { Disconnected, @@ -69,6 +78,12 @@ public enum HandShakeStatus server_not_accepted } + public enum LockStatus + { + lock_not_ackquired, + lock_acquired + } + [Export(typeof(ICommunicationChannel))] [AppliesTo(Constants.RadOrVisualCProjectCapability)] public sealed class CommunicationChannel : ICommunicationChannel @@ -88,10 +103,12 @@ public ClientState ConnectionState } } - private static readonly TimeSpan _connectionTimeout = new TimeSpan(hours: 0, minutes: 0, seconds: 5); + private static readonly TimeSpan _connectionTimeout = new TimeSpan(hours: 0, minutes: 0, seconds: 10); + private static readonly TimeSpan _lockTimeout = new TimeSpan(hours: 0, minutes: 0, seconds: 10); private static readonly Regex _extensionVersionRegex = new Regex(@".*\/(?.*)\/RadeonAsmDebugger\.dll", RegexOptions.Compiled); private readonly OutputWindowWriter _outputWindowWriter; + private readonly VsStatusBarWriter _statusBar; private readonly IProject _project; private TcpClient _connection; @@ -104,6 +121,7 @@ public CommunicationChannel(SVsServiceProvider provider, IProject project) { _outputWindowWriter = new OutputWindowWriter(provider, Constants.OutputPaneServerGuid, Constants.OutputPaneServerTitle); + _statusBar = new VsStatusBarWriter(provider); _project = project; _project.RunWhenLoaded((options) => options.PropertyChanged += (s, e) => { if (e.PropertyName == nameof(options.ActiveProfile)) ForceDisconnect(); }); @@ -114,7 +132,29 @@ public CommunicationChannel(SVsServiceProvider provider, IProject project) } public Task SendWithReplyAsync(ICommand command) where T : IResponse => - SendWithReplyAsync(command, tryReconnect: true); + SendWithReplyAsync(command, tryReconnect: false); + + private async Task TryAcquireServerLock(TcpClient client) { + StreamReader reader = new StreamReader(client.GetStream(), Encoding.UTF8); + await _statusBar.SetTextAsync("Acquiring lock on the server"); + using (var cts = new CancellationTokenSource(_lockTimeout)) + using (cts.Token.Register(() => client.Dispose())) + try + { + + if (await reader.ReadLineAsync() != LockStatus.lock_acquired.ToString()) + { + throw new UnaquiredLockException(ConnectionOptions); + } + await _statusBar.SetTextAsync("Acquired Lock"); + } + catch (Exception e) + { + MessageBox.Show("Unable to acquire lock, try again."); + return false; + } + return true; + } private async Task TryProcessClientHandshake(TcpClient client) { @@ -202,6 +242,7 @@ public async Task> GetRemoteEnvironmentAsync await EstablishServerConnectionAsync().ConfigureAwait(false); var environment = await SendWithReplyAsync(new ListEnvironmentVariables()); _remoteEnvironment = environment.Variables; + ForceDisconnect(); } return _remoteEnvironment; } @@ -223,14 +264,15 @@ private async Task EstablishServerConnectionAsync() var client = new TcpClient(); try { - using (var cts = new CancellationTokenSource(_connectionTimeout)) - using (cts.Token.Register(() => client.Dispose())) { + using (var cts = new CancellationTokenSource(_connectionTimeout)) + using (cts.Token.Register(() => client.Dispose())) await client.ConnectAsync(ConnectionOptions.RemoteMachine, ConnectionOptions.Port); await TryProcessClientHandshake(client); _connection = client; ConnectionState = ClientState.Connected; } + await TryAcquireServerLock(client); } catch (Exception e) when (!(e is UnsupportedDebuggerVersionException) && !(e is UnsupportedServerVersionException)) {