diff --git a/app/api/common/models.py b/app/api/common/models.py index e43846a..ffe36e1 100644 --- a/app/api/common/models.py +++ b/app/api/common/models.py @@ -223,7 +223,11 @@ def get_by_near_intents_id(cls, near_intents_id: str): def to_spec(self) -> ChainSpec: return ChainSpec(coin=self.coin, chain_id=self.chain_id) - def __eq__(self, other: Chain) -> bool: + def __eq__(self, other) -> bool: + if other is None: + return False + if not isinstance(other, Chain): + return False return self.coin == other.coin and self.chain_id == other.chain_id def __str__(self): diff --git a/app/api/common/test_models.py b/app/api/common/test_models.py new file mode 100644 index 0000000..82ce736 --- /dev/null +++ b/app/api/common/test_models.py @@ -0,0 +1,67 @@ +import pytest + +from app.api.common.models import Chain, Coin + + +@pytest.mark.parametrize( + "comparison_chain", + [Chain.BITCOIN, Chain.ETHEREUM, Chain.SOLANA, Chain.CARDANO], +) +def test_chain_get_returns_none_comparison(comparison_chain): + chain = Chain.get("INVALID_COIN", "invalid_chain_id") + assert chain is None + assert (chain == comparison_chain) is False + + +@pytest.mark.parametrize( + "chain,other_object", + [ + (Chain.BITCOIN, "bitcoin"), + (Chain.ETHEREUM, 1), + (Chain.SOLANA, Coin.BTC), + (Chain.BITCOIN, {"coin": "BTC"}), + (Chain.ETHEREUM, []), + (Chain.SOLANA, None), + ], +) +def test_chain_equals_non_chain_object_returns_false(chain, other_object): + assert (chain == other_object) is False + + +@pytest.mark.parametrize( + "coin,chain_id,expected_chain,different_chain", + [ + ("BTC", "bitcoin_mainnet", Chain.BITCOIN, Chain.ETHEREUM), + ("ETH", "0x1", Chain.ETHEREUM, Chain.SOLANA), + ("SOL", "0x65", Chain.SOLANA, Chain.BITCOIN), + ("ADA", "cardano_mainnet", Chain.CARDANO, Chain.FILECOIN), + ], +) +def test_chain_get_valid_chain_comparison( + coin, chain_id, expected_chain, different_chain +): + chain = Chain.get(coin, chain_id) + assert chain is not None + assert chain == expected_chain + # Verify it's not equal to a different chain + assert (chain == different_chain) is False + assert chain != different_chain + + +def test_chain_equality_with_case_insensitive_chain_id(): + chain1 = Chain.get("BTC", "BITCOIN_MAINNET") + chain2 = Chain.get("btc", "bitcoin_mainnet") + assert chain1 == chain2 + assert chain1 == Chain.BITCOIN + assert chain2 == Chain.BITCOIN + + +@pytest.mark.parametrize("chain", list(Chain)) +def test_all_chains(chain): + assert (chain is None) is False + none = None + assert (none == chain) is False + + assert chain == chain + assert chain is chain + assert chain is not None diff --git a/app/api/oauth/zebpay.py b/app/api/oauth/zebpay.py index 73af213..327449a 100644 --- a/app/api/oauth/zebpay.py +++ b/app/api/oauth/zebpay.py @@ -1,8 +1,9 @@ +from urllib.parse import parse_qs, urlencode + import httpx from fastapi import APIRouter, HTTPException, Request from fastapi.responses import JSONResponse, RedirectResponse from starlette.datastructures import URL -from urllib.parse import parse_qs, urlencode from app.api.common.models import Tags from app.api.oauth.models import Environment diff --git a/app/api/swap/providers/jupiter/client.py b/app/api/swap/providers/jupiter/client.py index 3b9118e..8730ec6 100644 --- a/app/api/swap/providers/jupiter/client.py +++ b/app/api/swap/providers/jupiter/client.py @@ -158,7 +158,9 @@ async def _get_order( self._handle_error_response(response) # _handle_error_response is expected to always raise SwapError, # but add an explicit raise to make the control flow clear. - raise SwapError(message="Unhandled Jupiter API error", kind=SwapErrorKind.UNKNOWN) + raise SwapError( + message="Unhandled Jupiter API error", kind=SwapErrorKind.UNKNOWN + ) async def get_indicative_routes( self, diff --git a/pyproject.toml b/pyproject.toml index 1bc8f26..1173677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "gate3" -version = "0.11.0" +version = "0.11.1" description = "Gate API for web3 applications at Brave" authors = [ ]