Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple security #15

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from run import Predictor
from utils.image_utils import load_image_array_from_bytes, load_image_tensor_from_bytes
from utils.logging_utils import get_logger_name
from api.simple_security import SimpleSecurity, session_key_required

# Reading environment files
try:
Expand All @@ -45,6 +46,10 @@

app = Flask(__name__)

enable_security = os.getenv("SECURITY_ENABLED", False)
api_key_user_json_string = os.getenv("API_KEY_USER_JSON_STRING")
SimpleSecurity(app, enable_security, api_key_user_json_string)

predictor = None
gen_page = None

Expand Down Expand Up @@ -233,6 +238,7 @@ def check_exception_callback(future: Future):

@app.route("/predict", methods=["POST"])
@exception_predict_counter.count_exceptions()
@session_key_required
def predict() -> tuple[Response, int]:
"""
Run the prediction on a submitted image
Expand Down Expand Up @@ -299,6 +305,7 @@ def predict() -> tuple[Response, int]:


@app.route("/prometheus", methods=["GET"])
@session_key_required
def metrics() -> bytes:
"""
Return the Prometheus metrics for the running flask application
Expand Down
65 changes: 65 additions & 0 deletions api/simple_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import functools
import uuid

import flask
from flask import request, Response, jsonify, Flask
import json


class SimpleSecurity:
def __init__(self, app: Flask, enabled: bool = False, key_user_json: str = None):
app.extensions["security"] = self
self.enabled = enabled
if enabled:
self.register_login_resource(app)
try:
self.api_key_user = json.loads(key_user_json)
self.session_key_user = {}
except Exception as e:
raise ValueError("When security is enabled, key_user_json should be a valid json string. ", e)

def is_known_session_key(self, session_key: str):
return session_key in self.session_key_user.keys()

def register_login_resource(self, app):
@app.route("/login", methods=["POST"])
def login():
if "Authorization" in request.headers.keys():
api_key = request.headers["Authorization"]
session_key = self.login(api_key)

if session_key is not None:
response = Response(status=204)
response.headers["X_AUTH_TOKEN"] = session_key

return response

return Response(status=401)

def login(self, api_key: str) -> str | None:
if self.enabled and api_key in self.api_key_user:
session_key = str(uuid.uuid4())
self.session_key_user[session_key] = self.api_key_user[api_key]
return session_key

return None


def session_key_required(func):
@functools.wraps(func)
def decorator(*args, **kwargs) -> Response:
security_ = flask.current_app.extensions["security"]
if security_.enabled:
if "Authorization" in request.headers.keys():
session_key = request.headers["Authorization"]
if security_.is_known_session_key(session_key):
return func(*args, **kwargs)

response = jsonify({"message": "Expected a valid session key in the Authorization header"})
response.status_code = 401
return response
else:
print("security disabled")
return func(*args, **kwargs)

return decorator
2 changes: 1 addition & 1 deletion api/start_flask_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ if [[ $( builtin cd "$( dirname ${BASH_SOURCE[0]} )/.."; pwd ) != $( pwd ) ]]; t
fi

LAYPA_MAX_QUEUE_SIZE=128 \
LAYPA_MODEL_BASE_PATH="/home/stefan/Documents/models/" \
LAYPA_MODEL_BASE_PATH="/home/martijnm/workspace/images/laypa-models" \
LAYPA_OUTPUT_BASE_PATH="/tmp/" \
FLASK_DEBUG=true FLASK_APP=api.flask_app.py flask run
14 changes: 14 additions & 0 deletions api/start_flask_local_with_security.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash

if [[ $( builtin cd "$( dirname ${BASH_SOURCE[0]} )/.."; pwd ) != $( pwd ) ]]; then
DIR_OF_SCRIPT=$( builtin cd "$( dirname ${BASH_SOURCE[0]} )/.."; pwd )
echo "Change to laypa base folder ($DIR_OF_SCRIPT)"
cd $DIR_OF_SCRIPT
fi

LAYPA_MAX_QUEUE_SIZE=128 \
LAYPA_MODEL_BASE_PATH="/home/martijnm/workspace/images/laypa-models" \
LAYPA_OUTPUT_BASE_PATH="/tmp/" \
SECURITY_ENABLED="True" \
API_KEY_USER_JSON_STRING='{"1234": "test user"}' \
FLASK_DEBUG=true FLASK_APP=api.flask_app.py flask run