Skip to content

Commit e429a3b

Browse files
swolchokpytorchmergebot
authored andcommitted
Move complex<Half> from Half.h to complex.h (pytorch#140565)
Executing on old TODO on the way to sharing Half.h with ExecuTorch. Differential Revision: [D65888037](https://our.internmc.facebook.com/intern/diff/D65888037/) Pull Request resolved: pytorch#140565 Approved by: https://github.com/ezyang, https://github.com/malfet ghstack dependencies: pytorch#140564
1 parent f630799 commit e429a3b

File tree

3 files changed

+51
-51
lines changed

3 files changed

+51
-51
lines changed

c10/util/Half.h

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <c10/macros/Export.h>
1313
#include <c10/macros/Macros.h>
1414
#include <c10/util/bit_cast.h>
15-
#include <c10/util/complex.h>
1615
#include <c10/util/floating_point_utils.h>
1716
#include <type_traits>
1817

@@ -385,56 +384,6 @@ struct alignas(2) Half {
385384
#endif
386385
};
387386

388-
// TODO : move to complex.h
389-
template <>
390-
struct alignas(4) complex<Half> {
391-
Half real_;
392-
Half imag_;
393-
394-
// Constructors
395-
complex() = default;
396-
// Half constructor is not constexpr so the following constructor can't
397-
// be constexpr
398-
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
399-
: real_(real), imag_(imag) {}
400-
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
401-
: real_(value.real()), imag_(value.imag()) {}
402-
403-
// Conversion operator
404-
inline C10_HOST_DEVICE operator c10::complex<float>() const {
405-
return {real_, imag_};
406-
}
407-
408-
constexpr C10_HOST_DEVICE Half real() const {
409-
return real_;
410-
}
411-
constexpr C10_HOST_DEVICE Half imag() const {
412-
return imag_;
413-
}
414-
415-
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
416-
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
417-
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
418-
return *this;
419-
}
420-
421-
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
422-
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
423-
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
424-
return *this;
425-
}
426-
427-
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
428-
auto a = static_cast<float>(real_);
429-
auto b = static_cast<float>(imag_);
430-
auto c = static_cast<float>(other.real());
431-
auto d = static_cast<float>(other.imag());
432-
real_ = a * c - b * d;
433-
imag_ = a * d + b * c;
434-
return *this;
435-
}
436-
};
437-
438387
C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
439388
out << (float)value;
440389
return out;

c10/util/complex.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <complex>
44

55
#include <c10/macros/Macros.h>
6+
#include <c10/util/Half.h>
67

78
#if defined(__CUDACC__) || defined(__HIPCC__)
89
#include <thrust/complex.h>
@@ -606,6 +607,55 @@ C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
606607
#endif
607608
}
608609

610+
template <>
611+
struct alignas(4) complex<Half> {
612+
Half real_;
613+
Half imag_;
614+
615+
// Constructors
616+
complex() = default;
617+
// Half constructor is not constexpr so the following constructor can't
618+
// be constexpr
619+
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
620+
: real_(real), imag_(imag) {}
621+
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
622+
: real_(value.real()), imag_(value.imag()) {}
623+
624+
// Conversion operator
625+
inline C10_HOST_DEVICE operator c10::complex<float>() const {
626+
return {real_, imag_};
627+
}
628+
629+
constexpr C10_HOST_DEVICE Half real() const {
630+
return real_;
631+
}
632+
constexpr C10_HOST_DEVICE Half imag() const {
633+
return imag_;
634+
}
635+
636+
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
637+
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
638+
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
639+
return *this;
640+
}
641+
642+
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
643+
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
644+
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
645+
return *this;
646+
}
647+
648+
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
649+
auto a = static_cast<float>(real_);
650+
auto b = static_cast<float>(imag_);
651+
auto c = static_cast<float>(other.real());
652+
auto d = static_cast<float>(other.imag());
653+
real_ = a * c - b * d;
654+
imag_ = a * d + b * c;
655+
return *this;
656+
}
657+
};
658+
609659
} // namespace c10
610660

611661
C10_CLANG_DIAGNOSTIC_POP()

torch/csrc/utils/byte_order.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <c10/util/BFloat16.h>
2+
#include <c10/util/complex.h>
23
#include <c10/util/irange.h>
34
#include <torch/csrc/utils/byte_order.h>
45

0 commit comments

Comments
 (0)