@@ -898,13 +898,46 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() {
   return absl::OkStatus();
 }
 
+int64_t GetConstantTensorSize(Node* n) {
+  if (n->op_def().name() != "Const") return -1;
+
+  const TensorProto* proto = nullptr;
+  Status s = GetNodeAttr(n->def(), "value", &proto);
+  if (!s.ok()) return -1;
+
+  if (!proto->has_tensor_shape()) {
+    return -1;
+  }
+  const auto& tensor_shape_proto = proto->tensor_shape();
+  if (tensor_shape_proto.unknown_rank()) {
+    return -1;
+  }
+  int64_t num_elements = 1;
+  for (const auto& dim : tensor_shape_proto.dim()) {
+    // Note that in some cases, dim.size() can be zero (e.g., empty vector).
+    num_elements *= dim.size();
+  }
+  return num_elements;
+}
+
 Status MarkForCompilationPassImpl::DeclusterNodes() {
   for (Node* n : compilation_candidates_) {
     Cluster* cluster = GetClusterForNode(n);
     if (cluster == nullptr) {
      continue;
    }
 
+    // Remove large constants from clustering so they don't get compiled,
+    // avoiding unnecessary copies (threshold is based on L1 cache size).
+
+    const int64_t kLargeConstantThreshold = 16384;
+    if (n->op_def().name() == "Const") {
+      int64_t tensor_size = GetConstantTensorSize(n);
+      if (tensor_size > kLargeConstantThreshold) {
+        declustered_nodes_.insert(n);
+      }
+    }
+
     // De-cluster Fill ops that are
     // - used at least once outside the cluster, and
     // - not used inside the cluster.
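For reference, here is a minimal standalone sketch (not part of the change) of how the element count computed by GetConstantTensorSize compares against kLargeConstantThreshold. The NumElements helper and the example shape are illustrative assumptions; only TensorShapeProto and its accessors come from TensorFlow.

// Standalone sketch: reproduces the element-count logic from the diff and
// checks an example shape against the 16384-element threshold.
// Assumes the TensorFlow proto headers are available; NumElements is a
// hypothetical helper, not part of the pass.
#include <cstdint>
#include <iostream>

#include "tensorflow/core/framework/tensor_shape.pb.h"

// Product of dimension sizes: 1 for a scalar (rank 0), 0 if any dimension
// has size 0, -1 if the rank is unknown -- mirroring GetConstantTensorSize.
int64_t NumElements(const tensorflow::TensorShapeProto& shape) {
  if (shape.unknown_rank()) return -1;
  int64_t num_elements = 1;
  for (const auto& dim : shape.dim()) {
    num_elements *= dim.size();
  }
  return num_elements;
}

int main() {
  constexpr int64_t kLargeConstantThreshold = 16384;  // same value as the pass

  tensorflow::TensorShapeProto shape;
  shape.add_dim()->set_size(130);
  shape.add_dim()->set_size(130);  // 130 * 130 = 16900 elements

  const int64_t n = NumElements(shape);
  std::cout << "elements: " << n
            << ", exceeds threshold: " << (n > kLargeConstantThreshold) << "\n";
  // Prints "elements: 16900, exceeds threshold: 1". A [128, 128] constant has
  // exactly 16384 elements and would not pass the strict '>' comparison.
  return 0;
}

Note that the comparison is on element count, not bytes, so the byte footprint that trips the threshold depends on dtype; for float32 it is anything over 16384 * 4 bytes = 64 KiB.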