Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change to keras3 for multi backend #4

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/tkat_ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# CI workflow: installs the project with Poetry and runs run_tests.py
# once per Python version in the matrix.
name: Run tests on multiple backends

on:
  push:
    branches: [ main, beta ]
  pull_request:
    branches: [ main ]


jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11"]

    steps:
    # v4 / v5 are the current major versions of these actions; the v2
    # releases target a Node runtime that GitHub has deprecated.
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install poetry
        poetry install
    - name: Run tests
      run: |
        poetry run python run_tests.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

![TKAT representation](images/model_representation.jpeg)

This folder includes the original code implemented for the [paper](https://arxiv.org/abs/2406.02486) of the same name.
This folder includes the original code implemented for the [paper](https://arxiv.org/abs/2406.02486) of the same name. The model is implemented in Keras 3 and supports all backends (JAX, TensorFlow, PyTorch).

It is inspired by the Temporal Fusion Transformer by [google-research](https://github.com/google-research/google-research/tree/master/tft) and the [Temporal Kolmogorov Arnold Network](https://github.com/remigenet/TKAN).

Expand Down
27 changes: 22 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,31 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "tkat"
version = "0.1.1"
version = "0.2.0"
description = "Temporal KAN Transformer"
authors = [ "Rémi Genet", "Hugo Inzirillo"]
readme = "README.md"
packages = [{include = "tkat"}]

[tool.poetry.dependencies]
python = ">=3.10,<4.0"
numpy = ">=1.2,<2"
tensorflow = ">=2.8,<3"
tkan = ">=0.3.0,<0.4.0"
python = ">=3.9,<3.12"
keras = ">=3.0.0,<4.0"
keras_efficient_kan = "^0.1.4"
tkan = "^0.4.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-xdist = "^3.3.0"
tensorflow = "^2.15.0"
torch = "^2.0.0"
jax = "^0.4.13"
jaxlib = "^0.4.13"

[tool.pytest.ini_options]
addopts = "-v"
testpaths = ["tests"]
filterwarnings = [
"ignore:Can't initialize NVML:UserWarning",
"ignore:jax.xla_computation is deprecated:DeprecationWarning",
"ignore::DeprecationWarning:jax._src.dtypes"
]
26 changes: 26 additions & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import subprocess

def run_test(backend):
    """Run the pytest module for one Keras backend and print its output.

    Launches ``pytest tests/test_<backend>.py`` in a child process whose
    environment is a copy of the current one with ``KERAS_BACKEND`` set to
    ``backend``. The captured stdout is always printed; stderr is printed
    only when it is non-empty.

    Args:
        backend: Backend name, used both for ``KERAS_BACKEND`` and to build
            the test file path.

    Returns:
        The exit code of the pytest child process.
    """
    child_env = {**os.environ, 'KERAS_BACKEND': backend}
    command = ['pytest', f'tests/test_{backend}.py']
    completed = subprocess.run(
        command,
        env=child_env,
        capture_output=True,
        text=True,
    )

    print(f"\n--- {backend.upper()} Backend Test Results ---")
    print(completed.stdout)

    captured_errors = completed.stderr
    if captured_errors:
        print("Errors:")
        print(captured_errors)

    return completed.returncode

if __name__ == "__main__":
    backends = ['tensorflow', 'torch', 'jax']

    # Run every backend's suite before deciding the exit status, so a
    # failure in one backend does not hide the results of the others.
    exit_codes = [run_test(backend) for backend in backends]

    # Raise SystemExit directly instead of calling the `exit` helper: that
    # helper is injected by the `site` module for interactive sessions and
    # is not available when the interpreter starts without `site`.
    if any(exit_codes):
        raise SystemExit(1)

    print("\nAll tests passed successfully!")
    raise SystemExit(0)
Empty file added tests/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
BACKEND = 'jax'
os.environ['KERAS_BACKEND'] = BACKEND

import pytest
import keras
from keras import ops
from keras import backend
from keras import random
from tkat import TKAT # Assuming you've defined TKAT in a separate file

def generate_random_tensor(shape):
    """Return a tensor of ``shape`` sampled with ``keras.random.normal``.

    The dtype is whatever ``keras.backend.floatx()`` reports at call time.
    """
    float_dtype = backend.floatx()
    return random.normal(shape=shape, dtype=float_dtype)

def test_tkat_basic():
    """Calling a TKAT model on a random batch yields a (batch, n_ahead) output."""
    assert keras.backend.backend() == BACKEND

    batch_size = 32
    sequence_length = 10
    n_ahead = 5
    num_unknow_features = 3
    num_know_features = 2

    model_kwargs = {
        'sequence_length': sequence_length,
        'num_unknow_features': num_unknow_features,
        'num_know_features': num_know_features,
        'num_embedding': 8,
        'num_hidden': 16,
        'num_heads': 4,
        'n_ahead': n_ahead,
    }
    tkat_model = TKAT(**model_kwargs)

    # The input carries the past window followed by the forecast horizon
    # on axis 1, and all features on axis 2.
    total_steps = sequence_length + n_ahead
    total_features = num_unknow_features + num_know_features
    input_data = generate_random_tensor((batch_size, total_steps, total_features))

    output = tkat_model(input_data)

    expected_output_shape = (batch_size, n_ahead)
    assert output.shape == expected_output_shape, f"Expected shape {expected_output_shape}, but got {output.shape}"

def test_tkat_variable_selection():
    """Check the output shapes of the two variable selection layers.

    The layers are looked up by name ('embedding_layer', 'vsn_past_features',
    'vsn_future_features') and applied by hand to slices of the embedded
    input.
    """
    assert keras.backend.backend() == BACKEND

    batch_size = 16
    sequence_length = 8
    n_ahead = 3
    num_unknow_features = 4
    num_know_features = 3
    num_hidden = 8

    model_kwargs = {
        'sequence_length': sequence_length,
        'num_unknow_features': num_unknow_features,
        'num_know_features': num_know_features,
        'num_embedding': 4,
        'num_hidden': num_hidden,
        'num_heads': 2,
        'n_ahead': n_ahead,
    }
    tkat_model = TKAT(**model_kwargs)

    total_steps = sequence_length + n_ahead
    total_features = num_unknow_features + num_know_features
    input_data = generate_random_tensor((batch_size, total_steps, total_features))

    # Embed the raw input first; the selection layers consume the embedding.
    embedded_input = tkat_model.get_layer('embedding_layer')(input_data)

    vsn_past = tkat_model.get_layer('vsn_past_features')
    vsn_future = tkat_model.get_layer('vsn_future_features')

    # Past window: the first `sequence_length` steps, last two axes untouched.
    past_features = embedded_input[:, :sequence_length, :, :]
    # Forecast horizon: the remaining steps, restricted to the trailing
    # `num_know_features` entries of the last axis.
    # NOTE(review): presumably the last axis indexes features — confirm
    # against the embedding layer's output layout.
    future_features = embedded_input[:, sequence_length:, :, -num_know_features:]

    past_output = vsn_past(past_features)
    future_output = vsn_future(future_features)

    assert past_output.shape == (batch_size, sequence_length, num_hidden)
    assert future_output.shape == (batch_size, n_ahead, num_hidden)



def test_tkat_attention():
    """Check that the model exposes a MultiHeadAttention layer and still
    produces a (batch, n_ahead) output.
    """
    assert keras.backend.backend() == BACKEND
    batch_size, sequence_length, n_ahead = 8, 6, 2
    num_unknow_features, num_know_features = 4, 3
    num_embedding, num_hidden, num_heads = 4, 8, 2

    tkat_model = TKAT(
        sequence_length=sequence_length,
        num_unknow_features=num_unknow_features,
        num_know_features=num_know_features,
        num_embedding=num_embedding,
        num_hidden=num_hidden,
        num_heads=num_heads,
        n_ahead=n_ahead,
    )

    input_shape = (batch_size, sequence_length + n_ahead, num_unknow_features + num_know_features)
    input_data = generate_random_tensor(input_shape)

    # Look the attention layer up with a default so that a missing layer is
    # reported as an assertion failure with a message, not as a bare
    # StopIteration escaping from next().
    attention_layer = next(
        (layer for layer in tkat_model.layers
         if isinstance(layer, keras.layers.MultiHeadAttention)),
        None,
    )
    assert attention_layer is not None, "TKAT model has no MultiHeadAttention layer"

    # Test the model output produced with the attention layer in place
    output = tkat_model(input_data)
    assert output.shape == (batch_size, n_ahead)

def test_tkat_training():
    """Fit TKAT for two epochs on random data and check the loss history.

    The test passes when two loss values are recorded and the second is
    strictly lower than the first.
    """
    assert keras.backend.backend() == BACKEND

    batch_size = 64
    sequence_length = 12
    n_ahead = 4
    num_unknow_features = 4
    num_know_features = 3

    model_kwargs = {
        'sequence_length': sequence_length,
        'num_unknow_features': num_unknow_features,
        'num_know_features': num_know_features,
        'num_embedding': 8,
        'num_hidden': 16,
        'num_heads': 4,
        'n_ahead': n_ahead,
    }
    tkat_model = TKAT(**model_kwargs)

    total_steps = sequence_length + n_ahead
    total_features = num_unknow_features + num_know_features
    input_data = generate_random_tensor((batch_size, total_steps, total_features))
    target_data = generate_random_tensor((batch_size, n_ahead))

    tkat_model.compile(optimizer='adam', loss='mse')
    history = tkat_model.fit(input_data, target_data, epochs=2, batch_size=16, verbose=0)

    losses = history.history['loss']
    assert len(losses) == 2
    # NOTE(review): inputs and targets are unseeded random draws, so this
    # strict decrease is not guaranteed on every run — consider seeding.
    assert losses[1] < losses[0]

def test_tkat_prediction():
    """``predict`` on a random batch returns an array of shape (batch, n_ahead)."""
    assert keras.backend.backend() == BACKEND

    batch_size = 32
    sequence_length = 10
    n_ahead = 5
    num_unknow_features = 3
    num_know_features = 2

    model_kwargs = {
        'sequence_length': sequence_length,
        'num_unknow_features': num_unknow_features,
        'num_know_features': num_know_features,
        'num_embedding': 8,
        'num_hidden': 16,
        'num_heads': 4,
        'n_ahead': n_ahead,
    }
    tkat_model = TKAT(**model_kwargs)

    total_steps = sequence_length + n_ahead
    total_features = num_unknow_features + num_know_features
    input_data = generate_random_tensor((batch_size, total_steps, total_features))

    predictions = tkat_model.predict(input_data)

    expected_shape = (batch_size, n_ahead)
    assert predictions.shape == expected_shape

if __name__ == "__main__":
    # pytest.main returns the run's exit code; raise it as SystemExit so
    # running this file directly reports failures through the process
    # status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
Loading
Loading