Skip to content

Commit 19e55ef

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
[XLA] Directly track callers and callees of an HloComputation.
Currently there's no direct way to navigate from an HloComputation to its callers and callees. To determine the callees, one must iterate through all of the instructions in a computation, which can be slow. And there is no way to navigate to the callers of a computation other than iterating over all the computations in a module. This change adds absl::btree_map<> data structures that allow one to navigate from an HloComputation to its callers and callees. For each neighbor, we keep a count of the number of references. PiperOrigin-RevId: 735584503
1 parent 2343740 commit 19e55ef

8 files changed

+247
-17
lines changed

xla/hlo/evaluator/hlo_evaluator.cc

+7-1
Original file line numberDiff line numberDiff line change
@@ -1036,13 +1036,19 @@ absl::StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
10361036

10371037
std::unique_ptr<HloInstruction> cloned_instruction =
10381038
instruction->CloneWithNewOperands(instruction->shape(), operands);
1039-
// TODO(phawkins): it's unfortunate that we need to call set_parent() here.
1039+
// TODO(phawkins): it's unfortunate that we need to call set_parent() here,
1040+
// since it violates the invariant that an instruction has a parent iff it is
1041+
// in a computation.
10401042
// It's probably better to avoid constructing new instructions here in the
10411043
// first place.
10421044
cloned_instruction->set_parent(
10431045
const_cast<HloComputation*>(instruction->parent()));
10441046
auto result = Evaluate(cloned_instruction.get());
10451047

1048+
// Undo the parent change, since it will confuse code that expects the
1049+
// instruction to be in a computation.
1050+
cloned_instruction->set_parent(nullptr);
1051+
10461052
return result;
10471053
}
10481054

xla/hlo/ir/hlo_computation.cc

+96-2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ limitations under the License.
2929
#include <vector>
3030

3131
#include "absl/algorithm/container.h"
32+
#include "absl/container/btree_map.h"
3233
#include "absl/container/flat_hash_map.h"
3334
#include "absl/container/flat_hash_set.h"
3435
#include "absl/container/inlined_vector.h"
3536
#include "absl/functional/function_ref.h"
37+
#include "absl/log/check.h"
3638
#include "absl/memory/memory.h"
3739
#include "absl/status/status.h"
3840
#include "absl/strings/str_cat.h"
@@ -183,11 +185,38 @@ HloComputation::~HloComputation() {
183185
async_start_->ClearCalledComputations();
184186
}
185187
Cleanup();
188+
ClearCalledComputations();
189+
190+
// We need to make sure there are no dangling references to this computation
191+
// from instructions in other computations.
192+
std::vector<HloComputation*> callers;
193+
for (const auto& [caller, count] : caller_computations_) {
194+
callers.push_back(caller);
195+
}
196+
for (HloComputation* caller : callers) {
197+
for (HloInstruction* inst : caller->instructions()) {
198+
for (int i = 0; i < inst->called_computations().size(); ++i) {
199+
if (inst->called_computations()[i] == this) {
200+
inst->set_called_computation(i, nullptr);
201+
}
202+
}
203+
}
204+
}
205+
CHECK(caller_computations_.empty());
206+
186207
for (const auto& i : instructions_) {
187208
delete i.inst();
188209
}
189210
}
190211

212+
void HloComputation::ClearCalledComputations() {
213+
for (HloInstruction* i : instructions()) {
214+
i->ClearCalledComputations();
215+
}
216+
// Clearing the instructions should have removed all callee computations.
217+
CHECK(callee_computations_.empty());
218+
}
219+
191220
void HloComputation::SetInstruction(HloInstruction* instruction,
192221
InstructionType type) {
193222
static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1,
@@ -241,6 +270,38 @@ HloInstruction* HloComputation::AddInstruction(
241270
return AddInstruction(std::move(instruction));
242271
}
243272

273+
static void IncrementCount(
274+
absl::btree_map<HloComputation*, int, HloComputation::UniqueIdComparator>&
275+
map,
276+
HloComputation* key) {
277+
++map[key];
278+
}
279+
280+
// Returns true if the callee was present and its count was decremented; returns
281+
// false if the callee was not present.
282+
static void DecrementCount(
283+
absl::btree_map<HloComputation*, int, HloComputation::UniqueIdComparator>&
284+
map,
285+
HloComputation* key) {
286+
auto it = map.find(key);
287+
CHECK(it != map.end());
288+
CHECK_GT(it->second, 0);
289+
--it->second;
290+
if (it->second == 0) {
291+
map.erase(it);
292+
}
293+
}
294+
295+
void HloComputation::AddCallee(HloComputation* callee) {
296+
IncrementCount(callee_computations_, callee);
297+
IncrementCount(callee->caller_computations_, this);
298+
}
299+
300+
void HloComputation::RemoveCallee(HloComputation* callee) {
301+
DecrementCount(callee_computations_, callee);
302+
DecrementCount(callee->caller_computations_, this);
303+
}
304+
244305
HloInstruction* HloComputation::AddInstructionInternal(
245306
std::unique_ptr<HloInstruction> instruction) {
246307
if (parent() != nullptr) {
@@ -265,6 +326,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
265326
CHECK(parent() == nullptr || called_computation->parent() == parent())
266327
<< "Called computation " << called_computation->name()
267328
<< " is not in the same module as " << name();
329+
AddCallee(called_computation);
268330
}
269331
return pinst;
270332
}
@@ -521,13 +583,13 @@ absl::Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
521583

522584
HloInstructionInfo* info = &instructions_[instruction->index_in_parent_];
523585
DCHECK_EQ(info->inst(), instruction);
524-
info->inst()->set_parent(nullptr);
525586
to_be_deleted_.push_back(info->inst()); // Takes ownership
526587
to_be_deleted_.back()->DetachFromOperandsAndUsers();
527588
// Clear all operands to avoid Null operands.
528589
to_be_deleted_.back()->RemoveAllOperands();
529590
to_be_deleted_.back()->ClearCalledComputations();
530591
to_be_deleted_.back()->MarkAsDead();
592+
info->inst()->set_parent(nullptr);
531593

532594
// If this instruction is a constant, clear the literal eagerly instead of
533595
// waiting for the instruction to be deleted in Cleanup(). This greatly
@@ -1089,7 +1151,7 @@ HloComputation::CreateFromProto(
10891151

10901152
auto computation = absl::WrapUnique(
10911153
new HloComputation(proto.name(), parameter_count, &instructions, root));
1092-
computation->unique_id_ = proto.id();
1154+
computation->SetUniqueIdHelper(proto.id());
10931155
if (proto.is_fusion_computation()) {
10941156
computation->instruction_and_type_ =
10951157
static_cast<uintptr_t>(InstructionType::kFusion);
@@ -1840,4 +1902,36 @@ bool HloComputation::CanExpandIntoSingleInstruction() const {
18401902
});
18411903
}
18421904

1905+
void HloComputation::ClearUniqueIdInternal() { SetUniqueIdHelper(-1); }
1906+
1907+
void HloComputation::SetUniqueId(int64_t id) {
1908+
CHECK_EQ(unique_id_, -1);
1909+
CHECK_GE(id, 0);
1910+
SetUniqueIdHelper(id);
1911+
}
1912+
1913+
void HloComputation::SetUniqueIdHelper(int64_t id) {
1914+
// The caller/callee computations are ordered by unique ID, so we need to
1915+
// remove and readd them to our neighbor's data structures.
1916+
for (auto& [computation, count] : caller_computations_) {
1917+
auto it = computation->callee_computations_.find(this);
1918+
CHECK(it != computation->callee_computations_.end());
1919+
CHECK_EQ(it->second, count);
1920+
computation->callee_computations_.erase(it);
1921+
}
1922+
for (auto& [computation, count] : callee_computations_) {
1923+
auto it = computation->caller_computations_.find(this);
1924+
CHECK(it != computation->caller_computations_.end());
1925+
CHECK_EQ(it->second, count);
1926+
computation->caller_computations_.erase(it);
1927+
}
1928+
unique_id_ = id;
1929+
for (auto& [computation, count] : caller_computations_) {
1930+
computation->callee_computations_[this] = count;
1931+
}
1932+
for (auto& [computation, count] : callee_computations_) {
1933+
computation->caller_computations_[this] = count;
1934+
}
1935+
}
1936+
18431937
} // namespace xla

xla/hlo/ir/hlo_computation.h

+63-6
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ limitations under the License.
2121
#include <memory>
2222
#include <optional>
2323
#include <string>
24+
#include <tuple>
2425
#include <utility>
2526
#include <vector>
2627

2728
#include "absl/algorithm/container.h"
29+
#include "absl/container/btree_map.h"
2830
#include "absl/container/flat_hash_map.h"
2931
#include "absl/container/flat_hash_set.h"
3032
#include "absl/container/inlined_vector.h"
@@ -917,14 +919,10 @@ class HloComputation {
917919

918920
// Clear the unique ID of the computation so that it can be re-assigned, such
919921
// as for the purpose of compacting the unique IDs.
920-
void ClearUniqueIdInternal() { unique_id_ = -1; }
922+
void ClearUniqueIdInternal();
921923

922924
// The id of this computation should be unique within the module.
923-
void SetUniqueId(int64_t id) {
924-
CHECK_EQ(unique_id_, -1);
925-
CHECK_GE(id, 0);
926-
unique_id_ = id;
927-
}
925+
void SetUniqueId(int64_t id);
928926

929927
// Returns the instruction in this computation that has name `name`. Returns
930928
// null if there is no such computation.
@@ -957,6 +955,34 @@ class HloComputation {
957955
// Returns true iff this computation can be inlined as a single instruction.
958956
bool CanExpandIntoSingleInstruction() const;
959957

958+
// A comparator that orders computations by their unique IDs. This is used
959+
// for determinism.
960+
struct UniqueIdComparator {
961+
bool operator()(const HloComputation* lhs,
962+
const HloComputation* rhs) const {
963+
// We include the computation pointer so that we can disambiguate
964+
// computations that do not belong to any module and therefore have a
965+
// unique ID of -1. This is not deterministic, but we don't need
966+
// determinism for computations not in a module since they are ignored
967+
// by the topological sorting code.
968+
return std::tie(lhs->unique_id_, lhs) < std::tie(rhs->unique_id_, rhs);
969+
}
970+
};
971+
972+
// Count of times this computation calls other computations.
973+
absl::btree_map<HloComputation*, int, UniqueIdComparator>
974+
callee_computations() const {
975+
return callee_computations_;
976+
}
977+
978+
// Count of times this computation is called by other computations.
979+
absl::btree_map<HloComputation*, int, UniqueIdComparator>
980+
caller_computations() const {
981+
return caller_computations_;
982+
}
983+
984+
void ClearCalledComputations();
985+
960986
private:
961987
friend class HloModule;
962988

@@ -1018,6 +1044,18 @@ class HloComputation {
10181044
// set the parent of a computation is to add it to a module.
10191045
void set_parent(HloModule* module) { parent_ = module; }
10201046

1047+
// Helper that updates the unique ID of the computation. This requires
1048+
// updating the callee_computations_ and caller_computations_ sets since they
1049+
// are ordered by unique ID.
1050+
void SetUniqueIdHelper(int64_t id);
1051+
1052+
friend class HloInstruction;
1053+
void AddCallee(HloComputation* callee);
1054+
void RemoveCallee(HloComputation* callee);
1055+
1056+
// Unique ID of this computation.
1057+
// This is set to -1 if the computation is not in a module. Should only be
1058+
// updated by SetUniqueIdHelper().
10211059
int64_t unique_id_;
10221060
HloInstruction* root_instruction_;
10231061

@@ -1056,6 +1094,25 @@ class HloComputation {
10561094

10571095
std::string name_;
10581096

1097+
// Callers and callees of this computation.
1098+
// * These include all computations that have a caller/callee relationship
1099+
// with this computation, even those that may not belong to a module. For
1100+
// example, a computation that has been created and is in the process of
1101+
// being constructed but has not been added to a module yet may appear here.
1102+
// * These are ordered maps, ordered by (unique ID, computation pointer). The
1103+
// unique ID is used to ensure determinism, whereas the computation pointer
1104+
// is used to disambiguate computations that do not belong to any module and
1105+
// therefore have a unique ID of -1. We assume that determinism only matters
1106+
// for computations that belong to a module (i.e, unique_id != -1), since
1107+
// the primary use case for this data structure is to topologically sort
1108+
// computations in a module.
1109+
// * The values of the maps are the number of times the computation is
1110+
// referenced. In a graph sense, this is the number of parallel edges.
1111+
absl::btree_map<HloComputation*, int, UniqueIdComparator>
1112+
callee_computations_;
1113+
absl::btree_map<HloComputation*, int, UniqueIdComparator>
1114+
caller_computations_;
1115+
10591116
HloComputation(const HloComputation&) = delete;
10601117
HloComputation& operator=(const HloComputation&) = delete;
10611118
};

xla/hlo/ir/hlo_instruction.cc

+27-2
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,28 @@ void HloInstruction::AppendComputation(HloComputation* computation) {
219219
// In .cc file since PtrVec<T*>::push_back() wants to check the alignment
220220
// of T and hlo_instruction.h does not include hlo_computation.h.
221221
mutable_rare()->called_computations.push_back(computation);
222+
if (parent()) {
223+
parent()->AddCallee(computation);
224+
}
222225
}
223226

224227
void HloInstruction::set_called_computation(int index,
225228
HloComputation* computation) {
226-
mutable_rare()->called_computations[index] = computation;
227229
// TODO(b/399394039): Consider also enforcing that computation->parent() !=
228230
// nullptr.
229231
CHECK(parent() == nullptr || parent()->parent() == nullptr ||
230-
parent()->parent() == computation->parent())
232+
computation == nullptr || parent()->parent() == computation->parent())
231233
<< ToString();
234+
HloComputation* old_computation = computation;
235+
std::swap(old_computation, mutable_rare()->called_computations[index]);
236+
if (parent()) {
237+
if (old_computation) {
238+
parent()->RemoveCallee(old_computation);
239+
}
240+
if (computation) {
241+
parent()->AddCallee(computation);
242+
}
243+
}
232244
}
233245

234246
void HloInstruction::ReplaceCalledComputations(
@@ -238,6 +250,19 @@ void HloInstruction::ReplaceCalledComputations(
238250
}
239251
}
240252

253+
void HloInstruction::ClearCalledComputations() {
254+
if (has_rare()) {
255+
if (parent()) {
256+
for (HloComputation* computation : called_computations()) {
257+
if (computation) {
258+
parent()->RemoveCallee(computation);
259+
}
260+
}
261+
}
262+
mutable_rare()->called_computations.clear();
263+
}
264+
}
265+
241266
HloInstruction* HloInstruction::AddInstruction(
242267
std::unique_ptr<HloInstruction> derived_instruction) {
243268
HloInstruction* derived =

xla/hlo/ir/hlo_instruction.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -1742,11 +1742,7 @@ class HloInstruction {
17421742
// clearing out the computations, we reflect the fact that all side-effecting
17431743
// properties have been reflected in the caller, and make the call HLO
17441744
// removable.
1745-
virtual void ClearCalledComputations() {
1746-
if (has_rare()) {
1747-
mutable_rare()->called_computations.clear();
1748-
}
1749-
}
1745+
virtual void ClearCalledComputations();
17501746

17511747
// Returns true if this instruction performs an elementwise operation on
17521748
// `operand_idx`-th operand. An instruction is elementwise on an operand iff,

xla/hlo/ir/hlo_module.cc

+8
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,14 @@ HloModule::HloModule(const std::string& name,
9090
metadata_.set_canonical_module_id(unique_id_);
9191
}
9292

93+
HloModule::~HloModule() {
94+
// To avoid dangling references between computations, we first clear all the
95+
// inter-computation references before deleting any of the computations.
96+
for (const auto& computation : computations_) {
97+
computation->ClearCalledComputations();
98+
}
99+
}
100+
93101
absl::Status HloModule::set_schedule(HloSchedule schedule) {
94102
TF_RET_CHECK(schedule.module() == this);
95103
TF_RETURN_IF_ERROR(schedule.Verify());

xla/hlo/ir/hlo_module.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class HloModule {
8888
HloModule(const std::string& name,
8989
std::shared_ptr<const HloModuleConfig> config,
9090
std::unique_ptr<CompilationEnvironments> comp_envs);
91-
virtual ~HloModule() = default;
91+
virtual ~HloModule();
9292

9393
// Adds an entry computation to the module. A module can only have one entry
9494
// computation. Returns a pointer to the newly added computation.

0 commit comments

Comments
 (0)