[TK] Add support for ops required for Flash Attention 2 #385

Merged: 2 commits, Feb 1, 2024
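The new handlers cover the element-wise, reduction, and shape ops a Flash Attention 2 inner loop needs: exp2, row-wise max and sum, dot/matmul accumulation, broadcast, and transpose. The sketch below is a rough NumPy illustration of one online-softmax step, included only to motivate that op set; names, shapes, and the base-2 formulation are illustrative and are not the Turbine Kernel API (the 1/sqrt(d) scale and the final 1/l normalization are omitted).

```python
# Illustrative only: one online-softmax step of Flash Attention 2 in NumPy,
# showing why exp2, vector max/sum, dot, broadcast, and transpose are needed.
# Names and shapes are hypothetical, not the shark_turbine kernel API.
import numpy as np

def fa2_inner_step(q_tile, k_tile, v_tile, m_prev, l_prev, acc_prev):
    # q_tile: (BLOCK_M, d); k_tile, v_tile: (BLOCK_N, d)
    s = q_tile @ k_tile.T                        # dot with transposed K: scores (BLOCK_M, BLOCK_N)
    m_new = np.maximum(m_prev, s.max(axis=1))    # running row max (vector_max)
    p = np.exp2(s - m_new[:, None])              # exp2 of scores, row max broadcast across columns
    correction = np.exp2(m_prev - m_new)         # rescale previously accumulated rows
    l_new = correction * l_prev + p.sum(axis=1)  # running row sum (vector_sum)
    acc_new = correction[:, None] * acc_prev + p @ v_tile  # accumulate P @ V (vector_dot with acc)
    return m_new, l_new, acc_new
```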
76 changes: 75 additions & 1 deletion python/shark_turbine/kernel/_support/tracing.py
@@ -264,6 +264,13 @@ def wrapper(f):
### ========================================================================
### Math Operations
### ========================================================================
def handle_exp2(self, op, val):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(val,),
kwargs={},
)

def handle_vector_constant(
self, op, shape: Tuple[int, ...], dtype, value: int | float
@@ -278,15 +285,82 @@ def handle_vector_constant(
### ========================================================================
### Reduction Operations
### ========================================================================
def handle_vector_max(self, op, vector, axis=None, acc=None):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, axis, acc),
kwargs={},
)

def handle_vector_sum(self, op, vector, axis=None, acc=None):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, axis, acc),
kwargs={},
)

-    def handle_vector_dot(self, op, lhs, rhs, acc):
+    def handle_vector_dot(self, op, lhs, rhs, acc=None):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(lhs, rhs, acc),
kwargs={},
)

### ========================================================================
### Shape Manipulation Operations
### ========================================================================
def handle_vector_broadcast(self, op, vector, leading_sizes):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, leading_sizes),
kwargs={},
)

def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions):
        # Currently, we do not have a corresponding op in MLIR, so
# we trace this to broadcast + transpose.
# TODO: Add a vector dialect op for this in MLIR.

# Remove broadcast_dimensions from shape.
shape_with_leading = tuple(
dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
)

# Broadcast
broadcasted_vector = self.region_graph.create_proxy(
"call_function",
target=ops.vector_broadcast,
args=(vector, shape_with_leading),
kwargs={},
)

# Get the permutation for the transpose.
permutation = tuple(
i for i in range(len(shape)) if i not in broadcast_dimensions
)
permutation = permutation + tuple(broadcast_dimensions)

# Transpose
return self.region_graph.create_proxy(
"call_function",
target=ops.vector_transpose,
args=(broadcasted_vector, permutation),
kwargs={},
)

def handle_vector_transpose(self, op, vector, permutation):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, permutation),
kwargs={},
)


###############################################################################
# Launch context
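Since `handle_vector_broadcast_in_dim` has no single MLIR counterpart, it is traced as a broadcast over the missing dimensions followed by a transpose that restores the requested dimension order. Below is a minimal standalone sketch of that shape/permutation arithmetic in plain Python; the concrete shape and `broadcast_dimensions` values are hypothetical and only illustrate the handler's logic.

```python
# Sketch of the shape arithmetic in handle_vector_broadcast_in_dim.
# The example shape and broadcast_dimensions are illustrative only.

def broadcast_in_dim_plan(shape, broadcast_dimensions):
    # Dimensions the input vector does not already have; these become the
    # leading sizes of the intermediate broadcast result.
    leading_sizes = tuple(
        dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
    )
    # After broadcasting, the layout is (leading dims, original dims); this
    # permutation puts the dimensions back in the requested order.
    permutation = tuple(
        i for i in range(len(shape)) if i not in broadcast_dimensions
    ) + tuple(broadcast_dimensions)
    return leading_sizes, permutation

# Example: place a vector of shape (M,) at dimension 0 of an (M, N) result.
M, N = 128, 64
leading, perm = broadcast_in_dim_plan(shape=(M, N), broadcast_dimensions=(0,))
assert leading == (N,)    # broadcast yields an (N, M) intermediate
assert perm == (1, 0)     # transposing with (1, 0) then gives the (M, N) result
```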
31 changes: 29 additions & 2 deletions python/shark_turbine/kernel/compiler/builder.py
@@ -24,6 +24,7 @@
Value,
VectorType,
arith_d,
math_d,
builtin_d,
)

@@ -139,7 +140,7 @@ def binary_arithmetic(

def binary_vector_arithmetic(
self, op: str, lhs: IRProxyValue, rhs: IRProxyValue
-    ) -> Value:
+    ) -> IRProxyValue:
lhs_ir = lhs.ir_value
rhs_ir = rhs.ir_value
lhs_element_type = VectorType(lhs_ir.type).element_type
@@ -149,10 +150,33 @@ def binary_vector_arithmetic(
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} (tried '{attr_name}')"
)
return handler(lhs, rhs)

def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
val_ir_type = val.ir_value.type
attr_name = f"unary_{op}_{val_ir_type}"
try:
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform unary arithmetic operation '{op}' on {val_ir_type} (tried '{attr_name}')"
)
return handler(val)

def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
val_ir = val.ir_value
val_element_type = VectorType(val_ir.type).element_type
attr_name = f"unary_{op}_{val_element_type}"
try:
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform unary arithmetic operation '{op}' on {val_ir.type} (tried '{attr_name}')"
)
return handler(val)

def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
i32_type = IntegerType.get_signless(32)
i32 = arith_d.index_cast(i32_type, value)
@@ -215,5 +239,8 @@ def binary_truediv_f32_f32(
) -> IRProxyValue:
return IRProxyValue(arith_d.divf(lhs.ir_value, rhs.ir_value))

def unary_exp2_f32(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.exp2(val.ir_value))


ScalarBuilder = _ScalarBuilder()
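The builder resolves element-wise ops by name-mangled attribute lookup: `unary_arithmetic` and `unary_vector_arithmetic` build an attribute name from the op and the (element) type, so exp2 on an f32 value or vector lands on `unary_exp2_f32`. A minimal standalone sketch of that dispatch pattern follows; the class and types are hypothetical and the real builder operates on MLIR `IRProxyValue`/`VectorType` values rather than Python floats.

```python
# Minimal sketch of the name-mangled dispatch used by _ScalarBuilder
# (standalone illustration, not the Turbine builder itself).
class TinyBuilder:
    def unary_arithmetic(self, op: str, val: float, ir_type: str = "f32") -> float:
        # Build the handler name from the op and the operand type, then look it up.
        attr_name = f"unary_{op}_{ir_type}"
        try:
            handler = getattr(self, attr_name)
        except AttributeError:
            raise RuntimeError(
                f"Cannot perform unary arithmetic operation '{op}' on {ir_type} "
                f"(tried '{attr_name}')"
            )
        return handler(val)

    # Registered the same way unary_exp2_f32 is: one method per (op, type) pair.
    def unary_exp2_f32(self, val: float) -> float:
        return 2.0 ** val

print(TinyBuilder().unary_arithmetic("exp2", 3.0))  # 8.0
```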