Skip to content

Commit

Permalink
update tensorflow and fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgeorg committed Apr 24, 2024
1 parent c0974be commit 2f085e1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions pypress/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def test_use_in_model_works():
feats, y = _test_data(n_samples=1000)

model = tf.keras.Sequential()
model.add(layers.PredictiveStateSimplex(5, input_dim=feats.shape[1]))
model.add(tf.keras.layers.Input(shape=(feats.shape[1],)))
model.add(layers.PredictiveStateSimplex(5))
model.add(layers.PredictiveStateMeans(1, "linear"))
model.compile(loss="mse", optimizer=tf.keras.optimizers.Nadam(learning_rate=0.01))

Expand All @@ -65,7 +66,8 @@ def test_press_in_model_works():
feats, y = _test_data(n_samples=1000)

model = tf.keras.Sequential()
model.add(layers.PRESS(units=1, n_states=5, input_dim=feats.shape[1]))
model.add(tf.keras.layers.Input(shape=(feats.shape[1],)))
model.add(layers.PRESS(units=1, n_states=5))
model.compile(loss="mse", optimizer=tf.keras.optimizers.Nadam(learning_rate=0.01))

model.fit(feats, y, epochs=4)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tensorflow~=2.11.0
tensorflow>=2.11.0
numpy>=1.11.5
pandas>=1.0.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
author_email="im@gmge.org",
description="Predictive State Smoothing (PRESS) in Python (keras)",
packages=find_packages(),
install_requires=["numpy >= 1.11.0", "tensorflow ~= 2.11.0", "pandas >= 1.0.0"],
install_requires=["numpy >= 1.11.0", "tensorflow >= 2.11.0", "pandas >= 1.0.0"],
)

0 comments on commit 2f085e1

Please sign in to comment.