From ff6da13fdf09c6b7771783bdab92ee6cfaa03e5f Mon Sep 17 00:00:00 2001 From: TosinIbikunle Date: Tue, 14 Oct 2025 22:34:10 +0100 Subject: [PATCH 1/2] feat: implement ruff linter and autoformatter --- .pre-commit-config.yaml | 19 +++----- ami/base/models.py | 2 +- ami/base/serializers.py | 4 +- ami/exports/format_types.py | 4 +- ami/exports/utils.py | 2 +- .../management/commands/update_stale_jobs.py | 4 +- ami/jobs/models.py | 8 ++-- ami/main/api/serializers.py | 12 ++--- ami/main/charts.py | 22 +++++----- ami/main/management/commands/import_taxa.py | 18 ++++---- .../commands/import_trapdata_project.py | 16 +++---- ami/main/management/commands/update_taxa.py | 4 +- ami/main/models.py | 29 ++++++------ ami/main/models_future/filters.py | 3 +- ami/main/models_future/projects.py | 2 +- ami/main/tests.py | 8 +++- ami/ml/models/algorithm.py | 2 +- ami/ml/models/pipeline.py | 25 +++++------ ami/ml/views.py | 2 +- ami/tasks.py | 2 +- ami/tests/fixtures/images.py | 2 +- ami/tests/fixtures/main.py | 44 +++++++++---------- ami/tests/fixtures/storage.py | 16 +++---- ami/users/tests/test_forms.py | 1 + ami/utils/s3.py | 4 +- config/asgi.py | 1 + config/settings/base.py | 6 ++- config/wsgi.py | 1 + processing_services/example/api/algorithms.py | 4 +- processing_services/example/api/pipelines.py | 14 ++++-- pyproject.toml | 29 ++++++++++++ requirements/base.txt | 6 +-- setup.cfg | 6 +-- 33 files changed, 177 insertions(+), 145 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index acde97cd9..58211c6dd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,20 +29,12 @@ repos: - id: pyupgrade args: [--py310-plus] - - repo: https://github.com/psf/black - rev: 23.3.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.1 hooks: - - id: black - - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 + - id: ruff + args: [--fix] + - id: ruff-format - repo: https://github.com/Riverside-Healthcare/djLint rev: v1.31.1 @@ -50,7 +42,6 @@ repos: - id: djlint-reformat-django - id: djlint-django - # sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date ci: autoupdate_schedule: weekly diff --git a/ami/base/models.py b/ami/base/models.py index 2f245b745..68927dea4 100644 --- a/ami/base/models.py +++ b/ami/base/models.py @@ -153,7 +153,7 @@ def get_project(self): def __str__(self) -> str: """All django models should have this method.""" if hasattr(self, "name"): - name = getattr(self, "name") or "Untitled" + name = self.name or "Untitled" return f"#{self.pk} {name}" else: return f"{self.__class__.__name__} #{self.pk}" diff --git a/ami/base/serializers.py b/ami/base/serializers.py index e39ede52b..129e1e6ed 100644 --- a/ami/base/serializers.py +++ b/ami/base/serializers.py @@ -13,7 +13,9 @@ logger = logging.getLogger(__name__) -def reverse_with_params(viewname: str, args=None, kwargs=None, request=None, params: dict = {}, **extra) -> str: +def reverse_with_params(viewname: str, args=None, kwargs=None, request=None, params: dict = None, **extra) -> str: + if params is None: + params = {} query_string = urllib.parse.urlencode(params) base_url = reverse(viewname, request=request, args=args, kwargs=kwargs, **extra) url = urllib.parse.urlunsplit(("", "", base_url, query_string, "")) diff --git a/ami/exports/format_types.py b/ami/exports/format_types.py index 087e50e8f..f145739a5 100644 --- a/ami/exports/format_types.py +++ 
b/ami/exports/format_types.py @@ -63,7 +63,7 @@ def export(self): first = True f.write("[") records_exported = 0 - for i, batch in enumerate(get_data_in_batches(self.queryset, self.get_serializer_class())): + for _i, batch in enumerate(get_data_in_batches(self.queryset, self.get_serializer_class())): json_data = json.dumps(batch, cls=DjangoJSONEncoder) json_data = json_data[1:-1] # remove [ and ] from json string f.write(",\n" if not first else "") @@ -153,7 +153,7 @@ def export(self): writer = csv.DictWriter(csvfile, fieldnames=field_names) writer.writeheader() - for i, batch in enumerate(get_data_in_batches(self.queryset, self.serializer_class)): + for _i, batch in enumerate(get_data_in_batches(self.queryset, self.serializer_class)): writer.writerows(batch) records_exported += len(batch) self.update_job_progress(records_exported) diff --git a/ami/exports/utils.py b/ami/exports/utils.py index e59454219..7f9fb601a 100644 --- a/ami/exports/utils.py +++ b/ami/exports/utils.py @@ -80,7 +80,7 @@ def get_data_in_batches(QuerySet: models.QuerySet, Serializer: type[serializers. batch = [] fake_request = generate_fake_request() - for i, item in enumerate(items): + for _i, item in enumerate(items): try: serializer = Serializer( item, diff --git a/ami/jobs/management/commands/update_stale_jobs.py b/ami/jobs/management/commands/update_stale_jobs.py index da3a53e3d..d862ca045 100644 --- a/ami/jobs/management/commands/update_stale_jobs.py +++ b/ami/jobs/management/commands/update_stale_jobs.py @@ -7,9 +7,7 @@ class Command(BaseCommand): - help = ( - "Update the status of all jobs that are not in a final state " "and have not been updated in the last X hours." - ) + help = "Update the status of all jobs that are not in a final state and have not been updated in the last X hours." 
# Add argument for the number of hours to consider a job stale def add_arguments(self, parser): diff --git a/ami/jobs/models.py b/ami/jobs/models.py index f7b85283b..8b3dbfa89 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -412,17 +412,17 @@ def run(cls, job: "Job"): for i, chunk in enumerate(chunks): request_sent = time.time() - job.logger.info(f"Processing image batch {i+1} of {len(chunks)}") + job.logger.info(f"Processing image batch {i + 1} of {len(chunks)}") try: results = job.pipeline.process_images( images=chunk, job_id=job.pk, project_id=job.project.pk, ) - job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") + job.logger.info(f"Processed image batch {i + 1} in {time.time() - request_sent:.2f}s") except Exception as e: # Log error about image batch and continue - job.logger.error(f"Failed to process image batch {i+1}: {e}") + job.logger.error(f"Failed to process image batch {i + 1}: {e}") request_failed_images.extend([img.pk for img in chunk]) else: total_captures += len(results.source_images) @@ -433,7 +433,7 @@ def run(cls, job: "Job"): # @TODO add callback to report errors while saving results marking the job as failed save_results_task: AsyncResult = job.pipeline.save_results_async(results=results, job_id=job.pk) save_tasks.append((i + 1, save_results_task)) - job.logger.info(f"Saving results for batch {i+1} in sub-task {save_results_task.id}") + job.logger.info(f"Saving results for batch {i + 1} in sub-task {save_results_task.id}") job.progress.update_stage( "process", diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index b07ce8657..7dff02974 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -1299,12 +1299,12 @@ def get_determination_details(self, obj: Occurrence): else: prediction = ClassificationNestedSerializer(obj.best_prediction, context=context).data - return dict( - taxon=taxon, - identification=identification, - prediction=prediction, - score=obj.determination_score, - ) + return { + "taxon": taxon, + "identification": identification, + "prediction": prediction, + "score": obj.determination_score, + } class OccurrenceSerializer(OccurrenceListSerializer): diff --git a/ami/main/charts.py b/ami/main/charts.py index 2d88ed573..6ec7d28f2 100644 --- a/ami/main/charts.py +++ b/ami/main/charts.py @@ -91,7 +91,7 @@ def captures_per_day(project_pk: int): ) if captures_per_date: - days, counts = list(zip(*captures_per_date)) + days, counts = list(zip(*captures_per_date, strict=False)) days = [day for day in days if day] # tickvals_per_month = [f"{d:%b}" for d in days] tickvals = [f"{days[0]:%b %d}", f"{days[-1]:%b %d}"] @@ -119,7 +119,7 @@ def captures_per_month(project_pk: int): ) if captures_per_month: - months, counts = list(zip(*captures_per_month)) + months, counts = list(zip(*captures_per_month, strict=False)) # tickvals_per_month = [f"{d:%b}" for d in days] tickvals = [f"{months[0]}", f"{months[-1]}"] # labels = [f"{d}" for d in months] @@ -146,7 +146,7 @@ def events_per_week(project_pk: int): ) if captures_per_week: - weeks, counts = list(zip(*captures_per_week)) + weeks, counts = list(zip(*captures_per_week, strict=False)) # tickvals_per_month = [f"{d:%b}" for d in days] tickvals = [f"{weeks[0]}", f"{weeks[-1]}"] labels = [f"{d}" for d in weeks] @@ -172,7 +172,7 @@ def events_per_month(project_pk: int): ) if captures_per_month: - months, counts = list(zip(*captures_per_month)) + months, counts = list(zip(*captures_per_month, strict=False)) # tickvals_per_month = [f"{d:%b}" for 
d in days] tickvals = [f"{months[0]}", f"{months[-1]}"] # labels = [f"{d}" for d in months] @@ -257,7 +257,7 @@ def occurrences_accumulated(project_pk: int): occurrences_exist = Occurrence.objects.filter(project=project_pk).exists() if occurrences_exist: - days, counts = list(zip(*occurrences_per_day)) + days, counts = list(zip(*occurrences_per_day, strict=False)) # Accumulate the counts counts = list(itertools.accumulate(counts)) # tickvals = [f"{d:%b %d}" for d in days] @@ -288,7 +288,9 @@ def event_detections_per_hour(event_pk: int): # hours, counts = list(zip(*detections_per_hour)) if detections_per_hour: hours, counts = list( - zip(*[(d["source_image__timestamp__hour"], d["num_detections"]) for d in detections_per_hour]) + zip( + *[(d["source_image__timestamp__hour"], d["num_detections"]) for d in detections_per_hour], strict=False + ) ) hours, counts = shift_to_nighttime(list(hours), list(counts)) # @TODO show a tick for every hour even if there are no detections @@ -317,7 +319,7 @@ def event_top_taxa(event_pk: int, top_n: int = 10): ) if top_taxa: - taxa, counts = list(zip(*[(t["name"], t["num_detections"]) for t in reversed(top_taxa)])) + taxa, counts = list(zip(*[(t["name"], t["num_detections"]) for t in reversed(top_taxa)], strict=False)) taxa = [t or "Unknown" for t in taxa] counts = [c or 0 for c in counts] else: @@ -340,7 +342,7 @@ def project_top_taxa(project_pk: int, top_n: int = 10): ) if top_taxa: - taxa, counts = list(zip(*[(t.name, t.occurrence_count) for t in reversed(top_taxa)])) + taxa, counts = list(zip(*[(t.name, t.occurrence_count) for t in reversed(top_taxa)], strict=False)) else: taxa, counts = [], [] @@ -363,7 +365,7 @@ def unique_species_per_month(project_pk: int): ) # Create a dictionary mapping month numbers to species counts - month_to_count = {month: count for month, count in unique_species_per_month} + month_to_count = dict(unique_species_per_month) # Create lists for all 12 months, using 0 for months with no data all_months = list(range(1, 13)) # 1-12 for January-December @@ -393,7 +395,7 @@ def average_occurrences_per_month(project_pk: int): ) # Create a dictionary mapping month numbers to occurrence counts - month_to_count = {month: count for month, count in occurrences_per_month} + month_to_count = dict(occurrences_per_month) # Create lists for all 12 months, using 0 for months with no data all_months = list(range(1, 13)) # 1-12 for January-December diff --git a/ami/main/management/commands/import_taxa.py b/ami/main/management/commands/import_taxa.py index 73eb1cb08..0b25918be 100644 --- a/ami/main/management/commands/import_taxa.py +++ b/ami/main/management/commands/import_taxa.py @@ -15,7 +15,7 @@ from ...models import TaxaList, Taxon, TaxonRank -RANK_CHOICES = [rank for rank in TaxonRank] +RANK_CHOICES = list(TaxonRank) logger = logging.getLogger(__name__) # Set level @@ -34,7 +34,7 @@ def read_csv(fname: str) -> list[dict]: reader = csv.DictReader(open(fname)) - taxa = [row for row in reader] + taxa = list(reader) return taxa @@ -79,7 +79,7 @@ def fix_generic_names(taxon_data: dict) -> dict: fixed_taxon_data = taxon_data.copy() generic_names = ["sp.", "sp", "spp", "spp.", "cf.", "cf", "aff.", "aff"] fallback_name_keys = ["bold_taxon_bin", "inat_taxon_id", "gbif_taxon_key"] - for key, value in taxon_data.items(): + for _key, value in taxon_data.items(): if value and value.lower() in generic_names: # set name to first fallback name that exists fallback_name = None @@ -236,7 +236,7 @@ def handle(self, *args, **options): taxalist, created = 
TaxaList.objects.get_or_create(name=list_name) if created: - self.stdout.write(self.style.SUCCESS('Successfully created taxa list "%s"' % taxalist)) + self.stdout.write(self.style.SUCCESS(f'Successfully created taxa list "{taxalist}"')) if options["purge"]: self.stdout.write(self.style.WARNING("Purging all taxa from the database in 5 seconds...")) @@ -319,11 +319,11 @@ def create_taxon(self, taxon_data: dict, root_taxon_parent: Taxon) -> tuple[set[ # If the taxon already exists, use it and maybe update it taxon, created = Taxon.objects.get_or_create( name=name, - defaults=dict( - rank=rank, - gbif_taxon_key=gbif_taxon_key, - parent=parent_taxon, - ), + defaults={ + "rank": rank, + "gbif_taxon_key": gbif_taxon_key, + "parent": parent_taxon, + }, ) taxa_in_row.append(taxon) diff --git a/ami/main/management/commands/import_trapdata_project.py b/ami/main/management/commands/import_trapdata_project.py index fa43829a1..b8dee4238 100644 --- a/ami/main/management/commands/import_trapdata_project.py +++ b/ami/main/management/commands/import_trapdata_project.py @@ -62,7 +62,7 @@ def handle(self, *args, **options): project, created = Project.objects.get_or_create(name="Default Project") if created: - self.stdout.write(self.style.SUCCESS('Successfully created project "%s"' % project)) + self.stdout.write(self.style.SUCCESS(f'Successfully created project "{project}"')) algorithm, created = Algorithm.objects.get_or_create(name="Latest Model", version="1.0") for occurrence in occurrences: deployment, created = Deployment.objects.get_or_create( @@ -70,14 +70,14 @@ def handle(self, *args, **options): project=project, ) if created: - self.stdout.write(self.style.SUCCESS('Successfully created deployment "%s"' % deployment)) + self.stdout.write(self.style.SUCCESS(f'Successfully created deployment "{deployment}"')) event, created = Event.objects.get_or_create( start=parse_date(occurrence["event"]["day"]), deployment=deployment, ) if created: - self.stdout.write(self.style.SUCCESS('Successfully created event "%s"' % event)) + self.stdout.write(self.style.SUCCESS(f'Successfully created event "{event}"')) best_taxon, created = Taxon.objects.get_or_create(name=occurrence["label"]) occ = Occurrence.objects.create( @@ -86,7 +86,7 @@ def handle(self, *args, **options): project=project, determination=best_taxon, ) - self.stdout.write(self.style.SUCCESS('Successfully created occurrence "%s"' % occ)) + self.stdout.write(self.style.SUCCESS(f'Successfully created occurrence "{occ}"')) for example in occurrence["examples"]: try: @@ -100,9 +100,9 @@ def handle(self, *args, **options): size=example["source_image_filesize"], ) if created: - self.stdout.write(self.style.SUCCESS('Successfully created image "%s"' % image)) + self.stdout.write(self.style.SUCCESS(f'Successfully created image "{image}"')) except KeyError as e: - self.stdout.write(self.style.ERROR('Error creating image "%s"' % e)) + self.stdout.write(self.style.ERROR(f'Error creating image "{e}"')) image = None if image: @@ -114,7 +114,7 @@ def handle(self, *args, **options): bbox=example["bbox"], ) if created: - self.stdout.write(self.style.SUCCESS('Successfully created detection "%s"' % detection)) + self.stdout.write(self.style.SUCCESS(f'Successfully created detection "{detection}"')) else: detection = None @@ -132,7 +132,7 @@ def handle(self, *args, **options): ) if created: self.stdout.write( - self.style.SUCCESS('Successfully created classification "%s"' % classification) + self.style.SUCCESS(f'Successfully created classification "{classification}"') ) # 
Update event start and end times based on the first and last detections diff --git a/ami/main/management/commands/update_taxa.py b/ami/main/management/commands/update_taxa.py index 4e1d75c33..58ae2e77d 100644 --- a/ami/main/management/commands/update_taxa.py +++ b/ami/main/management/commands/update_taxa.py @@ -18,7 +18,7 @@ def read_csv(fname: str) -> list[dict[str, Any]]: with open(fname) as f: reader = csv.DictReader(f) - taxa = [row for row in reader] + taxa = list(reader) return taxa @@ -248,6 +248,6 @@ def handle(self, *args, **options): if not_found: self.stdout.write(self.style.WARNING(f"Could not find {len(not_found)} taxa")) for i, data in enumerate(not_found[:5]): # Show only first 5 - self.stdout.write(f" {i+1}. {data}") + self.stdout.write(f" {i + 1}. {data}") if len(not_found) > 5: self.stdout.write(f" ... and {len(not_found) - 5} more") diff --git a/ami/main/models.py b/ami/main/models.py index e3b2ab8ce..f65b7737b 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -1164,7 +1164,8 @@ def audit_event_lengths(deployment: Deployment): logger.warning(f"Found {events_over_24_hours} event(s) over 24 hours in deployment {deployment}. ") events_starting_before_noon = Event.objects.filter( - deployment=deployment, start__hour__lt=12 # Before hour 12 + deployment=deployment, + start__hour__lt=12, # Before hour 12 ).count() if events_starting_before_noon: logger.warning( @@ -1189,7 +1190,7 @@ def group_images_into_events( ) if dupes.count(): values = "\n".join( - [f'{d.strftime("%Y-%m-%d %H:%M:%S")} x{c}' for d, c in dupes.values_list("timestamp", "count")] + [f"{d.strftime('%Y-%m-%d %H:%M:%S')} x{c}" for d, c in dupes.values_list("timestamp", "count")] ) logger.warning( f"Found {len(values)} images with the same timestamp in deployment '{deployment}'. 
" @@ -1240,7 +1241,7 @@ def group_images_into_events( ) logger.info( - f"Done grouping {len(image_timestamps)} captures into {len(events)} events " f"for deployment {deployment}" + f"Done grouping {len(image_timestamps)} captures into {len(events)} events for deployment {deployment}" ) if delete_empty: @@ -1912,9 +1913,7 @@ def set_dimensions_for_collection( width, height = image.get_dimensions() if width and height: - logger.info( - f"Setting dimensions for {event.captures.count()} images in event {event.pk} to " f"{width}x{height}" - ) + logger.info(f"Setting dimensions for {event.captures.count()} images in event {event.pk} to {width}x{height}") if replace_existing: captures = event.captures.all() else: @@ -2106,9 +2105,7 @@ def save(self, *args, **kwargs): Identification.objects.filter( occurrence=self.occurrence, user=self.user, - ).exclude( - pk=self.pk - ).update(withdrawn=True) + ).exclude(pk=self.pk).update(withdrawn=True) super().save(*args, **kwargs) @@ -2260,7 +2257,7 @@ def predictions(self, sort=True) -> typing.Iterable[tuple[str, float]]: if not self.category_map: raise ValueError("Classification must have a category map to get predictions.") scores = self.scores or [] - preds = zip(self.category_map.labels, scores) + preds = zip(self.category_map.labels, scores, strict=False) if sort: return sorted(preds, key=lambda x: x[1], reverse=True) else: @@ -2278,7 +2275,7 @@ def predictions_with_taxa(self, sort=True) -> typing.Iterable[tuple["Taxon", flo scores = self.scores or [] category_data_with_taxa = self.category_map.with_taxa() taxa_sorted_by_index = [cat["taxon"] for cat in sorted(category_data_with_taxa, key=lambda cat: cat["index"])] - preds = zip(taxa_sorted_by_index, scores) + preds = zip(taxa_sorted_by_index, scores, strict=False) if sort: return sorted(preds, key=lambda x: x[1], reverse=True) else: @@ -2972,7 +2969,7 @@ def add_genus_parents(self): Create a genus if it doesn't exist based on the scientific name of the species. This will replace any parents of a species that are not of the GENUS rank. 
""" - Taxon: "Taxon" = self.model # type: ignore + Taxon: Taxon = self.model # type: ignore species = self.get_queryset().filter(rank=TaxonRank.SPECIES) # , parent=None) updated = [] for taxon in species: @@ -3010,9 +3007,11 @@ def update_display_names(self, queryset: models.QuerySet | None = None): self.bulk_update(taxa, ["display_name"]) # Method that returns taxa nested in a tree structure - def tree(self, root: typing.Optional["Taxon"] = None, filter_ranks: list[TaxonRank] = []) -> dict: + def tree(self, root: typing.Optional["Taxon"] = None, filter_ranks: list[TaxonRank] = None) -> dict: """Build a recursive tree of taxa.""" + if filter_ranks is None: + filter_ranks = [] root = root or self.root() # Fetch all taxa @@ -3795,7 +3794,7 @@ def sample_common_combined( def sample_interval( self, minute_interval: int = 10, - exclude_events: list[int] = [], + exclude_events: list[int] = None, deployment_id: int | None = None, # Deprecated hour_start: int | None = None, hour_end: int | None = None, @@ -3809,6 +3808,8 @@ def sample_interval( ): """Create a sample of source images based on a time interval""" + if exclude_events is None: + exclude_events = [] qs = self.get_queryset() qs = self._filter_sample( qs=qs, diff --git a/ami/main/models_future/filters.py b/ami/main/models_future/filters.py index 6689065c2..ba00d0aa9 100644 --- a/ami/main/models_future/filters.py +++ b/ami/main/models_future/filters.py @@ -8,9 +8,10 @@ from django.db.models import Q if TYPE_CHECKING: - from ami.main.models import Project, Taxon from rest_framework.request import Request + from ami.main.models import Project, Taxon + from ami.utils.requests import get_apply_default_filters_flag, get_default_classification_threshold diff --git a/ami/main/models_future/projects.py b/ami/main/models_future/projects.py index 19b2156e8..7ce4b608b 100644 --- a/ami/main/models_future/projects.py +++ b/ami/main/models_future/projects.py @@ -49,7 +49,7 @@ class ProjectSettingsMixin(models.Model): related_name="exclude_taxa_default_projects", blank=True, help_text=( - "Taxa that are excluded by default in the occurrence filters and metrics. " "For example, 'Not a Moth'." + "Taxa that are excluded by default in the occurrence filters and metrics. For example, 'Not a Moth'." 
), ) diff --git a/ami/main/tests.py b/ami/main/tests.py index 3bd1f77a5..f0733fb70 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -345,7 +345,9 @@ def test_event_calculated_fields_batch(self): update_calculated_fields_for_events(last_updated=datetime.datetime(3000, 1, 1, 0, 0, 0)) - for event, last_updated in zip(self.deployment.events.all().order_by("pk"), last_updated_timestamps): + for event, last_updated in zip( + self.deployment.events.all().order_by("pk"), last_updated_timestamps, strict=False + ): self.assertEqual(event.captures_count, event.get_captures_count()) self.assertEqual(event.detections_count, event.get_detections_count()) self.assertEqual(event.occurrences_count, event.get_occurrences_count()) @@ -356,7 +358,9 @@ def test_event_calculated_fields_batch(self): update_calculated_fields_for_events(last_updated=datetime.datetime(3000, 1, 1, 0, 0, 0)) - for event, last_updated in zip(self.deployment.events.all().order_by("pk"), last_updated_timestamps): + for event, last_updated in zip( + self.deployment.events.all().order_by("pk"), last_updated_timestamps, strict=False + ): self.assertEqual(event.captures_count, event.get_captures_count()) self.assertEqual(event.detections_count, event.get_detections_count()) self.assertEqual(event.occurrences_count, event.get_occurrences_count()) diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py index 48b2e2336..dfd3d6545 100644 --- a/ami/ml/models/algorithm.py +++ b/ami/ml/models/algorithm.py @@ -75,7 +75,7 @@ class AlgorithmCategoryMap(BaseModel): max_length=255, blank=True, null=True, - help_text=("A URI to the category map file. " "Could be a public web URL or object store path."), + help_text=("A URI to the category map file. Could be a public web URL or object store path."), ) algorithms: models.QuerySet[Algorithm] diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index a7470ec6a..6cdb02c7e 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ami.ml.models import ProcessingService, ProjectPipelineConfig from ami.jobs.models import Job + from ami.ml.models import ProcessingService, ProjectPipelineConfig import collections import dataclasses @@ -215,7 +215,7 @@ def process_images( if pipeline_config.get("reprocess_existing_detections", True): reprocess_existing_detections = True - for source_image, url in zip(images, urls): + for source_image, url in zip(images, urls, strict=False): if url: source_image_request = SourceImageRequest( id=str(source_image.pk), @@ -404,9 +404,9 @@ def get_or_create_detection( serialized_bbox = list(detection_resp.bbox.dict().values()) detection_repr = f"Detection {detection_resp.source_image_id} {serialized_bbox}" - assert str(detection_resp.source_image_id) == str( - source_image.pk - ), f"Detection belongs to a different source image: {detection_repr}" + assert str(detection_resp.source_image_id) == str(source_image.pk), ( + f"Detection belongs to a different source image: {detection_repr}" + ) existing_detection = Detection.objects.filter( source_image=source_image, @@ -595,9 +595,9 @@ def create_classification( :return: A tuple of the Classification object and a boolean indicating whether it was created """ - assert ( - classification_resp.algorithm - ), f"No classification algorithm was specified for classification {classification_resp}" + assert classification_resp.algorithm, ( + f"No classification algorithm was specified for classification {classification_resp}" + 
) logger.debug(f"Processing classification {classification_resp}") try: @@ -705,7 +705,7 @@ def create_classifications( existing_classifications: list[Classification] = [] new_classifications: list[Classification] = [] - for detection, detection_resp in zip(detections, detection_responses): + for detection, detection_resp in zip(detections, detection_responses, strict=False): for classification_resp in detection_resp.classifications: classification, created = create_classification( detection=detection, @@ -1017,10 +1017,10 @@ def get_config(self, project_id: int | None = None) -> PipelineRequestConfigPara if project_pipeline_config.config: config.update(project_pipeline_config.config) logger.debug( - f"Using ProjectPipelineConfig for Pipeline {self} and Project #{project_id}:" f"config: {config}" + f"Using ProjectPipelineConfig for Pipeline {self} and Project #{project_id}:config: {config}" ) except self.project_pipeline_configs.model.DoesNotExist as e: - logger.warning(f"No project-pipeline config for Pipeline {self} " f"and Project #{project_id}: {e}") + logger.warning(f"No project-pipeline config for Pipeline {self} and Project #{project_id}: {e}") return config def collect_images( @@ -1056,8 +1056,7 @@ def choose_processing_service_for_pipeline( # get all processing services that are associated with the provided pipeline project processing_services = self.processing_services.filter(projects=project_id) task_logger.info( - f"Searching processing services:" - f"{[processing_service.name for processing_service in processing_services]}" + f"Searching processing services:{[processing_service.name for processing_service in processing_services]}" ) # check the status of all processing services diff --git a/ami/ml/views.py b/ami/ml/views.py index 0e0bcf2f9..375f20312 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -50,7 +50,7 @@ class AlgorithmViewSet(DefaultViewSet, ProjectMixin): search_fields = ["name"] def get_queryset(self) -> QuerySet["Algorithm"]: - qs: QuerySet["Algorithm"] = super().get_queryset() + qs: QuerySet[Algorithm] = super().get_queryset() qs = qs.with_category_count() # type: ignore[union-attr] # Custom queryset method return qs diff --git a/ami/tasks.py b/ami/tasks.py index 47abfbd09..bb4cf9c9b 100644 --- a/ami/tasks.py +++ b/ami/tasks.py @@ -94,7 +94,7 @@ def regroup_events(deployment_id: int) -> None: if deployment: logger.info(f"Grouping captures for {deployment}") events = group_images_into_events(deployment) - logger.info(f"{deployment } now has {len(events)} events") + logger.info(f"{deployment} now has {len(events)} events") else: logger.error(f"Deployment with id {deployment_id} not found") diff --git a/ami/tests/fixtures/images.py b/ami/tests/fixtures/images.py index 3a1984516..85498ea44 100644 --- a/ami/tests/fixtures/images.py +++ b/ami/tests/fixtures/images.py @@ -248,7 +248,7 @@ def generate_multiple_series( max_moths: int = 8, ) -> None: all_series_data = [] - for i in range(num_series): + for _i in range(num_series): num_moths = random.randint(min_moths, max_moths) series_data = generate_moth_series(frames_per_series, width, height, num_moths) all_series_data.extend(series_data) diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py index 398a27605..196020f46 100644 --- a/ami/tests/fixtures/main.py +++ b/ami/tests/fixtures/main.py @@ -54,9 +54,9 @@ def create_processing_service(project: Project, name: str = "Test Processing Ser processing_service.save() if created: - logger.info(f'Successfully created processing service with 
{processing_service_to_add["endpoint_url"]}.') + logger.info(f"Successfully created processing service with {processing_service_to_add['endpoint_url']}.") else: - logger.info(f'Using existing processing service with {processing_service_to_add["endpoint_url"]}.') + logger.info(f"Using existing processing service with {processing_service_to_add['endpoint_url']}.") for project_data in processing_service_to_add["projects"]: try: @@ -64,7 +64,7 @@ def create_processing_service(project: Project, name: str = "Test Processing Ser processing_service.projects.add(project) processing_service.save() except Exception: - logger.error(f'Could not find project {project_data["name"]}.') + logger.error(f"Could not find project {project_data['name']}.") processing_service.get_status() processing_service.create_pipelines() @@ -82,16 +82,16 @@ def create_deployment( deployment, _ = Deployment.objects.get_or_create( project=project, name=name, - defaults=dict( - description=f"Created at {timezone.now()}", - data_source=data_source, - data_source_subdir="/", - data_source_regex=".*\\.jpg", - latitude=45.0, - longitude=-123.0, - research_site=project.sites.first(), - device=project.devices.first(), - ), + defaults={ + "description": f"Created at {timezone.now()}", + "data_source": data_source, + "data_source_subdir": "/", + "data_source_regex": ".*\\.jpg", + "latitude": 45.0, + "longitude": -123.0, + "research_site": project.sites.first(), + "device": project.devices.first(), + }, ) return deployment @@ -191,11 +191,11 @@ def create_captures_from_files( source_images = SourceImage.objects.filter(deployment=deployment).order_by("timestamp") source_images = [img for img in source_images if any(img.path.endswith(frame.filename) for frame in frame_data)] - assert len(source_images) == len( - frame_data - ), f"There are {len(source_images)} source images and {len(frame_data)} frame data items." + assert len(source_images) == len(frame_data), ( + f"There are {len(source_images)} source images and {len(frame_data)} frame data items." 
+ ) frame_data = sorted(frame_data, key=lambda x: x.timestamp) - frames_with_images = list(zip(source_images, frame_data)) + frames_with_images = list(zip(source_images, frame_data, strict=False)) for source_image, frame in frames_with_images: assert source_image.timestamp == frame.timestamp assert source_image.path.endswith(frame.filename) @@ -216,10 +216,10 @@ def create_taxa(project: Project) -> TaxaList: species_taxa = [] taxon, _ = Taxon.objects.get_or_create( name=species, - defaults=dict( - parent=genus_taxon, - rank=TaxonRank.SPECIES.name, - ), + defaults={ + "parent": genus_taxon, + "rank": TaxonRank.SPECIES.name, + }, ) species_taxa.append(taxon) taxon.projects.add(project) @@ -272,7 +272,7 @@ def create_detections( source_image: SourceImage, bboxes: list[tuple[float, float, float, float]], ): - for i, bbox in enumerate(bboxes): + for _i, bbox in enumerate(bboxes): detection = Detection.objects.create( source_image=source_image, timestamp=source_image.timestamp, diff --git a/ami/tests/fixtures/storage.py b/ami/tests/fixtures/storage.py index d53c782ff..f6a87006b 100644 --- a/ami/tests/fixtures/storage.py +++ b/ami/tests/fixtures/storage.py @@ -26,14 +26,14 @@ def create_storage_source(project: Project, name: str, prefix: str = S3_TEST_CON data_source, _created = S3StorageSource.objects.get_or_create( project=project, name=name, - defaults=dict( - bucket=S3_TEST_CONFIG.bucket_name, - prefix=prefix, - endpoint_url=S3_TEST_CONFIG.endpoint_url, - access_key=S3_TEST_CONFIG.access_key_id, - secret_key=S3_TEST_CONFIG.secret_access_key, - public_base_url=S3_TEST_CONFIG.public_base_url, - ), + defaults={ + "bucket": S3_TEST_CONFIG.bucket_name, + "prefix": prefix, + "endpoint_url": S3_TEST_CONFIG.endpoint_url, + "access_key": S3_TEST_CONFIG.access_key_id, + "secret_key": S3_TEST_CONFIG.secret_access_key, + "public_base_url": S3_TEST_CONFIG.public_base_url, + }, ) return data_source diff --git a/ami/users/tests/test_forms.py b/ami/users/tests/test_forms.py index 23566fcf8..24f9f41c5 100644 --- a/ami/users/tests/test_forms.py +++ b/ami/users/tests/test_forms.py @@ -1,6 +1,7 @@ """ Module for all Form Tests. 
""" + from django.utils.translation import gettext_lazy as _ from ami.users.forms import UserAdminCreationForm diff --git a/ami/utils/s3.py b/ami/utils/s3.py index ce157b213..8cd9d118a 100644 --- a/ami/utils/s3.py +++ b/ami/utils/s3.py @@ -300,7 +300,7 @@ def list_files_paginated( regex = _compile_regex_filter(regex_filter) num_files_checked = 0 - for i, page in enumerate(page_iterator): + for _i, page in enumerate(page_iterator): if "Contents" in page: for obj in page["Contents"]: num_files_checked += 1 @@ -685,7 +685,7 @@ def test(): deployments = list_deployments(config, project) print("\tDeployments:", deployments) - for deployment in deployments: + for _deployment in deployments: # print("\t\tFile Count:", count_files(deployment)) for file, _ in list_files(config, limit=1): diff --git a/config/asgi.py b/config/asgi.py index eb6ad19e7..79067a391 100644 --- a/config/asgi.py +++ b/config/asgi.py @@ -7,6 +7,7 @@ https://docs.djangoproject.com/en/dev/howto/deployment/asgi/ """ + import os import sys from pathlib import Path diff --git a/config/settings/base.py b/config/settings/base.py index aec01b715..f3e1e2b87 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -430,9 +430,11 @@ # Default processing service settings # If not set, we will not create a default processing service DEFAULT_PROCESSING_SERVICE_NAME = env( - "DEFAULT_PROCESSING_SERVICE_NAME", default="Default Processing Service" # type: ignore[no-untyped-call] + "DEFAULT_PROCESSING_SERVICE_NAME", + default="Default Processing Service", # type: ignore[no-untyped-call] ) DEFAULT_PROCESSING_SERVICE_ENDPOINT = env( - "DEFAULT_PROCESSING_SERVICE_ENDPOINT", default=None # type: ignore[no-untyped-call] + "DEFAULT_PROCESSING_SERVICE_ENDPOINT", + default=None, # type: ignore[no-untyped-call] ) DEFAULT_PIPELINES_ENABLED = env.list("DEFAULT_PIPELINES_ENABLED", default=None) # type: ignore[no-untyped-call] diff --git a/config/wsgi.py b/config/wsgi.py index 98e8217a9..c58ee86fc 100644 --- a/config/wsgi.py +++ b/config/wsgi.py @@ -13,6 +13,7 @@ framework. 
""" + import os import sys from pathlib import Path diff --git a/processing_services/example/api/algorithms.py b/processing_services/example/api/algorithms.py index 8a80038dd..6e0f8f5be 100644 --- a/processing_services/example/api/algorithms.py +++ b/processing_services/example/api/algorithms.py @@ -210,7 +210,7 @@ def run(self, detections: list[Detection]) -> list[Detection]: end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - for detection, preds in zip(detections, results): + for detection, preds in zip(detections, results, strict=False): labels = [pred["label"] for pred in preds] scores = [pred["score"] for pred in preds] max_score_index = scores.index(max(scores)) @@ -264,7 +264,7 @@ def get_category_map(self) -> AlgorithmCategoryMapResponse: # Create labels and data labels = [id2label[str(i)] for i in indices] - data = [{"label": label, "index": idx} for idx, label in zip(indices, labels)] + data = [{"label": label, "index": idx} for idx, label in zip(indices, labels, strict=False)] # Build description description_text = ( diff --git a/processing_services/example/api/pipelines.py b/processing_services/example/api/pipelines.py index 02b31d0d9..69b061593 100644 --- a/processing_services/example/api/pipelines.py +++ b/processing_services/example/api/pipelines.py @@ -57,10 +57,16 @@ class Pipeline: def __init__( self, source_images: list[SourceImage], - request_config: PipelineRequestConfigParameters | dict = {}, - existing_detections: list[Detection] = [], - custom_batch_sizes: list[int] = [], + request_config: "PipelineRequestConfigParameters | dict | None" = None, + existing_detections: list[Detection] | None = None, + custom_batch_sizes: list[int] | None = None, ): + if custom_batch_sizes is None: + custom_batch_sizes = [] + if existing_detections is None: + existing_detections = [] + if request_config is None: + request_config = {} self.source_images = source_images self.request_config = request_config if isinstance(request_config, dict) else request_config.model_dump() self.existing_detections = existing_detections @@ -81,7 +87,7 @@ def get_stages(self) -> list[Algorithm]: def compile(self): logger.info("Compiling algorithms....") for stage_idx, stage in enumerate(self.stages): - logger.info(f"[{stage_idx+1}/{len(self.stages)}] Compiling {stage.algorithm_config_response.name}...") + logger.info(f"[{stage_idx + 1}/{len(self.stages)}] Compiling {stage.algorithm_config_response.name}...") stage.compile() def run(self) -> PipelineResultsResponse: diff --git a/pyproject.toml b/pyproject.toml index e050b2755..a74868c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,3 +103,32 @@ indent_size = 2 [tool.djlint.js] indent_size = 2 + +# ==== Ruff Implementation ==== +[tool.ruff] +line-length = 119 +target-version = "py311" +extend-exclude = ["migrations", "static/CACHE", "docs", "node_modules"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "DJ", # flake8-django + "B", # flake8-bugbear + "C4", # flake8-comprehensions +] +ignore = ["E501","B904","DJ001","DJ012","B007","B019","F823"] # Ruff is Black-compatible by default + +[tool.ruff.lint.isort] +known-first-party = ["ami", "config"] + +[tool.ruff.format] +# Use Black-compatible formatting style +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..a6658fb20 100644 --- 
a/requirements/base.txt +++ b/requirements/base.txt @@ -51,7 +51,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb -psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +psycopg[binary] # https://github.com/psycopg/psycopg watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing @@ -71,10 +71,8 @@ sphinx-autobuild==2021.3.14 # https://github.com/GaretJax/sphinx-autobuild # Code quality # ------------------------------------------------------------------------------ -flake8==6.0.0 # https://github.com/PyCQA/flake8 -flake8-isort==6.0.0 # https://github.com/gforcada/flake8-isort +ruff==0.13.1 # https://github.com/astral-sh/ruff coverage==7.2.7 # https://github.com/nedbat/coveragepy -black==23.3.0 # https://github.com/psf/black djlint==1.31.1 # https://github.com/Riverside-Healthcare/djLint pylint-django==2.5.3 # https://github.com/PyCQA/pylint-django pylint-celery==0.3 # https://github.com/PyCQA/pylint-celery diff --git a/setup.cfg b/setup.cfg index 829064213..f361a07f6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,5 @@ -# flake8 and pycodestyle don't support pyproject.toml -# https://github.com/PyCQA/flake8/issues/234 +# pycodestyle does not support pyproject.toml # https://github.com/PyCQA/pycodestyle/issues/813 -[flake8] -max-line-length = 119 -exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv,.venv [pycodestyle] max-line-length = 119 From eb126523b9463371c9b458fe1412627f2c99de24 Mon Sep 17 00:00:00 2001 From: TosinIbikunle Date: Tue, 14 Oct 2025 22:34:51 +0100 Subject: [PATCH 2/2] vscode settings for ruff linter --- .vscode/settings.json | 83 +++++++++++++++++++++---------------------- README.md | 3 ++ pyproject.toml | 14 -------- 3 files changed, 43 insertions(+), 57 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index cf74d0079..5d3969f3b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,47 +1,44 @@ { - "editor.formatOnSave": true, - "diffEditor.codeLens": true, - "eslint.enable": true, + "editor.formatOnSave": true, + "diffEditor.codeLens": true, + "eslint.enable": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit", + "source.sortImports": "explicit", + "source.fixAll.markdownlint": "explicit", + "source.fixAll": "explicit" + }, + + "typescript.format.enable": true, + "prettier.requireConfig": true, + + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", "editor.codeActionsOnSave": { - "source.organizeImports": "explicit", - "source.sortImports": "explicit", - "source.fixAll.markdownlint": "explicit", - "source.fixAll": "explicit" + "source.fixAll": "explicit" }, - "isort.args": [ - "--profile", - "black" - ], - "typescript.format.enable": true, - "prettier.requireConfig": true, - "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.tabSize": 4, - "editor.rulers": [ - 119 - ] - }, - "black-formatter.args": ["--line-length", "119"], - "flake8.args": [ - "--max-line-length", "119" - ], - "[javascript]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - "[javascriptreact]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - "[typescript]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - "[typescriptreact]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - "[json]": { - "editor.defaultFormatter": "esbenp.prettier-vscode" - }, - 
"files.eol": "\n", - "files.insertFinalNewline": true, - "files.trimTrailingWhitespace": true, - "python.analysis.typeCheckingMode": "basic" + "editor.tabSize": 4, + "editor.rulers": [119] + }, + + "[javascript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[javascriptreact]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[typescript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[typescriptreact]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + "[json]": { + "editor.defaultFormatter": "esbenp.prettier-vscode" + }, + + "files.eol": "\n", + "files.insertFinalNewline": true, + "files.trimTrailingWhitespace": true, + "python.analysis.typeCheckingMode": "basic" } diff --git a/README.md b/README.md index 7d8a26eff..8483fc65d 100644 --- a/README.md +++ b/README.md @@ -252,3 +252,6 @@ The local environment uses a local PostgreSQL database in a Docker container. ### Load fixtures with test data docker compose run --rm django python manage.py migrate +### Linting + +This project uses Ruff for linting and formatting — it’s super fast and replaces tools like Black, isort, and Flake8 for a cleaner, all-in-one setup. diff --git a/pyproject.toml b/pyproject.toml index a74868c6f..25f6310ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,22 +14,8 @@ omit = ["*/migrations/*", "*/tests/*"] plugins = ["django_coverage_plugin"] -# ==== black ==== -[tool.black] -line-length = 119 -target-version = ['py311'] -# ==== isort ==== -[tool.isort] -profile = "black" -line_length = 119 -known_first_party = [ - "ami", - "config", -] -skip = ["venv/"] -skip_glob = ["**/migrations/*.py"] # ==== mypy ====