Skip to content

Commit

Permalink
Fix tsan bug by avoiding string references after done callbacks have …
Browse files Browse the repository at this point in the history
…been invoked.

Tsan bug: Done callback completes the RPC and destroys the request, releasing the underlying string (barrier_id).

Fun fact: We already had a comment in coordination_service.cc that explicitly says not to reference `barrier_id`, but I did it anyway.

Reverts 9136df1

PiperOrigin-RevId: 686605801
  • Loading branch information
Google-ML-Automation committed Oct 16, 2024
1 parent 8f76467 commit d5ff7c1
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 37 deletions.
44 changes: 24 additions & 20 deletions xla/tsl/distributed_runtime/coordination/coordination_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,16 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
absl::Status error) override;
std::vector<CoordinatedTaskStateInfo> GetTaskState(
const std::vector<CoordinatedTask>& task) override;
absl::Status InsertKeyValue(std::string key, std::string value) override;
absl::Status InsertKeyValue(std::string key, std::string value,
absl::Status InsertKeyValue(std::string_view key,
std::string_view value) override;
absl::Status InsertKeyValue(std::string_view key, std::string_view value,
bool allow_overwrite) override;
void GetKeyValueAsync(std::string key, StatusOrValueCallback done) override;
absl::StatusOr<std::string> TryGetKeyValue(std::string key) override;
std::vector<KeyValueEntry> GetKeyValueDir(std::string directory_key) override;
absl::Status DeleteKeyValue(std::string key) override;
void GetKeyValueAsync(std::string_view key,
StatusOrValueCallback done) override;
absl::StatusOr<std::string> TryGetKeyValue(std::string_view key) override;
std::vector<KeyValueEntry> GetKeyValueDir(
std::string_view directory_key) override;
absl::Status DeleteKeyValue(std::string_view key) override;
void BarrierAsync(std::string barrier_id, absl::Duration timeout,
const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
Expand Down Expand Up @@ -989,12 +992,12 @@ std::string NormalizeKey(std::string_view orig_key) {
}

absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
std::string key, std::string value) {
std::string_view key, std::string_view value) {
return InsertKeyValue(key, value, /*allow_overwrite=*/false);
}

absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
std::string key, std::string value, bool allow_overwrite) {
std::string_view key, std::string_view value, bool allow_overwrite) {
VLOG(3) << "InsertKeyValue(): " << key << ": " << value
<< " allow_overwrite: " << allow_overwrite;
const std::string norm_key = NormalizeKey(key);
Expand All @@ -1015,7 +1018,7 @@ absl::Status CoordinationServiceStandaloneImpl::InsertKeyValue(
}

void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
std::string key, StatusOrValueCallback done) {
std::string_view key, StatusOrValueCallback done) {
VLOG(3) << "GetKeyValue(): " << key;
const std::string norm_key = NormalizeKey(key);
absl::MutexLock l(&kv_mu_);
Expand All @@ -1033,7 +1036,7 @@ void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
}

absl::StatusOr<std::string> CoordinationServiceStandaloneImpl::TryGetKeyValue(
std::string key) {
std::string_view key) {
VLOG(3) << "TryGetKeyValue(): " << key;
const std::string norm_key = NormalizeKey(key);
absl::MutexLock l(&kv_mu_);
Expand All @@ -1045,7 +1048,7 @@ absl::StatusOr<std::string> CoordinationServiceStandaloneImpl::TryGetKeyValue(
}

std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
std::string directory_key) {
std::string_view directory_key) {
VLOG(3) << "TryGetKeyValueDir(): " << directory_key;
std::vector<KeyValueEntry> kvs_in_directory;
const std::string norm_key = NormalizeKey(directory_key);
Expand Down Expand Up @@ -1073,7 +1076,7 @@ std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
}

absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
std::string key) {
std::string_view key) {
VLOG(3) << "DeleteKeyValue(): " << key;
const std::string norm_key = NormalizeKey(key);
absl::MutexLock l(&kv_mu_);
Expand Down Expand Up @@ -1283,6 +1286,9 @@ bool CoordinationServiceStandaloneImpl::InitializeBarrier(
}

void CoordinationServiceStandaloneImpl::BarrierAsync(
// Note: `barrier_id` uses a `std::string` instead of `string_view` as the
// RPC may end (i.e. done callback is invoked) before this handler
// completes, which would invalidate the `string_view`.
std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) {
Expand Down Expand Up @@ -1360,6 +1366,9 @@ void CoordinationServiceStandaloneImpl::BarrierAsync(
}

absl::Status CoordinationServiceStandaloneImpl::CancelBarrier(
// Note: `barrier_id` uses a `std::string` instead of `string_view` as the
// RPC may end (i.e. done callback is invoked) before this handler
// completes, which would invalidate the `string_view`.
std::string barrier_id, const CoordinatedTask& task) {
absl::MutexLock l(&state_mu_);
if (ServiceHasStopped()) {
Expand Down Expand Up @@ -1408,23 +1417,18 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id,
}
barrier->tasks_at_barrier.clear();
ongoing_barriers_.erase(barrier_id);
// Note: barrier_id shouldn't be referenced after this line as its lifetime
// may be tied to one of the callbacks.
// Propagate results to participating tasks.
for (const auto& callback : barrier->done_callbacks) {
callback(result);
}
barrier->done_callbacks.clear();
// Special hook for shutdown barrier to disconnect tasks at the barrier and
// propagate errors to those that have not.
if (barrier_id == shutdown_barrier_id_) {
CompleteShutdownAfterBarrier(result, barrier);
// Exit early if the service has stopped due to
// `CompleteShutdownAfterBarrier()`. This prevents any illegal memory access
// into erased state.
if (ServiceHasStopped()) {
return;
}
// Note: this may stop the service. Be careful about referencing barrier
// state after this point.
}
barrier->done_callbacks.clear();
}

void CoordinationServiceStandaloneImpl::SendErrorPollingResponse(
Expand Down
18 changes: 9 additions & 9 deletions xla/tsl/distributed_runtime/coordination/coordination_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class CoordinationServiceInterface {
std::unique_ptr<CoordinationClientCache> cache)>;

using StatusOrValueCallback =
std::function<void(const absl::StatusOr<std::string>&)>;
std::function<void(const absl::StatusOr<std::string_view>&)>;

virtual ~CoordinationServiceInterface() = default;

Expand Down Expand Up @@ -168,29 +168,31 @@ class CoordinationServiceInterface {
// Insert a configuration key-value in the coordination service.
// For now, a key-value can only be inserted once and cannot be updated.
// The key-values are not persisted and will be lost if the leader fails.
virtual absl::Status InsertKeyValue(std::string key, std::string value) = 0;
virtual absl::Status InsertKeyValue(std::string key, std::string value,
virtual absl::Status InsertKeyValue(std::string_view key,
std::string_view value) = 0;
virtual absl::Status InsertKeyValue(std::string_view key,
std::string_view value,
bool allow_overwrite) = 0;

// Get a configuration key-value from the coordination service. The `done`
// callback is invoked when the key-value becomes available.
virtual void GetKeyValueAsync(std::string key,
virtual void GetKeyValueAsync(std::string_view key,
StatusOrValueCallback done) = 0;

// Get a configuration key-value from the coordination service. If the key
// does not exist, return NotFound error.
virtual absl::StatusOr<std::string> TryGetKeyValue(std::string key) = 0;
virtual absl::StatusOr<std::string> TryGetKeyValue(std::string_view key) = 0;

// Gets all values under a directory (key).
// A value is considered to be in the directory if its key is prefixed with
// the directory. This is not a blocking call. Agent does not need to be
// connected to utilize the distributed key-value store.
virtual std::vector<tensorflow::KeyValueEntry> GetKeyValueDir(
std::string directory_key) = 0;
std::string_view directory_key) = 0;

// Delete configuration key-value. If key is a directory, recursively clean
// up all key-values under the directory.
virtual absl::Status DeleteKeyValue(std::string key) = 0;
virtual absl::Status DeleteKeyValue(std::string_view key) = 0;

// Blocks until all (or a subset of) tasks are at the barrier or the barrier
// fails.
Expand Down Expand Up @@ -223,8 +225,6 @@ class CoordinationServiceInterface {
// list of participating tasks.
// - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state.
virtual void BarrierAsync(
// TODO: b/369222279 - Investigate data race and revert to `string_view`
// for all APIs.
std::string barrier_id, absl::Duration timeout,
const tensorflow::CoordinatedTask& task,
const std::vector<tensorflow::CoordinatedTask>& participating_tasks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.

#include <cstdint>
#include <iterator>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -214,8 +213,9 @@ void CoordinationServiceRpcHandler::GetKeyValueAsync(
}
response->mutable_kv()->set_key(request->key());
service_->GetKeyValueAsync(
request->key(), [response, done = std::move(done)](
const absl::StatusOr<std::string>& status_or_value) {
request->key(),
[response, done = std::move(done)](
const absl::StatusOr<std::string_view>& status_or_value) {
if (status_or_value.ok()) {
auto value = status_or_value.value();
response->mutable_kv()->set_value(value.data(), value.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,9 +651,9 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) {

// Get simple key
absl::Notification n1;
absl::StatusOr<std::string> ret;
absl::StatusOr<std::string_view> ret;
coord_service_->GetKeyValueAsync(
"key0", [&](const absl::StatusOr<std::string>& status_or_value) {
"key0", [&](const absl::StatusOr<std::string_view>& status_or_value) {
ret = status_or_value;
n1.Notify();
});
Expand All @@ -664,7 +664,7 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) {
absl::Notification n2;
coord_service_->GetKeyValueAsync(
"path//to///key1////",
[&](const absl::StatusOr<std::string>& status_or_value) {
[&](const absl::StatusOr<std::string_view>& status_or_value) {
ret = status_or_value;
n2.Notify();
});
Expand All @@ -676,7 +676,7 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) {
// Get key that is not available
absl::Notification n3;
coord_service_->GetKeyValueAsync(
"key0", [&](const absl::StatusOr<std::string>& status_or_value) {
"key0", [&](const absl::StatusOr<std::string_view>& status_or_value) {
ret = status_or_value;
n3.Notify();
});
Expand All @@ -696,7 +696,7 @@ TEST_F(CoordinateTwoTasksTest, TestSetGetValues) {
// service shutdown. Hence, we use a shared pointer for notification so
// that it will not be deallocated before the pending callback is
// cleaned up.
[n4](const absl::StatusOr<std::string>& status_or_value) {
[n4](const absl::StatusOr<std::string_view>& status_or_value) {
n4->Notify();
});
EXPECT_FALSE(n4->HasBeenNotified());
Expand Down

0 comments on commit d5ff7c1

Please sign in to comment.