Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/analysis_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def on_new_task(job: ACPJob):
AnalysisResult(
vault_info=VaultInfo(
chain=analysis_request.chain,
vault_address=price_history.vault_address,
vault_name=price_history.vault_name,
address=price_history.address,
name=price_history.name,
protocol="Morpho Meta Vault",
last_updated_timestamp=int(time.time()),
),
Expand Down
28 changes: 14 additions & 14 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_analyze_yield_with_daily_share_price_success(self) -> None:

# Create SharePriceHistory object
share_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=list(zip(timestamps, prices)),
)

Expand All @@ -42,8 +42,8 @@ def test_analyze_yield_with_daily_share_price_insufficient_data(self) -> None:
"""Test analysis with insufficient data."""
# Test with empty price history
empty_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=[],
)
with pytest.raises(
Expand All @@ -53,8 +53,8 @@ def test_analyze_yield_with_daily_share_price_insufficient_data(self) -> None:

# Test with single price
single_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=[(1640995200, 1.0)],
)
with pytest.raises(
Expand All @@ -69,8 +69,8 @@ def test_analyze_yield_with_daily_share_price_decreasing_trend(self) -> None:
# No reverse here

share_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=list(zip(timestamps, prices)),
)

Expand All @@ -96,8 +96,8 @@ def test_analyze_yield_with_daily_share_price_volatile_data(self) -> None:
# No reverse here

share_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=list(zip(timestamps, prices)),
)

Expand All @@ -114,8 +114,8 @@ def test_analyze_yield_with_daily_share_price_custom_risk_free_rate(self) -> Non
# No reverse here

share_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=list(zip(timestamps, prices)),
)

Expand All @@ -137,8 +137,8 @@ def test_analyze_yield_with_daily_share_price_90_days_data(self) -> None:
# No reverse here

share_price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=list(zip(timestamps, prices)),
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def test_format_price_history_response_valid_data(self) -> None:

assert len(result) == 1
assert isinstance(result[0], SharePriceHistory)
assert result[0].vault_name == "Test Vault"
assert result[0].vault_address == "0x1234567890abcdef1234567890abcdef12345678"
assert result[0].name == "Test Vault"
assert result[0].address == "0x1234567890abcdef1234567890abcdef12345678"
assert len(result[0].price_history) == 2

def test_format_price_history_response_no_data(self) -> None:
Expand Down Expand Up @@ -104,6 +104,6 @@ def test_get_daily_share_price_history_from_subgraph(

assert len(result) == 1
assert isinstance(result[0], SharePriceHistory)
assert result[0].vault_name == "Test Vault"
assert result[0].vault_address == "0x1234567890abcdef1234567890abcdef12345678"
assert result[0].name == "Test Vault"
assert result[0].address == "0x1234567890abcdef1234567890abcdef12345678"
mock_send_query.assert_called_once()
57 changes: 31 additions & 26 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AnalysisResult,
AuditStatus,
Chain,
Contract,
PerformanceAnalysis,
RegistrationRequest,
RegistrationResponse,
Expand Down Expand Up @@ -74,8 +75,8 @@ def test_vault_info_creation(self) -> None:
"""Test VaultInfo model creation."""
vault_info = VaultInfo(
chain=Chain.BASE,
vault_address="0x1234567890abcdef1234567890abcdef12345678",
vault_name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
protocol="Test",
max_deposit_amount=1000000.0,
last_updated_timestamp=1640995200,
Expand All @@ -85,8 +86,8 @@ def test_vault_info_creation(self) -> None:
)

assert vault_info.chain == Chain.BASE
assert vault_info.vault_address == "0x1234567890abcdef1234567890abcdef12345678"
assert vault_info.vault_name == "Test Vault"
assert vault_info.address == "0x1234567890abcdef1234567890abcdef12345678"
assert vault_info.name == "Test Vault"
assert vault_info.protocol == "Test"
assert vault_info.max_deposit_amount == 1000000.0
assert vault_info.entry_cost_bps == 0.0 # Default value
Expand All @@ -97,8 +98,8 @@ def test_vault_info_serialization(self) -> None:
"""Test VaultInfo model serialization."""
vault_info = VaultInfo(
chain=Chain.BASE,
vault_address="0x1234567890abcdef1234567890abcdef12345678",
vault_name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
protocol="Test",
max_deposit_amount=1000000.0,
last_updated_timestamp=1640995200,
Expand All @@ -108,8 +109,8 @@ def test_vault_info_serialization(self) -> None:
)
obj = vault_info.model_dump(mode="json")
assert obj["chain"] == "base"
assert obj["vault_address"] == "0x1234567890abcdef1234567890abcdef12345678"
assert obj["vault_name"] == "Test Vault"
assert obj["address"] == "0x1234567890abcdef1234567890abcdef12345678"
assert obj["name"] == "Test Vault"
assert obj["protocol"] == "Test"
assert obj["max_deposit_amount"] == 1000000.0
assert obj["last_updated_timestamp"] == 1640995200
Expand Down Expand Up @@ -140,8 +141,8 @@ def test_vault_performance_analysis_creation(self) -> None:
"""Test AnalysisResult model creation."""
vault_info = VaultInfo(
chain=Chain.BASE,
vault_address="0x1234567890abcdef1234567890abcdef12345678",
vault_name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
protocol="Test",
max_deposit_amount=1000000.0,
last_updated_timestamp=1640995200,
Expand Down Expand Up @@ -170,8 +171,8 @@ def test_analysis_response_creation(self) -> None:
"""Test AnalysisResponse model creation."""
vault_info = VaultInfo(
chain=Chain.BASE,
vault_address="0x1234567890abcdef1234567890abcdef12345678",
vault_name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
protocol="Test",
max_deposit_amount=1000000.0,
last_updated_timestamp=1640995200,
Expand Down Expand Up @@ -201,41 +202,45 @@ def test_analysis_response_creation(self) -> None:
def test_share_price_history_creation(self) -> None:
"""Test SharePriceHistory model creation."""
price_history = SharePriceHistory(
vault_name="Test Vault",
vault_address="0x1234567890abcdef1234567890abcdef12345678",
name="Test Vault",
address="0x1234567890abcdef1234567890abcdef12345678",
price_history=[(1640995200, 1.05), (1640908800, 1.04)],
)

assert price_history.vault_name == "Test Vault"
assert (
price_history.vault_address == "0x1234567890abcdef1234567890abcdef12345678"
)
assert price_history.name == "Test Vault"
assert price_history.address == "0x1234567890abcdef1234567890abcdef12345678"
assert len(price_history.price_history) == 2
assert price_history.price_history[0] == (1640995200, 1.05)

def test_registration_request_creation(self) -> None:
"""Test RegistrationRequest model creation."""
request = RegistrationRequest(
chain=Chain.BASE,
vault_address="0x1234567890abcdef1234567890abcdef12345678",
vault=Contract(
chain=Chain.BASE,
address="0x1234567890abcdef1234567890abcdef12345678",
)
)

assert request.chain == Chain.BASE
assert request.vault_address == "0x1234567890abcdef1234567890abcdef12345678"
assert request.vault.chain == Chain.BASE
assert request.vault.address == "0x1234567890abcdef1234567890abcdef12345678"

def test_registration_request_validation(self) -> None:
"""Test RegistrationRequest model validation."""
request = {
"chain": "base",
"vault_address": "0x1234567890abcdef1234567890abcdef12345678",
"vault": {
"chain": "base",
"address": "0x1234567890abcdef1234567890abcdef12345678",
}
}
RegistrationRequest.model_validate(request)

def test_registration_request_json_validation(self) -> None:
"""Test RegistrationRequest model JSON validation."""
request = {
"chain": "base",
"vault_address": "0x1234567890abcdef1234567890abcdef12345678",
"vault": {
"chain": "base",
"address": "0x1234567890abcdef1234567890abcdef12345678",
}
}
RegistrationRequest.model_validate_json(json.dumps(request))

Expand Down
20 changes: 10 additions & 10 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from yield_analysis_sdk.exceptions import ValidationError
from yield_analysis_sdk.type import Chain
from yield_analysis_sdk.validators import (
AddressValidatorMixin,
ChainMixin,
UnderlyingTokenValidatorMixin,
VaultAddressValidatorMixin,
normalize_address,
validate_address_value,
validate_chain_value,
Expand Down Expand Up @@ -116,23 +116,23 @@ def test_validate_address_value(self) -> None:
with pytest.raises(ValidationError, match="Invalid address format"):
validate_address_value("0x1234567890abcdef")

def test_vault_address_validator_mixin(self) -> None:
"""Test the VaultAddressValidatorMixin."""
def test_address_validator_mixin(self) -> None:
"""Test the AddressValidatorMixin."""

class TestModel(VaultAddressValidatorMixin, BaseModel):
vault_address: str
class TestModel(AddressValidatorMixin, BaseModel):
address: str

# Test with valid address
model = TestModel(vault_address="0x1234567890abcdef1234567890abcdef12345678")
assert model.vault_address == "0x1234567890abcdef1234567890abcdef12345678"
model = TestModel(address="0x1234567890abcdef1234567890abcdef12345678")
assert model.address == "0x1234567890abcdef1234567890abcdef12345678"

# Test with address without 0x prefix
model = TestModel(vault_address="1234567890abcdef1234567890abcdef12345678")
assert model.vault_address == "0x1234567890abcdef1234567890abcdef12345678"
model = TestModel(address="1234567890abcdef1234567890abcdef12345678")
assert model.address == "0x1234567890abcdef1234567890abcdef12345678"

# Test with invalid address
with pytest.raises(ValidationError, match="Invalid address format"):
TestModel(vault_address="0x1234567890abcdef")
TestModel(address="0x1234567890abcdef")

def test_token_address_validator_mixin(self) -> None:
"""Test the UnderlyingTokenValidatorMixin."""
Expand Down
8 changes: 4 additions & 4 deletions yield_analysis_sdk/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def _format_price_history_response(

if vault_address not in history_by_vault:
history_by_vault[vault_address] = {
"vault_name": vault_name,
"vault_address": vault_address,
"name": vault_name,
"address": vault_address,
"price_history": [],
}

Expand All @@ -103,8 +103,8 @@ def _format_price_history_response(
result = []
for vault_data in history_by_vault.values():
share_price_history = SharePriceHistory(
vault_name=vault_data["vault_name"],
vault_address=vault_data["vault_address"],
name=vault_data["name"],
address=vault_data["address"],
price_history=vault_data["price_history"],
)
result.append(share_price_history)
Expand Down
26 changes: 16 additions & 10 deletions yield_analysis_sdk/type.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from .validators import (
AddressValidatorMixin,
ChainMixin,
UnderlyingTokenValidatorMixin,
VaultAddressValidatorMixin,
)


Expand Down Expand Up @@ -75,9 +75,15 @@ class AuditStatus(Enum):
UNKNOWN = "unknown"


class RegistrationRequest(VaultAddressValidatorMixin, ChainMixin, BaseModel):
class Contract(AddressValidatorMixin, ChainMixin, BaseModel):
address: str
chain: Chain
vault_address: str


class RegistrationRequest(BaseModel):
vault: Contract
contracts: Optional[List[Contract]] = None
github_repo_url: Optional[HttpUrl] = None


class RegistrationResponse(BaseModel):
Expand All @@ -95,11 +101,11 @@ class AnalysisRequest(BaseModel):
strategies: List[Strategy]


class VaultInfo(VaultAddressValidatorMixin, ChainMixin, BaseModel):
class VaultInfo(AddressValidatorMixin, ChainMixin, BaseModel):
# Basic Vault Information
chain: Chain
vault_address: str
vault_name: str
address: str
name: str
protocol: str = Field(
..., description="The protocol/platform this vault belongs to"
)
Expand Down Expand Up @@ -160,7 +166,7 @@ class AnalysisResponse(BaseModel):
analyses: List[AnalysisResult] = Field(..., description="List of vault analyses")


class SharePriceHistory(VaultAddressValidatorMixin, BaseModel):
vault_name: str
vault_address: str
class SharePriceHistory(AddressValidatorMixin, BaseModel):
name: str
address: str
price_history: List[Tuple[int, float]]
12 changes: 6 additions & 6 deletions yield_analysis_sdk/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ def validate_chain(cls, v: Any) -> "Chain":
return Chain.OTHER


class VaultAddressValidatorMixin:
"""Mixin class that provides vault address validation functionality."""
class AddressValidatorMixin:
"""Mixin class that provides address validation functionality."""

@field_validator("vault_address", mode="before")
@field_validator("address", mode="before")
@classmethod
def validate_vault_address(cls, v: Any) -> str:
"""Validate vault address format and normalize it."""
def validate_address(cls, v: Any) -> str:
"""Validate address format and normalize it."""
if isinstance(v, str):
return normalize_address(v)
elif v is None:
raise ValidationError("Vault address cannot be None")
raise ValidationError("Address cannot be None")
else:
return str(v)

Expand Down