Skip to content

Commit 9de005d

Browse files
adding a check for large constant
1 parent 96c8197 commit 9de005d

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tensorflow/compiler/jit/mark_for_compilation_pass.cc

+33
Original file line number | Diff line number | Diff line change
@@ -898,13 +898,46 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
898898
return absl::OkStatus();
899899
}
900900

901+
// Returns the number of elements of the tensor stored in the "value" attribute
// of the Const node `n`, computed as the product of the dimension sizes listed
// in the attribute's tensor shape proto.
//
// Returns -1 when the element count cannot be determined:
//   - `n` is not a "Const" op,
//   - the "value" attribute cannot be read,
//   - the proto carries no tensor shape,
//   - the shape has unknown rank, or
//   - a dimension size is negative (NOTE(review): TensorShapeProto uses a
//     negative size for an unknown dimension -- confirm against the proto
//     definition).
//
// Returns 0 as soon as a zero-sized dimension is encountered (e.g., an empty
// vector). A shape with no dimensions (a scalar) yields 1. The product
// saturates at INT64_MAX rather than overflowing, since signed overflow is
// undefined behavior.
int64_t GetConstantTensorSize(Node* n) {
  if (n->op_def().name() != "Const") return -1;

  const TensorProto* proto = nullptr;
  Status s = GetNodeAttr(n->def(), "value", &proto);
  // Defensive: do not dereference if the lookup failed or left `proto` unset.
  if (!s.ok() || proto == nullptr) return -1;

  if (!proto->has_tensor_shape()) {
    return -1;
  }
  const auto& tensor_shape_proto = proto->tensor_shape();
  if (tensor_shape_proto.unknown_rank()) {
    return -1;
  }

  int64_t num_elements = 1;
  for (const auto& dim : tensor_shape_proto.dim()) {
    const int64_t dim_size = dim.size();
    // A negative size would make the product meaningless (and possibly flip
    // its sign), so report the count as undeterminable.
    if (dim_size < 0) return -1;
    // Note that in some cases, dim.size() can be zero (e.g., empty vector).
    // The tensor is then empty regardless of the remaining dimensions.
    if (dim_size == 0) return 0;
    // Here dim_size >= 1 and num_elements >= 1, so the division is safe and
    // the comparison detects overflow before it happens.
    if (num_elements > INT64_MAX / dim_size) {
      num_elements = INT64_MAX;
    } else {
      num_elements *= dim_size;
    }
  }
  return num_elements;
}
922+
901923
Status MarkForCompilationPassImpl::DeclusterNodes() {
902924
for (Node* n : compilation_candidates_) {
903925
Cluster* cluster = GetClusterForNode(n);
904926
if (cluster == nullptr) {
905927
continue;
906928
}
907929

930+
// Remove large constants from clustering so they don't get compiled.
931+
// Avoid unnecessary copies of large constants (based on L1 cache).
932+
933+
const int64_t kLargeConstantThreshold = 16384;
934+
if (n->op_def().name() == "Const") {
935+
int64_t tensor_size = GetConstantTensorSize(n);
936+
if (tensor_size > kLargeConstantThreshold) {
937+
declustered_nodes_.insert(n);
938+
}
939+
}
940+
908941
// De-cluster Fill ops that are
909942
// - used at least once outside the cluster, and
910943
// - not used inside the cluster.

0 commit comments

Comments
 (0)