Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post #134

Merged
merged 1 commit into from
Nov 3, 2024
Merged

Post #134

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TopicId,
TwitterTimestamp,
UserEnrollment,
UserId,
)
from birdxplorer_common.storage import Storage

Expand Down Expand Up @@ -255,6 +256,7 @@ def get_posts(
request: Request,
post_ids: Union[List[PostId], None] = Query(default=None),
note_ids: Union[List[NoteId], None] = Query(default=None),
user_ids: Union[List[UserId], None] = Query(default=None),
created_at_from: Union[None, TwitterTimestamp, str] = Query(
default=None, **V1DataPostsDocs.params["created_at_from"]
),
Expand All @@ -275,6 +277,7 @@ def get_posts(
storage.get_posts(
post_ids=post_ids,
note_ids=note_ids,
user_ids=user_ids,
start=created_at_from,
end=created_at_to,
search_text=search_text,
Expand All @@ -288,6 +291,7 @@ def get_posts(
total_count = storage.get_number_of_posts(
post_ids=post_ids,
note_ids=note_ids,
user_ids=user_ids,
start=created_at_from,
end=created_at_to,
search_text=search_text,
Expand Down
6 changes: 5 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def _get_number_of_notes(
def _get_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[str], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -403,6 +404,8 @@ def _get_posts(
note.note_id in note_ids and note.post_id == post.post_id for note in note_samples
):
continue
if user_ids is not None and post.x_user_id not in user_ids:
continue
if start is not None and post.created_at < start:
continue
if end is not None and post.created_at >= end:
Expand All @@ -426,12 +429,13 @@ def _get_posts(
def _get_number_of_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[str], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
) -> int:
return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url)))
return len(list(_get_posts(post_ids, note_ids, user_ids, start, end, search_text, search_url)))

mock.get_number_of_posts.side_effect = _get_number_of_posts

Expand Down
6 changes: 6 additions & 0 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def get_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[UserId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -452,6 +453,8 @@ def get_posts(
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if user_ids is not None:
query = query.filter(PostRecord.user_id.in_(user_ids))
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
Expand All @@ -474,6 +477,7 @@ def get_number_of_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
user_ids: Union[List[UserId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
Expand All @@ -487,6 +491,8 @@ def get_number_of_posts(
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if user_ids is not None:
query = query.filter(PostRecord.user_id.in_(user_ids))
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
Expand Down
Loading