Skip to content

Commit

Permalink
Merge pull request #157 from crestalnetwork/feat/enso-skills-async
Browse files Browse the repository at this point in the history
Feat: Enso skills async
  • Loading branch information
hyacinthus authored Feb 3, 2025
2 parents 3feb535 + c9d3d9d commit 97364ab
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 74 deletions.
1 change: 1 addition & 0 deletions app/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ async def initialize_agent(aid):
agentkit.wallet if agentkit else None,
config.rpc_base_mainnet,
skill_store,
agent_store,
aid,
)
tools.append(s)
Expand Down
36 changes: 22 additions & 14 deletions skills/enso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from skills.enso.route import EnsoRouteShortcut
from skills.enso.tokens import EnsoGetTokens
from skills.enso.wallet import (
EnsoBroadcastWalletApprove,
EnsoGetWalletApprovals,
EnsoGetWalletBalances,
EnsoWalletApprove,
)


Expand All @@ -20,8 +20,9 @@ def get_enso_skill(
api_token: str,
main_tokens: list[str],
wallet: Wallet,
rpc_nodes: dict[str, str],
store: SkillStoreABC,
rpc_node: str,
skill_store: SkillStoreABC,
agent_store: SkillStoreABC,
agent_id: str,
) -> EnsoBaseTool:
if not api_token:
Expand All @@ -31,23 +32,26 @@ def get_enso_skill(
return EnsoGetNetworks(
api_token=api_token,
main_tokens=main_tokens,
store=store,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

if name == "get_tokens":
return EnsoGetTokens(
api_token=api_token,
main_tokens=main_tokens,
store=store,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

if name == "get_prices":
return EnsoGetPrices(
api_token=api_token,
main_tokens=main_tokens,
store=store,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

Expand All @@ -58,7 +62,8 @@ def get_enso_skill(
api_token=api_token,
main_tokens=main_tokens,
wallet=wallet,
store=store,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

Expand All @@ -69,31 +74,34 @@ def get_enso_skill(
api_token=api_token,
main_tokens=main_tokens,
wallet=wallet,
store=store,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

if name == "wallet_approve":
if not wallet:
raise ValueError("Wallet is empty")
return EnsoBroadcastWalletApprove(
return EnsoWalletApprove(
api_token=api_token,
main_tokens=main_tokens,
wallet=wallet,
rpc_nodes=rpc_nodes,
store=store,
rpc_node=rpc_node,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

if name == "broadcast_route_shortcut":
if name == "route_shortcut":
if not wallet:
raise ValueError("Wallet is empty")
return EnsoRouteShortcut(
api_token=api_token,
main_tokens=main_tokens,
wallet=wallet,
rpc_nodes=rpc_nodes,
store=store,
rpc_node=rpc_node,
skill_store=skill_store,
agent_store=agent_store,
agent_id=agent_id,
)

Expand Down
12 changes: 8 additions & 4 deletions skills/enso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cdp import Wallet
from pydantic import BaseModel, Field

from abstracts.agent import AgentStoreABC
from abstracts.skill import IntentKitSkill, SkillStoreABC

base_url = "https://api.enso.finance"
Expand All @@ -15,11 +16,14 @@ class EnsoBaseTool(IntentKitSkill):
api_token: str = Field(description="API token")
main_tokens: list[str] = Field(description="Main supported tokens")
wallet: Wallet | None = Field(None, description="The wallet of the agent")
rpc_nodes: dict[str, str] | None = Field(
None, description="RPC nodes for different networks"
)
rpc_node: str | None = Field(None, description="RPC nodes for different networks")
name: str = Field(description="The name of the tool")
description: str = Field(description="A description of what the tool does")
args_schema: Type[BaseModel]
agent_id: str = Field(description="The ID of the agent")
store: SkillStoreABC = Field(description="The skill store for persisting data")
agent_store: AgentStoreABC = Field(
description="The agent store for persisting data"
)
skill_store: SkillStoreABC = Field(
description="The skill store for persisting data"
)
17 changes: 14 additions & 3 deletions skills/enso/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ class EnsoGetNetworks(EnsoBaseTool):
args_schema: Type[BaseModel] = EnsoGetNetworksInput

def _run(self) -> EnsoGetNetworksOutput:
"""Run the tool to get all supported networks.
Returns:
EnsoGetNetworksOutput: A structured output containing the result of the networks.
Raises:
Exception: If there's an error accessing the Enso API.
"""
raise NotImplementedError("Use _arun instead")

async def _arun(self) -> EnsoGetNetworksOutput:
"""
Function to request the list of supported networks and their chain id and name.
Expand All @@ -58,10 +69,10 @@ def _run(self) -> EnsoGetNetworksOutput:
"Authorization": f"Bearer {self.api_token}",
}

with httpx.Client() as client:
async with httpx.AsyncClient() as client:
try:
# Send the GET request
response = client.get(url, headers=headers)
response = await client.get(url, headers=headers)
response.raise_for_status()

# Parse the response JSON into the NetworkResponse model
Expand All @@ -74,7 +85,7 @@ def _run(self) -> EnsoGetNetworksOutput:
networks.append(network)
networks_memory[network.id] = network.model_dump(exclude_none=True)

self.store.save_agent_skill_data(
await self.skill_store.save_agent_skill_data(
self.agent_id,
"enso_get_networks",
"networks",
Expand Down
19 changes: 17 additions & 2 deletions skills/enso/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ def _run(
self,
chainId: int = default_chain_id,
address: str = "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
) -> EnsoGetPricesOutput:
"""Run the tool to get the token price from the API.
Returns:
EnsoGetPricesOutput: A structured output containing the result of token prices.
Raises:
Exception: If there's an error accessing the Enso API.
"""
raise NotImplementedError("Use _arun instead")

async def _arun(
self,
chainId: int = default_chain_id,
address: str = "0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
) -> EnsoGetPricesOutput:
"""
Asynchronous function to request the token price from the API.
Expand All @@ -62,9 +77,9 @@ def _run(
"Authorization": f"Bearer {self.api_token}",
}

with httpx.Client() as client:
async with httpx.AsyncClient() as client:
try:
response = client.get(url, headers=headers)
response = await client.get(url, headers=headers)
response.raise_for_status()
json_dict = response.json()

Expand Down
58 changes: 36 additions & 22 deletions skills/enso/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ class EnsoRouteShortcut(EnsoBaseTool):
"""

name: str = "enso_route_shortcut"
description: str = "This tool is used specifically for broadcasting a route transaction calldata to the network. It should only be used when the user explicitly requests to broadcast a route transaction with routeId."
description: str = (
"This tool is used specifically for broadcasting a route transaction calldata to the network. It should only be used when the user explicitly requests to broadcast a route transaction with routeId."
)
args_schema: Type[BaseModel] = EnsoRouteShortcutInput

def _run(
Expand All @@ -170,6 +172,25 @@ def _run(
"""
Run the tool to get swap route information.
Returns:
EnsoRouteShortcutOutput: The response containing route shortcut information.
Raises:
Exception: If there's an error accessing the Enso API.
"""
raise NotImplementedError("Use _arun instead")

async def _arun(
self,
amountIn: list[int],
tokenIn: list[str],
tokenOut: list[str],
chainId: int = default_chain_id,
broadcast_requested: bool | None = False,
) -> EnsoRouteShortcutOutput:
"""
Run the tool to get swap route information.
Args:
amountIn (list[int]): Amount of tokenIn to swap in wei, you should multiply user's requested value by token decimals.
tokenIn (list[str]): Ethereum address of the token to swap or enter into a position from (For ETH, use 0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee).
Expand All @@ -181,10 +202,10 @@ def _run(
EnsoRouteShortcutOutput: The response containing route shortcut information.
"""

with httpx.Client() as client:
async with httpx.AsyncClient() as client:
try:
network_name = None
networks = self.store.get_agent_skill_data(
networks = await self.skill_store.get_agent_skill_data(
self.agent_id, "enso_get_networks", "networks"
)

Expand All @@ -195,17 +216,15 @@ def _run(
else None
)
if network_name is None:
networks_list = (
EnsoGetNetworks(
api_token=self.api_token,
main_tokens=self.main_tokens,
store=self.store,
agent_id=self.agent_id,
)
.run(EnsoGetNetworksInput())
.res
)
for network in networks_list:
networks = await EnsoGetNetworks(
api_token=self.api_token,
main_tokens=self.main_tokens,
skill_store=self.skill_store,
agent_store=self.agent_store,
agent_id=self.agent_id,
).arun(EnsoGetNetworksInput())

for network in networks.res:
if network.id == chainId:
network_name = network.name

Expand All @@ -219,7 +238,7 @@ def _run(
"Authorization": f"Bearer {self.api_token}",
}

token_decimals = self.store.get_agent_skill_data(
token_decimals = await self.skill_store.get_agent_skill_data(
self.agent_id,
"enso_get_tokens",
"decimals",
Expand Down Expand Up @@ -252,7 +271,7 @@ def _run(

params["fromAddress"] = self.wallet.addresses[0].address_id

response = client.get(url, headers=headers, params=params)
response = await client.get(url, headers=headers, params=params)
response.raise_for_status() # Raise HTTPError for non-2xx responses
json_dict = response.json()

Expand All @@ -264,13 +283,8 @@ def _run(
)

if broadcast_requested:
if not self.rpc_nodes.get(str(chainId)):
raise ToolException(
f"rpc node not found for chainId: {chainId}"
)

contract = EvmContractWrapper(
self.rpc_nodes[str(chainId)], ABI_ROUTE, json_dict.get("tx")
self.rpc_node, ABI_ROUTE, json_dict.get("tx")
)

fn, fn_args = contract.fn_and_args
Expand Down
23 changes: 19 additions & 4 deletions skills/enso/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ def _run(
self,
chainId: int = default_chain_id,
protocolSlug: str | None = None,
) -> EnsoGetTokensOutput:
"""Run the tool to get the tokens and APYs from the API.
Returns:
EnsoGetPricesOutput: A structured output containing the result of tokens and APYs.
Raises:
Exception: If there's an error accessing the Enso API.
"""
raise NotImplementedError("Use _arun instead")

async def _arun(
self,
chainId: int = default_chain_id,
protocolSlug: str | None = None,
) -> EnsoGetTokensOutput:
"""Run the tool to get Tokens and APY.
Args:
Expand All @@ -165,13 +180,13 @@ def _run(
params["page"] = 1
params["includeMetadata"] = "true"

with httpx.Client() as client:
async with httpx.AsyncClient() as client:
try:
response = client.get(url, headers=headers, params=params)
response = await client.get(url, headers=headers, params=params)
response.raise_for_status()
json_dict = response.json()

token_decimals = self.store.get_agent_skill_data(
token_decimals = await self.skill_store.get_agent_skill_data(
self.agent_id,
"enso_get_tokens",
"decimals",
Expand All @@ -196,7 +211,7 @@ def _run(
for u_token in token_response.underlyingTokens:
token_decimals[u_token.address] = u_token.decimals

self.store.save_agent_skill_data(
await self.skill_store.save_agent_skill_data(
self.agent_id,
"enso_get_tokens",
"decimals",
Expand Down
Loading

0 comments on commit 97364ab

Please sign in to comment.