diff --git a/src/bootlace/table/base.py b/src/bootlace/table/base.py index 8da15d2..49a7cfc 100644 --- a/src/bootlace/table/base.py +++ b/src/bootlace/table/base.py @@ -30,10 +30,19 @@ def __tag__(self) -> tags.html_tag: @attrs.define class ColumnBase(ABC): heading: Heading = attrs.field(converter=maybe(Heading)) # type: ignore - attribute: str | None = None + _attribute: str | None = None + + def __set_name__(self, owner: type, name: str) -> None: + self._attribute = self._attribute or name + + @property + def attribute(self) -> str: + if self._attribute is None: + raise ValueError("column must be named in Table or attribute= parameter must be provided") + return self._attribute @abstractmethod - def cell(self, name: str, value: Any) -> tags.html_tag: + def cell(self, value: Any) -> tags.html_tag: raise NotImplementedError("Subclasses must implement this method") @@ -106,8 +115,8 @@ def render(self, items: list[Any]) -> tags.html_tag: for item in items: id = getattr(item, "id", None) tr = tags.tr(id=f"item-{id}" if id else None) - for column_name, column in self.columns.items(): - td = column.cell(column_name, item) + for column in self.columns.values(): + td = column.cell(item) tr.add(td) tbody.add(tr) table.add(tbody) diff --git a/src/bootlace/table/columns.py b/src/bootlace/table/columns.py index 45e885b..797ae04 100644 --- a/src/bootlace/table/columns.py +++ b/src/bootlace/table/columns.py @@ -12,17 +12,17 @@ @attrs.define class Column(ColumnBase): - def cell(self, name: str, value: Any) -> tags.html_tag: - return tags.td(getattr(value, name)) + def cell(self, value: Any) -> tags.html_tag: + return tags.td(getattr(value, self.attribute)) @attrs.define class EditColumn(ColumnBase): endpoint: str = attrs.field(default=".edit") - def cell(self, name: str, value: Any) -> tags.html_tag: + def cell(self, value: Any) -> tags.html_tag: id = getattr(value, "id", None) - return tags.td(tags.a(getattr(value, name), href=url_for(self.endpoint, id=id))) + return tags.td(tags.a(getattr(value, self.attribute), href=url_for(self.endpoint, id=id))) @attrs.define @@ -31,8 +31,8 @@ class CheckColumn(ColumnBase): yes: Icon = attrs.field(default=Icon("check", width=16, height=16)) no: Icon = attrs.field(default=Icon("x", width=16, height=16)) - def cell(self, name: str, value: Any) -> tags.html_tag: - if getattr(value, name): + def cell(self, value: Any) -> tags.html_tag: + if getattr(value, self.attribute): return tags.td(as_tag(self.yes)) return tags.td(as_tag(self.no)) @@ -40,5 +40,5 @@ def cell(self, name: str, value: Any) -> tags.html_tag: @attrs.define class Datetime(ColumnBase): - def cell(self, name: str, value: Any) -> tags.html_tag: - return tags.td(getattr(value, name).isoformat()) + def cell(self, value: Any) -> tags.html_tag: + return tags.td(getattr(value, self.attribute).isoformat()) diff --git a/tests/table/test_columns.py b/tests/table/test_columns.py index 87a2ff4..1537106 100644 --- a/tests/table/test_columns.py +++ b/tests/table/test_columns.py @@ -24,6 +24,14 @@ class Item: when: dt.datetime = dt.datetime(2021, 1, 1, 12, 18, 5) +def test_unnamed_column() -> None: + + col = EditColumn(heading="Edit", endpoint="index") + + with pytest.raises(ValueError): + col.attribute + + @pytest.mark.usefixtures("homepage") def test_edit_column(app: Flask) -> None: @@ -34,7 +42,7 @@ def test_edit_column(app: Flask) -> None: assert str(th) == "Edit" with app.test_request_context("/"): - td = col.cell("editor", Item()) + td = col.cell(Item()) expected = '