Skip to content

Commit

Permalink
change the traces route not to wait to shields processing (async hand…
Browse files Browse the repository at this point in the history
…ling)
  • Loading branch information
cmpxchg16 committed Dec 18, 2023
1 parent 23a62d6 commit 06645e9
Showing 1 changed file with 4 additions and 20 deletions.
24 changes: 4 additions & 20 deletions vibraniumdome-shields/vibraniumdome_shields/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import os
import tempfile
import asyncio

from dotenv import load_dotenv
from flask import Flask, Response, jsonify, request
Expand All @@ -29,10 +28,6 @@
logger.setLevel(log_level)


_logger = logging.getLogger(__name__)
trace_thread_pool_executor = None


class ScanInputSchema(Schema):
llm_session = fields.String(required=True)

Expand All @@ -51,13 +46,6 @@ class ScanInputSchema(Schema):
interaction_service = LLMInteractionService()


def _get_thread_pool_executor():
global trace_thread_pool_executor
if trace_thread_pool_executor is None:
trace_thread_pool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4, thread_name_prefix="traces")
return trace_thread_pool_executor


@app.route("/api/health", methods=["GET"])
def api():
return jsonify({"status": "OK"}), 200
Expand Down Expand Up @@ -88,22 +76,18 @@ def scan():
@app.route("/v1/traces", methods=["POST"])
def receive_traces():
llm_interaction: LLMInteraction = parser.parse_llm_call(request.data)

executor = concurrent.futures.ThreadPoolExecutor()
def process_traces(llm_interaction: LLMInteraction):
try:
# policy = policy_service.get_default_policy()
policy = policy_service.get_policy_by_name(llm_interaction._interaction.get("service_name", "default"))
llm_interaction._shields_result = captain_llm.deflect_shields(llm_interaction, policy)
interaction_service.save_llm_interaction(llm_interaction)
except Exception:
_logger.exception("error while deflecting shields for interaction= %s with policy= %s", llm_interaction, policy)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
logger.exception("error while deflecting shields for interaction= %s with policy= %s", llm_interaction, policy)

# Submit the async task to the thread pool with input arguments
loop.run_in_executor(_get_thread_pool_executor(), process_traces, llm_interaction)
executor.submit(process_traces, llm_interaction)

_logger.debug("done proccesing %s", llm_interaction._id)
return Response(ExportTraceServiceResponse().SerializeToString(), mimetype="application/octet-stream")


Expand Down

0 comments on commit 06645e9

Please sign in to comment.