Skip to content

Commit e9ddd15

Browse files
add AD support for SampledSpectrum
1 parent acdb3f8 commit e9ddd15

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/util/spec.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class SampledSpectrum {
5555
private:
5656
Local<float> _samples;
5757

58+
private:
59+
explicit SampledSpectrum(Local<float> &&samples) noexcept
60+
: _samples{std::move(samples)} {}
61+
5862
public:
5963
SampledSpectrum(uint n, Expr<float> value) noexcept : _samples{n} {
6064
compute::outline([&] {
@@ -84,6 +88,12 @@ class SampledSpectrum {
8488
}
8589
[[nodiscard]] Local<float> &values() noexcept { return _samples; }
8690
[[nodiscard]] const Local<float> &values() const noexcept { return _samples; }
91+
92+
void requires_grad() const noexcept { _samples.requires_grad(); }
93+
void backward() const noexcept { _samples.backward(); }
94+
void backward(const SampledSpectrum &grad) const noexcept { _samples.backward(grad._samples); }
95+
[[nodiscard]] auto grad() const noexcept { return SampledSpectrum{_samples.grad()}; }
96+
8797
[[nodiscard]] Float &operator[](Expr<uint> i) noexcept {
8898
return dimension() == 1u ? _samples[0u] : _samples[i];
8999
}
@@ -179,7 +189,7 @@ class SampledSpectrum {
179189
}); \
180190
return *this; \
181191
} \
182-
auto &operator op##=(const SampledSpectrum &rhs) noexcept { \
192+
auto &operator op##=(const SampledSpectrum & rhs) noexcept { \
183193
LUISA_ASSERT(rhs.dimension() == 1u || dimension() == rhs.dimension(), \
184194
"Invalid sampled spectrum dimension for operator" #op "=: {} vs {}.", \
185195
dimension(), rhs.dimension()); \

0 commit comments

Comments
 (0)