Skip to content

Commit

Permalink
🔧 Introduce NeedsPartsView type
Browse files Browse the repository at this point in the history
Make it clearer where we are using an list of the "expanded" needs+parts.

As discussed in #1264, there are a number of different representations of the needs,
and so this makes it clearer which one a variable is.
  • Loading branch information
chrisjsewell committed Sep 5, 2024
1 parent c862e9d commit 1ba22a9
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 53 deletions.
20 changes: 19 additions & 1 deletion sphinx_needs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Final, Literal, Mapping, NewType, TypedDict
from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
Literal,
Mapping,
NewType,
Sequence,
TypedDict,
)

from sphinx.util.logging import getLogger

Expand Down Expand Up @@ -685,6 +695,14 @@ class NeedsUmlType(NeedsBaseDataType):
(e.g. back links have been computed etc)
"""

NeedsPartsView = NewType("NeedsPartsView", Sequence[NeedsInfoType])
"""A read-only view of a sequence of needs and parts,
after resolution (e.g. back links have been computed etc)
The parts are created by creating a copy of the need for each item in ``parts``,
and then overwriting a subset of fields with the values from the part.
"""


class SphinxNeedsData:
"""Centralised access to sphinx-needs data, stored within the Sphinx environment."""
Expand Down
29 changes: 11 additions & 18 deletions sphinx_needs/directives/needpie.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import hashlib
from typing import Any, Iterable, Sequence
from typing import Iterable, Sequence

from docutils import nodes
from docutils.parsers.rst import directives
Expand Down Expand Up @@ -171,7 +171,7 @@ def process_needpie(
elif current_needpie["filter_func"] and not content:
# check and get filter_func
try:
filter_func_sig = check_and_get_external_filter_func(
ff_result = check_and_get_external_filter_func(
current_needpie.get("filter_func")
)
except NeedsInvalidFilter as e:
Expand All @@ -185,30 +185,23 @@ def process_needpie(
continue

# execute filter_func code
if filter_func_sig:
# Provides only a copy of needs to avoid data manipulations.
context: dict[str, Any] = {
"needs": need_list,
"results": [],
}
args = filter_func_sig.args.split(",") if filter_func_sig.args else []
for index, arg in enumerate(args):
# All rgs are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str
context[f"arg{index + 1}"] = arg

filter_func_sig.func(**context)

sizes = context["results"]
if ff_result:
args = ff_result.args.split(",") if ff_result.args else []
args_context = {f"arg{index+1}": arg for index, arg in enumerate(args)}

sizes = []
ff_result.func(needs=need_list, results=sizes, **args_context)

# check items in sizes
if not isinstance(sizes, list):
logger.error(
f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid."
f"The returned values from the given filter_func {ff_result.sig!r} is not valid."
" It must be a list."
)
for item in sizes:
if not isinstance(item, int) and not isinstance(item, float):
logger.error(
f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid. "
f"The returned values from the given filter_func {ff_result.sig!r} is not valid. "
"It must be a list with items of type int/float."
)

Expand Down
55 changes: 27 additions & 28 deletions sphinx_needs/filter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sphinx_needs.data import (
NeedsFilteredBaseType,
NeedsInfoType,
NeedsPartsView,
NeedsView,
SphinxNeedsData,
)
Expand Down Expand Up @@ -134,9 +135,7 @@ def process_filters(

# Check if external filter code is defined
try:
filter_func_sig = check_and_get_external_filter_func(
filter_data.get("filter_func")
)
ff_result = check_and_get_external_filter_func(filter_data.get("filter_func"))
except NeedsInvalidFilter as e:
log_warning(
log,
Expand All @@ -151,7 +150,7 @@ def process_filters(
if not filter_code and filter_data["filter_code"]:
filter_code = "\n".join(filter_data["filter_code"])

if (not filter_code or filter_code.isspace()) and not filter_func_sig:
if (not filter_code or filter_code.isspace()) and not ff_result:
if bool(filter_data["status"] or filter_data["tags"] or filter_data["types"]):
found_needs_by_options: list[NeedsInfoType] = []
for need_info in all_needs_incl_parts:
Expand Down Expand Up @@ -202,35 +201,36 @@ def process_filters(
location=location,
)
else:
# Provides only a copy of needs to avoid data manipulations.
context: dict[str, Any] = {
"needs": all_needs_incl_parts,
"results": [],
}
# The filter results may be dirty, as it may continue manipulated needs.
found_dirty_needs: list[NeedsInfoType] = []

if filter_code: # code from content
# TODO better context type
context: dict[str, list[NeedsInfoType]] = {
"needs": all_needs_incl_parts, # type: ignore[dict-item]
"results": [],
}
exec(filter_code, context)
elif filter_func_sig: # code from external file
found_dirty_needs = context["results"]
elif ff_result: # code from external file
args = []
if filter_func_sig.args:
args = filter_func_sig.args.split(",")
for index, arg in enumerate(args):
# All args are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str
context[f"arg{index+1}"] = arg
if ff_result.args:
args = ff_result.args.split(",")
args_context = {f"arg{index+1}": arg for index, arg in enumerate(args)}

# Decorate function to allow time measurments
filter_func = measure_time_func(
filter_func_sig.func, category="filter_func", source="user"
ff_result.func, category="filter_func", source="user"
)
filter_func(
needs=all_needs_incl_parts, results=found_dirty_needs, **args_context
)
filter_func(**context)
else:
log_warning(
log, "Something went wrong running filter", None, location=location
)
return []

# The filter results may be dirty, as it may continue manipulated needs.
found_dirty_needs: list[NeedsInfoType] = context["results"]
found_needs = []

# Check if config allow unsafe filters
Expand Down Expand Up @@ -277,15 +277,17 @@ def process_filters(
return found_needs


def expand_needs_view(needs_view: NeedsView) -> list[NeedsInfoType]:
"""Turns a needs view into a list of needs, expanding all need["parts"] to be items of the list."""
all_needs_incl_parts: list[NeedsInfoType] = []
def expand_needs_view(needs_view: NeedsView) -> NeedsPartsView:
"""Turns a needs view into a sequence of needs,
expanding all ``need["parts"]`` to be items of the list.
"""
all_needs_incl_parts = []
for need in needs_view.values():
all_needs_incl_parts.append(need)
for need_part in iter_need_parts(need):
all_needs_incl_parts.append(need_part)

return all_needs_incl_parts
return NeedsPartsView(all_needs_incl_parts)


T = TypeVar("T")
Expand All @@ -295,19 +297,16 @@ def intersection_of_need_results(list_a: list[T], list_b: list[T]) -> list[T]:
return [a for a in list_a if a in list_b]


V = TypeVar("V", bound=NeedsInfoType)


@measure_time("filtering")
def filter_needs(
needs: Iterable[V],
needs: Iterable[NeedsInfoType],
config: NeedsSphinxConfig,
filter_string: None | str = "",
current_need: NeedsInfoType | None = None,
*,
location: tuple[str, int | None] | nodes.Node | None = None,
append_warning: str = "",
) -> list[V]:
) -> list[NeedsInfoType]:
"""
Filters given needs based on a given filter string.
Returns all needs, which pass the given filter.
Expand Down
22 changes: 16 additions & 6 deletions sphinx_needs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from dataclasses import dataclass
from functools import lru_cache, reduce, wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar
from urllib.parse import urlparse

from docutils import nodes
Expand All @@ -16,7 +16,7 @@

from sphinx_needs.api.exceptions import NeedsInvalidFilter
from sphinx_needs.config import LinkOptionsType, NeedsSphinxConfig
from sphinx_needs.data import NeedsInfoType, NeedsView, SphinxNeedsData
from sphinx_needs.data import NeedsInfoType, NeedsPartsView, NeedsView, SphinxNeedsData
from sphinx_needs.defaults import NEEDS_PROFILING
from sphinx_needs.logging import get_logger, log_warning

Expand Down Expand Up @@ -310,19 +310,29 @@ def check_and_calc_base_url_rel_path(external_url: str, fromdocname: str) -> str
return ref_uri


class FilterFunc(Protocol):
def __call__(
self,
*,
needs: NeedsPartsView,
results: list[Any],
**kwargs: str,
) -> None: ...


@dataclass
class FilterFunc:
class FilterFuncResult:
"""Dataclass for filter function."""

sig: str
func: Callable[..., Any]
func: FilterFunc
args: str


@lru_cache(maxsize=32)
def check_and_get_external_filter_func(
filter_func_ref: str | None,
) -> FilterFunc | None:
) -> FilterFuncResult | None:
"""Check and import filter function from external python file."""
if not filter_func_ref:
return None
Expand All @@ -348,7 +358,7 @@ def check_and_get_external_filter_func(
except Exception:
raise NeedsInvalidFilter(f"module does not have function: {filter_function}")

return FilterFunc(filter_func_ref, filter_func, filter_args)
return FilterFuncResult(filter_func_ref, filter_func, filter_args)


def jinja_parse(context: dict[str, Any], jinja_string: str) -> str:
Expand Down

0 comments on commit 1ba22a9

Please sign in to comment.