Skip to content

Commit 2df1630

Browse files
annavikmohamedelabbas1996mihow
authored
Make it possible to add tags to taxa (#892)
* Add support for Taxa Tags (#830) * feat: added support for tag creation for a given taxon * chore: moved Tag<->Taxon many2many to the Taxon model * feat: added the Tag admin model * feat: added custom assign_tags action to the TaxonViewSet to assign tags for a given taxon and return the assigned tags filtered by project * fixed formatting * merged migrations * feat: added tags to the Taxon admin model * feat: added global tags * added db migration * feat: show global tags if there is an active project * feat: added or based taxon filtering by tag_id * chore: renamed TaxonTagFilterBackend * fix: fixed filtering tags by project_id in the TaxonViewSet list * Add frontend support for taxa tags (#828) * feat: setup UI for taxon tags * feat: prepare UI controls for tag filtering * feat: hook up UI tags with backend * fix: update filter key from tag -> tag_id * fixed Taxon List tags column name * feat: added tags inverse filter * feat: return global tags with project tags * chore: reset tag migrations, add default tags * fix: remove invalid field in taxa list query * chore: add type hints for reverse relationships --------- Co-authored-by: Anna Viklund <annamariaviklund@gmail.com> Co-authored-by: Michael Bunsen <notbot@gmail.com> * chore: remove unused import * fix: pass project ID when assigning tags * chore: skip adding initial tags to the database * fix: update migration deps * fix: revert caption update to fix build * feat: put tags feature behind feature flag * feat: disable tag filters if no tags for project * chore: fix migration ordering * feat: store feature flags with project * fix: return feature flags as a dictionary in API response * style: simplify tags form --------- Co-authored-by: Mohamed Elabbas <hack1996man@gmail.com> Co-authored-by: Michael Bunsen <notbot@gmail.com>
1 parent 2455357 commit 2df1630

File tree

21 files changed

+564
-23
lines changed

21 files changed

+564
-23
lines changed

ami/main/admin.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Site,
2727
SourceImage,
2828
SourceImageCollection,
29+
Tag,
2930
TaxaList,
3031
Taxon,
3132
)
@@ -77,7 +78,18 @@ def save_related(self, request, form, formsets, change):
7778
inlines = [ProjectPipelineConfigInline]
7879

7980
fieldsets = (
80-
(None, {"fields": ("name", "description", "priority", "active")}),
81+
(
82+
None,
83+
{
84+
"fields": (
85+
"name",
86+
"description",
87+
"priority",
88+
"active",
89+
"feature_flags",
90+
)
91+
},
92+
),
8193
(
8294
"Ownership & Access",
8395
{
@@ -455,6 +467,7 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
455467
"rank",
456468
"parent",
457469
"parent_names",
470+
"tag_list",
458471
"list_names",
459472
"created_at",
460473
"updated_at",
@@ -475,10 +488,10 @@ def get_queryset(self, request):
475488

476489
return qs.annotate(occurrence_count=models.Count("occurrences")).order_by("-occurrence_count")
477490

478-
@admin.display(
479-
description="Occurrences",
480-
ordering="occurrence_count",
481-
)
491+
@admin.display(description="Tags")
492+
def tag_list(self, obj) -> str:
493+
return ", ".join([tag.name for tag in obj.tags.all()])
494+
482495
def occurrence_count(self, obj) -> int:
483496
return obj.occurrence_count
484497

@@ -596,3 +609,10 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
596609

597610
# Hide images many-to-many field from form. This would list all source images in the database.
598611
exclude = ("images",)
612+
613+
614+
@admin.register(Tag)
615+
class TagAdmin(admin.ModelAdmin):
616+
list_display = ("id", "name", "project")
617+
list_filter = ("project",)
618+
search_fields = ("name",)

ami/main/api/serializers.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ami.base.fields import DateStringField
1010
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, get_current_user, reverse_with_params
1111
from ami.jobs.models import Job
12-
from ami.main.models import create_source_image_from_upload
12+
from ami.main.models import Tag, create_source_image_from_upload
1313
from ami.ml.models import Algorithm
1414
from ami.ml.serializers import AlgorithmSerializer
1515
from ami.users.models import User
@@ -274,14 +274,21 @@ class Meta:
274274

275275
class ProjectSerializer(DefaultSerializer):
276276
deployments = DeploymentNestedSerializerWithLocationAndCounts(many=True, read_only=True)
277+
feature_flags = serializers.SerializerMethodField()
277278
owner = UserNestedSerializer(read_only=True)
278279

280+
def get_feature_flags(self, obj):
281+
if obj.feature_flags:
282+
return obj.feature_flags.dict()
283+
return {}
284+
279285
class Meta:
280286
model = Project
281287
fields = ProjectListSerializer.Meta.fields + [
282288
"deployments",
283289
"summary_data", # @TODO move to a 2nd request, it's too slow
284290
"owner",
291+
"feature_flags",
285292
]
286293

287294

@@ -497,11 +504,32 @@ class Meta:
497504
]
498505

499506

507+
class TagSerializer(DefaultSerializer):
508+
project = ProjectNestedSerializer(read_only=True)
509+
project_id = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all(), source="project", write_only=True)
510+
taxa_ids = serializers.PrimaryKeyRelatedField(
511+
queryset=Taxon.objects.all(), many=True, source="taxa", write_only=True, required=False
512+
)
513+
taxa = serializers.SerializerMethodField()
514+
515+
class Meta:
516+
model = Tag
517+
fields = ["id", "name", "project", "project_id", "taxa_ids", "taxa"]
518+
519+
def get_taxa(self, obj):
520+
return [{"id": taxon.id, "name": taxon.name} for taxon in obj.taxa.all()]
521+
522+
500523
class TaxonListSerializer(DefaultSerializer):
501524
# latest_detection = DetectionNestedSerializer(read_only=True)
502525
occurrences = serializers.SerializerMethodField()
503526
parents = TaxonNestedSerializer(read_only=True)
504527
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent")
528+
tags = serializers.SerializerMethodField()
529+
530+
def get_tags(self, obj):
531+
tag_list = getattr(obj, "prefetched_tags", [])
532+
return TagSerializer(tag_list, many=True, context=self.context).data
505533

506534
class Meta:
507535
model = Taxon
@@ -514,6 +542,7 @@ class Meta:
514542
"details",
515543
"occurrences_count",
516544
"occurrences",
545+
"tags",
517546
"last_detected",
518547
"best_determination_score",
519548
"created_at",
@@ -718,6 +747,12 @@ class TaxonSerializer(DefaultSerializer):
718747
parent = TaxonNoParentNestedSerializer(read_only=True)
719748
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
720749
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")
750+
tags = serializers.SerializerMethodField()
751+
752+
def get_tags(self, obj):
753+
# Use prefetched tags
754+
tag_list = getattr(obj, "prefetched_tags", [])
755+
return TagSerializer(tag_list, many=True, context=self.context).data
721756

722757
class Meta:
723758
model = Taxon
@@ -733,6 +768,7 @@ class Meta:
733768
"events_count",
734769
"occurrences",
735770
"gbif_taxon_key",
771+
"tags",
736772
"last_detected",
737773
"best_determination_score",
738774
]

ami/main/api/views.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
4444
from ami.base.views import ProjectMixin
45+
from ami.main.api.serializers import TagSerializer
4546
from ami.utils.requests import get_active_classification_threshold, project_id_doc_param
4647
from ami.utils.storages import ConnectionTestResult
4748

@@ -61,6 +62,7 @@
6162
SourceImage,
6263
SourceImageCollection,
6364
SourceImageUpload,
65+
Tag,
6466
TaxaList,
6567
Taxon,
6668
User,
@@ -1158,6 +1160,29 @@ def filter_queryset(self, request, queryset, view):
11581160
TaxonBestScoreFilter = ThresholdFilter.create("best_determination_score")
11591161

11601162

1163+
class TaxonTagFilter(filters.BaseFilterBackend):
1164+
"""FilterBackend that allows OR-based filtering of taxa by tag ID."""
1165+
1166+
def filter_queryset(self, request, queryset, view):
1167+
tag_ids = request.query_params.getlist("tag_id")
1168+
if tag_ids:
1169+
queryset = queryset.filter(tags__id__in=tag_ids).distinct()
1170+
return queryset
1171+
1172+
1173+
class TagInverseFilter(filters.BaseFilterBackend):
1174+
"""
1175+
Exclude taxa that have any of the specified tag IDs using `not_tag_id`.
1176+
Example: /api/v2/taxa/?not_tag_id=1&not_tag_id=2
1177+
"""
1178+
1179+
def filter_queryset(self, request, queryset, view):
1180+
not_tag_ids = request.query_params.getlist("not_tag_id")
1181+
if not_tag_ids:
1182+
queryset = queryset.exclude(tags__id__in=not_tag_ids)
1183+
return queryset.distinct()
1184+
1185+
11611186
class TaxonViewSet(DefaultViewSet, ProjectMixin):
11621187
"""
11631188
API endpoint that allows taxa to be viewed or edited.
@@ -1170,6 +1195,8 @@ class TaxonViewSet(DefaultViewSet, ProjectMixin):
11701195
TaxonCollectionFilter,
11711196
TaxonTaxaListFilter,
11721197
TaxonBestScoreFilter,
1198+
TaxonTagFilter,
1199+
TagInverseFilter,
11731200
]
11741201
filterset_fields = [
11751202
"name",
@@ -1294,6 +1321,7 @@ def get_queryset(self) -> QuerySet:
12941321
"""
12951322
qs = super().get_queryset()
12961323
project = self.get_active_project()
1324+
qs = self.attach_tags_by_project(qs, project)
12971325

12981326
if project:
12991327
# Allow showing detail views for unobserved taxa
@@ -1377,6 +1405,43 @@ def get_taxa_observed(self, qs: QuerySet, project: Project, include_unobserved=F
13771405
)
13781406
return qs
13791407

1408+
def attach_tags_by_project(self, qs: QuerySet, project: Project) -> QuerySet:
1409+
"""
1410+
Prefetch and override the `.tags` attribute on each Taxon
1411+
with only the tags belonging to the given project.
1412+
"""
1413+
# Include all tags if no project is passed
1414+
if project is None:
1415+
tag_qs = Tag.objects.all()
1416+
else:
1417+
# Prefetch only the tags that belong to the project or are global
1418+
tag_qs = Tag.objects.filter(models.Q(project=project) | models.Q(project__isnull=True))
1419+
1420+
tag_prefetch = Prefetch("tags", queryset=tag_qs, to_attr="prefetched_tags")
1421+
1422+
return qs.prefetch_related(tag_prefetch)
1423+
1424+
@action(detail=True, methods=["post"])
1425+
def assign_tags(self, request, pk=None):
1426+
"""
1427+
Assign tags to a taxon
1428+
"""
1429+
taxon = self.get_object()
1430+
tag_ids = request.data.get("tag_ids")
1431+
logger.info(f"Tag IDs: {tag_ids}")
1432+
if not isinstance(tag_ids, list):
1433+
return Response({"detail": "tag_ids must be a list of IDs."}, status=status.HTTP_400_BAD_REQUEST)
1434+
1435+
tags = Tag.objects.filter(id__in=tag_ids)
1436+
logger.info(f"Tags: {tags}, len: {len(tags)}")
1437+
taxon.tags.set(tags) # replaces all tags for this taxon
1438+
taxon.save()
1439+
logger.info(f"Tags after assingment : {len(taxon.tags.all())}")
1440+
return Response(
1441+
{"taxon_id": taxon.id, "assigned_tag_ids": [tag.pk for tag in tags]},
1442+
status=status.HTTP_200_OK,
1443+
)
1444+
13801445
@extend_schema(parameters=[project_id_doc_param])
13811446
def list(self, request, *args, **kwargs):
13821447
return super().list(request, *args, **kwargs)
@@ -1395,6 +1460,20 @@ def get_queryset(self):
13951460
serializer_class = TaxaListSerializer
13961461

13971462

1463+
class TagViewSet(DefaultViewSet, ProjectMixin):
1464+
queryset = Tag.objects.all()
1465+
serializer_class = TagSerializer
1466+
filterset_fields = ["taxa"]
1467+
1468+
def get_queryset(self):
1469+
qs = super().get_queryset()
1470+
project = self.get_active_project()
1471+
if project:
1472+
# Filter by project, but also include global tags
1473+
return qs.filter(models.Q(project=project) | models.Q(project__isnull=True))
1474+
return qs
1475+
1476+
13981477
class ClassificationViewSet(DefaultViewSet, ProjectMixin):
13991478
"""
14001479
API endpoint for viewing and adding classification results from a model.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Generated by Django 4.2.10 on 2025-05-15 21:23
2+
3+
from django.db import migrations, models
4+
import django.db.models.deletion
5+
6+
7+
class Migration(migrations.Migration):
8+
dependencies = [
9+
("main", "0060_alter_sourceimagecollection_method"),
10+
]
11+
12+
operations = [
13+
migrations.CreateModel(
14+
name="Tag",
15+
fields=[
16+
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
17+
("created_at", models.DateTimeField(auto_now_add=True)),
18+
("updated_at", models.DateTimeField(auto_now=True)),
19+
("name", models.CharField(max_length=255)),
20+
(
21+
"project",
22+
models.ForeignKey(
23+
blank=True,
24+
null=True,
25+
on_delete=django.db.models.deletion.CASCADE,
26+
related_name="tags",
27+
to="main.project",
28+
),
29+
),
30+
],
31+
options={
32+
"unique_together": {("name", "project")},
33+
},
34+
),
35+
migrations.AddField(
36+
model_name="taxon",
37+
name="tags",
38+
field=models.ManyToManyField(blank=True, related_name="taxa", to="main.tag"),
39+
),
40+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Generated by Django 4.2.10 on 2025-07-28 19:05
2+
3+
import ami.main.models
4+
from django.db import migrations
5+
import django_pydantic_field.fields
6+
7+
8+
class Migration(migrations.Migration):
9+
dependencies = [
10+
("main", "0061_tag_taxon_tags"),
11+
]
12+
13+
operations = [
14+
migrations.AddField(
15+
model_name="project",
16+
name="feature_flags",
17+
field=django_pydantic_field.fields.PydanticSchemaField(
18+
blank=True, config=None, default={"tags": False}, schema=ami.main.models.ProjectFeatureFlags
19+
),
20+
),
21+
]

0 commit comments

Comments
 (0)