-
Notifications
You must be signed in to change notification settings - Fork 0
/
integrated_model.py
52 lines (42 loc) · 1.78 KB
/
integrated_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
# Import custom modules
from kolmogorov_arnold_network import kolmogorov_arnold_network
from wae_mnist import WAE, build_encoder, build_decoder
def load_preprocess_data():
# Load MNIST data
(x_train, _), (x_test, _) = mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype('float32') / 255.0
x_test = np.expand_dims(x_test, -1).astype('float32') / 255.0
return x_train, x_test
class IntegratedModel:
def __init__(self, encoder, decoder, V, W):
self.encoder = encoder
self.decoder = decoder
self.V = V
self.W = W
def preprocess_with_ka_network(self, x):
# Flatten the images for processing with K-A network
flat_x = np.reshape(x, (x.shape[0], -1))
processed_x = np.array([kolmogorov_arnold_network(xi, self.V, self.W) for xi in flat_x])
return processed_x.reshape(x.shape)
def predict(self, x):
preprocessed_x = self.preprocess_with_ka_network(x)
z_mean, _ = self.encoder.predict(preprocessed_x)
reconstructed = self.decoder.predict(z_mean)
return reconstructed
# Setup
encoder = build_encoder()
decoder = build_decoder()
V = np.random.rand(784, 784) # Assuming flattening the 28x28 images
W = np.random.rand(784)
# Load and preprocess data
x_train, x_test = load_preprocess_data()
# Create integrated model instance
integrated_model = IntegratedModel(encoder, decoder, V, W)
# Use the integrated model to process and reconstruct images
reconstructed_images = integrated_model.predict(x_test[:10])
# Simple visualization of results (considering implementing a visualization function)
print("Reconstructed shapes:", reconstructed_images.shape)