Skip to content

Commit

Permalink
implemented feedback system
Browse files Browse the repository at this point in the history
  • Loading branch information
nerfZael committed Jul 2, 2024
1 parent fd61f5c commit b340e60
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 21 deletions.
26 changes: 22 additions & 4 deletions autotx/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, app_id: str):
self.client = get_db_client("public")
self.app_id = app_id

def start(self, prompt: str, address: str, chain_id: int, app_user_id: str) -> models.Task:
def start(self, prompt: str, address: str, chain_id: int, app_user_id: str, previous_task_id: str | None = None) -> models.Task:
client = get_db_client("public")

created_at = datetime.utcnow()
Expand All @@ -57,6 +57,8 @@ def start(self, prompt: str, address: str, chain_id: int, app_user_id: str) -> m
"messages": json.dumps([]),
"logs": json.dumps([]),
"intents": json.dumps([]),
"previous_task_id": previous_task_id,
"feedback": None
}
).execute()

Expand All @@ -72,6 +74,8 @@ def start(self, prompt: str, address: str, chain_id: int, app_user_id: str) -> m
messages=[],
logs=[],
intents=[],
previous_task_id=previous_task_id,
feedback=None
)

def stop(self, task_id: str) -> None:
Expand All @@ -95,9 +99,19 @@ def update(self, task: models.Task) -> None:
"messages": json.dumps(task.messages),
"error": task.error,
"logs": dump_pydantic_list(task.logs if task.logs else []),
"intents": dump_pydantic_list(task.intents)
"intents": dump_pydantic_list(task.intents),
"previous_task_id": task.previous_task_id
}
).eq("id", task.id).eq("app_id", self.app_id).execute()

def update_feedback(self, task_id: str, feedback: str) -> None:
client = get_db_client("public")

client.table("tasks").update(
{
"feedback": feedback
}
).eq("id", task_id).eq("app_id", self.app_id).execute()

def get(self, task_id: str) -> models.Task | None:
client = get_db_client("public")
Expand All @@ -124,7 +138,9 @@ def get(self, task_id: str) -> models.Task | None:
error=task_data["error"],
messages=json.loads(task_data["messages"]),
logs=[models.TaskLog(**log) for log in json.loads(task_data["logs"])] if task_data["logs"] else None,
intents=[load_intent(intent) for intent in json.loads(task_data["intents"])]
intents=[load_intent(intent) for intent in json.loads(task_data["intents"])],
previous_task_id=task_data["previous_task_id"],
feedback=task_data["feedback"]
)

def get_all(self) -> list[models.Task]:
Expand All @@ -147,7 +163,9 @@ def get_all(self) -> list[models.Task]:
error=task_data["error"],
messages=json.loads(task_data["messages"]),
logs=[models.TaskLog(**log) for log in json.loads(task_data["logs"])] if task_data["logs"] else None,
intents=[load_intent(intent) for intent in json.loads(task_data["intents"])]
intents=[load_intent(intent) for intent in json.loads(task_data["intents"])],
previous_task_id=task_data["previous_task_id"],
feedback=task_data["feedback"]
)
)

Expand Down
2 changes: 2 additions & 0 deletions autotx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Task(BaseModel):
messages: List[str]
logs: List[TaskLog] | None
intents: List[Intent]
previous_task_id: str | None
feedback: str | None

class TaskError(BaseModel):
id: str
Expand Down
88 changes: 72 additions & 16 deletions autotx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from autotx import models, setup, task_logs
from autotx import db
from autotx.AutoTx import AutoTx, Config as AutoTxConfig
from autotx.intents import Intent
from autotx.smart_accounts.smart_account import SmartAccount
from autotx.transactions import Transaction
Expand Down Expand Up @@ -124,29 +125,24 @@ def add_task_log(log: models.TaskLog, task_id: str, tasks: db.TasksRepository) -
task.logs.append(log)
tasks.update(task)

@app_router.post("/api/v1/tasks", response_model=models.Task)
async def create_task(task: models.TaskCreate, background_tasks: BackgroundTasks, authorization: Annotated[str | None, Header()] = None) -> models.Task:
from autotx.AutoTx import AutoTx, Config as AutoTxConfig

app = authorize(authorization)
app_user = db.get_app_user(app.id, task.user_id)
if not app_user:
raise HTTPException(status_code=400, detail="User not found")

tasks = db.TasksRepository(app.id)
def get_previous_tasks(task: models.Task, tasks: db.TasksRepository) -> List[models.Task]:
previous_tasks = []
current_task = task
while current_task is not None:
previous_tasks.append(current_task)
if current_task.previous_task_id is None:
break
current_task = tasks.get(current_task.previous_task_id)
return previous_tasks

global autotx_params
if not autotx_params.is_dev and (not task.address or not task.chain_id):
raise HTTPException(status_code=400, detail="Address and Chain ID are required for non-dev mode")

prompt = task.prompt

def run_task(prompt: str, task: models.TaskCreate, app: models.App, app_user: models.AppUser, tasks: db.TasksRepository, background_tasks: BackgroundTasks, previous_task_id: str | None = None) -> models.Task:
app_config = AppConfig(subsidized_chain_id=task.chain_id)

wallet = SafeSmartAccount(app_config.rpc_url, app_config.network_info, smart_account_addr=task.address)
api_wallet = ApiSmartAccount(app_config.web3, wallet, tasks)

created_task: models.Task = tasks.start(prompt, api_wallet.address.hex, app_config.network_info.chain_id.value, app_user.id)
created_task: models.Task = tasks.start(prompt, api_wallet.address.hex, app_config.network_info.chain_id.value, app_user.id, previous_task_id)
task_id = created_task.id
api_wallet.task_id = task_id

Expand Down Expand Up @@ -205,6 +201,66 @@ async def run_task() -> None:
stop_task_for_error(tasks, task_id, error, f"An error caused AutoTx to stop ({task_id})")
raise e

@app_router.post("/api/v1/tasks", response_model=models.Task)
async def create_task(task: models.TaskCreate, background_tasks: BackgroundTasks, authorization: Annotated[str | None, Header()] = None) -> models.Task:
app = authorize(authorization)
app_user = db.get_app_user(app.id, task.user_id)
if not app_user:
raise HTTPException(status_code=400, detail="User not found")

tasks = db.TasksRepository(app.id)

global autotx_params
if not autotx_params.is_dev and (not task.address or not task.chain_id):
raise HTTPException(status_code=400, detail="Address and Chain ID are required for non-dev mode")

prompt = task.prompt

task = run_task(prompt, task, app, app_user, tasks, background_tasks)

return task

class FeedbackParams(BaseModel):
feedback: str
user_id: str

@app_router.post("/api/v1/tasks/{task_id}/feedback", response_model=models.Task)
def provide_feedback(task_id: str, model: FeedbackParams, background_tasks: BackgroundTasks, authorization: Annotated[str | None, Header()] = None) -> 'models.Task':
(app, app_user) = authorize_app_and_user(authorization, model.user_id)

tasks = db.TasksRepository(app.id)

task = get_task_or_404(task_id, tasks)

if task.running:
raise HTTPException(status_code=400, detail="Task is still running")

global autotx_params
if not autotx_params.is_dev and (not task.address or not task.chain_id):
raise HTTPException(status_code=400, detail="Address and Chain ID are required for non-dev mode")

# Get all previous tasks
previous_tasks = get_previous_tasks(task, tasks)

prompt = "History:\n"
for previous_task in previous_tasks[::-1]:
if previous_task.previous_task_id is None:
prompt += "The user first said:\n"+ previous_task.prompt + "\n\n"
prompt += "Then the agents generated the following transactions:\n"
for intent in previous_task.intents:
prompt += intent.summary + "\n"
prompt += "\n"
if previous_task.feedback:
prompt += "The user then said:\n" + previous_task.feedback + "\n\n"

prompt += "Now the user provided feedback:\n" + model.feedback

tasks.update_feedback(task_id, model.feedback)

task = run_task(prompt, models.TaskCreate(prompt=prompt, address=task.address, chain_id=task.chain_id, user_id=app_user.user_id), app, app_user, tasks, background_tasks, task_id)

return task

@app_router.post("/api/v1/connect", response_model=models.AppUser)
async def connect(model: models.ConnectionCreate, authorization: Annotated[str | None, Header()] = None) -> models.AppUser:
app = authorize(authorization)
Expand Down
6 changes: 5 additions & 1 deletion autotx/utils/ethereum/lifi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ async def get_quote_to_amount(
attempt_count += 1
await asyncio.sleep(0.5)
continue
raise e
except Exception as e:
if "No available quotes for the requested transfer" in str(e) or "Unable to find quote to match expected output" in str(e):
if attempt_count < 5:
attempt_count += 1
await asyncio.sleep(0.5)
continue
raise e

@classmethod
async def get_quote_from_amount(
Expand Down Expand Up @@ -122,9 +124,11 @@ async def get_quote_from_amount(
attempt_count += 1
await asyncio.sleep(0.5)
continue
raise e
except Exception as e:
if "No available quotes for the requested transfer" in str(e) or "Unable to find quote to match expected output" in str(e):
if attempt_count < 5:
attempt_count += 1
await asyncio.sleep(0.5)
continue
continue
raise e
9 changes: 9 additions & 0 deletions supabase/migrations/20240702110412_feedback.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
alter table "public"."tasks" add column "feedback" text;

alter table "public"."tasks" add column "previous_task_id" uuid;

alter table "public"."tasks" add constraint "public_tasks_previous_task_id_fkey" FOREIGN KEY (previous_task_id) REFERENCES tasks(id) not valid;

alter table "public"."tasks" validate constraint "public_tasks_previous_task_id_fkey";


0 comments on commit b340e60

Please sign in to comment.