Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Transpose primitives #17

Merged
merged 12 commits into from
Jun 11, 2024
Merged

Implement Transpose primitives #17

merged 12 commits into from
Jun 11, 2024

Conversation

ASKabalan
Copy link
Collaborator

First pull request for Tranpose ops

They work for cubes for now, (Pencils and slabs)

@ASKabalan ASKabalan linked an issue May 2, 2024 that may be closed by this pull request
7 tasks
@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch 2 times, most recently from 68b1b73 to 822f657 Compare May 2, 2024 16:18
@@ -9,6 +9,7 @@
import jaxdecomp

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed on JZ for now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that's fine :-)

print(f"jd_tranposed_zy shape {jd_tranposed_zy.shape}")
print(f"jd_tranposed_yx shape {jd_tranposed_yx.shape}")

if pdims[1] == 1:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure that the grid is transposed and traposed back every 2 steps

case 'x_y' | 'y_z':
transpose_shape = (2, 0, 1)
case 'y_x' | 'z_y':
transpose_shape = (1, 2, 0)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the global shape has to follow the col to row tranposition that cudecomp does?
This is fine for cubes, for now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens for non cubes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I solved this in the last commit

@@ -118,129 +119,23 @@ decompGridDescConfig_t getAutotunedGridConfig(decompGridDescConfig_t grid_config
return output_config;
};

/// XLA interface ops
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of three custom callable funtion, I added only one and it uses an enum to know which transposition to do


# Y to X
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Derivatives are actually just residuals because tranposition is linear.
I defined a VJP anyways using the back tranposes as bwd passes for the forrward ones and vice versa

@@ -4,13 +4,15 @@
import jaxdecomp.fft as fft
from jaxdecomp.fft import pfft3d, pifft3d

from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY
from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this commented line, isort and yapf are stuck in an endless battle 🥲 during a precommit run

@ASKabalan ASKabalan requested a review from EiffL May 2, 2024 17:02
Copy link

@aboucaud aboucaud left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a nice job !

From my perspective, what is missing from such PR is a bit more documentation.
There are comments here and there but no docstrings or parts that express the reasoning behind what is done.
I also like the fact you tried to add some typing in Python, but it's rather sparse so far, hence I don't really see the point.

To be clear, making this work and behaving as expected is the priority. But then, I think all this very clever piece of code could benefit from statements / explanations on the design and strategy. This could be in the code itself or somewhere next to it. Feel free to disagree.

@ASKabalan
Copy link
Collaborator Author

Thank you @aboucaud .
The tranpose primitive uses a much better structure than the others, I am gradually trying to improve readability.
But it still missing some stuff that I am very close to solving (non-cubes and understand fft forward axis order).

That said, any feedback is much appreciated.
Once this PR is ready I would highly appreciate feed back on code readability and maintability.

@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch 4 times, most recently from 74f1656 to b5fe313 Compare May 4, 2024 08:23
@ASKabalan
Copy link
Collaborator Author

@aboucaud @EiffL .
This is good to go.
Now works for cubes non cubes , slabs XY, slabs YZ and pencils.

@@ -44,16 +44,21 @@ def abstract(x, kind, pdims, global_shape):
assert kind in ['x_y', 'y_z', 'z_y', 'y_x']
match kind:
# From X to Y the axis are rolled by 1 and pdims are swapped wrt to the input pdims
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The process dims are being transposed at every step.
But I need the original pdims for my compiled primitive (the grid descriptor and the tranpose operator are indexed using gdim and pdim ... so even after a tranpose I need them back to find my primitive in m_TransposeDescriptors32 or m_TransposeDescriptors64)

@@ -141,7 +154,10 @@ def infer_sharding_from_operands(kind: str, mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
input_sharding = arg_infos[0].sharding
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every step the processor grid is being switched from col major to row major
A test is added for this

transpose_type = _jaxdecomp.TRANSPOSE_ZY
case 'y_x':
transpose_shape = (1, 2, 0)
transpose_type = _jaxdecomp.TRANSPOSE_YX
case _:
raise ValueError("Invalid kind")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get the original shape for my indexed operator (same as above)

@@ -98,6 +98,16 @@ def test_tranpose(pdims):

print(f"JD tranposed yz sharding {jd_tranposed_yz.sharding.spec}")
print(f"JD tranposed xy sharding {jd_tranposed_xy.sharding.spec}")
print(f"JD tranposed zy sharding {jd_tranposed_zy.sharding.spec}")
print(f"JD tranposed yx sharding {jd_tranposed_yx.sharding.spec}")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test that the sharding is being transposed

@ASKabalan ASKabalan requested a review from aboucaud May 4, 2024 08:32
@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch from b5fe313 to d69c2a0 Compare May 4, 2024 08:33
@@ -150,7 +166,14 @@ def partition(kind: str, mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct],
input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec))
output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec))
global_shape = arg_infos[0].shape
pdims = (get_axis_size(input_sharding, 0), get_axis_size(input_sharding, 1))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure that pdims provided to the c++ primitive is the original one (same reason as above)
But in the JAX abstract eval we need the reversed pdims because it changes the pencil shape
So c++ pdims are the original and cudecomp knows how to handle them
We mirror cudecomp reversal in our abstract eval



Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cartesian product testing of cubes non cubes, prime factors with pencils and slabs

@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch from d69c2a0 to 48d9b49 Compare May 4, 2024 09:11
This was referenced May 4, 2024
@ASKabalan ASKabalan changed the title U/as kabalan/transpose ops Implement Transpose primitives May 4, 2024
transpose_pdims = pdims[::-1]
case 'y_x':
transpose_shape = (1, 2, 0)
transpose_pdims = pdims

Copy link
Collaborator Author

@ASKabalan ASKabalan May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A forward transpose cycle shifts the axis right for us (2 0 1)
A backward transpose cycle shifts the axis left (1 2 0)

Every transpose the pdims are switched (implicitly by cudecomp)
I need the original pdims for the grid descriptor
So I reverse the pdims every other transpose

@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch from 48d9b49 to 7126c3d Compare May 30, 2024 18:02
@ASKabalan ASKabalan force-pushed the u/ASKabalan/transpose_ops branch from 7126c3d to 1d1a779 Compare May 31, 2024 15:36
Copy link
Member

@EiffL EiffL left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@EiffL EiffL merged commit 80172aa into main Jun 11, 2024
1 check passed
@ASKabalan ASKabalan deleted the u/ASKabalan/transpose_ops branch June 19, 2024 09:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Things to update for version 0.0.1
3 participants