Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/common/linalg_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ void VecScaDiv(Context const* ctx, linalg::VectorView<float> x, double div) {
}

template <auto _tag = detail::SysTag()>
void LogE(Context const* ctx, linalg::VectorView<float> x) {
void LogE(Context const* ctx, linalg::VectorView<float> x, float rt_eps = 0.0f) {
CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v); });
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); });
}

template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
Expand Down
7 changes: 3 additions & 4 deletions src/common/stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,14 @@ void Median(Context const* ctx, linalg::Matrix<float> const& t,
}
}

void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out) {
v.SetDevice(ctx->Device());
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector<float>* out) {
out->SetDevice(ctx->Device());
out->Reshape(1);

if (ctx->IsCUDA()) {
cuda_impl::Mean(ctx, v.View(ctx->Device()), out->View(ctx->Device()));
cuda_impl::Mean(ctx, v, out->View(ctx->Device()));
} else {
auto h_v = v.HostView();
auto h_v = v;
float n = v.Size();
MemStackAllocator<float, DefaultMaxThreads()> tloc(ctx->Threads(), 0.0f);
ParallelFor(v.Size(), ctx->Threads(),
Expand Down
2 changes: 1 addition & 1 deletion src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void Median(Context const* ctx, linalg::Matrix<float> const& t,
/**
* @brief Calculate the mean value of a vector.
*/
void Mean(Context const* ctx, linalg::Vector<float> const& v, linalg::Vector<float>* out);
void Mean(Context const* ctx, linalg::VectorView<float const> v, linalg::Vector<float>* out);

/**
* @brief Calculate the mean value for the first axis.
Expand Down
12 changes: 11 additions & 1 deletion src/objective/multiclass_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/optional_weight.h" // for MakeOptionalWeights
#include "../common/stats.h" // for Mean
#include "../common/transform.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
Expand Down Expand Up @@ -197,7 +198,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
*base_score = linalg::Zeros<float>(this->ctx_, n_classes);

std::size_t n = info.labels.Size();

// Calculate probability
auto labels = info.labels.View(ctx_->Device());
auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_);
auto intercept = base_score->View(ctx_->Device());
Expand All @@ -209,6 +210,15 @@ class SoftmaxMultiClassObj : public ObjFunction {
collective::SafeColl(status);
CHECK_GE(sum_weight, kRtEps);
linalg::VecScaDiv(this->ctx_, intercept, sum_weight);
CHECK_EQ(base_score->Size(), n_classes);

// Transform it back to margin
// ln(v) - E[ln(v)]
linalg::Vector<float> mean;
linalg::LogE(this->ctx_, intercept, kRtEps);
common::Mean(this->ctx_, intercept, &mean);
auto d_mean = mean.View(this->ctx_->Device());
TransformKernel(this->ctx_, intercept, [=] XGBOOST_DEVICE(float v) { return v - d_mean(0); });
}

private:
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/common/test_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void TestMean(Context const* ctx) {
float mean = nf * (nf - 1) / 2 / n;

linalg::Vector<float> res{{1}, ctx->Device()};
Mean(ctx, data, &res);
Mean(ctx, data.View(ctx->Device()), &res);
auto h_res = res.HostView();
ASSERT_EQ(h_res.Size(), 1);
ASSERT_EQ(mean, h_res(0));
Expand Down
Loading