Skip to content

Commit

Permalink
api-server: fix postgres label sorting (#957)
Browse files Browse the repository at this point in the history
* fix postgres label sorting

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

* succesful test with both postgres and sqlite

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

* fix lint errors

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

---------

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>
koonpeng authored Jun 24, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 49a866a commit bb88e1c
Showing 21 changed files with 394 additions and 326 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ isort = "==5.13.2"
pylint = "==3.1.0"
coverage = "~=5.5"
# api-server
api-server = {editable = true, path = "./packages/api-server"}
api-server = {editable = true, path = "./packages/api-server", extras = ["postgres"]}
httpx = "~=0.26.0"
datamodel-code-generator = "==0.25.4"
requests = "~=2.25"
420 changes: 240 additions & 180 deletions Pipfile.lock

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions packages/api-server/README.md
Original file line number Diff line number Diff line change
@@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the databse should be reflected.
### Running unit tests

```bash
npm test
pnpm test
```

By default in-memory sqlite database is used for testing, to test on another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable.

```bash
RMF_API_SERVER_TEST_DB_URL=<db_url> pnpm test
```

### Collecting code coverage

```bash
npm run test:cov
pnpm run test:cov
```

Generate coverage report
```bash
npm run test:report
pnpm run test:report
```

## Live reload
6 changes: 5 additions & 1 deletion packages/api-server/api_server/dependencies.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,11 @@ def pagination_query(
) -> Pagination:
limit = limit or 100
offset = offset or 0
return Pagination(limit=limit, offset=offset, order_by=order_by)
return Pagination(
limit=limit,
offset=offset,
order_by=order_by.split(",") if order_by else [],
)


# hacky way to get the sio user
4 changes: 1 addition & 3 deletions packages/api-server/api_server/models/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from pydantic import BaseModel


class Pagination(BaseModel):
limit: int
offset: int
order_by: Optional[str]
order_by: list[str]
51 changes: 6 additions & 45 deletions packages/api-server/api_server/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import tortoise.functions as tfuncs
from tortoise.expressions import Q
from tortoise.queryset import MODEL, QuerySet

from api_server.models.pagination import Pagination
@@ -8,47 +6,10 @@
def add_pagination(
query: QuerySet[MODEL],
pagination: Pagination,
field_mappings: dict[str, str] | None = None,
group_by: str | None = None,
) -> QuerySet[MODEL]:
"""
Adds pagination and ordering to a query. If the order field starts with `label=`, it is
assumed to be a label and label sorting will used. In this case, the model must have
a reverse relation named "labels" and the `group_by` param is required.
:param field_mapping: A dict mapping the order fields to the fields used to build the
query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}`
will order the query result according to `db_field`.
:param group_by: Required when sorting by labels, must be the foreign key column of the label table.
"""
field_mappings = field_mappings or {}
annotations = {}
query = query.limit(pagination.limit).offset(pagination.offset)
if pagination.order_by is not None:
order_fields = []
order_values = pagination.order_by.split(",")
for v in order_values:
# perform the mapping after stripping the order prefix
order_prefix = ""
order_field = v
if v[0] in ["-", "+"]:
order_prefix = v[0]
order_field = v[1:]
order_field = field_mappings.get(order_field, order_field)

# add annotations required for sorting by labels
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = query.annotate(**annotations)
if group_by is not None:
query = query.group_by(group_by)
query = query.order_by(*order_fields)
return query
"""Adds pagination and ordering to a query"""
return (
query.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*pagination.order_by)
)
85 changes: 80 additions & 5 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
@@ -2,10 +2,11 @@
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple

import tortoise.functions as tfuncs
from fastapi import Depends, HTTPException
from tortoise.exceptions import FieldError, IntegrityError
from tortoise.expressions import Expression, Q
from tortoise.query_utils import Prefetch
from tortoise.queryset import QuerySet
from tortoise.transactions import in_transaction

from api_server.authenticator import user_dep
@@ -18,14 +19,14 @@
TaskEventLog,
TaskRequest,
TaskState,
TaskStatus,
User,
)
from api_server.models import tortoise_models as ttm
from api_server.models.rmf_api.log_entry import Tier
from api_server.models.rmf_api.task_state import Category, Id, Phase
from api_server.models.tortoise_models import TaskRequest as DbTaskRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.query import add_pagination
from api_server.rmf_io import task_events


@@ -96,11 +97,85 @@ async def save_task_state(self, task_state: TaskState) -> None:
await self.save_task_labels(db_task_state, labels)

async def query_task_states(
self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
self,
task_id: list[str] | None = None,
category: list[str] | None = None,
assigned_to: list[str] | None = None,
start_time_between: tuple[datetime, datetime] | None = None,
finish_time_between: tuple[datetime, datetime] | None = None,
status: list[str] | None = None,
label: Labels | None = None,
pagination: Optional[Pagination] = None,
) -> List[TaskState]:
filters = {}
if task_id is not None:
filters["id___in"] = task_id
if category is not None:
filters["category__in"] = category
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in TaskStatus]
filters["status__in"] = []
for status_string in status:
if status_string not in valid_values:
continue
filters["status__in"].append(TaskStatus(status_string))
query = DbTaskState.filter(**filters)

need_group_by = False
label_filters = {}
if label is not None:
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_",
_filter=Q(labels__label_name=k, labels__label_value=v),
)
for k, v in label.root.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = query.annotate(**label_filters).filter(**filter_gt)
need_group_by = True

if pagination:
order_fields: list[str] = []
annotations: dict[str, Expression] = {}
# add annotations required for sorting by labels
for f in pagination.order_by:
order_prefix = f[0] if f[0] == "-" else ""
order_field = f[1:] if order_prefix == "-" else f
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = (
query.annotate(**annotations)
.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*order_fields)
)
need_group_by = True

if need_group_by:
query = query.group_by("id_", "labels__state_id")

try:
if pagination:
query = add_pagination(query, pagination, group_by="labels__state_id")
# TODO: enforce with authz
results = await query.values_list("data")
return [TaskState(**r[0]) for r in results]
Original file line number Diff line number Diff line change
@@ -134,7 +134,7 @@ async def get_scheduled_tasks(
.offset(pagination.offset)
)
if pagination.order_by:
q.order_by(*pagination.order_by.split(","))
q.order_by(*pagination.order_by)
results = await q
await ttm.ScheduledTask.fetch_for_list(results)
return [ScheduledTask.model_validate(x) for x in results]
58 changes: 10 additions & 48 deletions packages/api-server/api_server/routes/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import List, Optional, Tuple, cast

import tortoise.functions as tfuncs
from fastapi import Body, Depends, HTTPException, Path, Query
from reactivex import operators as rxops
from tortoise.expressions import Q

from api_server import models as mdl
from api_server.dependencies import (
@@ -15,7 +13,6 @@
start_time_between_query,
)
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.repositories import TaskRepository, task_repo_dep
from api_server.response import RawJSONResponse
from api_server.rmf_io import task_events, tasks_service
@@ -60,51 +57,16 @@ async def query_task_states(
),
pagination: mdl.Pagination = Depends(pagination_query),
):
filters = {}
if task_id is not None:
filters["id___in"] = task_id.split(",")
if category is not None:
filters["category__in"] = category.split(",")
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to.split(",")
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in mdl.TaskStatus]
filters["status__in"] = []
for status_string in status.split(","):
if status_string not in valid_values:
continue
filters["status__in"].append(mdl.TaskStatus(status_string))
query = DbTaskState.filter(**filters)

label_filters = {}
if label is not None:
labels = mdl.Labels.from_strings(label.split(","))
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_", _filter=Q(labels__label_name=k, labels__label_value=v)
)
for k, v in labels.root.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = (
query.annotate(**label_filters)
.group_by(
"labels__state_id"
) # need to group by a related field to make tortoise-orm generate joins
.filter(**filter_gt)
)

return await task_repo.query_task_states(query, pagination)
return await task_repo.query_task_states(
task_id=task_id.split(",") if task_id else None,
category=category.split(",") if category else None,
assigned_to=assigned_to.split(",") if assigned_to else None,
start_time_between=start_time_between,
finish_time_between=finish_time_between,
status=status.split(",") if status else None,
label=mdl.Labels.from_strings(label.split(",")) if label else None,
pagination=pagination,
)


@router.get("/{task_id}/state", response_model=mdl.TaskState)
9 changes: 3 additions & 6 deletions packages/api-server/api_server/routes/tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -36,15 +36,12 @@ def setUpClass(cls):
cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids]
cls.clsSetupErr: str | None = None

if cls.client.portal is None:
cls.clsSetupErr = "missing client portal, is the client context entered?"
return

portal = cls.get_portal()
repo = TaskRepository(cls.admin_user)
for x in cls.task_states:
cls.client.portal.call(repo.save_task_state, x)
portal.call(repo.save_task_state, x)
for x in cls.task_logs:
cls.client.portal.call(repo.save_task_log, x)
portal.call(repo.save_task_log, x)

def setUp(self):
super().setUp()
4 changes: 2 additions & 2 deletions packages/api-server/api_server/routes/test_building_map.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from api_server.rmf_io import rmf_events
from api_server.test import AppFixture, make_building_map, try_until


class TestBuildingMapRoute(AppFixture):
def test_get_building_map(self):
building_map = make_building_map()
rmf_events.building_map.on_next(building_map)
portal = self.get_portal()
portal.call(building_map.save)

resp = try_until(
lambda: self.client.get("/building_map"), lambda x: x.status_code == 200
4 changes: 2 additions & 2 deletions packages/api-server/api_server/routes/test_dispensers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import List
from uuid import uuid4

@@ -12,8 +11,9 @@ def setUpClass(cls):
super().setUpClass()
cls.dispenser_states = [make_dispenser_state(f"test_{uuid4()}")]

portal = cls.get_portal()
for x in cls.dispenser_states:
asyncio.run(x.save())
portal.call(x.save)

def test_get_dispensers(self):
resp = self.client.get("/dispensers")
6 changes: 3 additions & 3 deletions packages/api-server/api_server/routes/test_doors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from uuid import uuid4

from rmf_door_msgs.msg import DoorMode as RmfDoorMode
@@ -12,12 +11,13 @@ class TestDoorsRoute(AppFixture):
def setUpClass(cls):
super().setUpClass()
cls.building_map = make_building_map()
asyncio.run(cls.building_map.save())
portal = cls.get_portal()
portal.call(cls.building_map.save)

cls.door_states = [make_door_state(f"test_{uuid4()}")]

for x in cls.door_states:
asyncio.run(x.save())
portal.call(x.save)

def test_get_doors(self):
resp = self.client.get("/doors")
Loading

0 comments on commit bb88e1c

Please sign in to comment.