diff --git a/cl/corpus_importer/management/commands/make_aws_manifest_files.py b/cl/corpus_importer/management/commands/make_aws_manifest_files.py
index 479708c0d0..ebc4690312 100644
--- a/cl/corpus_importer/management/commands/make_aws_manifest_files.py
+++ b/cl/corpus_importer/management/commands/make_aws_manifest_files.py
@@ -14,45 +14,53 @@
 s3_client = boto3.client("s3")
 
 
-def get_total_number_of_records(type: str, use_replica: bool = False) -> int:
+def get_total_number_of_records(type: str, options: dict[str, Any]) -> int:
     """
     Retrieves the total number of records for a specific data type.
 
     Args:
         type (str): The type of data to count. Must be one of the valid
             values from the `SEARCH_TYPES` class.
-        use_replica (bool, optional): Whether to use the replica database
-            connection (default: False).
+        options (dict[str, Any]): A dictionary containing options for filtering
+            the results.
+            - 'use_replica' (bool, optional): Whether to use the replica database
+                connection (default: False).
+            - 'random_sample_percentage' (float, optional): The percentage of
+                records to include in a random sample.
 
     Returns:
         int: The total number of records matching the specified data type.
     """
     match type:
         case SEARCH_TYPES.RECAP_DOCUMENT:
-            query = """
-                SELECT count(*) AS exact_count
-                FROM search_recapdocument
+            base_query = (
+                "SELECT count(*) AS exact_count FROM search_recapdocument"
+            )
+            filter_clause = """
                 WHERE is_available=True AND page_count>0 AND ocr_status!=1
             """
         case SEARCH_TYPES.OPINION:
-            query = """
-                SELECT count(*) AS exact_count
-                FROM search_opinion
-                WHERE extracted_by_ocr != true
-            """
+            base_query = "SELECT count(*) AS exact_count FROM search_opinion"
+            filter_clause = "WHERE extracted_by_ocr != true"
         case SEARCH_TYPES.ORAL_ARGUMENT:
-            query = """
-                SELECT count(*) AS exact_count
-                FROM audio_audio
-                WHERE
-                local_path_mp3 != '' AND
+            base_query = "SELECT count(*) AS exact_count FROM audio_audio"
+            filter_clause = """WHERE local_path_mp3 != '' AND
                 download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
                 position('Unavailable' in download_url) = 0 AND
                 duration > 30
             """
 
+    if options["random_sample_percentage"]:
+        percentage = options["random_sample_percentage"]
+        base_query = f"{base_query} TABLESAMPLE SYSTEM ({percentage})"
+
+    query = (
+        f"{base_query}\n"
+        if options["all_records"]
+        else f"{base_query}\n {filter_clause}\n"
+    )
     with connections[
-        "replica" if use_replica else "default"
+        "replica" if options["use_replica"] else "default"
     ].cursor() as cursor:
         cursor.execute(query, [])
         result = cursor.fetchone()
@@ -60,7 +68,9 @@ def get_total_number_of_records(type: str, use_replica: bool = False) -> int:
     return int(result[0])
 
 
-def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
+def get_custom_query(
+    type: str, last_pk: str, options: dict[str, Any]
+) -> tuple[str, list[Any]]:
     """
     Generates a custom SQL query based on the provided type and optional
     last pk.
@@ -69,6 +79,10 @@ def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
         type (str): Type of data to retrieve.
         last_pk (int, optional): Last primary key retrieved in a previous
             query. Defaults to None.
+        options (dict[str, Any]): A dictionary containing options for filtering
+            the results.
+            - 'random_sample_percentage' (float, optional): The percentage of
+                records to include in a random sample.
 
     Returns:
         tuple[str, list[Any]]: A tuple containing the constructed SQL
@@ -76,50 +90,48 @@ def get_custom_query(type: str, last_pk: str) -> tuple[str, list[Any]]:
         the query.
     """
     params = []
-
+    random_sample = options["random_sample_percentage"]
     match type:
         case SEARCH_TYPES.RECAP_DOCUMENT:
             base_query = "SELECT id from search_recapdocument"
             filter_clause = (
                 "WHERE is_available=True AND page_count>0 AND ocr_status!=1"
-                if not last_pk
-                else (
-                    "WHERE id > %s AND is_available = True AND page_count > 0"
-                    " AND ocr_status != 1"
-                )
             )
         case SEARCH_TYPES.OPINION:
             base_query = "SELECT id from search_opinion"
-            filter_clause = (
-                "WHERE extracted_by_ocr != true"
-                if not last_pk
-                else "WHERE id > %s AND extracted_by_ocr != true"
-            )
+            filter_clause = "WHERE extracted_by_ocr != true"
         case SEARCH_TYPES.ORAL_ARGUMENT:
             base_query = "SELECT id from audio_audio"
-            no_argument_where_clause = """
+            filter_clause = """
                 WHERE local_path_mp3 != '' AND
                 download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
                 position('Unavailable' in download_url) = 0 AND
                 duration > 30
             """
-            where_clause_with_argument = """
-                WHERE id > %s AND
-                local_path_mp3 != '' AND
-                download_url != 'https://www.cadc.uscourts.gov/recordings/recordings.nsf/' AND
-                position('Unavailable' in download_url) = 0 AND
-                duration > 30
-            """
-            filter_clause = (
-                no_argument_where_clause
-                if not last_pk
-                else where_clause_with_argument
-            )
 
-    if last_pk:
+    if random_sample:
+        base_query = f"{base_query} TABLESAMPLE SYSTEM ({random_sample})"
+
+    if options["all_records"]:
+        filter_clause = ""
+
+    # Using a WHERE clause with `id > last_pk` and a LIMIT clause for batch
+    # retrieval is not suitable for random sampling. The following logic
+    # removes these clauses when retrieving a random sample to ensure all rows
+    # have an equal chance of being selected.
+    if last_pk and not random_sample:
+        filter_clause = (
+            f"WHERE id > %s"
+            if not filter_clause
+            else f"{filter_clause} AND id > %s"
+        )
         params.append(last_pk)
 
-    query = f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
+    query = (
+        f"{base_query}\n {filter_clause}"
+        if random_sample
+        else f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s"
+    )
 
     return query, params
@@ -170,6 +182,27 @@ def add_arguments(self, parser: CommandParser):
             default=False,
             help="Use this flag to run the queries in the replica db",
         )
+        parser.add_argument(
+            "--file-name",
+            type=str,
+            default=None,
+            help="Custom name for the output files. If not provided, a default "
+            "name will be used.",
+        )
+        parser.add_argument(
+            "--random-sample-percentage",
+            type=float,
+            default=None,
+            help="Specifies the proportion of the table to be sampled (between "
+            "0.0 and 100.0). Use this flag to retrieve a random set of records.",
+        )
+        parser.add_argument(
+            "--all-records",
+            action="store_true",
+            default=False,
+            help="Use this flag to retrieve all records from the table without"
+            " applying any filters.",
+        )
 
     def handle(self, *args, **options):
         r = get_redis_interface("CACHE")
@@ -188,7 +221,7 @@ def handle(self, *args, **options):
         )
         if not total_number_of_records:
             total_number_of_records = get_total_number_of_records(
-                record_type, options["use_replica"]
+                record_type, options
             )
         r.hset(
             f"{record_type}_import_status",
@@ -200,12 +233,17 @@ def handle(self, *args, **options):
             r.hget(f"{record_type}_import_status", "next_iteration_counter")
             or 0
         )
+        file_name = (
+            options["file_name"]
+            if options["file_name"]
+            else f"{record_type}_filelist"
+        )
         while True:
             query, params = get_custom_query(
-                options["record_type"],
-                last_pk,
+                options["record_type"], last_pk, options
             )
-            params.append(options["query_batch_size"])
+            if not options["random_sample_percentage"]:
+                params.append(options["query_batch_size"])
 
             with connections[
                 "replica" if options["use_replica"] else "default"
@@ -226,22 +264,37 @@ def handle(self, *args, **options):
                     extrasaction="ignore",
                 )
                 for row in batched(rows, options["lambda_record_size"]):
-                    query_dict = {
-                        "bucket": bucket_name,
-                        "file_name": (
+                    if options["random_sample_percentage"]:
+                        # Create an underscore-separated file name that lambda
+                        # can split and use as part of batch processing.
+                        ids = [str(r[0]) for r in row]
+                        content = "_".join(ids)
+                    else:
+                        content = (
                             f"{row[0][0]}_{row[-1][0]}"
                             if len(row) > 1
                             else f"{row[0][0]}"
-                        ),
+                        )
+                    query_dict = {
+                        "bucket": bucket_name,
+                        "file_name": content,
                     }
                     writer.writerow(query_dict)
 
                 s3_client.put_object(
-                    Key=f"{record_type}_filelist_{counter}.csv",
+                    Key=f"{file_name}_{counter}.csv",
                     Bucket=bucket_name,
                     Body=csvfile.getvalue().encode("utf-8"),
                 )
 
+            if options["random_sample_percentage"]:
+                # Due to the non-deterministic nature of random sampling,
+                # storing data to recover the query for future executions
+                # wouldn't be meaningful. Random queries are unlikely to
+                # produce the same results on subsequent runs.
+                logger.info(f"Finished processing {record_count} records")
+                break
+
             counter += 1
             last_pk = rows[-1][0]
             records_processed = int(
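A note on the sampling approach: PostgreSQL's `TABLESAMPLE SYSTEM (p)` samples whole heap pages rather than individual rows, so the number of ids returned is only approximately `p` percent of the table and varies between runs; that is also why `get_total_number_of_records` can only report an approximate total in sampling mode. The sketch below is a standalone illustration, not code from this patch (`build_opinion_query` is a hypothetical helper); it mirrors the two query shapes `get_custom_query` now returns for `SEARCH_TYPES.OPINION` and shows why the `id > last_pk` / `LIMIT` pair is dropped when sampling:

```python
# Hypothetical helper mirroring the two query shapes built by
# get_custom_query for opinions; table/column names come from the diff.
from typing import Any


def build_opinion_query(
    last_pk: int | None, random_sample_percentage: float | None
) -> tuple[str, list[Any]]:
    params: list[Any] = []
    base_query = "SELECT id from search_opinion"
    filter_clause = "WHERE extracted_by_ocr != true"

    if random_sample_percentage:
        # TABLESAMPLE goes right after the table name. Re-running the
        # sampled query per batch with `id > last_pk` and LIMIT would
        # draw a fresh sample each iteration and skew the result, so
        # sampling mode issues a single unpaginated query instead.
        base_query = (
            f"{base_query} TABLESAMPLE SYSTEM ({random_sample_percentage})"
        )
        return f"{base_query}\n {filter_clause}", params

    if last_pk:
        # Keyset pagination: resume after the last id already processed.
        filter_clause = f"{filter_clause} AND id > %s"
        params.append(last_pk)
    # The caller appends the batch size for the LIMIT placeholder.
    return f"{base_query}\n {filter_clause}\n ORDER BY id\n LIMIT %s", params


print(build_opinion_query(None, 0.5)[0])   # sampled, no ORDER BY/LIMIT
print(build_opinion_query(1000, None))     # paginated, params == [1000]
```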
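The CSV manifest rows encode batches differently in the two modes: sequential runs store an inclusive `first_last` id range, while sampled runs store every selected id joined by underscores, since sampled ids are not contiguous. A hypothetical consumer-side parser (the lambda itself is outside this diff; `parse_manifest_name` is illustrative only) might recover them like this:

```python
# Hypothetical sketch of how a downstream lambda could interpret the
# file_name column written by this command.
def parse_manifest_name(
    file_name: str, random_sample: bool
) -> tuple[str, list[int]]:
    parts = [int(p) for p in file_name.split("_")]
    if random_sample:
        # Sampling mode: the exact ids selected, e.g. "101_205_309".
        return "ids", parts
    if len(parts) == 1:
        # Single-record batch, e.g. "123".
        return "ids", parts
    # Sequential mode: inclusive id bounds, e.g. "100_999". The ids in
    # between are whatever rows matched the filters, not every integer.
    return "range", [parts[0], parts[-1]]


assert parse_manifest_name("101_205_309", True) == ("ids", [101, 205, 309])
assert parse_manifest_name("100_999", False) == ("range", [100, 999])
```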
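For context on the restart behavior that sampling mode now skips: the regular path pages through the table by primary key and checkpoints progress in Redis, so an interrupted run resumes where it left off; a non-deterministic sample cannot be resumed meaningfully, hence the early `break`. A minimal sketch of that checkpoint pattern, assuming redis-py and a simplified schema (the real command tracks more fields under `f"{record_type}_import_status"`, such as `next_iteration_counter`):

```python
# Minimal sketch of checkpointed keyset pagination; key names and the
# cursor wiring are simplified stand-ins for the real handle() logic.
import redis

r = redis.Redis(decode_responses=True)


def iterate_ids(cursor, record_type: str, batch_size: int = 1000):
    status_key = f"{record_type}_import_status"
    # Resume from the last checkpoint if a previous run was interrupted.
    last_pk = int(r.hget(status_key, "last_pk") or 0)
    while True:
        cursor.execute(
            "SELECT id from search_opinion WHERE id > %s ORDER BY id LIMIT %s",
            [last_pk, batch_size],
        )
        rows = cursor.fetchall()
        if not rows:
            break
        yield rows  # caller writes the manifest chunk to S3 here
        last_pk = rows[-1][0]
        # Persist progress so the next run starts after this id.
        r.hset(status_key, "last_pk", last_pk)
```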