Refactor/decompose primitives #25
Conversation
While working on the sampling feature for MPS I realised that I'd prefer not to keep track of the IDs of all bonds in the network. This is just more data to keep around, and it's actually not that useful, since in the end I'm only ever applying […]. I have spent some time today thinking about the implications of this going forward. I think it'd make the implementation of TTN a bit easier, and it should work for the belief propagation algorithms as well, so I've decided I'm going for it. Since this PR already had to deal with Tensor's bond IDs, and since the new way of using […]
Love the simplifications! I wasn't too attentive esp. in `MPSxMPO`, but it looks good!
```python
# Register a (malloc, free, name) triple as cuTensorNet's device memory handler
memhandle = (malloc, free, "memory_handler")
cutn.set_device_mem_handler(self.handle, memhandle)
# No argument: refers to the CUDA device currently in use
dev = cp.cuda.Device()
```
Why do you not need to pass `device_id` anymore to `cp.cuda.Device()`?

P.S. I love the simplifications and removal of the boilerplate! Do you think this class is still usable for other purposes such as intrinsic MPI support for the backend?
In line 70, `cp.cuda.Device(device_id).use()`, I've set cupy to use `device_id`, so that any call to `cp` stuff will use it by default. In particular, `cp.cuda.Device()` does as well, so you may think that `dev == device_id` every time. However, the sneaky bit is that `device_id` may be `None`. In that case, `cp.cuda.Device(device_id).use()` is simply saying "use the default device" and `dev = cp.cuda.Device()` is returning the ID of said default device.
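To make the `None` case concrete, here's a minimal standalone sketch of this behaviour (the helper `current_device_id` is made up for illustration, it's not the actual class code):

```python
import cupy as cp

def current_device_id(device_id=None):
    # Device(None) refers to the currently active (default) device;
    # .use() makes it the default for subsequent cp calls on this thread.
    cp.cuda.Device(device_id).use()
    # With no argument, Device() is the device currently in use, so
    # .id recovers a concrete index even when device_id was None.
    dev = cp.cuda.Device()
    return dev.id

print(current_device_id())   # ID of the default device, e.g. 0
print(current_device_id(0))  # explicitly device 0
```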
> Do you think this class is still usable for other purposes such as intrinsic MPI support for the backend?

Yeah, as usable as it was before. We'd just need to keep multiple instances of `CuTensorNetHandle`, each on a different device. The non-trivial part would still be to refactor the `MPS` class itself so that we can keep track of which tensors are on which device, so that we know which of these `CuTensorNetHandle`s should be used when applying operations on them. And, of course, we'd need to decide what to do with message passing when acting between tensors on different devices (possibly just sending one of the tensors to the other device, updating, and then sending back).
Yeah, I am thinking about this more and more often - sounds like we really need it. Yes, I think we can use a common stencil-like approach as a first attempt. I hope to start looking into this soonish.
Refactor of the MPS algorithms so that every call to `tensor_svd` and `tensor_qr` is replaced with the higher-level function `decompose` from cuTensorNet (see the sketch after this list). This means that the following bits of code are no longer needed (and hence removed): […] `CuTensorNetHandle`.

- `MPSxGate` is greatly improved, since now `contract` uses the "subscript" notation, which is more human-readable.
- The `MPSxMPO` algorithm does keep track of global bond IDs for the MPOs, since I found that's useful here. I don't think the code readability has improved that much here, but it's not worse than before.
- The `Tensor` class has been removed, since we no longer need to keep track of bond IDs nor generate tensor descriptors (thanks to the higher-level API from cuTensorNet).

Additionally, I have taken the liberty of sneaking in a few related changes:

- The call to the `.use()` method of `cuda.Device` was moved from `MPS` to `CuTensorNetHandle`, where it more naturally belongs.
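For reference, here is a minimal sketch of what a subscript-based `decompose` call looks like, assuming cuQuantum's Python API (`cuquantum.cutensornet.tensor.decompose`); the tensor shape and bond labels are made up for illustration:

```python
import cupy as cp
from cuquantum.cutensornet.tensor import decompose, QRMethod, SVDMethod

# A rank-3 MPS tensor: p = physical bond, l/r = left/right virtual bonds
T = cp.random.random((2, 4, 4))

# QR decomposition over a new shared bond "k", in einsum-like notation
q, r = decompose("plr->plk,kr", T, method=QRMethod())

# Truncated SVD, capping the extent of the shared bond at 8
u, s, v = decompose("plr->plk,kr", T, method=SVDMethod(max_extent=8))
```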
Additionally, I have had a look at `contract_decompose` and `split_gate` from cuTensorNet:

- `split_gate` is never going to be useful for MPS, since the initial QR decompositions that are done would not decrease the rank of the tensors being SVD'd. Similarly for TreeTN, so this will only become useful when working with more general TN states.
- `contract_decompose` is potentially useful, since it doesn't just do `contract` then `decompose`, but keeps track of the maximum possible dimension of the shared bond and imposes it in the decomposition (see the sketch after this list). However, in the places where I apply QR and SVD I had already taken this into account, so I don't expect any performance improvement. It'd be non-trivial (but not too difficult either) to refactor the code to use `contract_decompose` and, since it's still in the `experimental` module, I'm not using it for now, but I'll keep an eye on it.
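For context, here is a rough sketch of what a `contract_decompose` call could look like, assuming the `cuquantum.cutensornet.experimental` API; the tensor shapes and bond labels are made up:

```python
import cupy as cp
from cuquantum.cutensornet.experimental import contract_decompose

# Two neighbouring MPS tensors sharing the virtual bond "k"
# (p/q = physical bonds, l/r = outer virtual bonds)
a = cp.random.random((2, 3, 4))  # modes p, l, k
b = cp.random.random((2, 4, 5))  # modes q, k, r

# Contract over "k" and split again over a new bond "m". By default this
# performs the contraction followed by a QR decomposition, with the
# extent of "m" capped at its maximum useful value.
a_out, b_out = contract_decompose("plk,qkr->plm,qmr", a, b)
```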