@@ -29,10 +29,12 @@ limitations under the License.
29
29
#include < vector>
30
30
31
31
#include " absl/algorithm/container.h"
32
+ #include " absl/container/btree_map.h"
32
33
#include " absl/container/flat_hash_map.h"
33
34
#include " absl/container/flat_hash_set.h"
34
35
#include " absl/container/inlined_vector.h"
35
36
#include " absl/functional/function_ref.h"
37
+ #include " absl/log/check.h"
36
38
#include " absl/memory/memory.h"
37
39
#include " absl/status/status.h"
38
40
#include " absl/strings/str_cat.h"
@@ -183,11 +185,38 @@ HloComputation::~HloComputation() {
183
185
async_start_->ClearCalledComputations ();
184
186
}
185
187
Cleanup ();
188
+ ClearCalledComputations ();
189
+
190
+ // We need to make sure there are no dangling references to this computation
191
+ // from instructions in other computations.
192
+ std::vector<HloComputation*> callers;
193
+ for (const auto & [caller, count] : caller_computations_) {
194
+ callers.push_back (caller);
195
+ }
196
+ for (HloComputation* caller : callers) {
197
+ for (HloInstruction* inst : caller->instructions ()) {
198
+ for (int i = 0 ; i < inst->called_computations ().size (); ++i) {
199
+ if (inst->called_computations ()[i] == this ) {
200
+ inst->set_called_computation (i, nullptr );
201
+ }
202
+ }
203
+ }
204
+ }
205
+ CHECK (caller_computations_.empty ());
206
+
186
207
for (const auto & i : instructions_) {
187
208
delete i.inst ();
188
209
}
189
210
}
190
211
212
+ void HloComputation::ClearCalledComputations () {
213
+ for (HloInstruction* i : instructions ()) {
214
+ i->ClearCalledComputations ();
215
+ }
216
+ // Clearing the instructions should have removed all callee computations.
217
+ CHECK (callee_computations_.empty ());
218
+ }
219
+
191
220
void HloComputation::SetInstruction (HloInstruction* instruction,
192
221
InstructionType type) {
193
222
static_assert (alignof (HloInstruction) == kInstructionTypeMask + 1 ,
@@ -241,6 +270,38 @@ HloInstruction* HloComputation::AddInstruction(
241
270
return AddInstruction (std::move (instruction));
242
271
}
243
272
273
+ static void IncrementCount (
274
+ absl::btree_map<HloComputation*, int , HloComputation::UniqueIdComparator>&
275
+ map,
276
+ HloComputation* key) {
277
+ ++map[key];
278
+ }
279
+
280
+ // Returns true if the callee was present and its count was decremented; returns
281
+ // false if the callee was not present.
282
+ static void DecrementCount (
283
+ absl::btree_map<HloComputation*, int , HloComputation::UniqueIdComparator>&
284
+ map,
285
+ HloComputation* key) {
286
+ auto it = map.find (key);
287
+ CHECK (it != map.end ());
288
+ CHECK_GT (it->second , 0 );
289
+ --it->second ;
290
+ if (it->second == 0 ) {
291
+ map.erase (it);
292
+ }
293
+ }
294
+
295
+ void HloComputation::AddCallee (HloComputation* callee) {
296
+ IncrementCount (callee_computations_, callee);
297
+ IncrementCount (callee->caller_computations_ , this );
298
+ }
299
+
300
+ void HloComputation::RemoveCallee (HloComputation* callee) {
301
+ DecrementCount (callee_computations_, callee);
302
+ DecrementCount (callee->caller_computations_ , this );
303
+ }
304
+
244
305
HloInstruction* HloComputation::AddInstructionInternal (
245
306
std::unique_ptr<HloInstruction> instruction) {
246
307
if (parent () != nullptr ) {
@@ -265,6 +326,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
265
326
CHECK (parent () == nullptr || called_computation->parent () == parent ())
266
327
<< " Called computation " << called_computation->name ()
267
328
<< " is not in the same module as " << name ();
329
+ AddCallee (called_computation);
268
330
}
269
331
return pinst;
270
332
}
@@ -521,13 +583,13 @@ absl::Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
521
583
522
584
HloInstructionInfo* info = &instructions_[instruction->index_in_parent_ ];
523
585
DCHECK_EQ (info->inst (), instruction);
524
- info->inst ()->set_parent (nullptr );
525
586
to_be_deleted_.push_back (info->inst ()); // Takes ownership
526
587
to_be_deleted_.back ()->DetachFromOperandsAndUsers ();
527
588
// Clear all operands to avoid Null operands.
528
589
to_be_deleted_.back ()->RemoveAllOperands ();
529
590
to_be_deleted_.back ()->ClearCalledComputations ();
530
591
to_be_deleted_.back ()->MarkAsDead ();
592
+ info->inst ()->set_parent (nullptr );
531
593
532
594
// If this instruction is a constant, clear the literal eagerly instead of
533
595
// waiting for the instruction to be deleted in Cleanup(). This greatly
@@ -1089,7 +1151,7 @@ HloComputation::CreateFromProto(
1089
1151
1090
1152
auto computation = absl::WrapUnique (
1091
1153
new HloComputation (proto.name (), parameter_count, &instructions, root));
1092
- computation->unique_id_ = proto.id ();
1154
+ computation->SetUniqueIdHelper ( proto.id () );
1093
1155
if (proto.is_fusion_computation ()) {
1094
1156
computation->instruction_and_type_ =
1095
1157
static_cast <uintptr_t >(InstructionType::kFusion );
@@ -1840,4 +1902,36 @@ bool HloComputation::CanExpandIntoSingleInstruction() const {
1840
1902
});
1841
1903
}
1842
1904
1905
+ void HloComputation::ClearUniqueIdInternal () { SetUniqueIdHelper (-1 ); }
1906
+
1907
+ void HloComputation::SetUniqueId (int64_t id) {
1908
+ CHECK_EQ (unique_id_, -1 );
1909
+ CHECK_GE (id, 0 );
1910
+ SetUniqueIdHelper (id);
1911
+ }
1912
+
1913
+ void HloComputation::SetUniqueIdHelper (int64_t id) {
1914
+ // The caller/callee computations are ordered by unique ID, so we need to
1915
+ // remove and readd them to our neighbor's data structures.
1916
+ for (auto & [computation, count] : caller_computations_) {
1917
+ auto it = computation->callee_computations_ .find (this );
1918
+ CHECK (it != computation->callee_computations_ .end ());
1919
+ CHECK_EQ (it->second , count);
1920
+ computation->callee_computations_ .erase (it);
1921
+ }
1922
+ for (auto & [computation, count] : callee_computations_) {
1923
+ auto it = computation->caller_computations_ .find (this );
1924
+ CHECK (it != computation->caller_computations_ .end ());
1925
+ CHECK_EQ (it->second , count);
1926
+ computation->caller_computations_ .erase (it);
1927
+ }
1928
+ unique_id_ = id;
1929
+ for (auto & [computation, count] : caller_computations_) {
1930
+ computation->callee_computations_ [this ] = count;
1931
+ }
1932
+ for (auto & [computation, count] : callee_computations_) {
1933
+ computation->caller_computations_ [this ] = count;
1934
+ }
1935
+ }
1936
+
1843
1937
} // namespace xla
0 commit comments