Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with connection validation #108

Merged
merged 6 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,9 @@ target-version = ['py39']
[[tool.mypy.overrides]]
module = "graphviz.*"
ignore_missing_imports = true

[tool.pytest.ini_options]
markers = [
"invalid_schema_examples",
"invalid_pydantic_examples"
]
18 changes: 8 additions & 10 deletions src/qref/schema_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
ConfigDict,
Field,
StringConstraints,
field_validator,
model_validator,
)
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Self

NAME_PATTERN = "[A-Za-z_][A-Za-z0-9_]*"
OPTIONALLY_NAMESPACED_NAME_PATTERN = rf"^({NAME_PATTERN}\.)?{NAME_PATTERN}$"
Expand Down Expand Up @@ -115,18 +116,15 @@ class RoutineV1(BaseModel):
def __init__(self, **data: Any):
super().__init__(**{k: v for k, v in data.items() if v != [] and v != {}})

@field_validator("connections", mode="after")
@classmethod
def _validate_connections(cls, v, values) -> list[ConnectionV1]:
children_port_names = [
f"{child.name}.{port.name}" for child in values.data.get("children") for port in child.ports
]
parent_port_names = [port.name for port in values.data["ports"]]
@model_validator(mode="after")
def _validate_connections(self) -> Self:
children_port_names = [f"{child.name}.{port.name}" for child in self.children for port in child.ports]
parent_port_names = [port.name for port in self.ports]
available_port_names = set(children_port_names + parent_port_names)

missed_ports = [
port
for connection in v
for connection in self.connections
for port in (connection.source, connection.target)
if port not in available_port_names
]
Expand All @@ -135,7 +133,7 @@ def _validate_connections(cls, v, values) -> list[ConnectionV1]:
"The following ports appear in a connection but are not "
"among routine's port or their children's ports: {missed_ports}."
)
return v
return self


class SchemaV1(BaseModel):
Expand Down
37 changes: 35 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

"""Common fixtures for QREF tests."""

from functools import lru_cache
from pathlib import Path

import pytest
import yaml

VALID_PROGRAMS_ROOT_PATH = Path(__file__).parent / "qref/data/valid_programs"
DATA_ROOT_PATH = Path(__file__).parent / "qref/data"
VALID_PROGRAMS_ROOT_PATH = DATA_ROOT_PATH / "valid_programs"
INVALID_YAML_PROGRAMS_PATH = DATA_ROOT_PATH / "invalid_yaml_programs.yaml"
INVALID_PYDANTIC_PROGRAMS_PATH = DATA_ROOT_PATH / "invalid_pydantic_programs.yaml"


def _load_valid_examples():
Expand All @@ -32,3 +35,33 @@ def _load_valid_examples():
@pytest.fixture(params=_load_valid_examples())
def valid_program(request):
return request.param


@lru_cache(maxsize=None)
def _load_yaml(path):
with open(path) as f:
return yaml.safe_load(f)


def pytest_generate_tests(metafunc):
marker_names = [marker.name for marker in metafunc.definition.iter_markers()]
if "invalid_schema_examples" in marker_names:
data = _load_yaml(INVALID_YAML_PROGRAMS_PATH)
metafunc.parametrize(
"input, error_path, error_message",
[
pytest.param(
example["input"],
example["error_path"],
example["error_message"],
id=example["description"],
)
for example in data
],
)
elif "invalid_pydantic_examples" in marker_names:
data = [
example["input"]
for example in (_load_yaml(INVALID_YAML_PROGRAMS_PATH) + _load_yaml(INVALID_PYDANTIC_PROGRAMS_PATH))
]
metafunc.parametrize("input", data)
38 changes: 36 additions & 2 deletions tests/qref/data/invalid_pydantic_programs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,39 @@
- source: foo.out_0
target: bar.in_1
description: "Connection contains non-existent port name"
error_path: "$.program.connections[0].source"
error_message: "'foo.foo.out_0' does not match '^(([A-Za-z_][A-Za-z0-9_]*)|([A-Za-z_][A-Za-z0-9_]*\\\\.[A-Za-z_][A-Za-z0-9_]*))$'"
- input:
version: v1
program:
name: root
children:
- name: foo
ports:
- name: in_0
direction: inpt # Warning: intentional typo here!
size: 1
- name: out_0
direction: output
size: 1
- name: bar
ports:
- name: in_0
direction: input
size: 1
- name: out_0
direction: output
size: 1
ports:
- name: in_0
direction: input
size: 1
- name: out_0
direction: output
size: 1
connections:
- source: in_0
target: foo.in_0
- source: foo.out_0
target: bar.in_0
- source: bar.oout_0 # Warning: intentional typo here!
target: out_0
description: "Validation error in child and in connections."
31 changes: 4 additions & 27 deletions tests/qref/test_schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import pydantic
import pytest
import yaml # type: ignore[import-untyped]
from jsonschema import ValidationError, validate

from qref import SchemaV1, generate_program_schema
Expand All @@ -26,27 +22,7 @@ def validate_with_v1(data):
validate(data, generate_program_schema(version="v1"))


def load_invalid_examples(add_pydantic=False):
with open(Path(__file__).parent / "data/invalid_yaml_programs.yaml") as f:
data = yaml.safe_load(f)

if add_pydantic:
with open(Path(__file__).parent / "data/invalid_pydantic_programs.yaml") as f:
additional_data = yaml.safe_load(f)
data += additional_data

return [
pytest.param(
example["input"],
example["error_path"],
example["error_message"],
id=example["description"],
)
for example in data
]


@pytest.mark.parametrize("input, error_path, error_message", load_invalid_examples())
@pytest.mark.invalid_schema_examples
def test_invalid_program_fails_to_validate_with_schema_v1(input, error_path, error_message):
with pytest.raises(ValidationError) as err_info:
validate_with_v1(input)
Expand All @@ -59,10 +35,11 @@ def test_valid_program_successfully_validates_with_schema_v1(valid_program):
validate_with_v1(valid_program)


@pytest.mark.parametrize("input", [input for input, *_ in load_invalid_examples(add_pydantic=True)])
@pytest.mark.invalid_pydantic_examples
def test_invalid_program_fails_to_validate_with_pydantic_model_v1(input):
with pytest.raises(pydantic.ValidationError):
with pytest.raises(pydantic.ValidationError) as e:
SchemaV1.model_validate(input)
print(e.value)


def test_valid_program_succesfully_validate_with_pydantic_model_v1(valid_program):
Expand Down
Loading