-
Notifications
You must be signed in to change notification settings - Fork 444
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add drift tools to langchain (#212)
# Pull Request Description This PR is the langchain implementation of #207 ## Changes Made This PR adds the following changes: <!-- List the key changes made in this PR --> - This PR adds files that implement the drift actions in a way compatible with langchain ## Implementation Details <!-- Provide technical details about the implementation --> - Just a quick conversion of the drift actions to langchain tool classes ## Transaction executed by agent and prompt used <!-- If applicable, provide example usage, transactions, or screenshots --> Example transaction: <img width="998" alt="Screenshot 2025-01-15 at 17 43 42" src="https://github.com/user-attachments/assets/25f12c26-0f1a-470a-a566-028a54adf995" /> <img width="998" alt="Screenshot 2025-01-15 at 17 43 27" src="https://github.com/user-attachments/assets/b07c6089-f5fc-4498-9d5a-14c5698c21a9" /> <img width="998" alt="Screenshot 2025-01-15 at 17 43 02" src="https://github.com/user-attachments/assets/69067241-bb22-429b-9021-024c526ec25f" /> ## Additional Notes <!-- Any additional information that reviewers should know --> ## Checklist - [x] I have tested these changes locally - [ ] I have updated the documentation - [ ] I have added a transaction link - [x] I have added the prompt used to test it
- Loading branch information
Showing
17 changed files
with
621 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaCreateDriftUserAccountTool extends Tool { | ||
name = "create_drift_user_account"; | ||
description = `Create a new user account with a deposit on Drift protocol. | ||
Inputs (JSON string): | ||
- amount: number, amount of the token to deposit (required) | ||
- symbol: string, symbol of the token to deposit (required)`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const res = await this.solanaKit.createDriftUserAccount( | ||
parsedInput.amount, | ||
parsedInput.symbol, | ||
); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: `User account created with ${parsedInput.amount} ${parsedInput.symbol} successfully deposited`, | ||
account: res.account, | ||
signature: res.txSignature, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "CREATE_DRIFT_USER_ACCOUNT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaCreateDriftVaultTool extends Tool { | ||
name = "create_drift_vault"; | ||
description = `Create a new drift vault delegating the agents address as the owner. | ||
Inputs (JSON string): | ||
- name: string, unique vault name (min 5 chars) | ||
- marketName: string, market name in TOKEN-SPOT format | ||
- redeemPeriod: number, days to wait before funds can be redeemed (min 1) | ||
- maxTokens: number, maximum tokens vault can accommodate (min 100) | ||
- minDepositAmount: number, minimum deposit amount | ||
- managementFee: number, fee percentage for managing funds (max 20) | ||
- profitShare: number, profit sharing percentage (max 90, default 5) | ||
- hurdleRate: number, optional hurdle rate | ||
- permissioned: boolean, whether vault has whitelist`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const tx = await this.solanaKit.createDriftVault(parsedInput); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Drift vault created successfully", | ||
vaultName: parsedInput.name, | ||
signature: tx, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "CREATE_DRIFT_VAULT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaDepositIntoDriftVaultTool extends Tool { | ||
name = "deposit_into_drift_vault"; | ||
description = `Deposit funds into an existing drift vault. | ||
Inputs (JSON string): | ||
- vaultAddress: string, address of the vault (required) | ||
- amount: number, amount to deposit (required)`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const tx = await this.solanaKit.depositIntoDriftVault( | ||
parsedInput.amount, | ||
parsedInput.vaultAddress, | ||
); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Funds deposited successfully", | ||
signature: tx, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "DEPOSIT_INTO_VAULT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaDepositToDriftUserAccountTool extends Tool { | ||
name = "deposit_to_drift_user_account"; | ||
description = `Deposit funds into your drift user account. | ||
Inputs (JSON string): | ||
- amount: number, amount to deposit (required) | ||
- symbol: string, token symbol (required) | ||
- repay: boolean, whether to repay borrowed funds (optional, default: false)`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const tx = await this.solanaKit.depositToDriftUserAccount( | ||
parsedInput.amount, | ||
parsedInput.symbol, | ||
parsedInput.repay, | ||
); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Funds deposited successfully", | ||
signature: tx, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "DEPOSIT_TO_DRIFT_ACCOUNT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaDeriveVaultAddressTool extends Tool { | ||
name = "derive_drift_vault_address"; | ||
description = `Derive a drift vault address from the vault's name. | ||
Inputs (JSON string): | ||
- name: string, name of the vault to derive the address of (required)`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const address = await this.solanaKit.deriveDriftVaultAddress(input); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Vault address derived successfully", | ||
address, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "DERIVE_VAULT_ADDRESS_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaCheckDriftAccountTool extends Tool { | ||
name = "does_user_have_drift_account"; | ||
description = `Check if a user has a Drift account. | ||
Inputs: No inputs required - checks the current user's account`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(_input: string): Promise<string> { | ||
try { | ||
const res = await this.solanaKit.doesUserHaveDriftAccount(); | ||
|
||
if (!res.hasAccount) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: "You do not have a Drift account", | ||
}); | ||
} | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Nice! You have a Drift account", | ||
account: res.account, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "CHECK_DRIFT_ACCOUNT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaDriftUserAccountInfoTool extends Tool { | ||
name = "drift_user_account_info"; | ||
description = `Get information about your drift account. | ||
Inputs: No inputs required - retrieves current user's account info`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(_input: string): Promise<string> { | ||
try { | ||
const accountInfo = await this.solanaKit.driftUserAccountInfo(); | ||
return JSON.stringify({ | ||
status: "success", | ||
data: accountInfo, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "DRIFT_ACCOUNT_INFO_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
export * from "./create_user_account"; | ||
export * from "./create_vault"; | ||
export * from "./deposit_into_vault"; | ||
export * from "./deposit_to_user_account"; | ||
export * from "./derive_vault_address"; | ||
export * from "./does_user_have_drift_account"; | ||
export * from "./drift_user_account_info"; | ||
export * from "./request_withdrawal"; | ||
export * from "./trade_delegated_vault"; | ||
export * from "./trade_perp_account"; | ||
export * from "./update_drift_vault_delegate"; | ||
export * from "./update_vault"; | ||
export * from "./vault_info"; | ||
export * from "./withdraw_from_account"; | ||
export * from "./withdraw_from_vault"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaRequestDriftWithdrawalTool extends Tool { | ||
name = "request_withdrawal_from_drift_vault"; | ||
description = `Request a withdrawal from an existing drift vault. | ||
Inputs (JSON string): | ||
- vaultAddress: string, vault address (required) | ||
- amount: number, amount of shares to withdraw (required)`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const tx = await this.solanaKit.requestWithdrawalFromDriftVault( | ||
parsedInput.amount, | ||
parsedInput.vaultAddress, | ||
); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: "Withdrawal request successful", | ||
signature: tx, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "REQUEST_DRIFT_WITHDRAWAL_ERROR", | ||
}); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import { Tool } from "langchain/tools"; | ||
import { SolanaAgentKit } from "../../agent"; | ||
|
||
export class SolanaTradeDelegatedDriftVaultTool extends Tool { | ||
name = "trade_delegated_drift_vault"; | ||
description = `Carry out trades in a Drift vault. | ||
Inputs (JSON string): | ||
- vaultAddress: string, address of the Drift vault | ||
- amount: number, amount to trade | ||
- symbol: string, symbol of the token to trade | ||
- action: "long" | "short", trade direction | ||
- type: "market" | "limit", order type | ||
- price: number, optional limit price`; | ||
|
||
constructor(private solanaKit: SolanaAgentKit) { | ||
super(); | ||
} | ||
|
||
protected async _call(input: string): Promise<string> { | ||
try { | ||
const parsedInput = JSON.parse(input); | ||
const tx = await this.solanaKit.tradeUsingDelegatedDriftVault( | ||
parsedInput.vaultAddress, | ||
parsedInput.amount, | ||
parsedInput.symbol, | ||
parsedInput.action, | ||
parsedInput.type, | ||
parsedInput.price, | ||
); | ||
|
||
return JSON.stringify({ | ||
status: "success", | ||
message: | ||
parsedInput.type === "limit" | ||
? "Order placed successfully" | ||
: "Trade successful", | ||
transactionId: tx, | ||
...parsedInput, | ||
}); | ||
} catch (error: any) { | ||
return JSON.stringify({ | ||
status: "error", | ||
message: error.message, | ||
code: error.code || "TRADE_DRIFT_VAULT_ERROR", | ||
}); | ||
} | ||
} | ||
} |
Oops, something went wrong.