-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization.py
72 lines (58 loc) · 2.06 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Callable, Optional, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FormatStrFormatter
from PIL import Image
def visualize_hopfield_network(
    hopfield_network: Callable[[np.ndarray], np.ndarray],
    energy_function: Callable[[np.ndarray], float],
    initial_state: np.ndarray,
    output_path: Optional[str] = None,
    steps: int = 10,
    image_shape: Tuple[int, int] = (28, 28),
):
    """Animate the dynamics of a Hopfield network alongside its energy.

    The figure has three panels: the initial state, the current state, and
    the energy recorded so far. The initial state is rendered first. Then
    ``hopfield_network`` is applied ``steps`` times. After every update the
    "Current State" panel and the energy curve are redrawn together, so each
    frame pairs a state with the energy of that same state. The last frame
    shows the final state.

    Args:
        hopfield_network: Maps a state vector to the next state vector.
        energy_function: Returns the scalar energy of a state vector.
        initial_state: Starting state. It must be reshapeable to
            ``image_shape`` for display.
        output_path: If given, every frame is rasterized and the sequence is
            saved with PIL at this path. PIL infers the format from the
            extension, presumably ``.gif``. If ``None``, the frames are shown
            interactively with a one-second pause between them.
        steps: Number of network updates to run. ``0`` produces a single
            frame showing only the initial state.
        image_shape: 2-D shape used to display a state vector. The default
            ``(28, 28)`` matches the previously hard-coded value.

    Returns:
        The ``(figure, axes)`` pair returned by ``plt.subplots``. The figure
        is left open for the caller.

    Raises:
        ValueError: If ``steps`` is negative, or if a state cannot be reshaped
            to ``image_shape``.

    Note:
        When ``output_path`` is given, this switches Matplotlib's global
        backend to "Agg" for the rest of the process.
    """
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")

    if output_path is not None:
        # Off-screen rendering so exporting frames works without a display.
        matplotlib.use("Agg")

    f, axs = plt.subplots(1, 3, figsize=(18, 6))
    axs[0].set_title("Initial State")
    axs[1].set_title("Current State")
    axs[2].set_title("Energy")
    axs[2].set_xlabel("Iterations")
    axs[2].set_box_aspect(1)
    axs[2].grid()
    axs[2].yaxis.set_major_formatter(FormatStrFormatter("%.2E"))
    axs[0].imshow(initial_state.reshape(image_shape))

    state = initial_state
    energy = [energy_function(state)]
    frames = []
    state_img = None
    energy_line = None

    # Iteration 0 renders the initial state. Iteration k >= 1 first applies
    # the k-th update, so the displayed state and the last energy point
    # always refer to the same state.
    for step in range(steps + 1):
        if step > 0:
            state = hopfield_network(state)
            energy.append(energy_function(state))

        # Replace the artists rather than mutating them, so imshow re-derives
        # its colour scaling from each new state.
        if state_img is not None:
            state_img.remove()
        if energy_line is not None:
            energy_line.remove()
        state_img = axs[1].imshow(state.reshape(image_shape))
        (energy_line,) = axs[2].plot(energy, color="blue")

        if output_path is not None:
            f.canvas.draw()  # render now so the frame captures this step
            frames.append(_figure_to_frame(f))
        else:
            plt.pause(1)

    if output_path is not None:
        # steps >= 0 guarantees at least one frame exists here.
        frames[0].save(
            output_path,
            save_all=True,
            append_images=frames[1:],
            duration=500,  # milliseconds each frame is displayed
            loop=0,  # number of GIF loops; 0 means loop forever
        )
    return f, axs
def _figure_to_frame(fig: plt.Figure) -> Image.Image:
    """Rasterize *fig* into a standalone RGB PIL image.

    Reads the canvas through ``buffer_rgba()`` instead of ``tostring_rgb()``,
    which is deprecated in Matplotlib 3.8 and removed in 3.10. The pixel
    array takes its shape from the rendered buffer itself. This stays correct
    on HiDPI canvases, where ``get_width_height()`` reports logical rather
    than physical pixels.

    Args:
        fig: Figure whose canvas provides ``buffer_rgba`` (Agg-based canvases).

    Returns:
        An RGB image that owns its pixel data, independent of the canvas.
    """
    # Render first so the buffer reflects the figure's current contents.
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel and copy. The renderer buffer can be reused by
    # later draws, so each frame must hold its own pixels.
    rgb = np.ascontiguousarray(rgba[..., :3])
    return Image.fromarray(rgb)