Commit 9897a3a 1 parent 07c00d1 commit 9897a3a Copy full SHA for 9897a3a
File tree 3 files changed +27
-6
lines changed
3 files changed +27
-6
lines changed Original file line number Diff line number Diff line change @@ -2515,12 +2515,18 @@ struct HashOutputIteratorDeref { // this is what you get when you dereference
2515
2515
2516
2516
template <typename T>
2517
2517
struct HashOutputIterator { // outputs just the index of the pair.
2518
- explicit HashOutputIterator (T *t) : t_(t) {}
2519
- __device__ __forceinline__ HashOutputIteratorDeref<T> operator [](
2518
+ explicit __host__ __device__ __forceinline__ HashOutputIterator (T *t)
2519
+ : t_(t) {}
2520
+ __host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator [](
2520
2521
int32_t idx) const {
2521
2522
return HashOutputIteratorDeref<T>(t_ + idx);
2522
2523
}
2523
- __device__ __forceinline__ HashOutputIterator operator +(size_t offset) {
2524
+ __host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator *()
2525
+ const {
2526
+ return HashOutputIteratorDeref<T>(t_);
2527
+ }
2528
+ __host__ __device__ __forceinline__ HashOutputIterator
2529
+ operator +(size_t offset) {
2524
2530
return HashOutputIterator{t_ + offset};
2525
2531
}
2526
2532
T *t_;
Original file line number Diff line number Diff line change @@ -578,12 +578,18 @@ struct PairOutputIteratorDeref { // this is what you get when you dereference
578
578
579
579
template <typename T>
580
580
struct PairOutputIterator { // outputs just the index of the pair.
581
- explicit PairOutputIterator (int32_t *i) : i_(i) {}
582
- __device__ __forceinline__ PairOutputIteratorDeref<T> operator [](
581
+ explicit __host__ __device__ __forceinline__ PairOutputIterator (int32_t *i)
582
+ : i_(i) {}
583
+ __host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator [](
583
584
int32_t idx) const {
584
585
return PairOutputIteratorDeref<T>(i_ + idx);
585
586
}
586
- __device__ __forceinline__ PairOutputIterator operator +(int32_t offset) {
587
+ __host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator *()
588
+ const {
589
+ return PairOutputIteratorDeref<T>(i_);
590
+ }
591
+ __host__ __device__ __forceinline__ PairOutputIterator
592
+ operator +(int32_t offset) {
587
593
return PairOutputIterator{i_ + offset};
588
594
}
589
595
int32_t *i_;
Original file line number Diff line number Diff line change 30
30
#include " k2/python/csrc/torch.h"
31
31
#include " torch/extension.h"
32
32
33
+ #if K2_TORCH_VERSION_MAJOR > 2 || \
34
+ (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 4 )
35
+ // For torch >= 2.4.x
36
+ // do nothing to fix the following error
37
+ // error: class "pybind11::detail::type_caster<c10::ScalarType, void>" has
38
+ // already been defined
39
+ #else
40
+ // For torch < 2.4
33
41
namespace pybind11 {
34
42
namespace detail {
35
43
@@ -71,6 +79,7 @@ struct type_caster<torch::ScalarType> {
71
79
72
80
} // namespace detail
73
81
} // namespace pybind11
82
+ #endif
74
83
75
84
namespace k2 {
76
85
/* Transfer an object to a specific device.
You can’t perform that action at this time.
0 commit comments