Skip to content

Commit

Permalink
Working!
Browse files Browse the repository at this point in the history
  • Loading branch information
peytondmurray committed Jan 3, 2025
1 parent f8bf29f commit 54f7907
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ async def api_list_environments(
query=query,
ordering_metadata=OrderingMetadata(
order_names=["namespace", "name"],
# column_names=['namespace.name', 'name'],
column_names=[orm.Namespace.name, orm.Environment.name],
column_names=["namespace.name", "name"],
column_objects=[orm.Namespace.name, orm.Environment.name],
),
cursor=cursor,
order_by=paginated_args.sort_by,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
import pydantic
from fastapi import HTTPException
from sqlalchemy import asc, desc, tuple_
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm import Query as SqlQuery
from sqlalchemy.sql.expression import ColumnClause

from conda_store_server._internal.orm import Base


class Cursor(pydantic.BaseModel):
last_id: int | None = 0
Expand All @@ -23,15 +26,47 @@ class Cursor(pydantic.BaseModel):
last_value: dict[str, str] | None = {}

def dump(self) -> str:
"""Dump the cursor as a b64-encoded string.
Returns
-------
str
base64-encoded string containing the information needed
to retrieve the page of data following the location of the cursor
"""
return base64.b64encode(self.model_dump_json().encode("utf8"))

@classmethod
def load(cls, data: str | None = None) -> Cursor | None:
"""Create a Cursor from a b64-encoded string.
Parameters
----------
data : str | None
base64-encoded string containing information about the cursor
Returns
-------
Cursor | None
Cursor representation of the b64-encoded string
"""
if data is None:
return None
return cls.from_json(base64.b64decode(data).decode("utf8"))

def get_last_values(self, order_names: list[str]) -> list[Any]:
"""Get a list of the values corresponding to the order_names.
Parameters
----------
order_names : list[str]
List of names of values stored in the cursor
Returns
-------
list[Any]
The last values pointed to by the cursor for the given order_names
"""
if order_names:
return [self.last_value[name] for name in order_names]
else:
Expand All @@ -43,7 +78,6 @@ def paginate(
ordering_metadata: OrderingMetadata,
cursor: Cursor | None = None,
order_by: list[str] | None = None,
# valid_order_by: dict[str, str] | None = None,
order: str = "asc",
limit: int = 10,
) -> tuple[SqlQuery, Cursor]:
Expand All @@ -61,8 +95,6 @@ def paginate(
----------
query : SqlQuery
Query containing database results to paginate
valid_order_by : dict[str, str] | None
Mapping between valid names to order by and the column names on the orm object they apply to
cursor : Cursor | None
Cursor object containing information about the last item on the previous page.
If None, the first page is returned.
Expand Down Expand Up @@ -106,10 +138,9 @@ def paginate(
)
)

breakpoint()
query = query.order_by(
*[order_func(col) for col in columns], order_func(queried_type.id)
)
order_by_args = [order_func(col) for col in columns] + [order_func(queried_type.id)]

query = query.order_by(*order_by_args)
data = query.limit(limit).all()
count = query.count()

Expand Down Expand Up @@ -159,11 +190,13 @@ def __init__(
self,
order_names: list[str] | None = None,
column_names: list[str] | None = None,
column_objects: list[InstrumentedAttribute] | None = None,
):
self.order_names = order_names
self.column_names = column_names
self.column_objects = column_objects

def validate(self, model: Any):
def validate(self, model: Base):
if len(self.order_names) != len(self.column_names):
raise ValueError(
"Each name of a valid ordering available to the order_by query parameter"
Expand Down Expand Up @@ -196,14 +229,19 @@ def get_requested_columns(
if order_by:
for order_name in order_by:
idx = self.order_names.index(order_name)
# columns.append(text(self.column_names[idx]))
columns.append(self.column_names[idx])
columns.append(self.column_objects[idx])

return columns

def __str__(self) -> str:
return f"OrderingMetadata<order_names={self.order_names}, column_names={self.column_names}>"

def __repr__(self) -> str:
return str(self)

def get_attr_values(
self,
obj: Any,
obj: Base,
order_by: list[str] | None = None,
) -> dict[str, Any]:
"""Using the order_by values, get the corresponding attribute values on obj.
Expand All @@ -223,7 +261,6 @@ def get_attr_values(
A mapping between the `order_by` values and the attribute values on `obj`
"""
breakpoint()
values = {}
for order_name in order_by:
idx = self.order_names.index(order_name)
Expand All @@ -233,7 +270,31 @@ def get_attr_values(
return values


def get_nested_attribute(obj: Any, attr: str) -> Any:
def get_nested_attribute(obj: Base, attr: str) -> str | int | float:
"""Get a nested attribute from the given sqlalchemy model.
Parameters
----------
obj : Base
A sqlalchemy model for which a (possibly nested) attribute is to be
retrieved
attr : str
String attribute; nested attributes should be separated with `.`
Returns
-------
str | int | float
Value of the attribute; strictly this can be any column type supported
by sqlalchemy, but for conda-store this is a str, an int, or a float
Examples
--------
>>> env = db.query(orm.Environment).join(orm.Namespace).first()
>>> get_nested_attribute(env, 'namespace.name')
'namespace1'
>>> get_nested_attribute(env, 'name')
'my_environment'
"""
attribute, *rest = attr.split(".")
while len(rest) > 0:
obj = getattr(obj, attribute)
Expand Down
19 changes: 15 additions & 4 deletions conda-store-server/tests/_internal/server/views/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,8 @@ def test_api_list_environments_by_namespace(
model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK

env_names = [env.namespace.name for env in model.data]
assert sorted(env_names, reverse=order == "desc") == env_names
namespace_names = [env.namespace.name for env in model.data]
assert sorted(namespace_names, reverse=order == "desc") == namespace_names


@pytest.mark.parametrize(
Expand All @@ -1145,5 +1145,16 @@ def test_api_list_environments_by_namespace_name(
model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK

env_names = [env.namespace.name for env in model.data]
assert sorted(env_names, reverse=order == "desc") == env_names
# Get the namespace and environment names from the returned environments
namespace_names = [env.namespace.name for env in model.data]
env_names = [env.name for env in model.data]

# Check that they are identical to what we get if we sort them with python
# by both the namespace name and then environment name
sorted_envs = sorted(
model.data,
reverse=order == "desc",
key=lambda env: (env.namespace.name, env.name),
)
assert [env.name for env in sorted_envs] == env_names
assert [env.namespace.name for env in sorted_envs] == namespace_names
19 changes: 8 additions & 11 deletions conda-store-server/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import typing
import uuid
from collections import defaultdict

import pytest
import yaml
Expand Down Expand Up @@ -169,24 +170,24 @@ def seed_conda_store(db, conda_store):

@pytest.fixture
def seed_conda_store_big(db, conda_store):
default = {}
namespace1 = {}
namespace2 = {}
"""Seed the conda-store db with 150 randomly named envs in 5 random namespaces."""
namespace_names = [str(uuid.uuid4()) for _ in range(5)]
namespaces = defaultdict(dict)
for i in range(50):
name = "".join(random.choices(string.ascii_letters, k=10))
default[name] = schema.CondaSpecification(
namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification(
name=name, channels=["defaults"], dependencies=["numpy"]
)

name = "".join(random.choices(string.ascii_letters, k=11))
namespace1[name] = schema.CondaSpecification(
namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification(
name=name,
channels=["defaults"],
dependencies=["flask"],
)

name = "".join(random.choices(string.ascii_letters, k=12))
namespace2[name] = schema.CondaSpecification(
namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification(
name=name,
channels=["defaults"],
dependencies=["flask"],
Expand All @@ -195,11 +196,7 @@ def seed_conda_store_big(db, conda_store):
_seed_conda_store(
db,
conda_store,
{
"default": default,
"namespace1": namespace1,
"namespace2": namespace2,
},
namespaces,
)

# for testing purposes make build 4 complete
Expand Down

0 comments on commit 54f7907

Please sign in to comment.