From 51a7d51e3123e5b19ff5e072e9d6c635b8f671c9 Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 13 Nov 2025 18:42:23 +0000 Subject: [PATCH 1/2] Add minimal example of MMD-AE --- examples/mmd_ae.py | 144 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 examples/mmd_ae.py diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py new file mode 100644 index 0000000..87beeee --- /dev/null +++ b/examples/mmd_ae.py @@ -0,0 +1,144 @@ +from itertools import cycle +from math import prod +from os import makedirs + +import torch +from torch import nn +from torch.optim import SGD, AdamW +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms.v2 import Compose, ToDtype, ToImage +from torchvision.utils import make_grid, save_image + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM +from dwave.plugins.torch.nn import (ConvolutionNetwork, FullyConnectedNetwork, LinearBlock, + MaximumMeanDiscrepancy, RadialBasis, StraightThroughTanh, + rands_like, zephyr_subgraph) +from dwave.system import DWaveSampler + + +@torch.compile +class Autoencoder(nn.Module): + def __init__(self, shape, n_bits): + super().__init__() + dim = prod(shape) + c, h, w = shape + chidden = 1 + depth_fcnn = 3 + depth_cnn = 3 + dropout = 0.0 + self.encoder = nn.Sequential( + ConvolutionNetwork([chidden]*depth_cnn, shape), + nn.Flatten(), + FullyConnectedNetwork(chidden*h*w, n_bits, depth_fcnn, False, dropout), + ) + self.mixer = LinearBlock(n_bits, n_bits, False, dropout) + self.binarizer = StraightThroughTanh() + self.decoder = nn.Sequential( + FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout), + nn.Unflatten(1, (chidden, h, w)), + ConvolutionNetwork([chidden]*(depth_cnn-1) + [1], (chidden, h, w)) + ) + + def decode(self, q): + z = self.mixer(q) + xhat = self.decoder(z) + return z, xhat + + def forward(self, x): + spins = self.binarizer(self.encoder(x)) + z, xhat = self.decode(spins) + return spins, z, xhat + + +def collect_stats(model, grbm, x, q, compute_mmd): + s, z, xhat = model(x) + zgen, xgen = model.decode(q) + stats = { + "quasi": grbm.quasi_objective(s.detach(), q), + "bce": nn.functional.binary_cross_entropy_with_logits(xhat, x), + "mmd": compute_mmd(s, q), + "mmd2": compute_mmd(z, zgen), + } + return stats + + +def get_dataset(bs, data_dir="/tmp/"): + transforms = Compose([ToImage(), ToDtype(torch.float32, scale=True)]) + train_kwargs = dict(root=data_dir, download=True) + transforms = Compose([transforms, lambda x: 1 - x]) + data_train = MNIST(transform=transforms, **train_kwargs) + train_loader = DataLoader(data_train, bs, True) + return train_loader + + +def round_graph_down(graph, group_size): + n_in = graph.number_of_nodes() + no = group_size*(n_in//group_size) + return graph.subgraph(list(graph.nodes)[:no]) + + +def run(*, num_steps): + sampler = DWaveSampler(solver="Advantage2_system1.7") + sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False) + h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"] + outdir = "output/mmd_ae/" + makedirs(outdir, exist_ok=True) + + device = "cuda" + + # Setup data + train_loader = get_dataset(500) + + # Instantiate model + G = zephyr_subgraph(sampler.to_networkx_graph(), 4) + nodes = list(G.nodes) + edges = list(G.edges) + grbm = GRBM(nodes, edges).to(device) + model = Autoencoder((1, 28, 28), grbm.n_nodes).to(device) + model.train() + grbm.train() + + 
compute_mmd = MaximumMeanDiscrepancy(RadialBasis()).to(device) + + opt_grbm = SGD(grbm.parameters(), lr=1e-3) + opt_ae = AdamW(model.parameters(), lr=1e-3) + + for step, (x, y) in enumerate(cycle(train_loader)): + torch.cuda.empty_cache() + if step > num_steps: + break + # Send data to device + x = x.to(device).float() + + q = grbm.sample(sampler, prefactor=1, linear_range=h_range, quadratic_range=j_range, + device=device, sample_params=sample_params) + + # Train autoencoder + stats = collect_stats(model, grbm, x, q, compute_mmd) + opt_ae.zero_grad() + (stats["bce"] + stats["mmd"] + stats["mmd2"]).backward() + opt_ae.step() + + # Train GRBM + if step < 1000: + # NOTE: collecting stats because the autoencoder has been updated. + stats = collect_stats(model, grbm, x, q, compute_mmd) + opt_grbm.zero_grad() + stats['quasi'].backward() + opt_grbm.step() + print(step, {k: v.item() for k, v in stats.items()}) + if step % 10 == 0: + with torch.no_grad(): + grbm.eval() + xgen = model.decode(q[:100])[-1] + xuni = model.decode(rands_like(q[:100]))[-1] + xhat = model(x[:100])[-1] + save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png") + save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png") + save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png") + grbm.train() + + +if __name__ == "__main__": + run(num_steps=10_000) From 2d28639dd23bc00faf92cded5f896bc75875b08e Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 20 Nov 2025 23:26:12 +0000 Subject: [PATCH 2/2] Simplify model and use public solver --- examples/mmd_ae.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py index 87beeee..f7bebf3 100644 --- a/examples/mmd_ae.py +++ b/examples/mmd_ae.py @@ -11,9 +11,9 @@ from torchvision.utils import make_grid, save_image from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM -from dwave.plugins.torch.nn import (ConvolutionNetwork, FullyConnectedNetwork, LinearBlock, - MaximumMeanDiscrepancy, RadialBasis, StraightThroughTanh, - rands_like, zephyr_subgraph) +from dwave.plugins.torch.nn.modules import (ConvolutionNetwork, FullyConnectedNetwork, + MaximumMeanDiscrepancy, RadialBasis, + StraightThroughTanh, rands_like, zephyr_subgraph) from dwave.system import DWaveSampler @@ -32,7 +32,6 @@ def __init__(self, shape, n_bits): nn.Flatten(), FullyConnectedNetwork(chidden*h*w, n_bits, depth_fcnn, False, dropout), ) - self.mixer = LinearBlock(n_bits, n_bits, False, dropout) self.binarizer = StraightThroughTanh() self.decoder = nn.Sequential( FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout), @@ -41,24 +40,21 @@ def __init__(self, shape, n_bits): ) def decode(self, q): - z = self.mixer(q) - xhat = self.decoder(z) - return z, xhat + xhat = self.decoder(q) + return xhat def forward(self, x): spins = self.binarizer(self.encoder(x)) - z, xhat = self.decode(spins) - return spins, z, xhat + xhat = self.decode(spins) + return spins, xhat def collect_stats(model, grbm, x, q, compute_mmd): - s, z, xhat = model(x) - zgen, xgen = model.decode(q) + s, xhat = model(x) stats = { "quasi": grbm.quasi_objective(s.detach(), q), "bce": nn.functional.binary_cross_entropy_with_logits(xhat, x), "mmd": compute_mmd(s, q), - "mmd2": compute_mmd(z, zgen), } return stats @@ -79,10 +75,10 @@ def round_graph_down(graph, group_size): def run(*, num_steps): - sampler = DWaveSampler(solver="Advantage2_system1.7") + sampler = 
DWaveSampler(solver="Advantage2_system1.8") sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False) h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"] - outdir = "output/mmd_ae/" + outdir = "output/example_mmd_ae/" makedirs(outdir, exist_ok=True) device = "cuda" @@ -117,7 +113,7 @@ def run(*, num_steps): # Train autoencoder stats = collect_stats(model, grbm, x, q, compute_mmd) opt_ae.zero_grad() - (stats["bce"] + stats["mmd"] + stats["mmd2"]).backward() + (stats["bce"] + stats["mmd"]).backward() opt_ae.step() # Train GRBM @@ -131,9 +127,10 @@ def run(*, num_steps): if step % 10 == 0: with torch.no_grad(): grbm.eval() - xgen = model.decode(q[:100])[-1] - xuni = model.decode(rands_like(q[:100]))[-1] + xgen = model.decode(q[:100]) + xuni = model.decode(rands_like(q[:100])) xhat = model(x[:100])[-1] + save_image(make_grid(x[:100], 10, pad_value=1), outdir + "x.png") save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png") save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png") save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png")
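
Running examples/mmd_ae.py as added above requires QPU access through a configured D-Wave Leap account (for DWaveSampler) and a CUDA device, since the script hard-codes device = "cuda".

For orientation, the MaximumMeanDiscrepancy(RadialBasis()) term that ties the encoder's spins to the QPU samples is presumably the usual (biased) squared-MMD estimate with a Gaussian kernel. A minimal standalone sketch in plain PyTorch under that assumption; the helper name mmd_rbf and the fixed bandwidth are illustrative only, not the plugin's API:

    import torch

    def mmd_rbf(x, y, bandwidth=1.0):
        # Biased MMD^2 estimate between samples x (n, d) and y (m, d),
        # using the Gaussian kernel k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2)).
        def kernel(a, b):
            d2 = torch.cdist(a, b).pow(2)
            return torch.exp(-d2 / (2 * bandwidth ** 2))
        return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

In the training loop, a term of this kind pushes the distribution of encoded spins toward the distribution of samples drawn from the graph-restricted Boltzmann machine, while the binary cross-entropy term handles reconstruction and grbm.quasi_objective drives the separate GRBM parameter update.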