2680 internal batch endpoint (#657)
* wip: s3 initiated ecs task

* feat: trigger ecs task from upload to s3

* chore: cleanup naming and comments

* chore: remove comment

* feat: get list of address process them then upload results to s3

* feat: update progress

* feat: bulk process status endpoint

* chore: test batch processing

* chore: use internal router

* fix: unable to use setting as default value

* fix: patch

* fix: pytest mocks for s3_client

* fix: mock in api instead of source

* feat: create bucket folders and only initiate task on address-list

* fix: mocking of s3_client

* fix: correct command

* feat: analysis documentation

* feat: checksum address and parse to json

* feat: log duration, return errors in csv

* chore: increase batch size and unique bucket names

* feat: better error handling and additional data in response

* feat: unique naming for buckets

* fix: detailed response

* fix: just polygon

* fix: mocking of model responses

* fix: polygon to zksync
tim-schultz authored Aug 14, 2024
1 parent 64208b4 commit 9f689e0
Showing 17 changed files with 943 additions and 116 deletions.
38 changes: 33 additions & 5 deletions api/aws_lambdas/passport/tests/test_passport_analysis_lambda.py
@@ -16,19 +16,39 @@

mock_model_responses = {
"ethereum_activity": {
"data": {"human_probability": 75, "n_transactions": 10},
"data": {
"human_probability": 75,
"n_transactions": 10,
"first_funder": "funder",
"first_funder_amount": 1000,
},
"metadata": {"model_name": "ethereum_activity", "version": "1.0"},
},
"nft": {
"data": {"human_probability": 85},
"data": {
"human_probability": 85,
"n_transactions": 10,
"first_funder": "funder",
"first_funder_amount": 1000,
},
"metadata": {"model_name": "social_media", "version": "2.0"},
},
"zksync": {
"data": {"human_probability": 95, "n_transactions": 5},
"data": {
"human_probability": 95,
"n_transactions": 10,
"first_funder": "funder",
"first_funder_amount": 1000,
},
"metadata": {"model_name": "transaction_history", "version": "1.5"},
},
"aggregate": {
"data": {"human_probability": 90},
"data": {
"human_probability": 90,
"n_transactions": 10,
"first_funder": "funder",
"first_funder_amount": 1000,
},
"metadata": {"model_name": "aggregate", "version": "2.5"},
},
}
@@ -43,7 +63,15 @@ def mock_post_response(session, url, data):
for model, endpoint in settings.MODEL_ENDPOINTS.items():
if endpoint in url:
response_data = mock_model_responses.get(
model, {"data": {"human_probability": 0}}
model,
{
"data": {
"human_probability": 0,
"n_transactions": 10,
"first_funder": "funder",
"first_funder_amount": 1000,
}
},
)
break
else:
60 changes: 50 additions & 10 deletions api/passport/api.py
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import aiohttp
from django.conf import settings
@@ -32,6 +32,13 @@ class ScoreModel(Schema):
score: int


class DetailedScoreModel(Schema):
score: int
n_transactions: Optional[int]
first_funder: Optional[str]
first_funder_amount: Optional[int]


@api.exception_handler(Ratelimited)
def service_unavailable(request, _):
return api.create_response(
@@ -42,7 +49,7 @@ def service_unavailable(request, _):


class PassportAnalysisDetails(Schema):
models: Dict[str, ScoreModel]
models: Dict[str, ScoreModel | DetailedScoreModel]


class PassportAnalysisResponse(Schema):
@@ -85,8 +92,20 @@ async def get_analysis(

async def fetch(session, url, data):
headers = {"Content-Type": "application/json"}
async with session.post(url, data=json.dumps(data), headers=headers) as response:
return await response.json()
try:
async with session.post(
url, data=json.dumps(data), headers=headers
) as response:
return await response.json()
except Exception as e:
log.error(f"Error fetching {url}", exc_info=True)
return {
"data": {
"human_probability": -1,
"n_transactions": -1,
"error": "Error fetching model response",
}
}


async def fetch_all(urls, payload):
@@ -100,8 +119,14 @@ async def fetch_all(urls, payload):


async def handle_get_analysis(
address: str, model_list: str = None
address: str,
model_list: str = None,
only_one_model=None,
additional_data=False,
) -> PassportAnalysisResponse:
only_one_model = (
only_one_model if only_one_model is not None else settings.ONLY_ONE_MODEL
)
# Set default in case nothing was selected by the user
if not model_list or model_list.strip() == "":
model_list = settings.MODEL_ENDPOINTS_DEFAULT
@@ -111,7 +136,7 @@
if not is_valid_address(address):
raise InvalidAddressException()

if settings.ONLY_ONE_MODEL and len(models) > 1:
if only_one_model and len(models) > 1:
raise BadModelNameError(
detail="Currently, only one model name can be provided at a time"
)
@@ -150,10 +175,25 @@
details=PassportAnalysisDetails(models={}),
)

for model, response in model_responses:
ret.details.models[model] = ScoreModel(
score=response.get("data", {}).get("human_probability", 0)
)
if additional_data:
for model, response in model_responses:
data = response.get("data", {})
score = data.get("human_probability", 0)
num_transactions = data.get("n_transactions", 0)
first_funder = data.get("first_funder", "")
first_funder_amount = data.get("first_funder_amount", 0)

ret.details.models[model] = DetailedScoreModel(
score=score,
n_transactions=num_transactions,
first_funder=first_funder,
first_funder_amount=first_funder_amount,
)
else:
for model, response in model_responses:
data = response.get("data", {})
score = data.get("human_probability", 0)
ret.details.models[model] = ScoreModel(score=score)

return ret
except Exception:
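For reference, a minimal sketch of how the new only_one_model and additional_data arguments might be exercised; the call shape and the example address mirror the tests added in api/passport/test/test_analysis.py below:

from asgiref.sync import async_to_sync
from passport.api import handle_get_analysis

# Detailed response: only_one_model=False, additional_data=True
analysis = async_to_sync(handle_get_analysis)(
    "0x06e3c221011767FE816D0B8f5B16253E43e4Af7D", "zksync", False, True
)
print(analysis.details.models["zksync"].n_transactions)

# Default behaviour: plain ScoreModel entries without the extra fields
analysis = async_to_sync(handle_get_analysis)(
    "0x06e3c221011767FE816D0B8f5B16253E43e4Af7D", "nft,zksync", False
)
print(analysis.details.models["zksync"].score)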
61 changes: 60 additions & 1 deletion api/passport/test/test_analysis.py
@@ -1,13 +1,21 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from asgiref.sync import async_to_sync
from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import Client, TestCase, override_settings
from web3 import Web3

from account.models import Account, AccountAPIKey
from aws_lambdas.passport.tests.test_passport_analysis_lambda import mock_post_response
from passport.api import (
PassportAnalysisDetails,
PassportAnalysisResponse,
ScoreModel,
fetch_all,
handle_get_analysis,
)

pytestmark = pytest.mark.django_db

@@ -18,6 +26,16 @@
User = get_user_model()


def assert_passport_analysis_structure(
actual: PassportAnalysisResponse, expected: PassportAnalysisResponse
):
assert actual.address == expected.address
assert set(actual.details.models.keys()) == set(expected.details.models.keys())
for model_name in expected.details.models:
assert isinstance(actual.details.models[model_name], ScoreModel)
assert isinstance(actual.details.models[model_name].score, (int, float))


@pytest.mark.django_db
class TestPassportAnalysis(TestCase):
def setUp(self):
@@ -166,3 +184,44 @@ def test_ignore_duplicate_model(self, mock_fetch):
self.assertEqual(analysis_response.status_code, 200)

assert mock_fetch.call_count == 1

@override_settings(ONLY_ONE_MODEL=False)
@patch("passport.api.fetch", side_effect=mock_post_response)
def test_handle_get_analysis_returns_additional_data(self, mock_fetch):
"""Test handle_get_analysis returns additional data when requested."""
analysis = async_to_sync(handle_get_analysis)(
"0x06e3c221011767FE816D0B8f5B16253E43e4Af7D", "zksync", False, True
)

assert analysis.details.models["zksync"].score == 95
assert analysis.details.models["zksync"].n_transactions == 10
assert analysis.details.models["zksync"].first_funder == "funder"
assert analysis.details.models["zksync"].first_funder_amount == 1000

@override_settings(ONLY_ONE_MODEL=False)
@patch("passport.api.fetch", side_effect=mock_post_response)
def test_handle_get_analysis_does_not_return_additional_data(self, mock_fetch):
"""Test handle_get_analysis does not return additional data when not requested."""
analysis = async_to_sync(handle_get_analysis)(
"0x06e3c221011767FE816D0B8f5B16253E43e4Af7D", "nft,zksync", False
)
expected = PassportAnalysisResponse(
address="0x06e3c221011767FE816D0B8f5B16253E43e4Af7D",
details=PassportAnalysisDetails(
models={"nft": ScoreModel(score=85), "zksync": ScoreModel(score=0)}
),
)

assert_passport_analysis_structure(analysis, expected)

# Check that additional data is not present
for model_name in ["nft", "zksync"]:
assert not hasattr(analysis.details.models[model_name], "n_transactions")
assert not hasattr(analysis.details.models[model_name], "first_funder")
assert not hasattr(
analysis.details.models[model_name], "first_funder_amount"
)

# Check specific scores
assert analysis.details.models["nft"].score == 85
assert analysis.details.models["zksync"].score == 95
7 changes: 6 additions & 1 deletion api/registry/admin.py
@@ -3,6 +3,7 @@
import boto3
from asgiref.sync import async_to_sync
from django import forms
from django.conf import settings
from django.contrib import admin, messages
from django.shortcuts import redirect, render
from django.urls import path
@@ -34,7 +35,11 @@
def get_s3_client():
global _s3_client
if not _s3_client:
_s3_client = boto3.client("s3")
_s3_client = boto3.client(
"s3",
aws_access_key_id=settings.S3_DATA_AWS_SECRET_KEY_ID,
aws_secret_access_key=settings.S3_DATA_AWS_SECRET_ACCESS_KEY,
)
return _s3_client


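A minimal sketch of how the cached client is reused to sign result URLs; the generate_presigned_url call mirrors the one added to api/registry/api/v2.py below, with placeholder bucket and key names:

from registry.admin import get_s3_client

s3 = get_s3_client()  # created once with the S3_DATA_* credentials, then cached
url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": "example-bucket", "Key": "model-score-results/example.csv"},
    ExpiresIn=60 * 60 * 24,  # 24 hrs, matching the endpoint below
)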
74 changes: 69 additions & 5 deletions api/registry/api/v2.py
@@ -1,12 +1,15 @@
from datetime import datetime
from typing import List, Optional

from django.db.models import Q
from ninja import Router, Schema
from ninja.security import APIKeyHeader

import api_logging as logging

# --- Deduplication Modules
from account.models import Community
from django.db.models import Q
from ninja import Router
from registry.admin import get_s3_client
from registry.api import common, v1
from registry.api.schema import (
CursorPaginatedHistoricalScoreResponse,
@@ -29,20 +32,23 @@
InvalidLimitException,
api_get_object_or_404,
)
from registry.models import Score
from registry.models import BatchModelScoringRequest, BatchRequestStatus, Score
from registry.utils import (
decode_cursor,
encode_cursor,
get_cursor_query_condition,
reverse_lazy_with_query,
)
from scorer import settings
from scorer.settings import (
BULK_MODEL_SCORE_REQUESTS_RESULTS_FOLDER,
BULK_SCORE_REQUESTS_BUCKET_NAME,
)

log = logging.getLogger(__name__)

router = Router()

analytics_router = Router()


@router.get(
"/signing-message",
@@ -320,3 +326,61 @@ def get_score_history(
)
def get_score(request, address: str, scorer_id: int) -> DetailedScoreResponse:
return v1.get_score(request, address, scorer_id)


internal_router = Router()


class DataScienceApiKey(APIKeyHeader):
param_name = "AUTHORIZATION"

def authenticate(self, request, key):
if key == settings.DATA_SCIENCE_API_KEY:
return key
return None


data_science_auth = DataScienceApiKey()


class BatchResponse(Schema):
created_at: str
s3_url: Optional[str]
status: BatchRequestStatus
percentage_complete: int


@internal_router.get(
"/analysis",
auth=data_science_auth,
response={
200: list[BatchResponse],
400: ErrorMessageResponse,
500: ErrorMessageResponse,
},
summary="Retrieve batch scoring status and result",
description="Retrieve batch scoring status and result",
)
def get_batch_analysis_stats(request, limit: int = 10) -> list[BatchResponse]:
requests = BatchModelScoringRequest.objects.order_by("-created_at")[:limit]
return [
BatchResponse(
created_at=req.created_at.isoformat(),
s3_url=(
get_s3_client().generate_presigned_url(
"get_object",
Params={
"Bucket": BULK_SCORE_REQUESTS_BUCKET_NAME,
"Key": f"{BULK_MODEL_SCORE_REQUESTS_RESULTS_FOLDER}/{req.s3_filename}",
},
# 24 hrs
ExpiresIn=60 * 60 * 24,
)
if req.status == BatchRequestStatus.DONE
else None
),
status=req.status,
percentage_complete=req.progress,
)
for req in requests
]
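A minimal sketch of how a data-science client might poll the new endpoint, assuming the internal router is mounted under /internal (the mount path is not part of this diff) and using a placeholder host:

import requests

BASE_URL = "https://scorer.example.com"  # placeholder host
resp = requests.get(
    f"{BASE_URL}/internal/analysis",
    params={"limit": 10},
    headers={"AUTHORIZATION": "<DATA_SCIENCE_API_KEY>"},
)
for batch in resp.json():
    # Fields mirror BatchResponse: created_at, s3_url, status, percentage_complete
    print(batch["created_at"], batch["status"], batch["percentage_complete"])
    if batch["s3_url"]:
        print("results:", batch["s3_url"])  # presigned URL, valid for 24 hours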
