Skip to content

Commit

Permalink
[OPIK-595] Automation rule evaluator cache (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora authored Jan 15, 2025
1 parent 2a5bca0 commit 270a8f2
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import jakarta.validation.constraints.NotNull;
import lombok.Builder;

import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record AutomationRuleEvaluatorUpdate(
@NotNull String name,
@NotNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode code,
@NotNull LlmAsJudgeCode code,
@NotNull Float samplingRate) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ int updateBaseRule(@Bind("id") UUID id,
@Bind("projectId") UUID projectId,
@Bind("workspaceId") String workspaceId,
@Bind("name") String name,
@Bind("samplingRate") float samplingRate,
@Bind("lastUpdatedBy") String lastUpdatedBy);
@Bind("samplingRate") float samplingRate);

@SqlUpdate("""
DELETE FROM automation_rules
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import com.comet.opik.api.AutomationRuleEvaluatorUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.infrastructure.cache.CacheEvict;
import com.comet.opik.infrastructure.cache.Cacheable;
import com.google.inject.ImplementedBy;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
Expand Down Expand Up @@ -61,6 +63,7 @@ class AutomationRuleEvaluatorServiceImpl implements AutomationRuleEvaluatorServi
private final @NonNull TransactionTemplate template;

@Override
@CacheEvict(name = "automation_rule_evaluators_find_by_type", key = "$projectId +'-'+ $workspaceId +'-'+ $inputRuleEvaluator.type")
public <E, T extends AutomationRuleEvaluator<E>> T save(@NonNull T inputRuleEvaluator,
@NonNull UUID projectId,
@NonNull String workspaceId,
Expand Down Expand Up @@ -108,6 +111,7 @@ public <E, T extends AutomationRuleEvaluator<E>> T save(@NonNull T inputRuleEval
}

@Override
@CacheEvict(name = "automation_rule_evaluators_find_by_type", key = "$projectId +'-'+ $workspaceId +'-*'", keyUsesPatternMatching = true)
public void update(@NonNull UUID id, @NonNull UUID projectId, @NonNull String workspaceId,
@NonNull String userName, @NonNull AutomationRuleEvaluatorUpdate evaluatorUpdate) {

Expand All @@ -118,7 +122,7 @@ public void update(@NonNull UUID id, @NonNull UUID projectId, @NonNull String wo

try {
int resultBase = dao.updateBaseRule(id, projectId, workspaceId, evaluatorUpdate.name(),
evaluatorUpdate.samplingRate(), userName);
evaluatorUpdate.samplingRate());

var modelUpdate = LlmAsJudgeAutomationRuleEvaluatorModel.builder()
.code(AutomationModelEvaluatorMapper.INSTANCE.map(evaluatorUpdate.code()))
Expand Down Expand Up @@ -166,6 +170,7 @@ public <E, T extends AutomationRuleEvaluator<E>> T findById(@NonNull UUID id, @N
}

@Override
@CacheEvict(name = "automation_rule_evaluators_find_by_type", key = "$projectId +'-'+ $workspaceId +'-*'", keyUsesPatternMatching = true)
public void delete(@NonNull Set<UUID> ids, @NonNull UUID projectId, @NonNull String workspaceId) {
if (ids.isEmpty()) {
log.info("Delete AutomationRuleEvaluator: ids list is empty, returning");
Expand Down Expand Up @@ -223,6 +228,7 @@ public AutomationRuleEvaluatorPage find(@NonNull UUID projectId,
}

@Override
@Cacheable(name = "automation_rule_evaluators_find_by_type", key = "$projectId +'-'+ $workspaceId +'-'+ $type", returnType = AutomationRuleEvaluator.class, wrapperType = List.class)
public List<AutomationRuleEvaluatorLlmAsJudge> findAll(@NonNull UUID projectId, @NonNull String workspaceId,
@NonNull AutomationRuleEvaluatorType type) {
log.debug("Finding AutomationRuleEvaluators with type '{}' in projectId '{}' and workspaceId '{}'", type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,13 @@
* @return SpEL expression evaluated to generate the cache key.
* */
String key();

/**
* @return whether the key is a pattern or not. Default is false.
*
* @see <a href="https://redis.io/commands/KEYS">Redis KEYS command documentation</a>
*
* This is useful when you want to evict multiple keys that match a pattern.
* */
boolean keyUsesPatternMatching() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,50 +50,54 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
var cacheable = method.getAnnotation(Cacheable.class);
if (cacheable != null) {
return runCacheAwareAction(invocation, isReactive, cacheable.name(), cacheable.key(),
(key, name) -> processCacheableMethod(invocation, isReactive, key, name, cacheable));
(key, group) -> processCacheableMethod(invocation, isReactive, key, group, cacheable));
}

var cachePut = method.getAnnotation(CachePut.class);
if (cachePut != null) {
return runCacheAwareAction(invocation, isReactive, cachePut.name(), cachePut.key(),
(key, name) -> processCachePutMethod(invocation, isReactive, key, name));
(key, group) -> processCachePutMethod(invocation, isReactive, key, group));
}

var cacheEvict = method.getAnnotation(CacheEvict.class);
if (cacheEvict != null) {
return runCacheAwareAction(invocation, isReactive, cacheEvict.name(), cacheEvict.key(),
(key, name) -> processCacheEvictMethod(invocation, isReactive, key));
(key, group) -> processCacheEvictMethod(invocation, isReactive, key, cacheEvict));
}

return invocation.proceed();
}

private Object runCacheAwareAction(MethodInvocation invocation, boolean isReactive, String name, String keyAgs,
private Object runCacheAwareAction(MethodInvocation invocation, boolean isReactive, String group, String keyAgs,
BiFunction<String, String, Object> action) throws Throwable {

String key;

try {
key = getKeyName(name, keyAgs, invocation);
key = getKeyName(group, keyAgs, invocation);
} catch (Exception e) {
// If there is an error evaluating the key, proceed without caching
log.error("Error evaluating key expression: {}", keyAgs, e);
log.warn("Cache will be skipped due to error evaluating key expression");
return invocation.proceed();
}

if (isReactive) {
return action.apply(key, name);
return action.apply(key, group);
}

return ((Mono<?>) action.apply(key, name)).block();
return ((Mono<?>) action.apply(key, group)).block();
}

private Mono<Object> processCacheEvictMethod(MethodInvocation invocation, boolean isReactive, String key) {
private Mono<Object> processCacheEvictMethod(MethodInvocation invocation, boolean isReactive, String key,
CacheEvict cacheEvict) {
if (isReactive) {
try {
return ((Mono<?>) invocation.proceed())
.flatMap(value -> cacheManager.get().evict(key).thenReturn(value))
.switchIfEmpty(cacheManager.get().evict(key).then(Mono.empty()))
.flatMap(value -> cacheManager.get().evict(key, cacheEvict.keyUsesPatternMatching())
.thenReturn(value))
.switchIfEmpty(
cacheManager.get().evict(key, cacheEvict.keyUsesPatternMatching()).then(Mono.empty()))
.map(Function.identity());
} catch (Throwable e) {
return Mono.error(e);
Expand All @@ -102,42 +106,42 @@ private Mono<Object> processCacheEvictMethod(MethodInvocation invocation, boolea
try {
var value = invocation.proceed();
if (value == null) {
return cacheManager.get().evict(key).then(Mono.empty());
return cacheManager.get().evict(key, cacheEvict.keyUsesPatternMatching()).then(Mono.empty());
}
return cacheManager.get().evict(key).thenReturn(value);
return cacheManager.get().evict(key, cacheEvict.keyUsesPatternMatching()).thenReturn(value);
} catch (Throwable e) {
return Mono.error(e);
}
}
}

private Mono<Object> processCachePutMethod(MethodInvocation invocation, boolean isReactive, String key,
String name) {
String group) {
if (isReactive) {
try {
return ((Mono<?>) invocation.proceed()).flatMap(value -> cachePut(value, key, name));
return ((Mono<?>) invocation.proceed()).flatMap(value -> cachePut(value, key, group));
} catch (Throwable e) {
return Mono.error(e);
}
} else {
try {
var value = invocation.proceed();
return cachePut(value, key, name).thenReturn(value);
return cachePut(value, key, group).thenReturn(value);
} catch (Throwable e) {
return Mono.error(e);
}
}
}

private Object processCacheableMethod(MethodInvocation invocation, boolean isReactive, String key,
String name, Cacheable cacheable) {
String group, Cacheable cacheable) {

if (isReactive) {

if (invocation.getMethod().getReturnType().isAssignableFrom(Mono.class)) {
return handleMono(invocation, key, name, cacheable);
return handleMono(invocation, key, group, cacheable);
} else {
return handleFlux(invocation, key, name, cacheable);
return handleFlux(invocation, key, group, cacheable);
}
} else {

Expand All @@ -146,16 +150,16 @@ private Object processCacheableMethod(MethodInvocation invocation, boolean isRea
cacheable.returnType());

return cacheManager.get().get(key, typeReference)
.switchIfEmpty(processSyncCacheMiss(invocation, key, name));
.switchIfEmpty(processSyncCacheMiss(invocation, key, group));
}

return cacheManager.get().get(key, invocation.getMethod().getReturnType())
.map(Object.class::cast)
.switchIfEmpty(processSyncCacheMiss(invocation, key, name));
.switchIfEmpty(processSyncCacheMiss(invocation, key, group));
}
}

private Flux<Object> handleFlux(MethodInvocation invocation, String key, String name, Cacheable cacheable) {
private Flux<Object> handleFlux(MethodInvocation invocation, String key, String group, Cacheable cacheable) {
if (cacheable.wrapperType() != Object.class) {
TypeReference typeReference = TypeReferenceUtils.forTypes(cacheable.wrapperType(),
cacheable.returnType());
Expand All @@ -168,7 +172,7 @@ public Type getType() {
}
};

return getFromCacheOrCallMethod(invocation, key, name, collectionType);
return getFromCacheOrCallMethod(invocation, key, group, collectionType);
}

TypeReference<List<?>> collectionType = new TypeReference<>() {
Expand All @@ -178,61 +182,61 @@ public Type getType() {
}
};

return getFromCacheOrCallMethod(invocation, key, name, collectionType);
return getFromCacheOrCallMethod(invocation, key, group, collectionType);
}

private Flux<Object> getFromCacheOrCallMethod(MethodInvocation invocation, String key, String name,
private Flux<Object> getFromCacheOrCallMethod(MethodInvocation invocation, String key, String group,
TypeReference<List<?>> collectionType) {
return cacheManager.get()
.get(key, collectionType)
.map(Collection.class::cast)
.flatMapMany(Flux::fromIterable)
.switchIfEmpty(processFluxCacheMiss(invocation, key, name));
.switchIfEmpty(processFluxCacheMiss(invocation, key, group));
}

private Mono<Object> handleMono(MethodInvocation invocation, String key, String name, Cacheable cacheable) {
private Mono<Object> handleMono(MethodInvocation invocation, String key, String group, Cacheable cacheable) {
if (cacheable.wrapperType() != Object.class) {
TypeReference typeReference = TypeReferenceUtils.forTypes(cacheable.wrapperType(),
cacheable.returnType());

return cacheManager.get().get(key, typeReference)
.switchIfEmpty(processCacheMiss(invocation, key, name));
.switchIfEmpty(processCacheMiss(invocation, key, group));
}

return cacheManager.get().get(key, cacheable.returnType())
.map(Object.class::cast)
.switchIfEmpty(processCacheMiss(invocation, key, name));
.switchIfEmpty(processCacheMiss(invocation, key, group));
}

private Mono<Object> processSyncCacheMiss(MethodInvocation invocation, String key, String name) {
private Mono<Object> processSyncCacheMiss(MethodInvocation invocation, String key, String group) {
return Mono.defer(() -> {
try {
return Mono.just(invocation.proceed());
} catch (Throwable e) {
return Mono.error(e);
}
}).flatMap(value -> cachePut(value, key, name));
}).flatMap(value -> cachePut(value, key, group));
}

private Mono<Object> processCacheMiss(MethodInvocation invocation, String key, String name) {
private Mono<Object> processCacheMiss(MethodInvocation invocation, String key, String group) {
return Mono.defer(() -> {
try {
return ((Mono<?>) invocation.proceed())
.flatMap(value -> cachePut(value, key, name));
.flatMap(value -> cachePut(value, key, group));
} catch (Throwable e) {
return Mono.error(e);
}
});
}

private Flux<Object> processFluxCacheMiss(MethodInvocation invocation, String key, String name) {
private Flux<Object> processFluxCacheMiss(MethodInvocation invocation, String key, String group) {
return Flux.defer(() -> {
try {
Flux<Object> flux = (Flux<Object>) invocation.proceed();

var cacheable = flux.cache()
.collectList()
.flatMap(value -> cachePut(value, key, name));
.flatMap(value -> cachePut(value, key, group));

return flux
.doOnSubscribe(subscription -> Schedulers.boundedElastic().schedule(() -> {
Expand All @@ -246,8 +250,8 @@ private Flux<Object> processFluxCacheMiss(MethodInvocation invocation, String ke
});
}

private Mono<Object> cachePut(Object value, String key, String name) {
Duration ttlDuration = cacheConfiguration.getCaches().getOrDefault(name,
private Mono<Object> cachePut(Object value, String key, String group) {
Duration ttlDuration = cacheConfiguration.getCaches().getOrDefault(group,
cacheConfiguration.getDefaultDuration());
return cacheManager.get().put(key, value, ttlDuration)
.thenReturn(value)
Expand All @@ -270,17 +274,12 @@ private String getKeyName(String name, String key, MethodInvocation invocation)
params.put("$" + parameters[i].getName(), value != null ? value : ""); // Null safety
}

try {
String evaluatedKey = Objects.requireNonNull(MVEL.evalToString(key, params),
"Key expression cannot return be null");
if (evaluatedKey.isEmpty() || evaluatedKey.equals("null")) {
throw new IllegalArgumentException("Key expression cannot return an empty string");
}
return "%s:-%s".formatted(name, evaluatedKey);
} catch (Exception e) {
log.error("Error evaluating key expression: {}", key, e);
throw new IllegalArgumentException("Error evaluating key expression: " + key);
String evaluatedKey = Objects.requireNonNull(MVEL.evalToString(key, params),
"Key expression cannot return be null");
if (evaluatedKey.isEmpty() || evaluatedKey.equals("null")) {
throw new IllegalArgumentException("Key expression cannot return an empty string");
}
return "%s:-%s".formatted(name, evaluatedKey);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

public interface CacheManager {

Mono<Boolean> evict(@NonNull String key);
Mono<Boolean> evict(@NonNull String key, boolean usePatternMatching);
Mono<Boolean> put(@NonNull String key, @NonNull Object value, @NonNull Duration ttlDuration);
<T> Mono<T> get(@NonNull String key, @NonNull Class<T> clazz);
<T> Mono<T> get(@NonNull String key, @NonNull TypeReference<T> clazz);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ class RedisCacheManager implements CacheManager {

private final @NonNull RedissonReactiveClient redisClient;

public Mono<Boolean> evict(@NonNull String key) {
public Mono<Boolean> evict(@NonNull String key, boolean usePatternMatching) {
if (usePatternMatching) {
return redisClient.getKeys().deleteByPattern(key)
.map(count -> count > 0);
}
return redisClient.getBucket(key).delete();
}

Expand Down
Loading

0 comments on commit 270a8f2

Please sign in to comment.