Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

default session saving completed & tested #156

Merged
merged 17 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Core Modules/WalletConnectSharp.Crypto/KeyChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ private async Task<Dictionary<string, string>> GetKeyChain()

private async Task SaveKeyChain()
{
await Storage.SetItem(StorageKey, this._keyChain);
// We need to copy the contents, otherwise Dispose()
// may clear the reference stored inside InMemoryStorage
await Storage.SetItem(StorageKey, new Dictionary<string, string>(this._keyChain));
}

public void Dispose()
Expand Down
14 changes: 14 additions & 0 deletions Core Modules/WalletConnectSharp.Network/JsonRpcProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ protected void RegisterEventListeners()
_hasRegisteredEventListeners = true;
}

protected void UnregisterEventListeners()
{
if (!_hasRegisteredEventListeners) return;

WCLogger.Log(
$"[JsonRpcProvider] Unregistering event listeners on connection object with context {_connection.ToString()} inside {Context}");
_connection.PayloadReceived -= OnPayload;
_connection.Closed -= OnConnectionDisconnected;
_connection.ErrorReceived -= OnConnectionError;

_hasRegisteredEventListeners = false;
}

private void OnConnectionError(object sender, Exception e)
{
this.ErrorReceived?.Invoke(this, e);
Expand Down Expand Up @@ -313,6 +326,7 @@ protected virtual void Dispose(bool disposing)

if (disposing)
{
UnregisterEventListeners();
_connection?.Dispose();
}

Expand Down
71 changes: 53 additions & 18 deletions Core Modules/WalletConnectSharp.Storage/FileSystemStorage.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using System.Text;
using Newtonsoft.Json;
using WalletConnectSharp.Common.Logging;
Expand Down Expand Up @@ -41,7 +42,11 @@ public FileSystemStorage(string filePath = null)
/// <returns></returns>
public override async Task Init()
{
if (Initialized)
return;

_semaphoreSlim = new SemaphoreSlim(1, 1);

await Task.WhenAll(
Load(), base.Init()
);
Expand Down Expand Up @@ -89,38 +94,68 @@ private async Task Save()
Directory.CreateDirectory(path);
}

var json = JsonConvert.SerializeObject(Entries,
string json;
json = JsonConvert.SerializeObject(Entries,
new JsonSerializerSettings() { TypeNameHandling = TypeNameHandling.All });

await _semaphoreSlim.WaitAsync();
await File.WriteAllTextAsync(FilePath, json, Encoding.UTF8);
_semaphoreSlim.Release();
try
{
if (!Disposed)
await _semaphoreSlim.WaitAsync();
int count = 5;
IOException lastException;
do
{
try
{
await File.WriteAllTextAsync(FilePath, json, Encoding.UTF8);
return;
}
catch (IOException e)
{
WCLogger.LogError($"Got error saving storage file: retries left {count}");
await Task.Delay(100);
count--;
lastException = e;
}
} while (count > 0);

throw lastException;
}
finally
{
if (!Disposed)
_semaphoreSlim.Release();
}
}

private async Task Load()
{
if (!File.Exists(FilePath))
return;

await _semaphoreSlim.WaitAsync();
var json = await File.ReadAllTextAsync(FilePath, Encoding.UTF8);
_semaphoreSlim.Release();
string json;
try
{
await _semaphoreSlim.WaitAsync();
json = await File.ReadAllTextAsync(FilePath, Encoding.UTF8);
}
finally
{
_semaphoreSlim.Release();
}

// Hard fail here if the storage file is bad, unless it's serialized as a Dictionary (for backwards compatibility)
var jsonSerializerSettings = new JsonSerializerSettings { TypeNameHandling = TypeNameHandling.Auto };
try
{
Entries = JsonConvert.DeserializeObject<Dictionary<string, object>>(json,
new JsonSerializerSettings() { TypeNameHandling = TypeNameHandling.Auto });
Entries = JsonConvert.DeserializeObject<ConcurrentDictionary<string, object>>(json,
jsonSerializerSettings);
}
catch (JsonSerializationException e)
catch (JsonSerializationException)
{
// Move the file to a .unsupported file
// and start fresh
WCLogger.LogError(e);
WCLogger.LogError("Cannot load JSON file, moving data to .unsupported file to force continue");
if (File.Exists(FilePath + ".unsupported"))
File.Move(FilePath + ".unsupported", FilePath + "." + Guid.NewGuid() + ".unsupported");
File.Move(FilePath, FilePath + ".unsupported");
Entries = new Dictionary<string, object>();
var dict = JsonConvert.DeserializeObject<Dictionary<string, object>>(json, jsonSerializerSettings);
Entries = new ConcurrentDictionary<string, object>(dict);
}
}

Expand Down
18 changes: 13 additions & 5 deletions Core Modules/WalletConnectSharp.Storage/InMemoryStorage.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
using System.Collections.Concurrent;
using WalletConnectSharp.Common.Model.Errors;
using WalletConnectSharp.Storage.Interfaces;

namespace WalletConnectSharp.Storage
{
public class InMemoryStorage : IKeyValueStorage
{
protected Dictionary<string, object> Entries = new Dictionary<string, object>();
private bool _initialized = false;
protected ConcurrentDictionary<string, object> Entries = new ConcurrentDictionary<string, object>();
protected bool Initialized = false;
protected bool Disposed;

public virtual Task Init()
{
_initialized = true;
if (Initialized)
return Task.CompletedTask;

Initialized = true;
return Task.CompletedTask;
}

Expand All @@ -24,6 +28,7 @@ public virtual Task<string[]> GetKeys()
public virtual async Task<T[]> GetEntriesOfType<T>()
{
IsInitialized();
// GetEntries is thread-safe
return (await GetEntries()).OfType<T>().ToArray();
}

Expand All @@ -43,13 +48,15 @@ public virtual Task SetItem<T>(string key, T value)
{
IsInitialized();
Entries[key] = value;

return Task.CompletedTask;
}

public virtual Task RemoveItem(string key)
{
IsInitialized();
Entries.Remove(key);
Entries.Remove(key, out _);

return Task.CompletedTask;
}

Expand All @@ -63,12 +70,13 @@ public virtual Task Clear()
{
IsInitialized();
Entries.Clear();

return Task.CompletedTask;
}

protected void IsInitialized()
{
if (!_initialized)
if (!Initialized)
{
throw WalletConnectException.FromType(ErrorType.NOT_INITIALIZED, "Storage");
}
Expand Down
49 changes: 26 additions & 23 deletions Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using WalletConnectSharp.Storage;
using WalletConnectSharp.Tests.Common;
using Xunit;
using Xunit.Abstractions;
using ErrorResponse = WalletConnectSharp.Auth.Models.ErrorResponse;

namespace WalletConnectSharp.Auth.Tests
Expand All @@ -26,6 +27,7 @@ public class AuthClientTests : IClassFixture<CryptoWalletFixture>, IAsyncLifetim
};

private readonly CryptoWalletFixture _cryptoWalletFixture;
private readonly ITestOutputHelper _testOutputHelper;

private IAuthClient PeerA;
public IAuthClient PeerB;
Expand Down Expand Up @@ -54,13 +56,14 @@ public string WalletAddress
}
}

public AuthClientTests(CryptoWalletFixture cryptoFixture)
public AuthClientTests(CryptoWalletFixture cryptoFixture, ITestOutputHelper testOutputHelper)
{
this._cryptoWalletFixture = cryptoFixture;
_testOutputHelper = testOutputHelper;
}

[Fact, Trait("Category", "unit")]
public async void TestInit()
public async Task TestInit()
{
Assert.NotNull(PeerA);
Assert.NotNull(PeerB);
Expand All @@ -77,7 +80,7 @@ public async void TestInit()
}

[Fact, Trait("Category", "unit")]
public async void TestPairs()
public async Task TestPairs()
{
var ogPairSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -110,7 +113,7 @@ public async void TestPairs()
}

[Fact, Trait("Category", "unit")]
public async void TestKnownPairings()
public async Task TestKnownPairings()
{
var ogSizeA = PeerA.Core.Pairing.Pairings.Length;
var history = await PeerA.AuthHistory();
Expand All @@ -121,7 +124,7 @@ public async void TestKnownPairings()
var ogHistorySizeB = historyB.Keys.Length;

List<TopicMessage> responses = new List<TopicMessage>();
TaskCompletionSource<TopicMessage> responseTask = new TaskCompletionSource<TopicMessage>();
TaskCompletionSource<TopicMessage> knownPairingTask = new TaskCompletionSource<TopicMessage>();

async void OnPeerBOnAuthRequested(object sender, AuthRequest request)
{
Expand All @@ -145,9 +148,9 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse args)
var sessionTopic = args.Topic;
var cacao = args.Response.Result;
var signature = cacao.Signature;
Console.WriteLine($"{sessionTopic}: {signature}");
_testOutputHelper.WriteLine($"{sessionTopic}: {signature}");
responses.Add(args);
responseTask.SetResult(args);
knownPairingTask.SetResult(args);
}

PeerA.AuthResponded += OnPeerAOnAuthResponded;
Expand All @@ -156,9 +159,9 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)
{
var sessionTopic = args.Topic;
var error = args.Error;
Console.WriteLine($"{sessionTopic}: {error}");
_testOutputHelper.WriteLine($"{sessionTopic}: {error}");
responses.Add(args);
responseTask.SetResult(args);
knownPairingTask.SetResult(args);
}

PeerA.AuthError += OnPeerAOnAuthError;
Expand All @@ -167,18 +170,18 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)

await PeerB.Core.Pairing.Pair(requestData.Uri);

await responseTask.Task;

await knownPairingTask.Task;
// Reset
responseTask = new TaskCompletionSource<TopicMessage>();
knownPairingTask = new TaskCompletionSource<TopicMessage>();

// Get last pairing, that is the one we just made
var knownPairing = PeerA.Core.Pairing.Pairings[^1];

var requestData2 = await PeerA.Request(DefaultRequestParams, knownPairing.Topic);

await responseTask.Task;

await knownPairingTask.Task;
Assert.Null(requestData2.Uri);

Assert.Equal(ogSizeA + 1, PeerA.Core.Pairing.Pairings.Length);
Expand All @@ -195,7 +198,7 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args)
}

[Fact, Trait("Category", "unit")]
public async void HandlesAuthRequests()
public async Task HandlesAuthRequests()
{
var ogSize = PeerB.Requests.Length;

Expand All @@ -218,7 +221,7 @@ public async void HandlesAuthRequests()
}

[Fact, Trait("Category", "unit")]
public async void TestErrorResponses()
public async Task TestErrorResponses()
{
var ogPSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -263,7 +266,7 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response)
}

[Fact, Trait("Category", "unit")]
public async void HandlesSuccessfulResponse()
public async Task HandlesSuccessfulResponse()
{
var ogPSize = PeerA.Core.Pairing.Pairings.Length;

Expand Down Expand Up @@ -313,7 +316,7 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response) =>
}

[Fact, Trait("Category", "unit")]
public async void TestCustomRequestExpiry()
public async Task TestCustomRequestExpiry()
{
var uri = "";
var expiry = 1000;
Expand Down Expand Up @@ -360,7 +363,7 @@ await PeerB.Respond(
}

[Fact, Trait("Category", "unit")]
public async void TestGetPendingPairings()
public async Task TestGetPendingPairings()
{
var ogCount = PeerB.PendingRequests.Count;

Expand All @@ -386,7 +389,7 @@ public async void TestGetPendingPairings()
}

[Fact, Trait("Category", "unit")]
public async void TestGetPairings()
public async Task TestGetPairings()
{
var peerAOgSize = PeerA.Core.Pairing.Pairings.Length;
var peerBOgSize = PeerB.Core.Pairing.Pairings.Length;
Expand Down Expand Up @@ -414,7 +417,7 @@ public async void TestGetPairings()
}

[Fact, Trait("Category", "unit")]
public async void TestPing()
public async Task TestPing()
{
TaskCompletionSource<bool> receivedAuthRequest = new TaskCompletionSource<bool>();
TaskCompletionSource<bool> receivedClientPing = new TaskCompletionSource<bool>();
Expand Down Expand Up @@ -453,7 +456,7 @@ public async void TestPing()
}

[Fact, Trait("Category", "unit")]
public async void TestDisconnectedPairing()
public async Task TestDisconnectedPairing()
{
var peerAOgSize = PeerA.Core.Pairing.Pairings.Length;
var peerBOgSize = PeerB.Core.Pairing.Pairings.Length;
Expand Down Expand Up @@ -493,7 +496,7 @@ public async void TestDisconnectedPairing()
}

[Fact, Trait("Category", "unit")]
public async void TestReceivesMetadata()
public async Task TestReceivesMetadata()
{
var receivedMetadataName = "";
var ogPairingSize = PeerA.Core.Pairing.Pairings.Length;
Expand Down
Loading
Loading