feat: extended the export command to include the scores for all models (#645)

* feat: extended the export command to include the scores for all models

* feat: update test cases for scorer_dump_data_model_score

* feat: update the cronjob exporting the model scores to include the scores for all existing models and use the parquet file format
nutrina authored Jul 29, 2024
1 parent b9f24a3 commit 9012b2c
Showing 3 changed files with 264 additions and 39 deletions.
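For reference, a minimal sketch of how the extended command might be invoked after this change. Only --filename, --s3-extra-args, and --format appear in the visible hunks; the data_model, s3_uri, and batch_size option names are assumptions inferred from the dests read in handle() (their add_arguments definitions sit above the shown diff), and all values are placeholders.

from django.core.management import call_command

# Hypothetical invocation sketch: option names follow the dests used in
# handle(); example values are placeholders, and required options not
# visible in this diff may be missing.
call_command(
    "scorer_dump_data_model_score",
    data_model="model_a,model_b",   # new: comma-separated list of models to export
    format="parquet",               # new: "jsonl" (default) or "parquet"
    filename="model_scores.parquet",
    s3_uri="s3://example-bucket/model-scores/",
    batch_size=1000,
    s3_extra_args='{"ACL": "bucket-owner-full-control"}',  # passed to boto3 upload_file as ExtraArgs
)

Previously the command always filtered the Cache table on a single key_0 value (the eth_stamp v1 model); with this change the filter is only applied when a data model list is given, so omitting it exports the scores for every model.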
76 changes: 54 additions & 22 deletions api/data_model/management/commands/scorer_dump_data_model_score.py
@@ -1,22 +1,47 @@
import json
import traceback
from contextlib import contextmanager
from logging import getLogger
from urllib.parse import urlparse


import pyarrow as pa
import pyarrow.parquet as pq
from django.core.management.base import BaseCommand
from django.core.serializers.json import DjangoJSONEncoder

from data_model.models import Cache
from scorer.export_utils import (
export_data_for_model,
get_pa_schema,
upload_to_s3,
)
from data_model.models import Cache
from contextlib import contextmanager
from django.core.serializers.json import DjangoJSONEncoder
from logging import getLogger

log = getLogger(__name__)


def get_writer(output_file):
def get_parquet_writer(output_file):
@contextmanager
def writer_context_manager(model):
schema = get_pa_schema(model)
try:
with pq.ParquetWriter(output_file, schema) as writer:

class WriterWrappe:
def __init__(self, writer):
self.writer = writer

def write_batch(self, data):
batch = pa.RecordBatch.from_pylist(data, schema=schema)
self.writer.write_batch(batch)

yield WriterWrappe(writer)
finally:
pass

return writer_context_manager


def get_jsonl_writer(output_file):
@contextmanager
def eth_stamp_writer_context_manager(queryset):
try:
@@ -31,17 +56,10 @@ def write_batch(self, data):
try:
value = d["value"]
address = d["key_1"].lower()
model = d["key_0"].lower()
self.file.write(
json.dumps(
{
"address": address,
"data": {
"score": str(
value["data"]["human_probability"]
)
},
"updated_at": d["updated_at"],
},
d,
cls=DjangoJSONEncoder,
)
+ "\n"
@@ -88,19 +106,30 @@ def add_arguments(self, parser):
)

parser.add_argument("--filename", type=str, help="The output filename")

parser.add_argument(
"--s3-extra-args",
type=str,
help="""JSON object, that contains extra args for the files uploaded to S3.
This will be passed in as the `ExtraArgs` parameter to boto3's upload_file method.""",
)
parser.add_argument(
"--format",
type=str,
choices=["jsonl", "parquet"],
help="The output format",
default="jsonl",
)

def handle(self, *args, **options):
batch_size = options["batch_size"]
s3_uri = options["s3_uri"]
filename = options["filename"]
data_model_name = options["data_model"]
format = options["format"]
data_model_names = (
[n.strip() for n in options["data_model"].split(",")]
if options["data_model"]
else None
)

extra_args = (
json.loads(options["s3_extra_args"]) if options["s3_extra_args"] else None
@@ -113,16 +142,19 @@ def handle(self, *args, **options):
parsed_uri = urlparse(s3_uri)
s3_bucket_name = parsed_uri.netloc
s3_folder = parsed_uri.path.strip("/")
query = Cache.objects.all()
if data_model_names:
query = query.filter(key_0__in=data_model_names)

try:
export_data_for_model(
Cache.objects.filter(
key_0=data_model_name
), # This will only filter the scores for eth_stamp_model (v1)
query,
"id",
batch_size,
get_writer(filename),
jsonfields_as_str=False,
get_parquet_writer(filename)
if format == "parquet"
else get_jsonl_writer(filename),
jsonfields_as_str=(format == "parquet"),
)

self.stdout.write(
(diff truncated; the remaining changes in this file and the other two changed files are not shown)
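To make the writer plumbing above easier to follow, here is a rough sketch of the contract that export_data_for_model (defined in scorer.export_utils and not part of this diff) is assumed to rely on: the factory returned by get_parquet_writer or get_jsonl_writer is called with the exported model (the parquet variant uses it to build the pyarrow schema), yields an object exposing write_batch(list_of_dicts), and rows are paged through the "id" cursor in batch_size chunks. The pagination loop, parameter names, and the use of .values() below are assumptions for illustration; the real helper also honours the jsonfields_as_str flag, which is omitted here.

def export_data_for_model_sketch(queryset, cursor_field, batch_size, writer_context_manager):
    # Assumed flow: open the writer for this model, then stream the queryset
    # in id-ordered pages so memory stays bounded regardless of table size.
    model = queryset.model
    with writer_context_manager(model) as writer:
        last_pk = 0
        while True:
            page = list(
                queryset.filter(**{f"{cursor_field}__gt": last_pk})
                .order_by(cursor_field)
                .values()[:batch_size]
            )
            if not page:
                break
            writer.write_batch(page)        # list of row dicts, one per Cache record
            last_pk = page[-1][cursor_field]

Keeping the format decision inside the writer factories means handle() only picks get_parquet_writer or get_jsonl_writer based on --format, while the export loop stays format-agnostic; jsonfields_as_str=(format == "parquet") presumably lets JSON columns be written as plain string columns in the parquet schema.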
