Skip to content

Commit 356c310

Browse files
authored
Fixes for numpy 2.0 (#928)
* Remove tests for float_ for numpy2.0 * Fix is_different for numpy=2.0 * Add tox test environments for numpy 2.0 * Add pull_request trigger to tests * Fix s3 tests (moto.mock_s3 -> moto.mock_aws) * Reproduce old np.array_equal behavior * Update tensorflow test configurations
1 parent cd90ee1 commit 356c310

File tree

5 files changed

+34
-56
lines changed

5 files changed

+34
-56
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ name: Tests
22

33
on:
44
- push
5+
- pull_request
56

67
jobs:
78
pytest:

sacred/config/custom_containers.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,19 @@ def type_changed(old_value, new_value):
300300
def is_different(old_value, new_value):
301301
"""Numpy aware comparison between two values."""
302302
if opt.has_numpy:
303-
return not opt.np.array_equal(old_value, new_value)
304-
else:
305-
return old_value != new_value
303+
# Reproduces np.array_equal from numpy<2
304+
# np.array_equal raises an exception when the arguments are scalar and
305+
# differ in type (e.g. int and str) in numpy>=2.0
306+
try:
307+
old_value = opt.np.asarray(old_value)
308+
new_value = opt.np.asarray(new_value)
309+
except:
310+
return False
311+
else:
312+
result = old_value == new_value
313+
if isinstance(result, bool):
314+
return result
315+
else:
316+
return result.all()
317+
318+
return old_value != new_value

tests/test_config/test_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"uint16",
2323
"uint32",
2424
"uint64",
25-
"float_",
2625
"float16",
2726
"float32",
2827
"float64",
@@ -49,7 +48,6 @@ def test_normalize_or_die_for_numpy_datatypes(typename):
4948
"uint16",
5049
"uint32",
5150
"uint64",
52-
"float_",
5351
"float16",
5452
"float32",
5553
"float64",

tests/test_observers/test_s3_observer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _get_file_data(bucket_name, key):
7777
return s3.Object(bucket_name, key).get()["Body"].read()
7878

7979

80-
@moto.mock_s3
80+
@moto.mock_aws
8181
def test_fs_observer_started_event_creates_bucket(observer, sample_run):
8282
_id = observer.started_event(**sample_run)
8383
run_dir = s3_join(BASEDIR, str(_id))
@@ -102,7 +102,7 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run):
102102
}
103103

104104

105-
@moto.mock_s3
105+
@moto.mock_aws
106106
def test_fs_observer_started_event_increments_run_id(observer, sample_run):
107107
_id = observer.started_event(**sample_run)
108108
_id2 = observer.started_event(**sample_run)
@@ -119,15 +119,15 @@ def test_s3_observer_equality():
119119
assert obs_one != different_bucket
120120

121121

122-
@moto.mock_s3
122+
@moto.mock_aws
123123
def test_raises_error_on_duplicate_id_directory(observer, sample_run):
124124
observer.started_event(**sample_run)
125125
sample_run["_id"] = 1
126126
with pytest.raises(FileExistsError):
127127
observer.started_event(**sample_run)
128128

129129

130-
@moto.mock_s3
130+
@moto.mock_aws
131131
def test_completed_event_updates_run_json(observer, sample_run):
132132
observer.started_event(**sample_run)
133133
run = json.loads(
@@ -145,7 +145,7 @@ def test_completed_event_updates_run_json(observer, sample_run):
145145
assert run["status"] == "COMPLETED"
146146

147147

148-
@moto.mock_s3
148+
@moto.mock_aws
149149
def test_interrupted_event_updates_run_json(observer, sample_run):
150150
observer.started_event(**sample_run)
151151
run = json.loads(
@@ -163,7 +163,7 @@ def test_interrupted_event_updates_run_json(observer, sample_run):
163163
assert run["status"] == "SERVER_EXPLODED"
164164

165165

166-
@moto.mock_s3
166+
@moto.mock_aws
167167
def test_failed_event_updates_run_json(observer, sample_run):
168168
observer.started_event(**sample_run)
169169
run = json.loads(
@@ -181,7 +181,7 @@ def test_failed_event_updates_run_json(observer, sample_run):
181181
assert run["status"] == "FAILED"
182182

183183

184-
@moto.mock_s3
184+
@moto.mock_aws
185185
def test_queued_event_updates_run_json(observer, sample_run):
186186
del sample_run["start_time"]
187187
sample_run["queue_time"] = T2
@@ -194,7 +194,7 @@ def test_queued_event_updates_run_json(observer, sample_run):
194194
assert run["status"] == "QUEUED"
195195

196196

197-
@moto.mock_s3
197+
@moto.mock_aws
198198
def test_artifact_event_works(observer, sample_run, tmpfile):
199199
observer.started_event(**sample_run)
200200
observer.artifact_event("test_artifact.py", tmpfile.name)

tox.ini

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# and then run "tox" from this directory.
55

66
[tox]
7-
envlist = py{38,39,310,311}, setup, numpy-{120,121,123}, tensorflow-{26,27,28,29,210,211}
7+
envlist = py{38,39,310,311}, setup, numpy-{120,121,123,200}, tensorflow-{212,216}
88

99
[testenv]
1010
deps =
@@ -53,68 +53,34 @@ deps =
5353
commands =
5454
pytest tests/test_config {posargs}
5555

56-
[testenv:tensorflow-115]
56+
[testenv:numpy-200]
5757
basepython = python
5858
deps =
5959
-rdev-requirements.txt
60-
tensorflow~=1.15.0
60+
numpy~=2.0.0
6161
commands =
62-
pytest tests/test_stflow tests/test_optional.py \
63-
{posargs}
64-
65-
[testenv:tensorflow-26]
66-
basepython = python
67-
deps =
68-
-rdev-requirements.txt
69-
tensorflow~=2.6.0
70-
commands =
71-
pytest tests/test_stflow tests/test_optional.py \
72-
{posargs}
73-
74-
[testenv:tensorflow-27]
75-
basepython = python
76-
deps =
77-
-rdev-requirements.txt
78-
tensorflow~=2.7.0
79-
commands =
80-
pytest tests/test_stflow tests/test_optional.py \
81-
{posargs}
62+
pytest tests/test_config {posargs}
8263

83-
[testenv:tensorflow-28]
64+
[testenv:tensorflow-212]
8465
basepython = python
8566
deps =
8667
-rdev-requirements.txt
87-
tensorflow~=2.8.0
68+
numpy<2.0.0
69+
tensorflow~=2.12.0
8870
commands =
8971
pytest tests/test_stflow tests/test_optional.py \
9072
{posargs}
9173

92-
[testenv:tensorflow-29]
93-
basepython = python
94-
deps =
95-
-rdev-requirements.txt
96-
tensorflow~=2.9.0
97-
commands =
98-
pytest tests/test_stflow tests/test_optional.py \
99-
{posargs}
10074

101-
[testenv:tensorflow-210]
75+
[testenv:tensorflow-216]
10276
basepython = python
10377
deps =
10478
-rdev-requirements.txt
105-
tensorflow~=2.10.0
79+
tensorflow~=2.16.0
10680
commands =
10781
pytest tests/test_stflow tests/test_optional.py \
10882
{posargs}
10983

110-
[testenv:tensorflow-211]
111-
basepython = python
112-
deps =
113-
-rdev-requirements.txt
114-
tensorflow~=2.11.0
115-
commands =
116-
pytest tests/test_stflow tests/test_optional.py \
117-
{posargs}
11884

11985
[testenv:setup]
12086
basepython = python

0 commit comments

Comments
 (0)