Skip to content

Commit

Permalink
Merge pull request #4236 from freelawproject/feat-enhance-command-to-…
Browse files Browse the repository at this point in the history
…create-manifest-files

feat(db_manifest): Enhance command to create aws manifest files
  • Loading branch information
mlissner authored Jul 23, 2024
2 parents 9ff801c + 1df7c4e commit 57ebc40
Showing 1 changed file with 106 additions and 53 deletions.
159 changes: 106 additions & 53 deletions cl/corpus_importer/management/commands/make_aws_manifest_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,63 @@
s3_client = boto3.client("s3")


def get_total_number_of_records(type: str, use_replica: bool = False) -> int:
def get_total_number_of_records(type: str, options: dict[str, Any]) -> int:
"""
Retrieves the total number of records for a specific data type.
Args:
type (str): The type of data to count. Must be one of the valid values
from the `SEARCH_TYPES` class.
use_replica (bool, optional): Whether to use the replica database
connection (default: False).
options (dict[str, Any]): A dictionary containing options for filtering
the results.
- 'use_replica' (bool, optional): Whether to use the replica database
connection (default: False).
- 'random_sample_percentage' (float, optional): The percentage of
records to include in a random sample.
Returns:
int: The total number of records matching the specified data type.
"""
match type:
case SEARCH_TYPES.RECAP_DOCUMENT:
query = """
SELECT count(*) AS exact_count
FROM search_recapdocument
base_query = (
"SELECT count(*) AS exact_count FROM search_recapdocument"
)
filter_clause = """
WHERE is_available=True AND page_count>0 AND ocr_status!=1
"""
case SEARCH_TYPES.OPINION:
query = """
SELECT count(*) AS exact_count
FROM search_opinion
WHERE extracted_by_ocr != true
"""
base_query = "SELECT count(*) AS exact_count FROM search_opinion"
filter_clause = "WHERE extracted_by_ocr != true"
case SEARCH_TYPES.ORAL_ARGUMENT:
query = """
SELECT count(*) AS exact_count
FROM audio_audio
WHERE
local_path_mp3 != '' AND
base_query = "SELECT count(*) AS exact_count FROM audio_audio"
filter_clause = """WHERE local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""

if options["random_sample_percentage"]:
percentage = options["random_sample_percentage"]
base_query = f"{base_query} TABLESAMPLE SYSTEM ({percentage})"

query = (
f"{base_query}\n"
if options["all_records"]
else f"{base_query}\n {filter_clause}\n"
)
with connections[
"replica" if use_replica else "default"
"replica" if options["use_replica"] else "default"
].cursor() as cursor:
cursor.execute(query, [])
result = cursor.fetchone()

return int(result[0])


def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
def get_custom_query(
type: str, last_pk: str, options: dict[str, Any]
) -> tuple[str, list[Any]]:
"""
Generates a custom SQL query based on the provided type and optional last
pk.
Expand All @@ -69,57 +79,59 @@ def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
type (str): Type of data to retrieve.
last_pk (int, optional): Last primary key retrieved in a previous
query. Defaults to None.
options (dict[str, Any]): A dictionary containing options for filtering
the results.
- 'random_sample_percentage' (float, optional): The percentage of
records to include in a random sample.
Returns:
tuple[str, list[Any]]: A tuple containing the constructed SQL
query(str) and a list of parameters (list[Any]) to be used with
the query.
"""
params = []

random_sample = options["random_sample_percentage"]
match type:
case SEARCH_TYPES.RECAP_DOCUMENT:
base_query = "SELECT id from search_recapdocument"
filter_clause = (
"WHERE is_available=True AND page_count>0 AND ocr_status!=1"
if not last_pk
else (
"WHERE id > %s AND is_available = True AND page_count > 0"
" AND ocr_status != 1"
)
)
case SEARCH_TYPES.OPINION:
base_query = "SELECT id from search_opinion"
filter_clause = (
"WHERE extracted_by_ocr != true"
if not last_pk
else "WHERE id > %s AND extracted_by_ocr != true"
)
filter_clause = "WHERE extracted_by_ocr != true"
case SEARCH_TYPES.ORAL_ARGUMENT:
base_query = "SELECT id from audio_audio"
no_argument_where_clause = """
filter_clause = """
WHERE local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""
where_clause_with_argument = """
WHERE id > %s AND
local_path_mp3 != '' AND
download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
position('Unavailable' in download_url) = 0 AND
duration > 30
"""
filter_clause = (
no_argument_where_clause
if not last_pk
else where_clause_with_argument
)

if last_pk:
if random_sample:
base_query = f"{base_query} TABLESAMPLE SYSTEM ({random_sample})"

if options["all_records"]:
filter_clause = ""

# Using a WHERE clause with `id > last_pk` and a LIMIT clause for batch
# retrieval is not suitable for random sampling. The following logic
# removes these clauses when retrieving a random sample to ensure all rows
# have an equal chance of being selected.
if last_pk and not random_sample:
filter_clause = (
f"WHERE id > %s"
if not filter_clause
else f"{filter_clause} AND id > %s"
)
params.append(last_pk)

query = f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
query = (
f"{base_query}\n {filter_clause}"
if random_sample
else f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
)

return query, params

Expand Down Expand Up @@ -170,6 +182,27 @@ def add_arguments(self, parser: CommandParser):
default=False,
help="Use this flag to run the queries in the replica db",
)
parser.add_argument(
"--file-name",
type=str,
default=None,
help="Custom name for the output files. If not provided, a default "
"name will be used.",
)
parser.add_argument(
"--random-sample-percentage",
type=float,
default=None,
help="Specifies the proportion of the table to be sampled (between "
"0.0 and 100.0). Use this flag to retrieve a random set of records.",
)
parser.add_argument(
"--all-records",
action="store_true",
default=False,
help="Use this flag to retrieve all records from the table without"
" applying any filters.",
)

def handle(self, *args, **options):
r = get_redis_interface("CACHE")
Expand All @@ -188,7 +221,7 @@ def handle(self, *args, **options):
)
if not total_number_of_records:
total_number_of_records = get_total_number_of_records(
record_type, options["use_replica"]
record_type, options
)
r.hset(
f"{record_type}_import_status",
Expand All @@ -200,12 +233,17 @@ def handle(self, *args, **options):
r.hget(f"{record_type}_import_status", "next_iteration_counter")
or 0
)
file_name = (
options["file_name"]
if options["file_name"]
else f"{record_type}_filelist"
)
while True:
query, params = get_custom_query(
options["record_type"],
last_pk,
options["record_type"], last_pk, options
)
params.append(options["query_batch_size"])
if not options["random_sample_percentage"]:
params.append(options["query_batch_size"])

with connections[
"replica" if options["use_replica"] else "default"
Expand All @@ -226,22 +264,37 @@ def handle(self, *args, **options):
extrasaction="ignore",
)
for row in batched(rows, options["lambda_record_size"]):
query_dict = {
"bucket": bucket_name,
"file_name": (
if options["random_sample_percentage"]:
# Create an underscore-separated file name that lambda
# can split and use as part of batch processing.
ids = [str(r[0]) for r in row]
content = "_".join(ids)
else:
content = (
f"{row[0][0]}_{row[-1][0]}"
if len(row) > 1
else f"{row[0][0]}"
),
)
query_dict = {
"bucket": bucket_name,
"file_name": content,
}
writer.writerow(query_dict)

s3_client.put_object(
Key=f"{record_type}_filelist_{counter}.csv",
Key=f"{file_name}_{counter}.csv",
Bucket=bucket_name,
Body=csvfile.getvalue().encode("utf-8"),
)

if options["random_sample_percentage"]:
# Due to the non-deterministic nature of random sampling,
# storing data to recover the query for future executions
# wouldn't be meaningful. Random queries are unlikely to
# produce the same results on subsequent runs.
logger.info(f"Finished processing {record_count} records")
break

counter += 1
last_pk = rows[-1][0]
records_processed = int(
Expand Down

0 comments on commit 57ebc40

Please sign in to comment.