Skip to content

Commit

Permalink
feat: adding the ability to update ttl with update (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
slorello89 authored Jul 29, 2024
1 parent 4890559 commit 893834a
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 61 deletions.
16 changes: 14 additions & 2 deletions src/Redis.OM/RedisCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -774,8 +774,9 @@ public static async Task<IDictionary<string, RedisReply>> HGetAllAsync(this IRed
/// <param name="key">The key.</param>
/// <param name="value">The value.</param>
/// <param name="storageType">The storage type of the value.</param>
/// <param name="ttl">The ttl for the key.</param>
/// <typeparam name="T">The type of the value.</typeparam>
internal static void UnlinkAndSet<T>(this IRedisConnection connection, string key, T value, StorageType storageType)
internal static void UnlinkAndSet<T>(this IRedisConnection connection, string key, T value, StorageType storageType, TimeSpan? ttl)
{
_ = value ?? throw new ArgumentNullException(nameof(value));
if (storageType == StorageType.Json)
Expand All @@ -791,6 +792,11 @@ internal static void UnlinkAndSet<T>(this IRedisConnection connection, string ke
{
args.Add(pair.Key);
args.Add(pair.Value);
if (ttl is not null)
{
args.Add("EXPIRE");
args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
}
}

connection.CreateAndEval(nameof(Scripts.UnlinkAndSetHash), new[] { key }, args.ToArray());
Expand All @@ -804,9 +810,10 @@ internal static void UnlinkAndSet<T>(this IRedisConnection connection, string ke
/// <param name="key">The key.</param>
/// <param name="value">The value.</param>
/// <param name="storageType">The storage type of the value.</param>
/// <param name="ttl">The time to live for the key.</param>
/// <typeparam name="T">The type of the value.</typeparam>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
internal static async Task UnlinkAndSetAsync<T>(this IRedisConnection connection, string key, T value, StorageType storageType)
internal static async Task UnlinkAndSetAsync<T>(this IRedisConnection connection, string key, T value, StorageType storageType, TimeSpan? ttl)
{
_ = value ?? throw new ArgumentNullException(nameof(value));
if (storageType == StorageType.Json)
Expand All @@ -822,6 +829,11 @@ internal static async Task UnlinkAndSetAsync<T>(this IRedisConnection connection
{
args.Add(pair.Key);
args.Add(pair.Value);
if (ttl is not null)
{
args.Add("EXPIRE");
args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
}
}

await connection.CreateAndEvalAsync(nameof(Scripts.UnlinkAndSetHash), new[] { key }, args.ToArray());
Expand Down
38 changes: 33 additions & 5 deletions src/Redis.OM/Scripts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal class Scripts
if index>=0 then
redis.call('JSON.ARRPOP', key, ARGV[i+1], index)
end
elseif 'EXPIRE' == ARGV[i] then
redis.call('PEXPIRE', key, tonumber(ARGV[i+1]))
else
if 'DEL' == ARGV[i] then
redis.call('JSON.DEL',key,ARGV[i+1])
Expand All @@ -38,6 +40,7 @@ internal class Scripts
local num_fields_to_set = ARGV[1]
local end_index = num_fields_to_set*2+1
local args = {}
local expire_time = -1
for i=2, end_index, 2 do
args[i-1] = ARGV[i]
args[i] = ARGV[i+1]
Expand All @@ -49,9 +52,19 @@ internal class Scripts
local second_op
args = {}
for i = end_index+1, num_args, 1 do
args[i-end_index] = ARGV[i]
if ARGV[i] == 'EXPIRE' then
expire_time = tonumber(ARGV[i+1])
else
args[i-end_index] = ARGV[i]
end
end
if table.getn(args) > 0 then
redis.call('HDEL',key,unpack(args))
end
redis.call('HDEL',key,unpack(args))
end
if expire_time > -1 then
redis.call('PEXPIRE', key, expire_time)
end
";

Expand All @@ -69,19 +82,34 @@ local second_op
local num_fields = ARGV[1]
local end_index = num_fields * 2 + 1
local args = {}
local expire_time = -1
for i = 2, end_index, 2 do
args[i-1] = ARGV[i]
args[i] = ARGV[i+1]
if ARGV[i] == 'EXPIRE' then
expire_time = tonumber(ARGV[i+1])
else
args[i-1] = ARGV[i]
args[i] = ARGV[i+1]
end
end
redis.call('HSET',KEYS[1],unpack(args))
if expire_time > -1 then
redis.call('PEXPIRE', KEYS[1], expire_time)
end
return 0
";

/// <summary>
/// Unlinks a JSON object and sets the key again with a fresh new JSON object.
/// </summary>
internal const string UnlinkAndSendJson = @"
local expiry = tonumber(redis.call('PTTL', KEYS[1]))
local num_args = table.getn(ARGV)
local expiry = -1
if num_args > 1 and 'EXPIRE' == ARGV[2] then
expiry = tonumber(ARGV[3])
else
expiry = tonumber(redis.call('PTTL', KEYS[1]))
end
redis.call('UNLINK', KEYS[1])
redis.call('JSON.SET', KEYS[1], '.', ARGV[1])
if expiry > 0 then
Expand Down
23 changes: 23 additions & 0 deletions src/Redis.OM/Searching/IRedisCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,29 @@ public interface IRedisCollection<T> : IOrderedQueryable<T>, IAsyncEnumerable<T>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
ValueTask UpdateAsync(IEnumerable<T> items);

/// <summary>
/// Updates the provided item in Redis. Document must have a property marked with the <see cref="RedisIdFieldAttribute"/>.
/// </summary>
/// <param name="item">The item to update.</param>
/// <param name="ttl">The updated ttl for the record.</param>
void Update(T item, TimeSpan ttl);

/// <summary>
/// Updates the provided item in Redis. Document must have a property marked with the <see cref="RedisIdFieldAttribute"/>.
/// </summary>
/// <param name="item">The item to update.</param>
/// <param name="ttl">The updated ttl for the record.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
Task UpdateAsync(T item, TimeSpan ttl);

/// <summary>
/// Updates the provided items in Redis. Document must have a property marked with the <see cref="RedisIdFieldAttribute"/>.
/// </summary>
/// <param name="items">The items to update.</param>
/// <param name="ttl">The updated ttl for the record.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
ValueTask UpdateAsync(IEnumerable<T> items, TimeSpan ttl);

/// <summary>
/// Deletes the item from Redis.
/// </summary>
Expand Down
179 changes: 125 additions & 54 deletions src/Redis.OM/Searching/RedisCollection.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -144,69 +145,37 @@ public bool Any(Expression<Func<T, bool>> expression)
/// <inheritdoc />
public void Update(T item)
{
var key = item.GetKey();
IList<IObjectDiff>? diff;
var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff);
if (diffConstructed)
{
if (diff!.Any())
{
var args = new List<string>();
var scriptName = diff!.First().Script;
foreach (var update in diff!)
{
args.AddRange(update.SerializeScriptArgs());
}

_connection.CreateAndEval(scriptName, new[] { key }, args.ToArray());
}
}
else
{
_connection.UnlinkAndSet(key, item, StateManager.DocumentAttribute.StorageType);
}

SaveToStateManager(key, item);
SendUpdate(item);
}

/// <inheritdoc />
public async Task UpdateAsync(T item)
public Task UpdateAsync(T item)
{
var key = item.GetKey();
IList<IObjectDiff>? diff;
var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff);
if (diffConstructed)
{
if (diff!.Any())
{
var args = new List<string>();
var scriptName = diff!.First().Script;
foreach (var update in diff!)
{
args.AddRange(update.SerializeScriptArgs());
}
return SendUpdateAsync(item);
}

await _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray());
}
}
else
{
await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType);
}
/// <inheritdoc />
public ValueTask UpdateAsync(IEnumerable<T> items)
{
return SendUpdateAsync(items);
}

SaveToStateManager(key, item);
/// <inheritdoc />
public void Update(T item, TimeSpan ttl)
{
SendUpdate(item, ttl);
}

/// <inheritdoc />
public async ValueTask UpdateAsync(IEnumerable<T> items)
public Task UpdateAsync(T item, TimeSpan ttl)
{
var tasks = items.Select(UpdateAsyncNoSave);
return SendUpdateAsync(item, ttl);
}

await Task.WhenAll(tasks);
foreach (var kvp in tasks.Select(x => x.Result))
{
SaveToStateManager(kvp.Key, kvp.Value);
}
/// <inheritdoc />
public ValueTask UpdateAsync(IEnumerable<T> items, TimeSpan ttl)
{
return SendUpdateAsync(items, ttl);
}

/// <inheritdoc />
Expand Down Expand Up @@ -774,7 +743,7 @@ private static MethodInfo GetMethodInfo<T1, T2>(Func<T1, T2> f, T1 unused)
return _connection.GetAsync<T>(key).AsTask();
}

private async Task<KeyValuePair<string, T>> UpdateAsyncNoSave(T item)
private async Task<KeyValuePair<string, T>> UpdateAsyncNoSave(T item, TimeSpan? ttl)
{
var key = item.GetKey();
IList<IObjectDiff>? diff;
Expand All @@ -790,12 +759,22 @@ private async Task<KeyValuePair<string, T>> UpdateAsyncNoSave(T item)
args.AddRange(update.SerializeScriptArgs());
}

if (ttl is not null)
{
args.Add("EXPIRE");
args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
}

await _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray());
}
else if (ttl is not null)
{
await _connection.ExecuteAsync("PEXPIRE", key, ttl.Value.TotalMilliseconds);
}
}
else
{
await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType);
await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType, ttl);
}

return new KeyValuePair<string, T>(key, item);
Expand Down Expand Up @@ -831,5 +810,97 @@ private void SaveToStateManager(string key, object value)
}
}
}

private void SendUpdate(T item, TimeSpan? ttl = null)
{
var key = item.GetKey();
IList<IObjectDiff>? diff;
var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff);
if (diffConstructed)
{
if (diff!.Any())
{
var args = new List<string>();
var scriptName = diff!.First().Script;
foreach (var update in diff!)
{
args.AddRange(update.SerializeScriptArgs());
}

if (ttl is not null)
{
args.Add("EXPIRE");
args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
}

_connection.CreateAndEval(scriptName, new[] { key }, args.ToArray());
}
else if (ttl is not null)
{
_connection.Execute("PEXPIRE", key, ttl.Value.TotalMilliseconds);
}
}
else
{
_connection.UnlinkAndSet(key, item, StateManager.DocumentAttribute.StorageType, ttl);
}

SaveToStateManager(key, item);
}

private Task SendUpdateAsync(T item, TimeSpan? ttl = null)
{
var key = item.GetKey();
IList<IObjectDiff>? diff;
var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff);
Task? task = null;
if (diffConstructed)
{
if (diff!.Any())
{
var args = new List<string>();
var scriptName = diff!.First().Script;
foreach (var update in diff!)
{
args.AddRange(update.SerializeScriptArgs());
}

if (ttl is not null)
{
args.Add("EXPIRE");
args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture));
}

task = _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray());
}
else if (ttl is not null)
{
task = _connection.ExecuteAsync("PEXPIRE", key, ttl.Value.TotalMilliseconds);
}
}
else
{
task = _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType, ttl);
}

SaveToStateManager(key, item);
if (task is null)
{
return Task.CompletedTask;
}

return task;
}

private async ValueTask SendUpdateAsync(IEnumerable<T> items, TimeSpan? ttl = null)
{
var tasks = items.Select(x => UpdateAsyncNoSave(x, ttl)).ToArray();

await Task.WhenAll(tasks);
foreach (var kvp in tasks.Select(x => x.Result))
{
SaveToStateManager(kvp.Key, kvp.Value);
}
}
}
}
Loading

0 comments on commit 893834a

Please sign in to comment.