Merge pull request #221 from ranking-agent/lookup_cache
Add caching for lookups
uhbrar authored Dec 11, 2023
2 parents 04b5629 + 7da7617 commit 8d2cc3b
Showing 8 changed files with 94 additions and 16 deletions.
2 changes: 1 addition & 1 deletion openapi-config.yaml
@@ -11,7 +11,7 @@ servers:
# url: http://127.0.0.1:5000
termsOfService: http://robokop.renci.org:7055/tos?service_long=ARAGORN&provider_long=RENCI
title: ARAGORN
version: 2.5.2
version: 2.6.0
tags:
- name: translator
- name: ARA
3 changes: 2 additions & 1 deletion requirements.txt
@@ -19,4 +19,5 @@ uvloop==0.17.0
opentelemetry-sdk==1.16.0
opentelemetry-instrumentation-fastapi==0.37b0
opentelemetry-exporter-jaeger==1.16.0
opentelemetry-instrumentation-httpx==0.37b0
opentelemetry-instrumentation-httpx==0.37b0
fakeredis<=2.10.2
15 changes: 13 additions & 2 deletions src/aragorn_app.py
@@ -188,12 +188,23 @@ class ClearCacheRequest(BaseModel):
    pswd: str


@ARAGORN_APP.post("/clear_cache", status_code=200, include_in_schema=False)
@ARAGORN_APP.post("/clear_creative_cache", status_code=200, include_in_schema=False)
def clear_redis_cache(request: ClearCacheRequest) -> dict:
    """Clear the creative (inferred) results cache."""
    if request.pswd == cache_password:
        cache = ResultsCache()
        cache.clear_cache()
        cache.clear_creative_cache()
        return {"status": "success"}
    else:
        raise HTTPException(status_code=401, detail="Invalid Password")


@ARAGORN_APP.post("/clear_lookup_cache", status_code=200, include_in_schema=False)
def clear_lookup_redis_cache(request: ClearCacheRequest) -> dict:
    """Clear the lookup results cache."""
    if request.pswd == cache_password:
        cache = ResultsCache()
        cache.clear_lookup_cache()
        return {"status": "success"}
    else:
        raise HTTPException(status_code=401, detail="Invalid Password")
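For reference, here is one way to exercise the two admin endpoints above once the service is running. This is only a sketch: the base URL, port, and the ARAGORN_URL / CACHE_CLEAR_PASSWORD environment variables are illustrative assumptions, not part of this PR; only the /clear_creative_cache and /clear_lookup_cache routes and the pswd field come from the code above.

import os
import httpx  # httpx is instrumented in requirements.txt, so it should already be available

# Assumed values for illustration; substitute the real deployment URL and cache password.
ARAGORN_URL = os.environ.get("ARAGORN_URL", "http://localhost:4868")
PASSWORD = os.environ.get("CACHE_CLEAR_PASSWORD", "changeme")

# Clear the creative (inferred) results cache.
httpx.post(f"{ARAGORN_URL}/clear_creative_cache", json={"pswd": PASSWORD}).raise_for_status()

# Clear the new lookup cache.
httpx.post(f"{ARAGORN_URL}/clear_lookup_cache", json={"pswd": PASSWORD}).raise_for_status()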
51 changes: 43 additions & 8 deletions src/results_cache.py
@@ -5,16 +5,30 @@

CACHE_HOST = os.environ.get("CACHE_HOST", "localhost")
CACHE_PORT = os.environ.get("CACHE_PORT", "6379")
CACHE_DB = os.environ.get("CACHE_DB", "0")
CREATIVE_CACHE_DB = os.environ.get("CREATIVE_CACHE_DB", "0")
LOOKUP_CACHE_DB = os.environ.get("LOOKUP_CACHE_DB", "1")
CACHE_PASSWORD = os.environ.get("CACHE_PASSWORD", "")

class ResultsCache:
    def __init__(self, redis_host=CACHE_HOST, redis_port=CACHE_PORT, redis_db=CACHE_DB, redis_password=CACHE_PASSWORD):
    def __init__(
        self,
        redis_host=CACHE_HOST,
        redis_port=CACHE_PORT,
        creative_redis_db=CREATIVE_CACHE_DB,
        lookup_redis_db=LOOKUP_CACHE_DB,
        redis_password=CACHE_PASSWORD,
    ):
        """Connect to cache."""
        self.redis = redis.StrictRedis(
        self.creative_redis = redis.StrictRedis(
            host=redis_host,
            port=redis_port,
            db=redis_db,
            db=creative_redis_db,
            password=redis_password,
        )
        self.lookup_redis = redis.StrictRedis(
            host=redis_host,
            port=redis_port,
            db=lookup_redis_db,
            password=redis_password,
        )

@@ -25,7 +39,7 @@ def get_query_key(self, input_id, predicate, qualifiers, source_input, caller, w

    def get_result(self, input_id, predicate, qualifiers, source_input, caller, workflow):
        key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow)
        result = self.redis.get(key)
        result = self.creative_redis.get(key)
        if result is not None:
            result = json.loads(gzip.decompress(result))
        return result
@@ -34,7 +48,28 @@ def get_result(self, input_id, predicate, qualifiers, source_input, caller, work
    def set_result(self, input_id, predicate, qualifiers, source_input, caller, workflow, final_answer):
        key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow)

        self.redis.set(key, gzip.compress(json.dumps(final_answer).encode()))
        self.creative_redis.set(key, gzip.compress(json.dumps(final_answer).encode()))

    def get_lookup_query_key(self, workflow, query_graph):
        keydict = {'workflow': workflow, 'query_graph': query_graph}
        return json.dumps(keydict, sort_keys=True)

    def get_lookup_result(self, workflow, query_graph):
        key = self.get_lookup_query_key(workflow, query_graph)
        result = self.lookup_redis.get(key)
        if result is not None:
            result = json.loads(gzip.decompress(result))
        return result


    def set_lookup_result(self, workflow, query_graph, final_answer):
        key = self.get_lookup_query_key(workflow, query_graph)

        self.lookup_redis.set(key, gzip.compress(json.dumps(final_answer).encode()))


    def clear_cache(self):
        self.redis.flushdb()
    def clear_creative_cache(self):
        self.creative_redis.flushdb()

    def clear_lookup_cache(self):
        self.lookup_redis.flushdb()
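A minimal sketch of the new lookup-cache round trip, assuming the repository layout above and using fakeredis in place of a live Redis server (the same substitution the tests make below); the workflow and query-graph values are placeholders:

import fakeredis
import redis

# Swap in an in-memory fake client, mirroring tests/helpers/redisMock.py.
redis.StrictRedis = lambda **kwargs: fakeredis.FakeStrictRedis()

from src.results_cache import ResultsCache

cache = ResultsCache()
workflow_def = [{"id": "lookup"}]
query_graph = {"nodes": {}, "edges": {}}

# Keys are the sorted JSON of the (workflow, query_graph) pair; values are gzipped JSON.
cache.set_lookup_result(workflow_def, query_graph, {"message": {"results": []}})
assert cache.get_lookup_result(workflow_def, query_graph) == {"message": {"results": []}}

cache.clear_lookup_cache()
assert cache.get_lookup_result(workflow_def, query_graph) is None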
13 changes: 13 additions & 0 deletions src/service_aggregator.py
@@ -165,9 +165,14 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):
    # We told the world what we can do!
    # Workflow will be a list of the functions, and the parameters if there are any

    try:
        query_graph = message["message"]["query_graph"]
    except KeyError:
        return f"No query graph", 422
    results_cache = ResultsCache()
    override_cache = (message.get("parameters") or {}).get("override_cache")
    override_cache = override_cache if type(override_cache) is bool else False
    results = None
    if infer:
        # We're going to cache infer queries, and we need to do that even if we're overriding the cache
        # because we need these values to post to the cache at the end.
@@ -182,6 +187,12 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):
                return results, 200
            else:
                logger.info(f"{guid}: Results cache miss")
    else:
        if not override_cache:
            results = results_cache.get_lookup_result(workflow_def, query_graph)
            if results is not None:
                logger.info(f"{guid}: Returning results cache lookup")
                return results, 200

    workflow = []

@@ -199,6 +210,8 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):

    if infer:
        results_cache.set_result(input_id, predicate, qualifiers, source_input, caller, workflow_def, final_answer)
    else:
        results_cache.set_lookup_result(workflow_def, query_graph, final_answer)

    # return the answer
    return final_answer, status_code
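Because lookup results are now cached on the (workflow, query_graph) pair, a caller that needs a fresh computation can set override_cache in the request parameters, which skips the cache read while still writing the new answer back. A sketch of such a request body follows; the nodes, edges, and identifiers are placeholders rather than values from this PR:

request_body = {
    "message": {
        "query_graph": {
            "nodes": {
                "n0": {"ids": ["MONDO:0005148"]},
                "n1": {"categories": ["biolink:ChemicalEntity"]},
            },
            "edges": {
                "e01": {"subject": "n1", "object": "n0", "predicates": ["biolink:treats"]},
            },
        }
    },
    "workflow": [{"id": "lookup"}],
    # Bypass the cache read for this call; the fresh result is still cached afterwards.
    "parameters": {"override_cache": True},
}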
10 changes: 10 additions & 0 deletions tests/helpers/redisMock.py
@@ -0,0 +1,10 @@
import fakeredis


def redisMock(host=None, port=None, db=None, password=None):
    """Drop-in replacement for redis.StrictRedis that needs no running Redis server."""
    # fakeredis keeps all data in memory; the connection kwargs are accepted and
    # ignored so the signature matches how ResultsCache calls redis.StrictRedis.
    # For the async variant, see:
    # https://github.com/cunla/fakeredis-py/issues/66#issuecomment-1316045893
    return fakeredis.FakeStrictRedis()
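The test modules below wire this mock in with pytest's monkeypatch so that ResultsCache never opens a real connection; roughly (test_example is a placeholder name):

import redis
from tests.helpers.redisMock import redisMock

def test_example(monkeypatch):
    # Every redis.StrictRedis(...) call inside ResultsCache now returns an in-memory fake.
    monkeypatch.setattr(redis, "StrictRedis", redisMock)
    ...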
8 changes: 6 additions & 2 deletions tests/test_aragorn.py
@@ -1,12 +1,14 @@
import pytest
from fastapi.testclient import TestClient
import redis
from src.server import APP
import os
import json
from datetime import datetime as dt, timedelta
from time import sleep
from unittest.mock import patch
from src.process_db import init_db
from tests.helpers.redisMock import redisMock

client = TestClient(APP)

@@ -46,7 +48,8 @@ def xtest_async(mock_callback):
    assert mock_callback.called


def test_aragorn_wf():
def test_aragorn_wf(monkeypatch):
    monkeypatch.setattr(redis, "StrictRedis", redisMock)
    init_db()
    workflow_A1("aragorn")

@@ -296,7 +299,8 @@ def x_test_standup_2():

    assert found

def test_null_results():
def test_null_results(monkeypatch):
    monkeypatch.setattr(redis, "StrictRedis", redisMock)
    init_db()
    # make sure that aragorn can handle cases where results is null (as opposed to missing)
    query = {
8 changes: 6 additions & 2 deletions tests/test_workflow.py
@@ -1,17 +1,20 @@
import pytest
from fastapi.testclient import TestClient
import redis
from src.server import APP as APP
from src import operations
import os
import json
from unittest.mock import patch
from random import shuffle
from src.process_db import init_db
from tests.helpers.redisMock import redisMock

client = TestClient(APP)
jsondir = 'InputJson_1.2'

def test_bad_ops():
def test_bad_ops(monkeypatch):
    monkeypatch.setattr(redis, "StrictRedis", redisMock)
    # get the location of the test file
    dir_path: str = os.path.dirname(os.path.realpath(__file__))
    test_filename = os.path.join(dir_path, jsondir, 'workflow_422.json')
@@ -25,8 +28,9 @@ def test_bad_ops():
    # was the request successful
    assert(response.status_code == 422)

def test_lookup_only():
def test_lookup_only(monkeypatch):
"""This has a workflow with a single op (lookup). So the result should not have scores"""
monkeypatch.setattr(redis, "StrictRedis", redisMock)
init_db()
dir_path: str = os.path.dirname(os.path.realpath(__file__))
test_filename = os.path.join(dir_path, jsondir, 'workflow_200.json')
