Skip to content

Commit

Permalink
Added sql backend for backgroundjobs
Browse files Browse the repository at this point in the history
  • Loading branch information
bitbyt3r committed May 9, 2021
1 parent 1ac6234 commit 6c8b6c5
Show file tree
Hide file tree
Showing 13 changed files with 384 additions and 102 deletions.
3 changes: 2 additions & 1 deletion backend/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.
pytest
coverage
coverage
fakeredis
86 changes: 77 additions & 9 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,100 @@
import importlib
import fakeredis
import sqlite3
import pytest
import json
import sys
import os

import tuber
settings_override = {
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': "sqlite:///:memory:"
}
tuber.app.config.update(settings_override)

def csrf(client):
for cookie in client.cookie_jar:
if cookie.name == "csrf_token":
return cookie.value
return ""

@pytest.fixture()
def tuber():
os.environ['REDIS_URL'] = ""
mod = importlib.import_module('tuber')
settings_override = {
'TESTING': True,
'SQLALCHEMY_DATABASE_URI': "sqlite:///:memory:"
}
mod.app.config.update(settings_override)
yield mod
for key in list(sys.modules.keys()):
if key.startswith("tuber"):
del sys.modules[key]

@pytest.fixture
def client_fresh():
def client_fresh(tuber):
"""Creates a client with a fresh database and no active sessions. Initial setup will not yet be completed.
"""
tuber.db.create_all()
with tuber.app.test_client() as client:
yield client
tuber.db.drop_all()
del sys.modules['tuber']

@pytest.fixture
def client(tuber):
"""Creates a test client with initial setup complete and the admin user logged in already.
Also patches the get/post/patch/delete functions to handle CSRF tokens for you.
"""
redis = fakeredis.FakeStrictRedis()
tuber.r = redis
tuber.api.r = redis
tuber.db.create_all()
with tuber.app.test_client() as client:
client.post('/api/initial_setup', json={"username": "admin", "email": "admin@magfest.org", "password": "admin"})
client.post("/api/login", json={"csrf_token": csrf(client), "username": "admin", "password": "admin"})
client.post("/api/events", json={"csrf_token": csrf(client), "name": "Tuber Event", "description": "It's a potato"})
_get = client.get
def get(*args, **kwargs):
if not 'query_string' in kwargs:
kwargs['query_string'] = {}
kwargs['query_string']['csrf_token'] = csrf(client)
rv = _get(*args, **kwargs)
return rv
_post = client.post
def post(*args, **kwargs):
if 'data' in kwargs:
kwargs['data']['csrf_token'] = csrf(client)
else:
if not 'json' in kwargs:
kwargs['json'] = {}
kwargs['json']['csrf_token'] = csrf(client)
rv = _post(*args, **kwargs)
return rv
_patch = client.patch
def patch(*args, **kwargs):
if 'data' in kwargs:
kwargs['data']['csrf_token'] = csrf(client)
else:
if not 'json' in kwargs:
kwargs['json'] = {}
kwargs['json']['csrf_token'] = csrf(client)
rv = _patch(*args, **kwargs)
return rv
_delete = client.delete
def delete(*args, **kwargs):
if 'data' in kwargs:
kwargs['data']['csrf_token'] = csrf(client)
else:
if not 'json' in kwargs:
kwargs['json'] = {}
kwargs['json']['csrf_token'] = csrf(client)
rv = _delete(*args, **kwargs)
return rv
client.get = get
client.post = post
client.patch = patch
client.delete = delete
yield client
tuber.db.drop_all()

@pytest.fixture
def client():
def client_noredis(tuber):
"""Creates a test client with initial setup complete and the admin user logged in already.
Also patches the get/post/patch/delete functions to handle CSRF tokens for you.
"""
Expand Down
71 changes: 71 additions & 0 deletions backend/tests/test_circuitbreaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import time

def test_fast_request_redis(client):
"""Make sure that fast requests do not create a job when using redis."""
result = client.get("/api/fast")
assert result.status_code == 200

def test_slow_request_redis(client):
"""Make sure that slow requests do create a job when using redis."""
result = client.get("/api/slow")
assert result.status_code == 202

def test_fast_request_no_redis(client_noredis):
"""Make sure that fast requests do not create a job when not using redis."""
result = client_noredis.get("/api/fast")
assert result.status_code == 200

def test_slow_request_no_redis(client_noredis):
"""Make sure that slow requests do create a job when not using redis."""
result = client_noredis.get("/api/slow")
assert result.status_code == 202

def test_job_retrieval(client):
"""Make sure jobs can be retrieved after a long running job is started."""
result = client.get("/api/slow")
assert result.status_code == 202
assert "job" in result.json
job_id = result.json['job']
start_time = time.time()
while time.time() - start_time < 15:
result = client.get("/api/jobs", query_string={"job": job_id})
assert result.status_code == 200
assert "progress" in result.json
assert "complete" in result.json['progress']
assert "result" in result.json
if result.json['result']:
assert result.json['progress']['complete']
break
else:
assert not result.json['progress']['complete']
assert result.json['result']['status_code'] == 200
assert result.json['result']['data'] == "success"
assert result.json['result']['mimetype'] == "text/html"
assert result.json['result']['headers']
assert result.json['result']['execution_time']

def test_job_retrieval_noredis(client_noredis):
"""Make sure jobs can be retrieved after a long running job is started."""
client = client_noredis
result = client.get("/api/slow")
assert result.status_code == 202
assert "job" in result.json
job_id = result.json['job']
start_time = time.time()
while time.time() - start_time < 15:
result = client.get("/api/jobs", query_string={"job": job_id})
assert result.status_code == 200
assert "progress" in result.json
assert "complete" in result.json['progress']
assert "result" in result.json
if result.json['result']:
assert result.json['progress']['complete']
break
else:
assert not result.json['progress']['complete']
print(result.json)
assert result.json['result']['status_code'] == 200
assert result.json['result']['data'] == "success"
assert result.json['result']['mimetype'] == "text/html"
assert result.json['result']['headers']
assert result.json['result']['execution_time']
17 changes: 17 additions & 0 deletions backend/tuber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@
from flask import Flask, g
from flask_sqlalchemy import SQLAlchemy
from flask_talisman import Talisman
import redis
import tuber.config as config
import json
import sys
import os
import re
import alembic
from alembic.config import Config as AlembicConfig

db = None
r = None
app = Flask(__name__)
initialized = False
alembic_config = None

print("Importing...")
if not initialized:
print("Initializing...")
initialized = True

if config.flask_env == "production":
Expand Down Expand Up @@ -58,6 +63,18 @@

alembic_config = AlembicConfig(config.alembic_ini)

print(config.redis_url)
if config.redis_url:
m = re.search("redis://([a-z0-9\.]+)(:(\d+))?(/(\d+))?", config.redis_url)
redis_host = m.group(1)
redis_port = 6379
if m.group(3):
redis_port = int(m.group(3))
redis_db = 0
if m.group(5):
redis_db = int(m.group(5))
r = redis.Redis(host=redis_host, port=redis_port, db=redis_db)

db = SQLAlchemy(app)
import tuber.csrf
import tuber.models
Expand Down
76 changes: 2 additions & 74 deletions backend/tuber/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,8 @@
from marshmallow_sqlalchemy import ModelSchema
import inspect
import json
import time
import uuid
import sqlalchemy

if config.enable_circuitbreaker:
from multiprocessing.pool import ThreadPool
pool = ThreadPool(processes=config.circuitbreaker_threads)

all_permissions = []

def check_matches(matches, row, env):
Expand Down Expand Up @@ -163,6 +157,7 @@ def register_crud(name, schema, methods=["GET", "POST", "PATCH", "DELETE"], perm
from .emails import *
from .badges import *
from .shifts import *
from .backgroundjobs import *

def indent(string, level=4):
lines = string.split("\n")
Expand Down Expand Up @@ -203,71 +198,4 @@ def underline(string, char="-"):
.. sourcecode:: json
{}
""".format(name, description, sample_json)

@app.route("/api/slow", methods=["GET"])
def slow_call():
#if check_permission("circuitbreaker.test"):
for i in range(10):
time.sleep(1)
g.progress(i*0.1)
return "That took some time", 200
#return "Permission Denied", 403

@app.route("/api/fast", methods=["GET"])
def fast_call():
if check_permission("circuitbreaker.test"):
time.sleep(0.1)
return "Super speedy", 200
return "Permission Denied", 403

# This wraps all view functions with a circuit breaker that allows soft timeouts.
# Basically, if a view function takes more than the timeout to render a page then
# the client gets a 202 and receives a claim ticket while the view function keeps
# running in the background. The result gets pushed to either redis or the sql db
# and the client can retrieve it (or the status while it's pending) later.
#
# The main design goal is for this to be transparent to people writing flask view
# functions, so make sure the environment matches whether this wrapper is enabled
# or not.
#
# View functions that often take a long time can be made aware of this behavior
# and can push status/progress updates into the job while they work.
def job_wrapper(func):
def wrapped(*args, **kwargs):
def yo_dawg(request_context, before_request_funcs):
with app.test_request_context(**request_context):
for before_request_func in before_request_funcs:
before_request_func()
def progress(amount):
print(f"Progress Callback {amount}")
g.progress = progress
return func(*args, **kwargs)
request_context = {
"path": request.path,
"base_url": request.base_url,
"query_string": request.query_string,
"method": request.method,
"headers": dict(request.headers),
"data": g.raw_data,
}
start_time = time.time()
jobid = str(uuid.uuid4())
def store_result(ret):
if time.time() - start_time > 0.9 * config.circuitbreaker_timeout:
print(f"Storing result as {jobid}", ret)
else:
print("Not storing result", ret)
if hasattr(ret, "data"):
print(ret.data)
result = pool.apply_async(yo_dawg, (request_context, app.before_request_funcs[None]), callback=store_result)
result.wait(timeout=config.circuitbreaker_timeout)
if result.ready():
return result.get()
else:
return jsonify(job=jobid), 202
return wrapped

if config.enable_circuitbreaker:
for key, val in app.view_functions.items():
app.view_functions[key] = job_wrapper(val)
""".format(name, description, sample_json)
Loading

0 comments on commit 6c8b6c5

Please sign in to comment.