Skip to content

Commit

Permalink
update async threadpool use
Browse files Browse the repository at this point in the history
  • Loading branch information
shlomsh committed Dec 17, 2023
1 parent 13480ab commit 2d7a2c1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
17 changes: 10 additions & 7 deletions vibraniumdome-app/src/server/api/routers/policy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ export const getPolicyByLLMAppApi = protectedProcedure
"name": "Default Policy",
"content": {
"shields_filter": "all",
"high_risk_threshold": 0.8,
"low_risk_threshold": 0.2,
"input_shields": [
{"type": "llm_shield", "metadata": {"model": "gpt-3.5-turbo", "model_vendor": "openai"}},
{"type": "vector_db_shield", "metadata": {}},
{"type": "regex_shield", "metadata": {}, "name": "policy number"},
{"type": "llm_shield", "metadata": {"model": "gpt-3.5-turbo", "model_vendor": "openai"}},
{"type": "transformer_shield", "metadata": {}},
{"type": "vector_db_shield", "metadata": {}},
{"type": "prompt_safety_shield", "metadata": {}},
{"type": "sensitive_shield", "metadata": {}},
{"type": "model_denial_of_service_shield", "metadata": {"threshold": 10, "interval_sec": 60, "limit_by": "llm.user"}},
],
"output_shields": [
{"type": "llm_shield", "metadata": {"model": "gpt-3.5-turbo", "model_vendor": "openai"}},
{"type": "regex_shield", "metadata": {}},
{"type": "transformer_shield", "metadata": {}},
{"type": "vector_db_shield", "metadata": {}},
{"type": "regex_output_shield", "metadata": {}, "name": "credit card"},
{"type": "refusal_shield", "metadata": {}},
{"type": "sensitive_shield", "metadata": {}},
{"type": "canary_token_disclosure_shield", "metadata": {"canary_tokens": []}},
{"type": "sensitive_output_shield", "metadata": {}},
],
},
}
Expand Down
24 changes: 19 additions & 5 deletions vibraniumdome-shields/vibraniumdome_shields/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import tempfile
import asyncio

from dotenv import load_dotenv
from flask import Flask, Response, jsonify, request
Expand All @@ -28,6 +29,10 @@
logger.setLevel(log_level)


_logger = logging.getLogger(__name__)
trace_thread_pool_executor = None


class ScanInputSchema(Schema):
    """Validation schema for scan requests.

    Declares a single required string field, ``llm_session`` — presumably the
    serialized LLM session payload submitted for scanning (marshmallow-style
    ``Schema``/``fields``; imports are outside this view — TODO confirm).
    """

    # required=True: requests missing `llm_session` fail validation
    llm_session = fields.String(required=True)

Expand All @@ -46,6 +51,13 @@ class ScanInputSchema(Schema):
interaction_service = LLMInteractionService()


def _get_thread_pool_executor():
    """Return the process-wide executor used for trace processing.

    Lazily creates a 4-worker ``ThreadPoolExecutor`` (threads named
    ``traces-*``) on first call and reuses it afterwards.

    NOTE(review): the lazy init is not lock-protected; two concurrent first
    calls could each build an executor — confirm first use happens before the
    server goes multi-threaded.
    """
    global trace_thread_pool_executor
    if trace_thread_pool_executor is not None:
        return trace_thread_pool_executor
    trace_thread_pool_executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=4,
        thread_name_prefix="traces",
    )
    return trace_thread_pool_executor


@app.route("/api/health", methods=["GET"])
def api():
    """Liveness probe: unconditionally report the service as healthy."""
    payload = {"status": "OK"}
    return jsonify(payload), 200
Expand Down Expand Up @@ -79,17 +91,19 @@ def receive_traces():

def process_traces(llm_interaction: LLMInteraction):
    """Run the shield pipeline for one interaction and persist the result.

    Resolves the policy by the interaction's ``service_name`` (falling back to
    ``"default"``), deflects the shields, and saves the outcome. Intended to
    run on a worker thread, so every failure is logged and swallowed rather
    than propagated.
    """
    # Pre-bind so the except handler can always reference it: if
    # get_policy_by_name itself raises, `policy` would otherwise be unbound
    # and the handler would die with a NameError instead of logging.
    policy = None
    try:
        policy = policy_service.get_policy_by_name(llm_interaction._interaction.get("service_name", "default"))
        llm_interaction._shields_result = captain_llm.deflect_shields(llm_interaction, policy)
        interaction_service.save_llm_interaction(llm_interaction)
    except Exception:
        _logger.exception("error while deflecting shields for interaction= %s with policy= %s", llm_interaction, policy)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Create a ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor() as executor:
executor.submit(process_traces, llm_interaction)
# Submit the async task to the thread pool with input arguments
loop.run_in_executor(_get_thread_pool_executor(), process_traces, llm_interaction)

_logger.debug("done proccesing %s", llm_interaction._id)
return Response(ExportTraceServiceResponse().SerializeToString(), mimetype="application/octet-stream")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def get_default_policy(self):
"high_risk_threshold": 0.8,
"low_risk_threshold": 0.2,
"input_shields": [
{"type": "llm_shield", "metadata": {"model": "gpt-3.5-turbo", "model_vendor": "openai"}},
{"type": "vector_db_shield", "metadata": {}},
{"type": "regex_shield", "metadata": {}, "name": "policy number"},
{"type": "llm_shield", "metadata": {"model": "gpt-3.5-turbo", "model_vendor": "openai"}},
{"type": "transformer_shield", "metadata": {}},
{"type": "vector_db_shield", "metadata": {}},
{"type": "model_denial_of_service_shield", "metadata": {"threshold": 10, "interval_sec": 60, "limit_by": "llm.user"}},
{"type": "sensitive_shield", "metadata": {}},
{"type": "prompt_safety_shield", "metadata": {}},
{"type": "sensitive_shield", "metadata": {}},
{"type": "model_denial_of_service_shield", "metadata": {"threshold": 10, "interval_sec": 60, "limit_by": "llm.user"}},
],
"output_shields": [
{"type": "regex_output_shield", "metadata": {}, "name": "credit card"},
{"type": "refusal_shield", "metadata": {}},
{"type": "sensitive_output_shield", "metadata": {}},
{"type": "canary_token_disclosure_shield", "metadata": {"canary_tokens": []}},
{"type": "sensitive_output_shield", "metadata": {}},
],
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class VibraniumShieldsFactory:

_vector_db_service: VectorDBService
_input_shields: dict
_output_shields: dict

def __init__(self, _vector_db_service: VectorDBService):
if not _vector_db_service:
Expand Down Expand Up @@ -94,7 +95,7 @@ def deflect_shield(tuple: [VibraniumShield, dict]) -> List[ShieldDeflectionResul
self._logger.exception("error while deflecting shield %s with scan_id=%s", shield.name, scan_id)

if execution_mode_async:
with concurrent.futures.ThreadPoolExecutor() as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=5, thread_name_prefix="CaptainLLM") as executor:
shields_res = executor.map(deflect_shield, shields)
results = dict(filter(lambda x: len(x[1]) > 0, shields_res))
else:
Expand Down

0 comments on commit 2d7a2c1

Please sign in to comment.