
Commit

Change the flags for gemm and how sdpa upcast should work.
Fix issue where cmul doesn't have an MPS implementation.
liuliu committed Aug 12, 2024
1 parent 00be1b4 commit 6c30517
Showing 5 changed files with 70 additions and 11 deletions.
10 changes: 5 additions & 5 deletions lib/nnc/ccv_cnnp_model_addons.c
@@ -3861,7 +3861,7 @@ typedef struct {
float scale;
int is_causal;
int has_attn_mask;
- int upcast;
+ int flags;
int fused_unify_head_weights;
int no_bias;
} ccv_cnnp_model_scaled_dot_product_attention_t;
@@ -3889,7 +3889,7 @@ static void _ccv_cnnp_scaled_dot_product_attention_build(ccv_cnnp_model_t* const
cmd.cmd = CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD;
cmd.info.scaled_dot_product_attention.scale = self->scale;
cmd.info.scaled_dot_product_attention.is_causal = self->is_causal;
- cmd.info.scaled_dot_product_attention.upcast = self->upcast;
+ cmd.info.scaled_dot_product_attention.flags = self->flags;
ccv_nnc_tensor_param_t output_params[3];
ccv_nnc_tensor_symbol_t output;
ccv_nnc_tensor_symbol_t saved_softmax_lse;
@@ -3976,7 +3976,7 @@ static const ccv_cnnp_model_vtab_t ccv_cnnp_scaled_dot_product_attention_fused_i
.copy = _ccv_cnnp_scaled_dot_product_attention_copy,
};

- ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int upcast, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name)
+ ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int flags, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name)
{
ccv_cnnp_model_scaled_dot_product_attention_t* const model_scaled_dot_product_attention = (ccv_cnnp_model_scaled_dot_product_attention_t*)cccalloc(1, sizeof(ccv_cnnp_model_scaled_dot_product_attention_t));
model_scaled_dot_product_attention->super.isa = fused_unify_head_weights ? &ccv_cnnp_scaled_dot_product_attention_fused_isa : &ccv_cnnp_scaled_dot_product_attention_isa;
@@ -3992,7 +3992,7 @@ ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const
model_scaled_dot_product_attention->scale = scale;
model_scaled_dot_product_attention->is_causal = is_causal;
model_scaled_dot_product_attention->has_attn_mask = has_attn_mask;
- model_scaled_dot_product_attention->upcast = upcast;
+ model_scaled_dot_product_attention->flags = flags;
model_scaled_dot_product_attention->fused_unify_head_weights = fused_unify_head_weights;
model_scaled_dot_product_attention->no_bias = no_bias;
return (ccv_cnnp_model_t*)model_scaled_dot_product_attention;
@@ -4001,5 +4001,5 @@ ccv_cnnp_model_t* ccv_cnnp_scaled_dot_product_attention(const float scale, const
static ccv_cnnp_model_t* _ccv_cnnp_scaled_dot_product_attention_copy(const ccv_cnnp_model_t* const super, void* const context)
{
const ccv_cnnp_model_scaled_dot_product_attention_t* const self = (const ccv_cnnp_model_scaled_dot_product_attention_t*)super;
- return ccv_cnnp_scaled_dot_product_attention(self->scale, self->is_causal, self->has_attn_mask, self->upcast, self->fused_unify_head_weights, self->no_bias, self->super.is_trainable, self->super.name);
+ return ccv_cnnp_scaled_dot_product_attention(self->scale, self->is_causal, self->has_attn_mask, self->flags, self->fused_unify_head_weights, self->no_bias, self->super.is_trainable, self->super.name);
}
12 changes: 9 additions & 3 deletions lib/nnc/ccv_nnc.h
@@ -99,6 +99,12 @@ enum {
CCV_NNC_PAD_REPLICATE = 1, /**< Pad by replicating the edge. */
};

enum {
CCV_NNC_GEMM_32F = 0x1, /**< For GEMM (or similar op), whether to prefer FP32 for the accumulator. */
CCV_NNC_GEMM_32TF = 0x2, /**< For GEMM (or similar op), whether to prefer TF32 for the accumulator. */
CCV_NNC_GEMM_16F = 0x4, /**< For GEMM (or similar op), whether to prefer FP16 for the accumulator. */
};

/**
* Parameters for command.
*/
@@ -256,7 +262,7 @@ typedef struct {
struct {
float scale; /**< [scaled_dot_product_attention.scale] The scale we multiply the dot product of Q & K by */
int is_causal; /**< [scaled_dot_product_attention.is_causal] Whether we have a causal matrix associated with the attention. The attention mask will be cut to triangular if provided. */
- int upcast; /**< [scaled_dot_product_attention.upcast] Whether we want to run the attention computation at higher precision (from FP16 to FP32). */
+ int flags; /**< [scaled_dot_product_attention.flags] Which precision is preferred for the accumulator, FP16 or FP32. */
int deterministic; /**< [scaled_dot_product_attention.deterministic] Whether we want the attention computation to be deterministic (CUDA only). */
} scaled_dot_product_attention;
struct {
@@ -4691,14 +4697,14 @@ CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_contiguous(const char* const name);
* @param scale The scale to be applied to the qk dot product.
* @param is_causal Whether to apply the is_causal mask to it. If both attn_mask and is_causal are supplied, we will cut attn_mask to the upper right triangle.
* @param has_attn_mask Whether the input would accept a 4th parameter, the attention mask.
- * @param upcast Whether the attention computation will be run at higher precision (from FP16 to FP32).
+ * @param flags Which precision is preferred for the attention computation to run at (FP16 or FP32).
* @param fused_unify_head_weights Whether we also have unifying head weight fused into it. The output would be in shape of (N, S, H * Ev).
* @param no_bias Whether we have bias or not for the unifying head output.
* @param is_trainable Whether or not it is trainable (if weight / bias provided).
* @param name The unique name of the model.
* @return A model that can apply scaled dot product attention compute.
*/
- CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int upcast, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name);
+ CCV_WARN_UNUSED(ccv_cnnp_model_t*) ccv_cnnp_scaled_dot_product_attention(const float scale, const int is_causal, const int has_attn_mask, const int flags, const int fused_unify_head_weights, const int no_bias, const int is_trainable, const char* const name);

/** @} */

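For orientation, a minimal usage sketch (not part of this commit) of how a caller might request FP32 accumulation through the new flags argument; every value other than the flags is an illustrative placeholder.

// Hedged sketch: build a causal attention model that prefers an FP32 accumulator.
ccv_cnnp_model_t* const attention = ccv_cnnp_scaled_dot_product_attention(
	0.125, // scale (placeholder)
	1, // is_causal
	0, // has_attn_mask
	CCV_NNC_GEMM_32F, // flags: prefer FP32 accumulation
	0, // fused_unify_head_weights
	0, // no_bias
	0, // is_trainable
	"attention"); // name (placeholder)
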
53 changes: 52 additions & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_cmul_mps.m
@@ -152,7 +152,58 @@ static int _ccv_nnc_cmul_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint
ccv_nnc_mfa_encode_cmul(context, params, command_batch, tensors, tensor_offsets);
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
} else {
- assert(0);
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
const ccv_nnc_tensor_view_t* const a = (const ccv_nnc_tensor_view_t*)inputs[0];
const ccv_nnc_tensor_view_t* const b = (const ccv_nnc_tensor_view_t*)inputs[1];
ccv_nnc_tensor_view_t* const c = (ccv_nnc_tensor_view_t*)outputs[0];
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[2];
int nd = ccv_nnc_tensor_nd(a->info.dim);
assert(nd == ccv_nnc_tensor_nd(b->info.dim));
assert(nd == ccv_nnc_tensor_nd(c->info.dim));
MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, a->info.dim, a->stride, &mps_input_a);
[inputTensors addObject:mps_input_a];
MPSGraphShapedType* mps_a_shape = ccv_nnc_mps_graph_tensor_input_shape(a, a->info.dim, a->stride);
[inputShapedTypes addObject:mps_a_shape];
MPSGraphTensor* mps_input_b;
MPSGraphTensor* mps_b = ccv_nnc_mps_graph_tensor_input(graph, b, b->info.dim, b->stride, &mps_input_b);
[inputTensors addObject:mps_input_b];
MPSGraphShapedType* mps_b_shape = ccv_nnc_mps_graph_tensor_input_shape(b, b->info.dim, b->stride);
[inputShapedTypes addObject:mps_b_shape];
int i;
// Reshape to [..., n / 2, 2]
NSMutableArray<NSNumber*>* a_shape = [NSMutableArray new];
for (i = 0; i < nd - 1; i++)
[a_shape addObject:@(a->info.dim[i])];
[a_shape addObject: @(a->info.dim[nd - 1] / 2)];
[a_shape addObject: @2];
mps_a = [graph reshapeTensor:mps_a withShape:a_shape name:nil];
[a_shape release];
NSArray<MPSGraphTensor*>* mps_a_splits = [graph splitTensor:mps_a numSplits:2 axis:nd name:nil];
NSMutableArray<NSNumber*>* b_shape = [NSMutableArray new];
for (i = 0; i < nd - 1; i++)
[b_shape addObject:@(b->info.dim[i])];
[b_shape addObject: @(b->info.dim[nd - 1] / 2)];
[b_shape addObject: @2];
mps_b = [graph reshapeTensor:mps_b withShape:b_shape name:nil];
[b_shape release];
NSArray<MPSGraphTensor*>* mps_b_splits = [graph splitTensor:mps_b numSplits:2 axis:nd name:nil];
MPSGraphTensor* mps_c_0 = [graph subtractionWithPrimaryTensor:[graph multiplicationWithPrimaryTensor:mps_a_splits[0] secondaryTensor:mps_b_splits[0] name:nil] secondaryTensor:[graph multiplicationWithPrimaryTensor:mps_a_splits[1] secondaryTensor:mps_b_splits[1] name:nil] name:nil];
MPSGraphTensor* mps_c_1 = [graph additionWithPrimaryTensor:[graph multiplicationWithPrimaryTensor:mps_a_splits[0] secondaryTensor:mps_b_splits[1] name:nil] secondaryTensor:[graph multiplicationWithPrimaryTensor:mps_a_splits[1] secondaryTensor:mps_b_splits[0] name:nil] name:nil];
NSMutableArray<NSNumber*>* c_shape = [NSMutableArray new];
for (i = 0; i < nd; i++)
[c_shape addObject:@(c->info.dim[i])];
MPSGraphTensor* mps_c = [graph reshapeTensor:[graph concatTensor:mps_c_0 withTensor:mps_c_1 dimension:nd name:nil] withShape:c_shape name:nil];
[resultTensors addObject:mps_c];
[c_shape release];
});
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, a->info.dim, a->stride);
MPSGraphTensorData* data_b = ccv_nnc_mps_graph_tensor_data(b, b->info.dim, b->stride);
MPSGraphTensorData* data[] = {data_a, data_b};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]]], &c, (int*[]){ c->info.dim }, (int*[]){ c->stride }, 1, 0);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
}
return CCV_NNC_EXEC_SUCCESS;
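As a reading aid (not part of the commit), the MPSGraph fallback above treats the last dimension as interleaved (real, imaginary) pairs: it reshapes to [..., n / 2, 2], splits off the two components, forms the real part as a0 * b0 - a1 * b1 and the imaginary part as a0 * b1 + a1 * b0, then concatenates and reshapes back. A plain-C sketch of the same arithmetic, using a hypothetical helper name:

// Hypothetical reference helper; n is the length of the innermost dimension and must be even.
static void cmul_reference_1d(const float* const a, const float* const b, float* const c, const int n)
{
	int i;
	for (i = 0; i < n; i += 2)
	{
		const float ar = a[i], ai = a[i + 1];
		const float br = b[i], bi = b[i + 1];
		c[i] = ar * br - ai * bi; // real part
		c[i + 1] = ar * bi + ai * br; // imaginary part
	}
}
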
4 changes: 3 additions & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
@@ -175,8 +175,10 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_gemv = !is_batched && ((a_rows == 1 && is_transpose_w && (w_rows % 4) == 0) || (!is_transpose_a && w_cols == 1 && (a_cols % 4) == 0));
// v1 only supports the same precision of accumulator as the tensor.
int is_different_accumulator_precision = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F) || ((cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F);
const int is_mfa_supported =
- ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !(cmd.info.blas.flags & CCV_NNC_DISABLE_MFA_GEMM)));
+ ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && (!is_batched || is_mfa_compatible_batch) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (is_mfa_gemv || (!(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM) && !is_different_accumulator_precision));

size_t a_data_size = 0;
if (CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
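For readability (an illustrative decomposition, not additional diff content), the new is_different_accumulator_precision guard covers two cases, and either one steers the GEMM away from the MFA v1 kernel, since v1 only accumulates in the tensor's own precision:

// Equivalent form of the check above; the two intermediate names are not in the source.
const int fp32_accum_on_fp16_data = (cmd.info.blas.flags & CCV_NNC_GEMM_32F) && a_datatype == CCV_16F;
const int fp16_accum_on_fp32_data = (cmd.info.blas.flags & CCV_NNC_GEMM_16F) && a_datatype == CCV_32F;
const int is_different_accumulator_precision = fp32_accum_on_fp16_data || fp16_accum_on_fp32_data;
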
@@ -166,7 +166,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.alpha = cmd.info.scaled_dot_product_attention.scale,
.batched = (attention_is_batched ? 1 : 0),
.masked = (attn_mask != NULL ? 1 : 0),
- .upcast = cmd.info.scaled_dot_product_attention.upcast,
+ .upcast = (cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_32F),

.batch_dims_q = { 0 },
.batch_dims_mask = { 0 },
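Put another way (an illustrative restatement, not additional diff content), the Metal FlashAttention upcast switch is now derived from the preference bitmask, so FP32 accumulation kicks in only when the caller explicitly asks for it:

// Equivalent boolean form of the assignment above.
const int upcast = (cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_32F) ? 1 : 0;
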
