Skip to content

Commit

Permalink
Merge branch 'main' of github.com:genia-dev/vibraniumdome
Browse files Browse the repository at this point in the history
  • Loading branch information
shlomsh committed Mar 26, 2024
2 parents 0c40f5c + 801c873 commit eef5f6f
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 77 deletions.
3 changes: 2 additions & 1 deletion vibraniumdome-app/src/server/api/root.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { policyRouter, getPolicyByLLMAppApi, getBasePolicy } from "~/server/api/routers/policy";
import { membershipRouter } from "~/server/api/routers/membership";
import { apiTokenRouter } from "~/server/api/routers/apitoken";
import { apiTokenRouter, validateAPIToken } from "~/server/api/routers/apitoken";
import { createTRPCRouter } from "~/server/api/trpc";

/**
Expand All @@ -14,6 +14,7 @@ export const appRouter = createTRPCRouter({
apitoken: apiTokenRouter,
policies: getPolicyByLLMAppApi,
base_policy: getBasePolicy,
validateAPIToken: validateAPIToken,
});

// export type definition of API
Expand Down
11 changes: 11 additions & 0 deletions vibraniumdome-app/src/server/api/routers/apitoken.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ function generateApiKey(): string {
}


export const validateAPIToken = protectedProcedure
.query(async ({ ctx }) => {
return ctx.db.aPIToken.findFirst({
// @ts-ignore
where: { user: { id: ctx.session.user.id } },
select: {
user: true,
}
});
})

export const apiTokenRouter = createTRPCRouter({
createApiToken: protectedProcedure
.input(z.object({ name: z.string().min(1) }))
Expand Down
3 changes: 2 additions & 1 deletion vibraniumdome-app/src/server/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ function generateOpenSearchJWT() {
*/
export const authOptions: NextAuthOptions = {
callbacks: {
//@ts-ignore
async session({ session, user, token }) {
const currUser = await db.user.findUnique({
//@ts-ignore
where: { email: session.user.email.toLowerCase() },
});

if (!currUser.isActive) {
if (!currUser || !currUser.isActive) {
return
}

Expand Down
2 changes: 1 addition & 1 deletion vibraniumdome-shields/docker-entrypoint.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

exec PROMETHEUS_MULTIPROC_DIR=/tmp/ $POETRY_HOME/bin/poetry run gunicorn --bind 0.0.0.0:5001 --threads 4 --workers 4 --preload vibraniumdome_shields.main:app -k gthread
PROMETHEUS_MULTIPROC_DIR=/tmp/ exec $POETRY_HOME/bin/poetry run gunicorn --bind 0.0.0.0:5001 --threads 4 --workers 4 --preload vibraniumdome_shields.main:app -k gthread
11 changes: 9 additions & 2 deletions vibraniumdome-shields/otel_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import time
import os

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

vibranium_dome_api_key = os.getenv('VIBRANIUM_DOME_API_KEY')

if vibranium_dome_api_key is None:
raise ValueError("VIBRANIUM_DOME_API_KEY environment variable is not set")

headers = {"Authorization": f"Bearer {vibranium_dome_api_key}"}

trace.set_tracer_provider(TracerProvider())
otlp_exporter = OTLPSpanExporter(endpoint="http://localhost:5001/v1/traces")
otlp_exporter = OTLPSpanExporter(endpoint="http://localhost:5001/v1/traces", headers=headers)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))

Expand Down
139 changes: 70 additions & 69 deletions vibraniumdome-shields/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion vibraniumdome-shields/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ presidio-analyzer = "^2.2.351"
spacy = "^3.7.2"
opentelemetry-proto = "^1.21.0"
pyrate-limiter = "^3.1.0"
vibraniumdome-sdk = "^0.1.0"
openai = "0.28.0"
httpx = "^0.25.2"
prometheus-client = "^0.20.0"
vibraniumdome-sdk = "^0.4.0"


[tool.poetry.group.dev.dependencies]
Expand Down
23 changes: 21 additions & 2 deletions vibraniumdome-shields/vibraniumdome_shields/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time

from dotenv import load_dotenv
from flask import Flask, Response, jsonify, request
from flask import Flask, Response, jsonify, request, make_response
from marshmallow import Schema, fields
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceResponse

Expand All @@ -18,6 +18,8 @@
from vibraniumdome_shields.user_interface.cli_app import main
from vibraniumdome_shields.vector_db.vector_db_service import VectorDBService

from vibraniumdome_shields.utils import check_api_token

from prometheus_client import multiprocess
from prometheus_client import generate_latest, CollectorRegistry, CONTENT_TYPE_LATEST, Counter, Histogram

Expand Down Expand Up @@ -86,7 +88,7 @@ def vector_reload():
return jsonify({"response": "Done"}), 200


@app.route("/api/scan", methods=["POST"])
@app.route("/v1/scan", methods=["POST"])
def scan():
data = request.json
try:
Expand All @@ -100,6 +102,23 @@ def scan():

@app.route("/v1/traces", methods=["POST"])
def receive_traces():
try:
vibranium_dome_base_url = settings.get("VIBRANIUM_DOME_APP_BASE_URL", "http://localhost:3000")
auth_header = request.headers.get('Authorization')

if not auth_header:
return make_response('Unauthorized Access', 401)

token_type, vibranium_dome_api_key = auth_header.split()
if token_type.lower() != 'bearer':
return make_response('Unauthorized Access', 401)

if not check_api_token(vibranium_dome_base_url, vibranium_dome_api_key):
logger.warning("got an invalid API key")
return make_response('Unauthorized Access', 401)
except Exception:
return make_response('Unauthorized Access', 401)

number_of_requests.inc()
llm_interactions: list(LLMInteraction) = parser.parse_llm_call(request.data)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1000, thread_name_prefix="traces")
Expand Down
12 changes: 12 additions & 0 deletions vibraniumdome-shields/vibraniumdome_shields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import yaml
from pydantic import BaseModel
import requests

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,3 +63,14 @@ def pydantic_json_encoder(obj):
if isinstance(obj, BaseModel):
return obj.model_dump()
return obj

def check_api_token(vibranium_dome_base_url, vibranium_dome_api_key) -> bool:
full_url = f"{vibranium_dome_base_url}/api/trpc/validateAPIToken"
headers = {"Authorization": f"Bearer {vibranium_dome_api_key}"}
try:
response = requests.get(full_url, headers=headers)
logger.debug(response)
return response.ok
except Exception:
logger.exception("failed to check_api_token")
return False

0 comments on commit eef5f6f

Please sign in to comment.