Conversation

@kevinchern (Collaborator) commented Nov 14, 2025

This PR introduces a block-spin update sampler.
It includes duplicate code from #40

@kevinchern marked this pull request as draft November 17, 2025 21:24
@kevinchern assigned VolodyaCO and mhramani and unassigned VolodyaCO and mhramani Nov 17, 2025
@jackraymond left a comment

Great job. I'll come back to this when you add tests, per your summary. Thanks for getting this draft up quickly.

It would be nice to have, as soon as possible, a tidy public GPU code to use and cite for its speedup relative to our CPU implementation in higher-throughput applications.

@kevinchern (Collaborator Author)

Samplers should have a consistent interface. Related: #52, #47

@kevinchern (Collaborator Author) commented Dec 11, 2025

@jackraymond I updated the PR:

  1. Added tests (some notable mentions below; see the magnetization sketch after this list):
    • test the Gibbs update by checking that the magnetization is close to its expected value
    • test the Metropolis update by checking that the magnetization is close to its expected value
    • test the Metropolis update by checking for oscillatory behavior in the beta = 0 case
  2. Added documentation (first pass).
  3. Capitalized Metropolis/Gibbs and made the names case-insensitive.
  4. Simplified the sampler's signature. It no longer tries to do the colouring itself; instead, it requests a colouring function. This generalizes it beyond CPZ topologies. One could, for example, default to a colouring heuristic when no colouring is given, but I've opted away from that for modularity.
  5. A metadata dictionary is now returned along with the sample.
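
For item 1, a minimal sketch of the Gibbs magnetization check, assuming an isolated spin with field h at inverse temperature beta, where the exact magnetization is tanh(beta * h); the sampler API is elided and the Gibbs conditional is written out directly:

import torch

# For a single spin with field h, a Gibbs update draws s = +1 with
# probability sigmoid(2 * beta * h); the exact mean is tanh(beta * h).
beta, h, n = 1.0, 0.5, 100_000
g = torch.Generator().manual_seed(0)
p_up = torch.sigmoid(torch.tensor(2 * beta * h))
samples = torch.where(torch.rand(n, generator=g) < p_up, 1.0, -1.0)
assert torch.isclose(samples.mean(), torch.tanh(torch.tensor(beta * h)), atol=0.01)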

@kevinchern requested a review from thisac December 11, 2025 16:18
@kevinchern marked this pull request as ready for review December 11, 2025 16:19
@VolodyaCO (Collaborator) left a comment

I left some comments about form.

Overall looks fantastic. Thank you.

@VolodyaCO (Collaborator)

Can the user compile a method of an existing nn.Module subclassed instance? Ideally, we would want the user to decide whether they want to compile the _step method. It seems as though torch.compile(module) compiles the __call__ method, so maybe we can wrap _step in it and time the compiled and non-compiled versions.
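
For what it's worth, torch.compile also accepts plain callables, including bound methods, so both variants can be built and timed. A minimal sketch with a stand-in module (Sampler and its _step here are placeholders, not the PR's class):

import torch
import torch.nn as nn

class Sampler(nn.Module):  # placeholder stand-in, not the PR's sampler
    def _step(self, x: torch.Tensor) -> torch.Tensor:
        return -x  # trivial update, only to have something to compile

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._step(x)

sampler = Sampler()
compiled_module = torch.compile(sampler)      # compiles forward via __call__
compiled_step = torch.compile(sampler._step)  # compiles just the bound method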

@kevinchern (Collaborator Author) commented Dec 15, 2025

Addressed @thisac's and @VolodyaCO's reviews. Briefly:

  1. Removed the inheritance from nn.Module.
  2. Added a functional module docstring.
  3. Added explicit named arguments to the random-spin helper and renamed it to randspin.
  4. Improved documentation and type-hinting.

@jackraymond left a comment

Looks good; the tests ran for me.

I'd recommend you allow setting of x as an initial condition. As tests, perhaps check that it doesn't move when num_sweeps is none, and that it oscillates under BlockMetropolis when temperature is infinity.
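
For the infinite-temperature case, that check could look roughly like this (BlockMetropolis, model, colouring, and the keyword names are hypothetical stand-ins; at beta = 0 every proposed block flip is accepted, so each full sweep negates the state):

x0 = torch.ones(8)
sampler = BlockMetropolis(model, colouring, x=x0.clone())  # hypothetical API
sampler.sample(num_sweeps=1, beta=0.0)  # infinite temperature
assert torch.equal(sampler.x, -x0)      # one sweep flips every block
sampler.sample(num_sweeps=1, beta=0.0)
assert torch.equal(sampler.x, x0)       # a second sweep flips back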

@kevinchern (Collaborator Author) commented Dec 17, 2025

> Looks good; the tests ran for me.
>
> I'd recommend you allow setting of x as an initial condition. As tests, perhaps check that it doesn't move when num_sweeps is none, and that it oscillates under BlockMetropolis when temperature is infinity.

In the latest commit, I made the following changes:

  • Fixed the documentation regarding the nonzero and finite temperature regimes.
  • Allowed setting of initial states (see the sketch after this list).
    • Tested by checking that sampler.x.tolist() equals the input list of initial states.
  • Added custom error classes to distinguish them from a generic ValueError.
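
Roughly, that initial-state round trip looks like this (BlockSpinSampler, model, colouring, and the keyword name x are hypothetical stand-ins for the PR's actual names):

x0 = [[1, -1, 1], [-1, -1, 1]]
sampler = BlockSpinSampler(model, colouring, x=x0)  # hypothetical API
assert sampler.x.tolist() == x0  # initial states are stored verbatim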

@jackraymond @thisac

@VolodyaCO (Collaborator) left a comment

All comments have been addressed. Thank you @kevinchern

@thisac (Contributor) left a comment

Thanks @kevinchern! 🎄

Comment on lines 55 to 64
if "low" in kwargs:
raise ValueError("Invalid keyword argument `low`.")
if "high" in kwargs:
raise ValueError("Invalid keyword argument `high`.")
return 2*torch.randint(0, 2, size,
generator=generator,
dtype=dtype,
device=device,
requires_grad=requires_grad,
**kwargs)-1
@thisac (Contributor)

Some cleanup needed: spaces around - and *, and perhaps separate out the torch.randint() call, since the final 2* and -1 get kind of lost in there.

@thisac (Contributor)

Also not convinced that this is needed. Couldn't you just use the regular torch.randint() and pass the results through bit2spin_soft()? Other than that, this only sets low/high, which isn't too cumbersome to do manually.

@kevinchern (Collaborator Author) commented Dec 24, 2025

> the final 2* and -1 get kind of lost in there

Changed it to b = randint ... and return 2 * b - 1.

@kevinchern (Collaborator Author)

> Also not convinced that this is needed. Couldn't you just use the regular torch.randint() and pass the results through bit2spin_soft()? Other than that, this only sets low/high, which isn't too cumbersome to do manually.

My motivation for adding it is similar to that for adding a bit2spin function: it's very simple to write repeatedly, but having a dedicated function helps readability. What are some reasons to avoid adding this?

@thisac (Contributor)

Alternatively, either be clear about this being a special case of torch.randint(), and simplify it a bit to make clear that it works with all kwargs supported by randint:

def randspin(size: _size, **kwargs) -> torch.Tensor:
    """Wrapper for ``torch.randint()`` restricted to spin outputs."""
    if "low" in kwargs or "high" in kwargs:
        raise ValueError(...)

    b = torch.randint(0, 2, size, **kwargs)
    return 2 * b - 1

or remove the kwargs from the signature and add all the relevant arguments explicitly, which circumvents the need to check for "high" and "low".
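
The second option might look like this (a sketch only, mirroring the keyword set used in the PR's snippet above):

def randspin(size, *, generator=None, dtype=None, device=None,
             requires_grad=False) -> torch.Tensor:
    """Random spins in {-1, +1}; a thin wrapper over torch.randint()."""
    # Explicit keywords make a stray `low` or `high` a TypeError for free.
    b = torch.randint(0, 2, size, generator=generator, dtype=dtype,
                      device=device, requires_grad=requires_grad)
    return 2 * b - 1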

Comment on lines +31 to +44
ZEPHYR = dnx.zephyr_graph(1, coordinates=True)
GRBM_ZEPHYR = GRBM(ZEPHYR.nodes, ZEPHYR.edges)
CRAYON_ZEPHYR = dnx.zephyr_four_color

BIPARTITE = nx.complete_bipartite_graph(5, 3)
GRBM_BIPARTITE = GRBM(BIPARTITE.nodes, BIPARTITE.edges)
def CRAYON_BIPARTITE(b): return b < 5

GRBM_SINGLE = GRBM([0], [])
def CRAYON_SINGLE(s): return 0

GRBM_CRAYON_TEST_CASES = [(GRBM_ZEPHYR, CRAYON_ZEPHYR),
                          (GRBM_BIPARTITE, CRAYON_BIPARTITE),
                          (GRBM_SINGLE, CRAYON_SINGLE)]
@thisac (Contributor)

Better put in a setUpClass() method.
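
For example, a sketch reusing the names above (assuming the test file's existing imports of unittest, dnx, and GRBM):

class TestGRBMSampling(unittest.TestCase):  # hypothetical class name
    @classmethod
    def setUpClass(cls):
        # Build the shared fixtures once per test class rather than at import.
        cls.zephyr = dnx.zephyr_graph(1, coordinates=True)
        cls.grbm_zephyr = GRBM(cls.zephyr.nodes, cls.zephyr.edges)
        cls.crayon_zephyr = dnx.zephyr_four_color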

@kevinchern (Collaborator Author)

If I understood correctly, that would be incompatible with @parameterized.expand, since the expanded parameters are consumed at decoration time, before setUpClass runs.

bss._x.data[:] = 1
zero = torch.tensor(0.0)
ones = torch.ones((sample_size, 1))
bss._gibbs_update(0.0, bss._partition[0], ones*zero)
@thisac (Contributor)

Generally, you should try to avoid calling hidden methods/attributes even in tests (only what is meant to be public should need testing), but I need to think about whether that makes sense here.

Overall, I'm not sure whether this test could be streamlined a bit.
