Skip to content

Commit

Permalink
[Mosaic TPU][Python] Check validity of VectorLayout on init
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661226283
  • Loading branch information
tlongeri authored and jax authors committed Aug 9, 2024
1 parent e57a7e3 commit 77afe25
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jaxlib/mlir/_mlir_libs/tpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,21 @@ PYBIND11_MODULE(_tpu_ext, m) {
.def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling,
MlirTpuImplicitDim implicit_dim) {
if (offsets.size() != 2) {
throw py::value_error("offsets should be of length 2");
throw py::value_error("Offsets should be of length 2");
}
return mlirTpuVectorLayoutCreate(
if (tiling.size() != 2) {
throw py::value_error("Tiling should be of length 2");
}
MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate(
bitwidth,
{offsetFromPyOffset(offsets[0]),
offsetFromPyOffset(offsets[1])},
{tiling[0].cast<int64_t>(), tiling[1].cast<int64_t>()},
implicit_dim);
if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) {
throw py::value_error("Layout not valid for target shape");
}
return layout;
}),
py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"),
py::arg("implicit_dim"))
Expand Down
5 changes: 5 additions & 0 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ void mlirTpuVectorLayoutPrint(
unwrap(layout)->print<llvm::raw_ostream>(stream);
}

bool mlirTpuVectorLayoutIsValid(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->isValid(unwrap(target_shape));
}

void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) {
delete unwrap(data_bounds);
}
Expand Down
3 changes: 3 additions & 0 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo(
MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint(
MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data);

MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutIsValid(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape);

MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy(
MlirTpuVregDataBounds data_bounds);

Expand Down

0 comments on commit 77afe25

Please sign in to comment.