Skip to content

Commit f2a8c5f

Browse files
authored
🔑 Key Rotator (#633)
1 parent df6d321 commit f2a8c5f

File tree

4 files changed

+77
-11
lines changed

4 files changed

+77
-11
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from collections import Counter
2+
3+
import pytest
4+
5+
from reworkd_platform.web.api.agent.api_utils import rotate_keys
6+
7+
ITERATIONS = 10000
8+
9+
10+
def test_rotate_keys():
11+
pk = "primary_key"
12+
sk = "secondary_key"
13+
14+
results = []
15+
for _ in range(ITERATIONS):
16+
key = rotate_keys(pk, sk)
17+
assert key in [pk, sk]
18+
results.append(key)
19+
20+
counter = Counter(results)
21+
assert 0.65 < counter[pk] / ITERATIONS < 0.75
22+
23+
24+
@pytest.mark.parametrize(
25+
"model",
26+
[
27+
"gpt-4",
28+
"gpt-4-0314",
29+
"gpt-4-turbo",
30+
],
31+
)
32+
def test_rotate_keys_gpt_4(model):
33+
pk = "gpt-3-primary_key"
34+
sk = "gpt-4-primary_key"
35+
36+
results = []
37+
for _ in range(ITERATIONS):
38+
key = rotate_keys(pk, sk, model=model)
39+
assert key in [pk, sk]
40+
results.append(key)
41+
42+
counter = Counter(results)
43+
assert 0.65 < counter[sk] / ITERATIONS < 0.75
44+
45+
46+
def test_rotate_keys_no_secondary():
47+
pk = "primary_key"
48+
assert rotate_keys(pk, None) == pk
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import random
2+
from typing import Optional
3+
4+
PRIMARY_KEY_RATE = 0.7
5+
SECONDARY_KEY_RATE = 1 - PRIMARY_KEY_RATE
6+
WEIGHTS = [PRIMARY_KEY_RATE, SECONDARY_KEY_RATE]
7+
8+
9+
def rotate_keys(
10+
primary_key: str, secondary_key: Optional[str], model: str = "gpt-3.5-turbo"
11+
) -> str:
12+
if not secondary_key:
13+
return primary_key
14+
15+
keys = [primary_key, secondary_key]
16+
if "gpt-4" in model:
17+
keys.reverse()
18+
19+
return random.choices(keys, WEIGHTS)[0]

platform/reworkd_platform/web/api/agent/model_settings.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from random import randint
21
from typing import Optional
32

43
import openai
54
from langchain.chat_models import ChatOpenAI
65
from pydantic import BaseModel
76

87
from reworkd_platform.settings import settings
8+
from reworkd_platform.web.api.agent.api_utils import rotate_keys
99

1010

1111
class ModelSettings(BaseModel):
@@ -16,13 +16,6 @@ class ModelSettings(BaseModel):
1616
language: Optional[str] = "English"
1717

1818

19-
def get_server_side_key() -> str:
20-
keys = [
21-
key.strip() for key in (settings.openai_api_key or "").split(",") if key.strip()
22-
]
23-
return keys[randint(0, len(keys) - 1)] if keys else ""
24-
25-
2619
GPT_35_TURBO = "gpt-3.5-turbo"
2720

2821
openai.api_base = settings.openai_api_base
@@ -33,7 +26,10 @@ def create_model(
3326
) -> ChatOpenAI:
3427
return ChatOpenAI(
3528
client=None, # Meta private value but mypy will complain its missing
36-
openai_api_key=get_server_side_key(),
29+
openai_api_key=rotate_keys(
30+
primary_key=settings.openai_api_key,
31+
secondary_key=settings.secondary_openai_api_key,
32+
),
3733
temperature=model_settings.customTemperature
3834
if model_settings and model_settings.customTemperature is not None
3935
else 0.9,

platform/reworkd_platform/web/api/agent/tools/image.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
44

55
from reworkd_platform.settings import settings
6+
from reworkd_platform.web.api.agent.api_utils import rotate_keys
67
from reworkd_platform.web.api.agent.model_settings import (
78
ModelSettings,
8-
get_server_side_key,
99
)
1010
from reworkd_platform.web.api.agent.tools.stream_mock import stream_string
1111
from reworkd_platform.web.api.agent.tools.tool import Tool
@@ -27,7 +27,10 @@ async def get_replicate_image(input_str: str) -> str:
2727

2828
# Use AI to generate an Image based on a prompt
2929
async def get_open_ai_image(input_str: str) -> str:
30-
api_key = get_server_side_key()
30+
api_key = rotate_keys(
31+
primary_key=settings.openai_api_key,
32+
secondary_key=settings.secondary_openai_api_key,
33+
)
3134

3235
response = openai.Image.create(
3336
api_key=api_key,

0 commit comments

Comments
 (0)