Skip to content

Commit

Permalink
[xla:cpu] Add xla_cpu.store operation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686930171
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Oct 17, 2024
1 parent 6683d83 commit 719882f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
21 changes: 18 additions & 3 deletions xla/backends/cpu/codegen/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
// RUN: xla_cpu_opt %s | xla_cpu_opt | FileCheck %s
// RUN: xla_cpu_opt %s --split-input-file | FileCheck %s

func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) -> tensor<32x32xf32> {
func.func @load(%arg0: !xla_cpu.call_frame) -> tensor<32x32xf32> {
%0 = xla_cpu.load %arg0, 0 : tensor<32x32xf32>
return %0 : tensor<32x32xf32>
}

// CHECK-LABEL: @call_frame_arg(
// CHECK-LABEL: @load(
// CHECK: %[[ARG0:.+]]: !xla_cpu.call_frame
// CHECK: ) -> tensor<32x32xf32> {
// CHECK: %[[LOAD:.+]] = xla_cpu.load %[[ARG0]], 0 : tensor<32x32xf32>
// CHECK: return %[[LOAD]] : tensor<32x32xf32>
// CHECK: }

// -----

func.func @store(%arg0: !xla_cpu.call_frame, %arg1: tensor<32x32xf32>) {
xla_cpu.store %arg1 into %arg0, 0 : tensor<32x32xf32>
return
}

// CHECK-LABEL: @store(
// CHECK: %[[ARG0:.+]]: !xla_cpu.call_frame,
// CHECK: %[[ARG1:.+]]: tensor<32x32xf32>
// CHECK: ) {
// CHECK: xla_cpu.store %[[ARG1]] into %[[ARG0]], 0 : tensor<32x32xf32>
// CHECK: return
// CHECK: }
2 changes: 1 addition & 1 deletion xla/backends/cpu/codegen/ir/tests/types.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: xla_cpu_opt %s | xla_cpu_opt | FileCheck %s
// RUN: xla_cpu_opt %s | FileCheck %s

func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) {
return
Expand Down
34 changes: 32 additions & 2 deletions xla/backends/cpu/codegen/ir/xla_cpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,40 @@ def XLACPU_LoadOp : XLACPU_Op<"load"> {
```
}];

let arguments = (ins XLACPU_CallFrame:$call_frame, I32Attr:$index);
let arguments = (ins XLACPU_CallFrame:$call_frame,
I32Attr:$index);

let results = (outs AnyStaticShapeTensor:$result);

let assemblyFormat = "$call_frame `,` $index attr-dict `:` type($result)";
let assemblyFormat = [{
$call_frame `,` $index attr-dict `:` type($result)
}];
}

//===----------------------------------------------------------------------===//
// !xla_cpu.store
//===----------------------------------------------------------------------===//

def XLACPU_StoreOp : XLACPU_Op<"store"> {
let summary = "Stores a tensor into an XLA:CPU call frame";

let description = [{
Stores a tensor into an XLA:CPU call frame at the given index.

```mlir
%0 = ... : tensor<32x32xf32>
xla_cpu.store %0 into %call_frame, 0 : tensor<32x32xf32>
```
}];

let arguments = (ins AnyStaticShapeTensor:$tensor,
XLACPU_CallFrame:$call_frame,
I32Attr:$index);

let assemblyFormat = [{
$tensor `into` $call_frame `,` $index attr-dict `:` type($tensor)
}];

}

#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS

0 comments on commit 719882f

Please sign in to comment.