Skip to content
This repository has been archived by the owner on Apr 14, 2024. It is now read-only.

Commit

Permalink
Add position filtering by specialization and skills
Browse files Browse the repository at this point in the history
  • Loading branch information
OlegYurchik committed Mar 22, 2024
1 parent 24cdf06 commit f4c4139
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 34 deletions.
4 changes: 2 additions & 2 deletions autotests/clients/rest/projects/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def get_positions(
specialization_ids: list[uuid.UUID] | Type[Empty] = Empty,
skill_ids: list[uuid.UUID] | Type[Empty] = Empty,
joined_user_id: uuid.UUID | Type[Empty] = Empty,
project_query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
project_startline_ge: datetime | Type[Empty] = Empty,
project_startline_le: datetime | Type[Empty] = Empty,
project_deadline_ge: datetime | Type[Empty] = Empty,
Expand All @@ -156,7 +156,7 @@ async def get_positions(
"specialization_ids": specialization_ids,
"skill_ids": skill_ids,
"joined_user_id": joined_user_id,
"project_query_text": project_query_text,
"query": query,
"project_startline_ge": project_startline_ge,
"project_startline_le": project_startline_le,
"project_deadline_ge": project_deadline_ge,
Expand Down
6 changes: 3 additions & 3 deletions autotests/rest/projects/test_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
pytest.lazy_fixture("random_projects_rest_client"),
))
@pytest.mark.parametrize(
("project_id", "specialization_ids", "skill_ids", "joined_user_id", "project_query_text",
("project_id", "specialization_ids", "skill_ids", "joined_user_id", "query",
"project_startline_ge", "project_startline_le", "project_deadline_ge", "project_deadline_le",
"project_statuses", "page", "per_page"),
(
Expand Down Expand Up @@ -48,7 +48,7 @@ async def test_get_positions(
specialization_ids: list[uuid.UUID] | Type[Empty],
skill_ids: list[uuid.UUID] | Type[Empty],
joined_user_id: uuid.UUID | Type[Empty],
project_query_text: str | Type[Empty],
query: str | Type[Empty],
project_startline_ge: datetime | Type[Empty],
project_startline_le: datetime | Type[Empty],
project_deadline_ge: datetime | Type[Empty],
Expand All @@ -62,7 +62,7 @@ async def test_get_positions(
specialization_ids=specialization_ids,
skill_ids=skill_ids,
joined_user_id=joined_user_id,
project_query_text=project_query_text,
query=query,
project_startline_ge=project_startline_ge,
project_startline_le=project_startline_le,
project_deadline_ge=project_deadline_ge,
Expand Down
2 changes: 1 addition & 1 deletion sapphire/projects/api/rest/positions/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def get_positions(
"specialization_ids": filters.specialization_ids,
"skill_ids": filters.skill_ids,
"joined_user_id": filters.joined_user_id,
"project_query_text": filters.project_query_text,
"query": filters.query,
"project_startline_ge": filters.project_startline_ge,
"project_startline_le": filters.project_startline_le,
"project_deadline_ge": filters.project_deadline_ge,
Expand Down
2 changes: 1 addition & 1 deletion sapphire/projects/api/rest/positions/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class PositionListFiltersRequest(BaseModel):
specialization_ids: list[uuid.UUID] | Type[Empty] = Field(fastapi.Query(Empty))
skill_ids: list[uuid.UUID] | Type[Empty] = Field(fastapi.Query(Empty))
joined_user_id: uuid.UUID | Type[Empty] = Empty
project_query_text: str | Type[Empty] = Empty
query: str | Type[Empty] = Empty
project_startline_ge: AwareDatetime | Type[Empty] = Empty
project_startline_le: AwareDatetime | Type[Empty] = Empty
project_deadline_ge: AwareDatetime | Type[Empty] = Empty
Expand Down
2 changes: 1 addition & 1 deletion sapphire/projects/api/rest/projects/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def get_projects(
async with database_service.transaction() as session:
params = {
"session": session,
"query_text": filters.query_text,
"query": filters.query,
"owner_id": filters.owner_id,
"user_id": filters.user_id,
"startline_le": filters.startline_le,
Expand Down
2 changes: 1 addition & 1 deletion sapphire/projects/api/rest/projects/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ProjectListResponse(PaginatedResponse):


class ProjectListFiltersRequest(BaseModel):
query_text: str | Type[Empty] = Empty
query: str | Type[Empty] = Empty
owner_id: uuid.UUID | Type[Empty] = Empty
user_id: uuid.UUID | Type[Empty] = Empty
startline_ge: AwareDatetime | Type[Empty] = Empty
Expand Down
55 changes: 34 additions & 21 deletions sapphire/projects/database/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
Project,
ProjectHistory,
ProjectStatusEnum,
Skill,
Specialization,
Review,
User,
)
Expand Down Expand Up @@ -137,7 +139,7 @@ async def _get_positions_filters(
specialization_ids: list[uuid.UUID] | Type[Empty] = Empty,
skill_ids: list[uuid.UUID] | Type[Empty] = Empty,
joined_user_id: uuid.UUID | Type[Empty] = Empty,
project_query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
project_startline_ge: datetime | Type[Empty] = Empty,
project_startline_le: datetime | Type[Empty] = Empty,
project_deadline_ge: datetime | Type[Empty] = Empty,
Expand Down Expand Up @@ -167,10 +169,21 @@ async def _get_positions_filters(
))
))

if project_query_text is not Empty:
project_filters.append(or_(
Project.name.icontains(project_query_text),
Project.description.icontains(project_query_text),
if query is not Empty:
filters.append(or_(
Position.project_id.in_(select(Project.id).where(or_(
Project.name.icontains(query),
Project.description.icontains(query),
))),
Position.specialization_id.in_(select(Specialization.id).where(or_(
Specialization.name.icontains(query),
Specialization.name_en.icontains(query),
))),
Position.id.in_(select(PositionSkill.position_id).where(
PositionSkill.skill_id.in_(select(Skill.id).where(
Skill.name.icontains(query),
)),
)),
))
if project_startline_ge is not Empty:
project_filters.append(Project.startline >= project_startline_ge)
Expand Down Expand Up @@ -210,7 +223,7 @@ async def get_positions_count(
specialization_ids: list[uuid.UUID] | Type[Empty] = Empty,
skill_ids: list[uuid.UUID] | Type[Empty] = Empty,
joined_user_id: uuid.UUID | Type[Empty] = Empty,
project_query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
project_startline_ge: datetime | Type[Empty] = Empty,
project_startline_le: datetime | Type[Empty] = Empty,
project_deadline_ge: datetime | Type[Empty] = Empty,
Expand All @@ -222,7 +235,7 @@ async def get_positions_count(
specialization_ids=specialization_ids,
skill_ids=skill_ids,
joined_user_id=joined_user_id,
project_query_text=project_query_text,
query=query,
project_startline_ge=project_startline_ge,
project_startline_le=project_startline_le,
project_deadline_ge=project_deadline_ge,
Expand All @@ -241,7 +254,7 @@ async def get_positions(
specialization_ids: list[uuid.UUID] | Type[Empty] = Empty,
skill_ids: list[uuid.UUID] | Type[Empty] = Empty,
joined_user_id: uuid.UUID | Type[Empty] = Empty,
project_query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
project_startline_ge: datetime | Type[Empty] = Empty,
project_startline_le: datetime | Type[Empty] = Empty,
project_deadline_ge: datetime | Type[Empty] = Empty,
Expand All @@ -255,7 +268,7 @@ async def get_positions(
specialization_ids=specialization_ids,
skill_ids=skill_ids,
joined_user_id=joined_user_id,
project_query_text=project_query_text,
query=query,
project_startline_ge=project_startline_ge,
project_startline_le=project_startline_le,
project_deadline_ge=project_deadline_ge,
Expand Down Expand Up @@ -486,7 +499,7 @@ async def update_participant_status(

async def _get_projects_filters(
self,
query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
owner_id: uuid.UUID | Type[Empty] = Empty,
user_id: uuid.UUID | Type[Empty] = Empty,
startline_le: datetime | Type[Empty] = Empty,
Expand All @@ -502,11 +515,11 @@ async def _get_projects_filters(
position_filters = []
participant_filters = []

if query_text is not Empty:
if query is not Empty:
filters.append(
or_(
Project.name.icontains(query_text),
Project.description.icontains(query_text),
Project.name.icontains(query),
Project.description.icontains(query),
)
)
if owner_id is not Empty:
Expand Down Expand Up @@ -568,7 +581,7 @@ async def _get_projects_filters(
async def get_projects_count(
self,
session: AsyncSession,
query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
owner_id: uuid.UUID | Type[Empty] = Empty,
user_id: uuid.UUID | Type[Empty] = Empty,
startline_le: datetime | Type[Empty] = Empty,
Expand All @@ -580,9 +593,9 @@ async def get_projects_count(
position_specialization_ids: list[uuid.UUID] | Type[Empty] = Empty,
participant_user_ids: list[uuid.UUID] | Type[Empty] = Empty,
) -> int:
query = select(func.count(Project.id)) # pylint: disable=not-callable
statement = select(func.count(Project.id)) # pylint: disable=not-callable
filters = await self._get_projects_filters(
query_text=query_text,
query=query,
owner_id=owner_id,
user_id=user_id,
startline_le=startline_le,
Expand All @@ -596,13 +609,13 @@ async def get_projects_count(
)
query = query.where(*filters)

result = await session.scalar(query)
result = await session.scalar(statement)
return result

async def get_projects(
self,
session: AsyncSession,
query_text: str | Type[Empty] = Empty,
query: str | Type[Empty] = Empty,
owner_id: uuid.UUID | Type[Empty] = Empty,
user_id: uuid.UUID | Type[Empty] = Empty,
startline_le: datetime | Type[Empty] = Empty,
Expand All @@ -616,9 +629,9 @@ async def get_projects(
page: int = 1,
per_page: int = 10,
) -> list[Project]:
query = select(Project).order_by(Project.created_at.desc())
statement = select(Project).order_by(Project.created_at.desc())
filters = await self._get_projects_filters(
query_text=query_text,
query=query,
owner_id=owner_id,
user_id=user_id,
startline_le=startline_le,
Expand All @@ -635,7 +648,7 @@ async def get_projects(
offset = (page - 1) * per_page
query = query.limit(per_page).offset(offset)

result = await session.execute(query)
result = await session.execute(statement)

return list(result.unique().scalars().all())

Expand Down
4 changes: 2 additions & 2 deletions sapphire/users/cache/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import uuid
import secrets

from sapphire.common.cache.service import BaseCacheService

Expand All @@ -7,7 +7,7 @@

class Service(BaseCacheService):
async def set_state(self) -> str:
state = str(uuid.uuid4())
state = secrets.token_hex(32)
key = f"users:auth:oauth2:habr:state:{state}"
await self.redis.set(key, state, ex=120)
return state
Expand Down
4 changes: 2 additions & 2 deletions tests/projects/database/test_projects_database_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def test_get_projects_with_all_query_params(service: Service):
startline_le = datetime.now(tz=timezone.utc) + timedelta(days=30)
deadline_ge = datetime.now(tz=timezone.utc) - timedelta(days=30)
deadline_le = datetime.now(tz=timezone.utc) + timedelta(days=30)
query_text = "query_text"
query = "query"
position_skill_ids = [uuid.uuid4(), uuid.uuid4()]
position_specialization_ids = [uuid.uuid4(), uuid.uuid4()]
expected_projects = [Project(id=project_id, name="test", owner_id=owner_id)]
Expand All @@ -124,7 +124,7 @@ async def test_get_projects_with_all_query_params(service: Service):

projects = await service.get_projects(
session=session,
query_text=query_text,
query=query,
owner_id=owner_id,
deadline_ge=deadline_ge,
deadline_le=deadline_le,
Expand Down

0 comments on commit f4c4139

Please sign in to comment.