-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Transpose primitives #17
Conversation
68b1b73
to
822f657
Compare
@@ -9,6 +9,7 @@ | |||
import jaxdecomp | |||
|
|||
# Initialize jax distributed to instruct jax local process which GPU to use | |||
jaxdecomp.init() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is needed on JZ for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, that's fine :-)
print(f"jd_tranposed_zy shape {jd_tranposed_zy.shape}") | ||
print(f"jd_tranposed_yx shape {jd_tranposed_yx.shape}") | ||
|
||
if pdims[1] == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure that the grid is transposed and traposed back every 2 steps
case 'x_y' | 'y_z': | ||
transpose_shape = (2, 0, 1) | ||
case 'y_x' | 'z_y': | ||
transpose_shape = (1, 2, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps the global shape has to follow the col to row tranposition that cudecomp does?
This is fine for cubes, for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what happens for non cubes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I solved this in the last commit
@@ -118,129 +119,23 @@ decompGridDescConfig_t getAutotunedGridConfig(decompGridDescConfig_t grid_config | |||
return output_config; | |||
}; | |||
|
|||
/// XLA interface ops |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of three custom callable funtion, I added only one and it uses an enum to know which transposition to do
|
||
# Y to X |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Derivatives are actually just residuals because tranposition is linear.
I defined a VJP anyways using the back tranposes as bwd passes for the forrward ones and vice versa
@@ -4,13 +4,15 @@ | |||
import jaxdecomp.fft as fft | |||
from jaxdecomp.fft import pfft3d, pifft3d | |||
|
|||
from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY | |||
from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
without this commented line, isort and yapf are stuck in an endless battle 🥲 during a precommit run
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a nice job !
From my perspective, what is missing from such PR is a bit more documentation.
There are comments here and there but no docstrings or parts that express the reasoning behind what is done.
I also like the fact you tried to add some typing in Python, but it's rather sparse so far, hence I don't really see the point.
To be clear, making this work and behaving as expected is the priority. But then, I think all this very clever piece of code could benefit from statements / explanations on the design and strategy. This could be in the code itself or somewhere next to it. Feel free to disagree.
Thank you @aboucaud . That said, any feedback is much appreciated. |
74f1656
to
b5fe313
Compare
@@ -44,16 +44,21 @@ def abstract(x, kind, pdims, global_shape): | |||
assert kind in ['x_y', 'y_z', 'z_y', 'y_x'] | |||
match kind: | |||
# From X to Y the axis are rolled by 1 and pdims are swapped wrt to the input pdims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The process dims are being transposed at every step.
But I need the original pdims for my compiled primitive (the grid descriptor and the tranpose operator are indexed using gdim
and pdim
... so even after a tranpose I need them back to find my primitive in m_TransposeDescriptors32
or m_TransposeDescriptors64
)
@@ -141,7 +154,10 @@ def infer_sharding_from_operands(kind: str, mesh: Mesh, | |||
arg_infos: Tuple[ShapeDtypeStruct], | |||
result_infos: Tuple[ShapedArray]): | |||
input_sharding = arg_infos[0].sharding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every step the processor grid is being switched from col major to row major
A test is added for this
transpose_type = _jaxdecomp.TRANSPOSE_ZY | ||
case 'y_x': | ||
transpose_shape = (1, 2, 0) | ||
transpose_type = _jaxdecomp.TRANSPOSE_YX | ||
case _: | ||
raise ValueError("Invalid kind") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get the original shape for my indexed operator (same as above)
@@ -98,6 +98,16 @@ def test_tranpose(pdims): | |||
|
|||
print(f"JD tranposed yz sharding {jd_tranposed_yz.sharding.spec}") | |||
print(f"JD tranposed xy sharding {jd_tranposed_xy.sharding.spec}") | |||
print(f"JD tranposed zy sharding {jd_tranposed_zy.sharding.spec}") | |||
print(f"JD tranposed yx sharding {jd_tranposed_yx.sharding.spec}") | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test that the sharding is being transposed
b5fe313
to
d69c2a0
Compare
jaxdecomp/_src/transpose.py
Outdated
@@ -150,7 +166,14 @@ def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct], | |||
input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec)) | |||
output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec)) | |||
global_shape = arg_infos[0].shape | |||
pdims = (get_axis_size(input_sharding, 0), get_axis_size(input_sharding, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure that pdims
provided to the c++ primitive is the original one (same reason as above)
But in the JAX abstract eval we need the reversed pdims because it changes the pencil shape
So c++ pdims
are the original and cudecomp knows how to handle them
We mirror cudecomp reversal in our abstract eval
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cartesian product testing of cubes non cubes, prime factors with pencils and slabs
d69c2a0
to
48d9b49
Compare
transpose_pdims = pdims[::-1] | ||
case 'y_x': | ||
transpose_shape = (1, 2, 0) | ||
transpose_pdims = pdims | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A forward transpose cycle shifts the axis right for us (2 0 1)
A backward transpose cycle shifts the axis left (1 2 0)
Every transpose the pdims are switched (implicitly by cudecomp)
I need the original pdims for the grid descriptor
So I reverse the pdims every other transpose
48d9b49
to
7126c3d
Compare
7126c3d
to
1d1a779
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
First pull request for Tranpose ops
They work for cubes for now, (Pencils and slabs)