@@ -538,3 +538,111 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
return {ssd_row_addrs, post_bwd_evicted_indices};
}
+
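+// Updates the scratch pad row addresses for the current iteration
+// (ssd_row_addrs_curr). For every unique row that is also used in the next
+// iteration (ssd_curr_next_map[n_curr] >= 0), the address is rewritten to
+// point at either the next iteration's scratch pad
+// (inserted_ssd_weights_addr_next) or the L1 cache (lxu_cache_weights_addr),
+// depending on lxu_cache_locations_curr. One warp handles one unique row.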
+__global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
+    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        ssd_row_addrs_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        ssd_curr_next_map,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        lxu_cache_locations_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        linear_index_inverse_indices_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        unique_indices_count_cumsum_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        cache_set_inverse_indices_curr,
+    const uint64_t lxu_cache_weights_addr,
+    const uint64_t inserted_ssd_weights_addr_next,
+    const int* N_unique_curr,
+    const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow
+) {
+  const auto n_curr = blockDim.y * blockIdx.x + threadIdx.y;
+  if (n_curr >= *N_unique_curr) {
+    return;
+  }
+
+  // Find mapping between n_curr and n_next
+  const auto n_next = ssd_curr_next_map[n_curr];
+  // Return if the row is not used in both current and next iterations
+  if (n_next < 0) {
+    return;
+  }
+
+  // Find out if the row gets moved to the next iteration's scratch pad or
+  // L1 by checking the lxu_cache_locations_curr
+  const auto cache_set_id_curr = cache_set_inverse_indices_curr[n_curr];
+  const auto segment_start_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr];
+  const auto segment_end_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr + 1];
+  const auto cache_loc_curr = lxu_cache_locations_curr
+      [linear_index_inverse_indices_curr[segment_start_curr]];
+
+  const uint64_t ptr_addr = (cache_loc_curr == -1)
+      // The row is moved from the current iteration's scratch pad to the
+      // next iteration's scratch pad
+      ? (inserted_ssd_weights_addr_next + (n_next * cache_row_bytes))
+      // The row is moved from the current iteration's scratch pad to L1 cache
+      : (lxu_cache_weights_addr + (cache_loc_curr * cache_row_bytes));
+
+  // Set pointer address
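+  // Each lane of the warp handles one duplicate of this unique index: every
+  // position in ssd_row_addrs_curr that maps back to it receives the new
+  // address.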
+  for (auto l = segment_start_curr + threadIdx.x; l < segment_end_curr;
+       l += blockDim.x) {
+    auto dst = linear_index_inverse_indices_curr[l];
+    *reinterpret_cast<uint64_t*>(&ssd_row_addrs_curr[dst]) = ptr_addr;
+  }
+}
+
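+// Host-side launcher for ssd_update_row_addrs_kernel. The intent, as far as
+// it can be inferred from the kernel above: when rows are moved out of the
+// current iteration's scratch pad (e.g., by the next iteration's prefetch),
+// the now-stale addresses in ssd_row_addrs_curr are patched to point at the
+// rows' new locations in L1 or in the next iteration's scratch pad.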
+void ssd_update_row_addrs_cuda(
+    const Tensor& ssd_row_addrs_curr,
+    const Tensor& ssd_curr_next_map,
+    const Tensor& lxu_cache_locations_curr,
+    const Tensor& linear_index_inverse_indices_curr,
+    const Tensor& unique_indices_count_cumsum_curr,
+    const Tensor& cache_set_inverse_indices_curr,
+    const Tensor& lxu_cache_weights,
+    const Tensor& inserted_ssd_weights_next,
+    const Tensor& unique_indices_length_curr) {
+  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
+      ssd_row_addrs_curr,
+      ssd_curr_next_map,
+      lxu_cache_locations_curr,
+      linear_index_inverse_indices_curr,
+      unique_indices_count_cumsum_curr,
+      cache_set_inverse_indices_curr,
+      lxu_cache_weights,
+      inserted_ssd_weights_next,
+      unique_indices_length_curr);
+
+  CUDA_DEVICE_GUARD(ssd_row_addrs_curr);
+
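+  // Pass raw data pointers to the kernel as 64-bit integers so that row
+  // addresses can be computed with byte offsets (cache_row_bytes per row).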
+  const auto lxu_cache_weights_addr =
+      reinterpret_cast<uint64_t>(lxu_cache_weights.data_ptr());
+  const auto inserted_ssd_weights_addr_next =
+      reinterpret_cast<uint64_t>(inserted_ssd_weights_next.data_ptr());
+  const auto cache_row_bytes =
+      lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
+  constexpr auto kNumWarps = kMaxThreads / kWarpSize;
+
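+  // Launch layout: each block holds kNumWarps warps along y; each warp
+  // processes one candidate row (warps with n_curr >= *N_unique_curr exit
+  // early), and the kWarpSize lanes along x parallelize the address writes
+  // for that row's duplicates.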
+  ssd_update_row_addrs_kernel<<<
+      div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
+      dim3(kWarpSize, kNumWarps),
+      0,
+      at::cuda::getCurrentCUDAStream()>>>(
+      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      lxu_cache_locations_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      linear_index_inverse_indices_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      unique_indices_count_cumsum_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      cache_set_inverse_indices_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      lxu_cache_weights_addr,
+      inserted_ssd_weights_addr_next,
+      unique_indices_length_curr.data_ptr<int32_t>(),
+      cache_row_bytes);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}