Skip to content

Commit

Permalink
feat(api): adding aggregate model (#646)
Browse files Browse the repository at this point in the history
* feat(api): adding aggregate model

* fix: test

* feat(infra): adding SSH key to pull from infra-libs

* feat(infra): remove SSH keys, use https for infra-libs

* feat(infra): load model paths from 1P env

* fix: accidentally deleted variable

* fix: adding test step back
  • Loading branch information
lucianHymer authored Jul 29, 2024
1 parent aabd726 commit b9f24a3
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 5,564 deletions.
58 changes: 44 additions & 14 deletions api/passport/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import asyncio
import json
from typing import Dict
from typing import Dict, List

import aiohttp
import api_logging as logging
from django.conf import settings
from django_ratelimit.exceptions import Ratelimited
from eth_utils.address import to_checksum_address
from ninja import Schema
from ninja_extra import NinjaExtraAPI
from ninja_extra.exceptions import APIException

import api_logging as logging
from registry.api.utils import aapi_key, check_rate_limit, is_valid_address
from registry.exceptions import InvalidAddressException
from scorer.settings.model_config import MODEL_AGGREGATION_KEYS

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,17 +89,11 @@ async def fetch(session, url, data):
return await response.json()


async def fetch_all(urls, payload):
    """POST *payload* to every URL in *urls* concurrently.

    Opens one shared aiohttp session, schedules a `fetch` task per URL,
    and gathers the decoded JSON responses.

    Args:
        urls: iterable of model endpoint URLs to call.
        payload: JSON-serialisable dict sent as the request body to every URL.

    Returns:
        List of decoded JSON responses, in the same order as *urls*
        (asyncio.gather preserves input order).
    """
    async with aiohttp.ClientSession() as session:
        # One task per URL so all requests are in flight at the same time.
        tasks = [asyncio.ensure_future(fetch(session, url, payload)) for url in urls]
        return await asyncio.gather(*tasks)
Expand Down Expand Up @@ -130,13 +126,17 @@ async def handle_get_analysis(
detail=f"Invalid model name(s): {', '.join(bad_models)}. Must be one of {', '.join(settings.MODEL_ENDPOINTS.keys())}"
)

urls = [settings.MODEL_ENDPOINTS[model] for model in models]

# The cache historically uses checksummed addresses, need to do this for consistency
checksummed_address = to_checksum_address(address)

try:
responses = await fetch_all(urls, checksummed_address)
# TODO How to handle this when multiple models allowed at once?
# Maybe prefetch all requested non-aggregate and pass them to the aggregate
# model which will skip checking those again?
if settings.AGGREGATE_MODEL_NAME in models:
responses = await get_aggregate_model_response(checksummed_address)
else:
responses = await get_model_responses(models, checksummed_address)

ret = PassportAnalysisResponse(
address=address,
Expand All @@ -148,7 +148,37 @@ async def handle_get_analysis(
)

return ret

except Exception:
log.error("Error retrieving Passport analysis", exc_info=True)
raise PassportAnalysisError()


async def get_aggregate_model_response(checksummed_address: str):
    """Query every individual model, then feed their results to the aggregate model.

    Fans out to all models listed in MODEL_AGGREGATION_KEYS, collects each
    model's human-probability score and transaction count, and posts the
    combined payload to the aggregate model endpoint.

    Args:
        checksummed_address: checksummed wallet address (the model cache
            historically keys on checksummed addresses — see caller).

    Returns:
        The fetch_all result for the single aggregate-model URL
        (a one-element list of decoded JSON responses).
    """
    models = list(MODEL_AGGREGATION_KEYS)

    model_responses = await get_model_responses(models, checksummed_address)

    payload = {
        "address": checksummed_address,
        "data": {},
    }

    # zip is safe here: get_model_responses preserves the order of `models`.
    for model, response in zip(models, model_responses):
        data = response.get("data", {})
        # Default to 0 so one model's missing/partial response doesn't
        # break aggregation for the rest.
        score = data.get("human_probability", 0)
        num_transactions = data.get("n_transactions", 0)
        model_key = MODEL_AGGREGATION_KEYS[model]

        payload["data"][f"score_{model_key}"] = score
        payload["data"][f"txs_{model_key}"] = num_transactions

    url = settings.MODEL_ENDPOINTS[settings.AGGREGATE_MODEL_NAME]

    return await fetch_all([url], payload)


async def get_model_responses(models: List[str], checksummed_address: str):
    """Fetch the analysis for *checksummed_address* from each named model.

    Resolves every model name to its configured endpoint URL and issues
    all requests concurrently via fetch_all, returning the responses in
    the same order as *models*.
    """
    payload = {"address": checksummed_address}
    endpoint_urls = [settings.MODEL_ENDPOINTS[name] for name in models]
    return await fetch_all(endpoint_urls, payload)
2 changes: 1 addition & 1 deletion api/passport/test/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_checksummed_address_is_passed_on(self, mock_post):
Changing this would affect the current cached values
"""
self.client.get(
"/passport/analysis/0x06e3c221011767FE816D0B8f5B16253E43e4Af7d".lower(),
"/passport/analysis/0x06e3c221011767FE816D0B8f5B16253E43e4Af7d?model_list=ethereum_activity".lower(),
content_type="application/json",
**self.headers,
)
Expand Down
17 changes: 16 additions & 1 deletion api/scorer/settings/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
OPTIMISM_MODEL_ENDPOINT = env(
    # NOTE(review): default previously pointed at the zksync path
    # ("http://localhost:80/zksync") — looks like a copy-paste slip;
    # corrected to the optimism path. Dev-only default, env var wins in deploys.
    "OPTIMISM_MODEL_ENDPOINT", default="http://localhost:80/optimism"
)
AGGREGATE_MODEL_ENDPOINT = env(
    "AGGREGATE_MODEL_ENDPOINT", default="http://localhost:80/aggregate"
)

# Reserved model name that triggers aggregation across the models below
# instead of a single-model lookup.
AGGREGATE_MODEL_NAME = "aggregate"


# Maps each individual model name to the short key used when building the
# aggregate model's payload (e.g. "zksync" -> "score_zk" / "txs_zk").
MODEL_AGGREGATION_KEYS = {
    "zksync": "zk",
    "polygon": "polygon",
    "ethereum_activity": "eth",
    "arbitrum": "arb",
    "optimism": "op",
}

MODEL_ENDPOINTS = {
"ethereum_activity": ETHEREUM_MODEL_ENDPOINT,
Expand All @@ -24,6 +38,7 @@
"polygon": POLYGON_MODEL_ENDPOINT,
"arbitrum": ARBITRUM_MODEL_ENDPOINT,
"optimism": OPTIMISM_MODEL_ENDPOINT,
AGGREGATE_MODEL_NAME: AGGREGATE_MODEL_ENDPOINT,
}

MODEL_ENDPOINTS_DEFAULT = "ethereum_activity"
MODEL_ENDPOINTS_DEFAULT = "aggregate"
47 changes: 0 additions & 47 deletions infra/aws/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,6 @@ const redashMailPassword = pulumi.secret(
)
);

// TODO: we could load these endpoints from the data-science stack
const ethereumModelEndpoint = `${process.env["ETHEREUM_MODEL_ENDPOINT"]}`;
const ethereumModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/ETHEREUM_MODEL_ENDPOINT`
);
const nftModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/NFT_MODEL_ENDPOINT`
);
const zksyncModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/ZKSYNC_MODEL_ENDPOINT`
);
const polygonModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/POLYGON_MODEL_ENDPOINT`
);
const arbitrumModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/ARBITRUM_MODEL_ENDPOINT`
);
const optimismModelEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/OPTIMISM_MODEL_ENDPOINT`
);

const pagerDutyIntegrationEndpoint = op.read.parse(
`op://DevOps/passport-scorer-${stack}-env/ci/PAGERDUTY_INTEGRATION_ENDPOINT`
);
Expand Down Expand Up @@ -1634,33 +1614,6 @@ buildHttpLambdaFn(
buildHttpLambdaFn(
{
...lambdaSettings,
environment: [
...lambdaSettings.environment,
{
name: "ETHEREUM_MODEL_ENDPOINT",
value: ethereumModelEndpoint,
},
{
name: "NFT_MODEL_ENDPOINT",
value: nftModelEndpoint,
},
{
name: "ZKSYNC_MODEL_ENDPOINT",
value: zksyncModelEndpoint,
},
{
name: "POLYGON_MODEL_ENDPOINT",
value: polygonModelEndpoint,
},
{
name: "ARBITRUM_MODEL_ENDPOINT",
value: arbitrumModelEndpoint,
},
{
name: "OPTIMISM_MODEL_ENDPOINT",
value: optimismModelEndpoint,
},
].sort(secretsManager.sortByName),
name: "passport-analysis-GET-0",
memorySize: 256,
dockerCmd: ["aws_lambdas.passport.analysis_GET.handler"],
Expand Down
Loading

0 comments on commit b9f24a3

Please sign in to comment.