Skip to content

Commit

Permalink
Pin tf<2.16 and resolve in a future pr
Browse files Browse the repository at this point in the history
  • Loading branch information
David Rubinstein committed Mar 12, 2024
1 parent ede09c9 commit ef94d1a
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ dependencies = [
"numpy>=1.18",
"onnxruntime; platform_system == 'Windows' and python_version < '3.11'",
"pretty_midi>=0.2.9",
"resampy>=0.2.2",
"resampy>=0.2.2,<0.4.3",
"scikit-learn",
"scipy>=1.4.1",
"tensorflow>=2.4.1; platform_system != 'Darwin' and python_version >= '3.11'",
"tensorflow>=2.4.1,<2.16; platform_system != 'Darwin' and python_version >= '3.11'",
"tensorflow-macos>=2.4.1; platform_system == 'Darwin' and python_version >= '3.11'",
"tflite-runtime; platform_system == 'Linux' and python_version < '3.11'",
"typing_extensions",
Expand Down Expand Up @@ -56,8 +56,8 @@ test = [
"pytest-mock",
]
tf = [
"tensorflow>=2.4.1; platform_system != 'Darwin'",
"tensorflow-macos>=2.4.1; platform_system == 'Darwin' and python_version > '3.7'",
"tensorflow>=2.4.1,<2.16; platform_system != 'Darwin'",
"tensorflow-macos>=2.4.1,<2.16; platform_system == 'Darwin' and python_version > '3.7'",
]
coreml = ["coremltools"]
onnx = ["onnxruntime"]
Expand Down
Binary file modified tests/resources/vocadito_10/model_output.npz
Binary file not shown.
Binary file modified tests/resources/vocadito_10/note_events.npz
Binary file not shown.
5 changes: 2 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ def test_predict() -> None:
assert all(note_pitch_max)
assert isinstance(note_events, list)

expected_model_output = np.load("tests/resources/vocadito_10/model_output.npz")
expected_model_output = np.load("tests/resources/vocadito_10/model_output.npz", allow_pickle=True)["arr_0"].item()
for k in expected_model_output.keys():
np.testing.assert_allclose(expected_model_output[k], model_output[k], atol=1e-4, rtol=0)

expected_note_events = np.load("tests/resources/vocadito_10/note_events.npz", allow_pickle=True)
expected_note_events = expected_note_events.get("arr_0")
expected_note_events = np.load("tests/resources/vocadito_10/note_events.npz", allow_pickle=True)["arr_0"]

assert len(expected_note_events) == len(note_events)
for expected, calculated in zip(expected_note_events, note_events):
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requires =
[testenv]
deps = -e .[test]
commands =
pytest tests --ignore tests/test_nn.py -s {posargs}
pytest tests --ignore tests/test_nn.py {posargs}
setenv =
SOURCE = {toxinidir}/basic_pitch
TEST_SOURCE = {toxinidir}/tests
Expand Down

0 comments on commit ef94d1a

Please sign in to comment.