Skip to content

Commit

Permalink
Use Whoosh for search
Browse files Browse the repository at this point in the history
  • Loading branch information
wsanchez committed Nov 8, 2023
1 parent 112b60a commit 6182d86
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 50 deletions.
8 changes: 7 additions & 1 deletion src/transmissions/model/_transmission.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,25 @@
from datetime import datetime as DateTime
from datetime import timedelta as TimeDelta
from pathlib import Path
from typing import ClassVar, TypeAlias

from attrs import field, frozen


__all__ = ()


KeyType = TypeAlias


@frozen(kw_only=True, order=True)
class Transmission:
"""
Radio transmission
"""

Key: ClassVar[TypeAlias] = KeyType

startTime: DateTime
eventID: str
station: str
Expand All @@ -53,7 +59,7 @@ def endTime(self) -> DateTime | None:
return self.startTime + self.duration

@property
def key(self) -> tuple[str, str, str, DateTime]:
def key(self) -> Key:
return (self.eventID, self.system, self.channel, self.startTime)

def __str__(self) -> str:
Expand Down
13 changes: 3 additions & 10 deletions src/transmissions/run/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,19 +275,12 @@ async def app(store: TXDataStore) -> None:
transmissionsByKey = {t.key: t for t in await store.transmissions()}

if search:
results = set()
index = TransmissionsIndex()
await index.connect()
await index.add(transmissionsByKey.values())
async for result in index.search(search):
key = (
result["eventID"],
result["system"],
result["channel"],
result["startTime"],
)
results.add(transmissionsByKey[key])
transmissions: Iterable[Transmission] = results
transmissions: Iterable[Transmission] = [
transmissionsByKey[key] async for key in index.search(search)
]
else:
transmissions = transmissionsByKey.values()

Expand Down
12 changes: 9 additions & 3 deletions src/transmissions/search/_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import AsyncIterable, Iterable
from enum import Enum, auto
from typing import Any, ClassVar
from typing import ClassVar

from attrs import mutable
from whoosh.fields import DATETIME, ID, NUMERIC, TEXT, Schema
Expand Down Expand Up @@ -91,7 +91,7 @@ async def clear(self) -> None:
# way to await on completion of the indexing thread here.
writer.commit(mergetype=CLEAR)

async def search(self, queryText: str) -> AsyncIterable[dict[str, Any]]:
async def search(self, queryText: str) -> AsyncIterable[Transmission.Key]:
"""
Perform search.
"""
Expand All @@ -102,4 +102,10 @@ async def search(self, queryText: str) -> AsyncIterable[dict[str, Any]]:

with self._index.searcher() as searcher:
for result in searcher.search(query, limit=None):
yield result.fields()
result.fields()
yield (
result["eventID"],
result["system"],
result["channel"],
result["startTime"],
)
28 changes: 4 additions & 24 deletions src/transmissions/tui/_transmissionlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def control(self) -> Widget:
dateTimeDisplayFormat = reactive("ddd YY/MM/DD HH:mm:ss")
timeZone = reactive("US/Pacific")
searchQuery = reactive("")
displayKeys: reactive[frozenset[str] | None] = reactive(None)

def __init__(self, id: str) -> None:
self._tableData: TransmissionTableData = ()
Expand Down Expand Up @@ -126,27 +127,6 @@ def dateTimeFromDisplayText(self, displayText: str) -> DateTime:
arrow = makeArrow(displayText, self.dateTimeDisplayFormat)
return arrow.datetime

def filterTable(self, row: TransmissionTableRowCells, key: str) -> bool:
query = self.searchQuery
if not query:
return True

transcription = row[8]

if not transcription:
return False

for term in query.split():
if term:
term = term.lower()
if transcription.lower().find(term) != -1:
continue
break
else:
return True

return False

def updateTable(self) -> None:
self.log("Updating table")
columns = []
Expand All @@ -173,7 +153,7 @@ def updateTable(self) -> None:
table = self.query_one(DataTable)
table.clear()
for row, key in self._tableData:
if self.filterTable(row, key):
if self.displayKeys is None or key in self.displayKeys:
table.add_row(*[row[column] for column in columns], key=key)

# def sortKey(startTime: str) -> Any:
Expand Down Expand Up @@ -237,8 +217,8 @@ def watch_transmissions(
except Exception as e:
self.log(f"Unable to update transmissions: {e}")

def watch_searchQuery(self, searchQuery: str) -> None:
self.log(f"Received search query: {searchQuery}")
def watch_displayKeys(self, displayKeys: frozenset[str]) -> None:
self.log(f"Received display keys: {displayKeys}")
try:
self.updateTable()
except Exception as e:
Expand Down
46 changes: 34 additions & 12 deletions src/transmissions/tui/_transmissionsscreen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from textual.screen import Screen

from transmissions.model import Transmission
from transmissions.search import TransmissionsIndex

from ._body import Body
from ._footer import Footer
Expand All @@ -18,15 +19,8 @@
__all__ = ()


def transmissionKey(transmission: Transmission) -> str:
return ":".join(
(
transmission.eventID,
transmission.system,
transmission.channel,
str(transmission.startTime),
)
)
def transmissionTableKey(key: Transmission.Key) -> str:
return ":".join(str(i) for i in key)


def transmissionAsTuple(
Expand Down Expand Up @@ -66,7 +60,7 @@ class TransmissionsScreen(Screen):

def __init__(self, transmissions: tuple[Transmission, ...]) -> None:
self.transmissionsByKey = {
transmissionKey(transmission): transmission
transmissionTableKey(transmission.key): transmission
for transmission in transmissions
}

Expand All @@ -81,6 +75,14 @@ async def on_mount(self) -> None:
for key, transmission in self.transmissionsByKey.items()
)

try:
transmissionsIndex = TransmissionsIndex()
await transmissionsIndex.connect()
await transmissionsIndex.add(self.transmissionsByKey.values())
self._transcriptionsIndex = transmissionsIndex
except Exception as e:
self.log(f"Unable to index transmissions: {e}")

def compose(self) -> ComposeResult:
yield Header("Radio Transmissions", id="Header")
yield Footer(
Expand All @@ -106,11 +108,31 @@ def handleTransmissionSelected(
)

@on(SearchField.QueryUpdated)
def handleSearchQueryUpdated(
async def handleSearchQueryUpdated(
self, message: SearchField.QueryUpdated
) -> None:
self.log(f"Search query: {message.query}")

searchQuery = message.query.strip()

transmissionList = cast(
TransmissionList, self.query_one("TransmissionList")
)
transmissionList.searchQuery = message.query
transmissionList.searchQuery = searchQuery

if searchQuery:
try:
keys = frozenset(
{
transmissionTableKey(result)
async for result in self._transcriptionsIndex.search(
message.query
)
}
)
self.log(f"{keys}")
transmissionList.displayKeys = keys
except Exception as e:
self.log(f"Unable to perform search: {e}")
else:
transmissionList.displayKeys = None

0 comments on commit 6182d86

Please sign in to comment.