Skip to content

Commit

Permalink
🔧 split filter_needs func by needs type
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Sep 4, 2024
1 parent 0948928 commit 14683a1
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 43 deletions.
17 changes: 6 additions & 11 deletions sphinx_needs/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from sphinx.builders import Builder

from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import NeedsInfoType, SphinxNeedsData
from sphinx_needs.data import SphinxNeedsData
from sphinx_needs.directives.need import post_process_needs_data
from sphinx_needs.filter_common import filter_needs_view
from sphinx_needs.logging import get_logger, log_warning
from sphinx_needs.needsfile import NeedsList

Expand Down Expand Up @@ -80,12 +81,10 @@ def finish(self) -> None:
# This is needed as needs could have been removed from documentation and if this is the case,
# removed needs would stay in needs_list, if list gets not cleaned.
needs_list.wipe_version(version)
#
from sphinx_needs.filter_common import filter_needs

filter_string = needs_config.builder_filter
filtered_needs: list[NeedsInfoType] = filter_needs(
data.get_needs_view().values(),
filtered_needs = filter_needs_view(
data.get_needs_view(),
needs_config,
filter_string,
append_warning="(from need_builder_filter)",
Expand Down Expand Up @@ -177,16 +176,12 @@ def finish(self) -> None:
post_process_needs_data(self.app)

data = SphinxNeedsData(self.env)
needs = (
data.get_needs_view().values()
) # We need a list of needs for later filter checks
version = getattr(self.env.config, "version", "unset")
needs_config = NeedsSphinxConfig(self.env.config)
filter_string = needs_config.builder_filter
from sphinx_needs.filter_common import filter_needs

filtered_needs = filter_needs(
needs,
filtered_needs = filter_needs_view(
data.get_needs_view(),
needs_config,
filter_string,
append_warning="(from need_builder_filter)",
Expand Down
10 changes: 8 additions & 2 deletions sphinx_needs/directives/needbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import NeedsBarType, SphinxNeedsData
from sphinx_needs.filter_common import FilterBase, expand_needs_view, filter_needs
from sphinx_needs.filter_common import (
FilterBase,
expand_needs_view,
filter_needs_parts,
)
from sphinx_needs.logging import get_logger, log_warning
from sphinx_needs.utils import (
add_doc,
Expand Down Expand Up @@ -302,7 +306,9 @@ def process_needbar(
line_number.append(float(element))
else:
result = len(
filter_needs(need_list, needs_config, element, location=node)
filter_needs_parts(
need_list, needs_config, element, location=node
)
)
line_number.append(float(result))
local_data_number.append(line_number)
Expand Down
6 changes: 3 additions & 3 deletions sphinx_needs/directives/needextend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NeedsMutable,
SphinxNeedsData,
)
from sphinx_needs.filter_common import filter_needs
from sphinx_needs.filter_common import filter_needs_mutable
from sphinx_needs.logging import get_logger, log_warning
from sphinx_needs.utils import add_doc

Expand Down Expand Up @@ -133,8 +133,8 @@ def extend_needs_data(
else:
# a filter string
try:
found_needs = filter_needs(
all_needs.values(),
found_needs = filter_needs_mutable(
all_needs,
needs_config,
need_filter,
location=(
Expand Down
4 changes: 2 additions & 2 deletions sphinx_needs/directives/needpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sphinx_needs.data import NeedsPieType, SphinxNeedsData
from sphinx_needs.debug import measure_time
from sphinx_needs.directives.utils import no_needs_found_paragraph
from sphinx_needs.filter_common import FilterBase, expand_needs_view, filter_needs
from sphinx_needs.filter_common import FilterBase, expand_needs_view, filter_needs_parts
from sphinx_needs.logging import get_logger, log_warning
from sphinx_needs.utils import (
add_doc,
Expand Down Expand Up @@ -165,7 +165,7 @@ def process_needpie(
sizes.append(abs(float(line)))
else:
result = len(
filter_needs(need_list, needs_config, line, location=node)
filter_needs_parts(need_list, needs_config, line, location=node)
)
sizes.append(result)
elif current_needpie["filter_func"] and not content:
Expand Down
6 changes: 2 additions & 4 deletions sphinx_needs/directives/needuml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sphinx_needs.debug import measure_time
from sphinx_needs.diagrams_common import calculate_link
from sphinx_needs.directives.needflow._plantuml import make_entity_name
from sphinx_needs.filter_common import filter_needs
from sphinx_needs.filter_common import filter_needs_view
from sphinx_needs.utils import add_doc

if TYPE_CHECKING:
Expand Down Expand Up @@ -423,9 +423,7 @@ def filter(self, filter_string: str) -> list[NeedsInfoType]:
"""
needs_config = NeedsSphinxConfig(self.app.config)

return filter_needs(
list(self.needs.values()), needs_config, filter_string=filter_string
)
return filter_needs_view(self.needs, needs_config, filter_string=filter_string)

def imports(self, *args: str) -> str:
if not self.parent_need_id:
Expand Down
64 changes: 61 additions & 3 deletions sphinx_needs/filter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sphinx_needs.data import (
NeedsFilteredBaseType,
NeedsInfoType,
NeedsMutable,
NeedsPartsView,
NeedsView,
SphinxNeedsData,
Expand Down Expand Up @@ -181,7 +182,7 @@ def process_filters(
if status_filter_passed and tags_filter_passed and type_filter_passed:
found_needs_by_options.append(need_info)
# Get need by filter string
found_needs_by_string = filter_needs(
found_needs_by_string = filter_needs_parts(
all_needs_incl_parts,
needs_config,
filter_data["filter"],
Expand All @@ -194,7 +195,7 @@ def process_filters(
else:
# There is no other config as the one for filter string.
# So we only need this result.
found_needs = filter_needs(
found_needs = filter_needs_parts(
all_needs_incl_parts,
needs_config,
filter_data["filter"],
Expand Down Expand Up @@ -297,8 +298,65 @@ def intersection_of_need_results(list_a: list[T], list_b: list[T]) -> list[T]:
return [a for a in list_a if a in list_b]


def filter_needs_mutable(
needs: NeedsMutable,
config: NeedsSphinxConfig,
filter_string: None | str = "",
current_need: NeedsInfoType | None = None,
*,
location: tuple[str, int | None] | nodes.Node | None = None,
append_warning: str = "",
) -> list[NeedsInfoType]:
return _filter_needs(
needs.values(),
config,
filter_string,
current_need,
location=location,
append_warning=append_warning,
)


def filter_needs_view(
needs: NeedsView,
config: NeedsSphinxConfig,
filter_string: None | str = "",
current_need: NeedsInfoType | None = None,
*,
location: tuple[str, int | None] | nodes.Node | None = None,
append_warning: str = "",
) -> list[NeedsInfoType]:
return _filter_needs(
needs.values(),
config,
filter_string,
current_need,
location=location,
append_warning=append_warning,
)


def filter_needs_parts(
needs: NeedsPartsView,
config: NeedsSphinxConfig,
filter_string: None | str = "",
current_need: NeedsInfoType | None = None,
*,
location: tuple[str, int | None] | nodes.Node | None = None,
append_warning: str = "",
) -> list[NeedsInfoType]:
return _filter_needs(
needs,
config,
filter_string,
current_need,
location=location,
append_warning=append_warning,
)


@measure_time("filtering")
def filter_needs(
def _filter_needs(
needs: Iterable[NeedsInfoType],
config: NeedsSphinxConfig,
filter_string: None | str = "",
Expand Down
9 changes: 6 additions & 3 deletions sphinx_needs/functions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from sphinx_needs.api.exceptions import NeedsInvalidFilter
from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import NeedsInfoType, NeedsView
from sphinx_needs.filter_common import filter_needs, filter_single_need
from sphinx_needs.filter_common import (
filter_needs_view,
filter_single_need,
)
from sphinx_needs.logging import log_warning
from sphinx_needs.utils import logger

Expand Down Expand Up @@ -166,8 +169,8 @@ def copy(
need = needs[need_id]

if filter:
result = filter_needs(
needs.values(),
result = filter_needs_view(

Check warning on line 172 in sphinx_needs/functions/common.py

View check run for this annotation

Codecov / codecov/patch

sphinx_needs/functions/common.py#L172

Added line #L172 was not covered by tests
needs,
NeedsSphinxConfig(app.config),
filter,
need,
Expand Down
11 changes: 7 additions & 4 deletions sphinx_needs/roles/need_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from sphinx_needs.api.exceptions import NeedsInvalidFilter
from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import SphinxNeedsData
from sphinx_needs.filter_common import expand_needs_view, filter_needs
from sphinx_needs.filter_common import (
expand_needs_view,
filter_needs_parts,
)
from sphinx_needs.logging import get_logger

log = get_logger(__name__)
Expand All @@ -39,7 +42,7 @@ def process_need_count(
need_list = expand_needs_view(needs_view) # adds parts to need_list
amount = str(
len(
filter_needs(
filter_needs_parts(
need_list,
needs_config,
filters[0],
Expand All @@ -50,12 +53,12 @@ def process_need_count(
elif len(filters) == 2:
need_list = expand_needs_view(needs_view) # adds parts to need_list
amount_1 = len(
filter_needs(
filter_needs_parts(
need_list, needs_config, filters[0], location=node_need_count
)
)
amount_2 = len(
filter_needs(
filter_needs_parts(
need_list, needs_config, filters[1], location=node_need_count
)
)
Expand Down
21 changes: 10 additions & 11 deletions sphinx_needs/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from sphinx.util import logging

from sphinx_needs.config import NEEDS_CONFIG, NeedsSphinxConfig
from sphinx_needs.data import NeedsInfoType, SphinxNeedsData
from sphinx_needs.filter_common import filter_needs
from sphinx_needs.data import NeedsView, SphinxNeedsData
from sphinx_needs.filter_common import filter_needs_view
from sphinx_needs.logging import get_logger, log_warning

logger = get_logger(__name__)
Expand All @@ -33,9 +33,9 @@ def process_warnings(app: Sphinx, exception: Exception | None) -> None:
return

env = app.env
needs = SphinxNeedsData(env).get_needs_view()
needs_view = SphinxNeedsData(env).get_needs_view()
# If no needs were defined, we do not need to do anything
if not needs:
if not needs_view:
return

# Check if warnings already got executed.
Expand All @@ -47,10 +47,9 @@ def process_warnings(app: Sphinx, exception: Exception | None) -> None:
env.needs_warnings_executed = True # type: ignore[attr-defined]

# Exclude external needs for warnings check
checked_needs: dict[str, NeedsInfoType] = {}
for need_id, need in needs.items():
if not need["is_external"]:
checked_needs[need_id] = need
needs_view = NeedsView(
{id: need for id, need in needs_view.items() if not need["is_external"]}
)

needs_config = NeedsSphinxConfig(app.config)
warnings_always_warn = needs_config.warnings_always_warn
Expand All @@ -61,16 +60,16 @@ def process_warnings(app: Sphinx, exception: Exception | None) -> None:
for warning_name, warning_filter in NEEDS_CONFIG.warnings.items():
if isinstance(warning_filter, str):
# filter string used
result = filter_needs(
checked_needs.values(),
result = filter_needs_view(
needs_view,
needs_config,
warning_filter,
append_warning=f"(from warning filter {warning_name!r})",
)
elif callable(warning_filter):
# custom defined filter code used from conf.py
result = []
for need in checked_needs.values():
for need in needs_view.values():
if warning_filter(need, logger):
result.append(need)
else:
Expand Down

0 comments on commit 14683a1

Please sign in to comment.