Skip to content

Commit e431139

Browse files
committed
Additional variables that require initialization in support of ROCm 6.2 enablement.
1 parent 5e5a784 commit e431139

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

fbgemm_gpu/codegen/embedding_common_code_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,8 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
10961096
10971097
at::acc_type<cache_t, true> adjusted_multiplier;
10981098
at::acc_type<cache_t, true> exp_reg_correction;
1099+
adjusted_multiplier = 0.0;
1100+
exp_reg_correction = 0.0;
10991101
11001102
if (threadIdx.x == 0) {
11011103
at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
@@ -1463,6 +1465,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
14631465
warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D;
14641466
14651467
at::acc_type<cache_t, true> m2;
1468+
m2 = 0.0;
14661469
if (threadIdx.x == 0) {
14671470
m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square;
14681471
momentum2[idx] = m2;
@@ -1609,6 +1612,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
16091612
warpReduceAllSum<at::acc_type<cache_t, true>, kThreadGroupSize>(g_local_sum_square) / D;
16101613
16111614
at::acc_type<cache_t, true> v_hat_t;
1615+
v_hat_t = 0.0;
16121616
if (threadIdx.x == 0) {
16131617
at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
16141618
momentum2[idx] = v_t;

0 commit comments

Comments
 (0)