Add block-spin update sampler #44
base: main
Conversation
jackraymond
left a comment
Great job. I'll come back to this when you add tests, per your summary. Thanks for getting this draft up quickly.
It would be nice to have tidy, public GPU code to use and cite for the speedup relative to our CPU implementation in higher-throughput applications as soon as possible.
@jackraymond I updated the PR:
VolodyaCO
left a comment
I left some comments about form.
Overall looks fantastic. Thank you.
Can the user compile a method of an existing
Addressed @thisac's and @VolodyaCO's reviews. Briefly:
jackraymond
left a comment
Looks good; tests ran for me.
I'd recommend allowing `x` to be set as an initial condition. As tests, perhaps check that it doesn't move when `num_sweeps` is `None`, and that it oscillates under `BlockMetropolis` when the temperature is infinite.
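The infinite-temperature oscillation check can be sketched with a toy, pure-Python block-Metropolis (hypothetical names and signatures; the PR's actual `BlockMetropolis` API may differ): at infinite temperature (beta = 0) every proposed block flip is accepted, so a deterministic full-block proposal makes the state oscillate with period 2, and with `num_sweeps=None` the state must not move.

```python
import math
import random

def block_metropolis_sweep(spins, beta, energy_diff):
    """One sweep of a toy block-Metropolis update: propose flipping the
    whole block and accept with probability min(1, exp(-beta * dE)).
    (Hypothetical stand-in for the PR's BlockMetropolis sampler.)"""
    dE = energy_diff(spins)
    if dE <= 0 or random.random() < math.exp(-beta * dE):
        return [-s for s in spins]  # accepted: flip every spin in the block
    return spins

def run(spins, beta, num_sweeps, energy_diff):
    """Apply num_sweeps sweeps; None means 'do nothing', matching the
    suggested no-op test."""
    if num_sweeps is None:
        return spins
    for _ in range(num_sweeps):
        spins = block_metropolis_sweep(spins, beta, energy_diff)
    return spins

x0 = [1, 1, -1]
flat = lambda s: 0.0  # dE is irrelevant at beta=0: acceptance is always 1

# num_sweeps=None: the state must not move.
assert run(x0, beta=1.0, num_sweeps=None, energy_diff=flat) == x0

# beta=0 (infinite temperature): every flip is accepted,
# so the state oscillates with period 2.
x1 = run(x0, beta=0.0, num_sweeps=1, energy_diff=flat)
x2 = run(x1, beta=0.0, num_sweeps=1, energy_diff=flat)
assert x1 == [-s for s in x0] and x2 == x0
```

The same structure translates directly into the two suggested unit tests: one asserting the no-op, one asserting the period-2 oscillation.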
In the latest commit, I made the following changes:
VolodyaCO
left a comment
All comments have been addressed. Thank you @kevinchern
thisac
left a comment
Thanks @kevinchern! 🎄
dwave/plugins/torch/tensor.py
Outdated
```python
if "low" in kwargs:
    raise ValueError("Invalid keyword argument `low`.")
if "high" in kwargs:
    raise ValueError("Invalid keyword argument `high`.")
return 2*torch.randint(0, 2, size,
                       generator=generator,
                       dtype=dtype,
                       device=device,
                       requires_grad=requires_grad,
                       **kwargs)-1
```
Some cleanup needed: spaces around `-` and `*`, and perhaps separate out the `torch.randint()` call, since the final `2*` and `-1` get kind of lost in there.
Also not convinced that this is needed. Couldn't you just use the regular `torch.randint()` and pass the results through `bit2spin_soft()`? Other than that, this only sets `low`/`high`, which isn't too cumbersome to do manually.
> the final `2*` and `-1` get kind of lost

Changed to `b = randint ...` and `return 2 * b - 1`.
> Also not convinced that this is needed. Couldn't you just use the regular `torch.randint()` and pass the results through `bit2spin_soft()`?

My motivation for adding it is similar to adding a `bit2spin` function: it's very simple to write repeatedly, but a dedicated function helps with readability. What are some reasons to avoid adding this?
Alternatively, be clear about this being a special case of `torch.randint()` and simplify it a bit, to make it clear that it should work with all kwargs supported by `randint`:
```python
def randspin(size: _size, **kwargs) -> torch.Tensor:
    """Wrapper for ``torch.randint()`` restricted to spin outputs."""
    if "high" in kwargs or "low" in kwargs:
        raise ValueError(...)
    b = torch.randint(0, 2, size, **kwargs)
    return 2 * b - 1
```

Or remove the kwargs from the signature and add all the relevant arguments explicitly, which circumvents the need to check for `"high"` and `"low"`.
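For readers skimming the thread: the `2 * b - 1` step maps uniform bits in {0, 1} to uniform spins in {-1, +1}. A plain-Python sketch of the same transform (stdlib only; the helper name is hypothetical and is not the PR's code):

```python
import random

def randspin_list(n, seed=None):
    """Sample n uniform spins in {-1, +1} by drawing uniform bits in
    {0, 1} and mapping each bit b to 2*b - 1.
    (Plain-Python illustration of the randspin() idea above.)"""
    rng = random.Random(seed)
    bits = [rng.randint(0, 1) for _ in range(n)]
    return [2 * b - 1 for b in bits]

spins = randspin_list(1000, seed=0)
assert set(spins) == {-1, 1}   # both spin values occur, nothing else
```

Since the mapping is affine and bijective on {0, 1}, uniform bits give uniform spins; this is the same argument that makes routing `torch.randint()` through a `bit2spin`-style helper equivalent.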
```python
ZEPHYR = dnx.zephyr_graph(1, coordinates=True)
GRBM_ZEPHYR = GRBM(ZEPHYR.nodes, ZEPHYR.edges)
CRAYON_ZEPHYR = dnx.zephyr_four_color


BIPARTITE = nx.complete_bipartite_graph(5, 3)
GRBM_BIPARTITE = GRBM(BIPARTITE.nodes, BIPARTITE.edges)
def CRAYON_BIPARTITE(b): return b < 5


GRBM_SINGLE = GRBM([0], [])
def CRAYON_SINGLE(s): return 0


GRBM_CRAYON_TEST_CASES = [(GRBM_ZEPHYR, CRAYON_ZEPHYR),
                          (GRBM_BIPARTITE, CRAYON_BIPARTITE),
                          (GRBM_SINGLE, CRAYON_SINGLE)]
```
Better to put these in a `setUpClass()` method.
If I understood correctly, that would be incompatible with `@parameterized.expand`, since the decorator consumes its parameters while the class body is being evaluated, before `setUpClass()` runs.
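A stdlib-only sketch of the timing issue (the `expand` below is a toy stand-in for `parameterized.expand`, not the real library): the decorator's argument is evaluated while the class body runs, so any fixtures it references must already exist at module scope; `setUpClass()` only fires later, when the test runner starts.

```python
import unittest

FIXTURES = [1, 2, 3]  # module level: exists when the class body is evaluated

def expand(cases):
    """Toy stand-in for parameterized.expand: `cases` is consumed here,
    at class-definition time -- before unittest ever calls setUpClass()."""
    def decorator(func):
        # one (case, bound test) pair per parameter, captured immediately
        func.generated = [(case, lambda self, c=case: func(self, c))
                          for case in cases]
        return func
    return decorator

class MyTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Anything created here did NOT exist when @expand ran above.
        cls.too_late = ["not", "visible", "to", "expand"]

    @expand(FIXTURES)           # works only because FIXTURES is module level
    def check_case(self, case):
        assert case in FIXTURES

# The parameter list was captured at class-definition time:
assert [c for c, _ in MyTests.check_case.generated] == [1, 2, 3]
```

This is why the `GRBM_*`/`CRAYON_*` fixtures in the diff are built at module level rather than in `setUpClass()`.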
```python
bss._x.data[:] = 1
zero = torch.tensor(0.0)
ones = torch.ones((sample_size, 1))
bss._gibbs_update(0.0, bss._partition[0], ones * zero)
```
Generally, you should try to avoid calling private methods/attributes even in tests (only what is meant to be public should need to be tested), but I need to think about whether that makes sense here.
Overall, I'm not sure whether this test could be streamlined a bit.
This PR introduces a block-spin update sampler.
It includes duplicate code from #40.