Skip to content

Commit

Permalink
feat: add drift tools to langchain (#212)
Browse files Browse the repository at this point in the history
# Pull Request Description

This PR is the langchain implementation of #207 

## Changes Made
This PR adds the following changes:
<!-- List the key changes made in this PR -->
- This PR adds files that implement the drift actions in a way
compatible with langchain
  
## Implementation Details
<!-- Provide technical details about the implementation -->
- Just a quick conversion of the drift actions to langchain tool classes

## Transaction executed by agent and prompt used
<!-- If applicable, provide example usage, transactions, or screenshots
-->
Example transaction: 
<img width="998" alt="Screenshot 2025-01-15 at 17 43 42"
src="https://github.com/user-attachments/assets/25f12c26-0f1a-470a-a566-028a54adf995"
/>
<img width="998" alt="Screenshot 2025-01-15 at 17 43 27"
src="https://github.com/user-attachments/assets/b07c6089-f5fc-4498-9d5a-14c5698c21a9"
/>
<img width="998" alt="Screenshot 2025-01-15 at 17 43 02"
src="https://github.com/user-attachments/assets/69067241-bb22-429b-9021-024c526ec25f"
/>

## Additional Notes
<!-- Any additional information that reviewers should know -->

## Checklist
- [x] I have tested these changes locally
- [ ] I have updated the documentation
- [ ] I have added a transaction link
- [x] I have added the prompt used to test it
  • Loading branch information
thearyanag authored Jan 15, 2025
2 parents 3a33894 + ccbdc27 commit 975dd79
Show file tree
Hide file tree
Showing 17 changed files with 621 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/langchain/drift/create_user_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCreateDriftUserAccountTool extends Tool {
name = "create_drift_user_account";
description = `Create a new user account with a deposit on Drift protocol.
Inputs (JSON string):
- amount: number, amount of the token to deposit (required)
- symbol: string, symbol of the token to deposit (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const res = await this.solanaKit.createDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
);

return JSON.stringify({
status: "success",
message: `User account created with ${parsedInput.amount} ${parsedInput.symbol} successfully deposited`,
account: res.account,
signature: res.txSignature,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_USER_ACCOUNT_ERROR",
});
}
}
}
42 changes: 42 additions & 0 deletions src/langchain/drift/create_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCreateDriftVaultTool extends Tool {
name = "create_drift_vault";
description = `Create a new drift vault delegating the agents address as the owner.
Inputs (JSON string):
- name: string, unique vault name (min 5 chars)
- marketName: string, market name in TOKEN-SPOT format
- redeemPeriod: number, days to wait before funds can be redeemed (min 1)
- maxTokens: number, maximum tokens vault can accommodate (min 100)
- minDepositAmount: number, minimum deposit amount
- managementFee: number, fee percentage for managing funds (max 20)
- profitShare: number, profit sharing percentage (max 90, default 5)
- hurdleRate: number, optional hurdle rate
- permissioned: boolean, whether vault has whitelist`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.createDriftVault(parsedInput);

return JSON.stringify({
status: "success",
message: "Drift vault created successfully",
vaultName: parsedInput.name,
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_VAULT_ERROR",
});
}
}
}
37 changes: 37 additions & 0 deletions src/langchain/drift/deposit_into_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDepositIntoDriftVaultTool extends Tool {
name = "deposit_into_drift_vault";
description = `Deposit funds into an existing drift vault.
Inputs (JSON string):
- vaultAddress: string, address of the vault (required)
- amount: number, amount to deposit (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositIntoDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);

return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_INTO_VAULT_ERROR",
});
}
}
}
39 changes: 39 additions & 0 deletions src/langchain/drift/deposit_to_user_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDepositToDriftUserAccountTool extends Tool {
name = "deposit_to_drift_user_account";
description = `Deposit funds into your drift user account.
Inputs (JSON string):
- amount: number, amount to deposit (required)
- symbol: string, token symbol (required)
- repay: boolean, whether to repay borrowed funds (optional, default: false)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositToDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
parsedInput.repay,
);

return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_TO_DRIFT_ACCOUNT_ERROR",
});
}
}
}
32 changes: 32 additions & 0 deletions src/langchain/drift/derive_vault_address.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDeriveVaultAddressTool extends Tool {
name = "derive_drift_vault_address";
description = `Derive a drift vault address from the vault's name.
Inputs (JSON string):
- name: string, name of the vault to derive the address of (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const address = await this.solanaKit.deriveDriftVaultAddress(input);

return JSON.stringify({
status: "success",
message: "Vault address derived successfully",
address,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DERIVE_VAULT_ADDRESS_ERROR",
});
}
}
}
38 changes: 38 additions & 0 deletions src/langchain/drift/does_user_have_drift_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCheckDriftAccountTool extends Tool {
name = "does_user_have_drift_account";
description = `Check if a user has a Drift account.
Inputs: No inputs required - checks the current user's account`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(_input: string): Promise<string> {
try {
const res = await this.solanaKit.doesUserHaveDriftAccount();

if (!res.hasAccount) {
return JSON.stringify({
status: "error",
message: "You do not have a Drift account",
});
}

return JSON.stringify({
status: "success",
message: "Nice! You have a Drift account",
account: res.account,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CHECK_DRIFT_ACCOUNT_ERROR",
});
}
}
}
29 changes: 29 additions & 0 deletions src/langchain/drift/drift_user_account_info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDriftUserAccountInfoTool extends Tool {
name = "drift_user_account_info";
description = `Get information about your drift account.
Inputs: No inputs required - retrieves current user's account info`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(_input: string): Promise<string> {
try {
const accountInfo = await this.solanaKit.driftUserAccountInfo();
return JSON.stringify({
status: "success",
data: accountInfo,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DRIFT_ACCOUNT_INFO_ERROR",
});
}
}
}
15 changes: 15 additions & 0 deletions src/langchain/drift/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export * from "./create_user_account";
export * from "./create_vault";
export * from "./deposit_into_vault";
export * from "./deposit_to_user_account";
export * from "./derive_vault_address";
export * from "./does_user_have_drift_account";
export * from "./drift_user_account_info";
export * from "./request_withdrawal";
export * from "./trade_delegated_vault";
export * from "./trade_perp_account";
export * from "./update_drift_vault_delegate";
export * from "./update_vault";
export * from "./vault_info";
export * from "./withdraw_from_account";
export * from "./withdraw_from_vault";
37 changes: 37 additions & 0 deletions src/langchain/drift/request_withdrawal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaRequestDriftWithdrawalTool extends Tool {
name = "request_withdrawal_from_drift_vault";
description = `Request a withdrawal from an existing drift vault.
Inputs (JSON string):
- vaultAddress: string, vault address (required)
- amount: number, amount of shares to withdraw (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.requestWithdrawalFromDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);

return JSON.stringify({
status: "success",
message: "Withdrawal request successful",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "REQUEST_DRIFT_WITHDRAWAL_ERROR",
});
}
}
}
49 changes: 49 additions & 0 deletions src/langchain/drift/trade_delegated_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaTradeDelegatedDriftVaultTool extends Tool {
name = "trade_delegated_drift_vault";
description = `Carry out trades in a Drift vault.
Inputs (JSON string):
- vaultAddress: string, address of the Drift vault
- amount: number, amount to trade
- symbol: string, symbol of the token to trade
- action: "long" | "short", trade direction
- type: "market" | "limit", order type
- price: number, optional limit price`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.tradeUsingDelegatedDriftVault(
parsedInput.vaultAddress,
parsedInput.amount,
parsedInput.symbol,
parsedInput.action,
parsedInput.type,
parsedInput.price,
);

return JSON.stringify({
status: "success",
message:
parsedInput.type === "limit"
? "Order placed successfully"
: "Trade successful",
transactionId: tx,
...parsedInput,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "TRADE_DRIFT_VAULT_ERROR",
});
}
}
}
Loading

0 comments on commit 975dd79

Please sign in to comment.