[CELEBORN-1374] Refactor SortBuffer and PartitionSortedBuffer
SteNicholas committed Apr 7, 2024
1 parent 186899f commit 31e1087
Showing 17 changed files with 374 additions and 373 deletions.
LICENSE (2 additions & 2 deletions)
@@ -252,6 +252,6 @@ Remote Shuffle Service for Flink
./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferRecycler.java
./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/CreditListener.java
- ./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
- ./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
+ ./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/DataBuffer.java
+ ./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBasedDataBuffer.java
./client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/TransferBufferPool.java
RemoteShuffleResultPartitionDelegation.java
@@ -34,8 +34,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

- import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
- import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+ import org.apache.celeborn.plugin.flink.buffer.BufferWithSubpartition;
+ import org.apache.celeborn.plugin.flink.buffer.DataBuffer;
+ import org.apache.celeborn.plugin.flink.buffer.SortBasedDataBuffer;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;

@@ -46,29 +47,29 @@ public class RemoteShuffleResultPartitionDelegation {
/** Size of network buffer and write buffer. */
public int networkBufferSize;

- /** {@link SortBuffer} for records sent by broadcastRecord. */
- public SortBuffer broadcastSortBuffer;
+ /** {@link DataBuffer} for records sent by broadcastRecord. */
+ public DataBuffer broadcastDataBuffer;

- /** {@link SortBuffer} for records sent by emitRecord. */
- public SortBuffer unicastSortBuffer;
+ /** {@link DataBuffer} for records sent by emitRecord. */
+ public DataBuffer unicastDataBuffer;

/** Utility to spill data to shuffle workers. */
public RemoteShuffleOutputGate outputGate;

/** Whether notifyEndOfData has been called or not. */
private boolean endOfDataNotified;

- private int numSubpartitions;
+ private final int numSubpartitions;
private BufferPool bufferPool;
private BufferCompressor bufferCompressor;
private Function<Buffer, Boolean> canBeCompressed;
private Runnable checkProducerState;
- private BiConsumer<SortBuffer.BufferWithChannel, Boolean> statisticsConsumer;
+ private final BiConsumer<BufferWithSubpartition, Boolean> statisticsConsumer;

public RemoteShuffleResultPartitionDelegation(
int networkBufferSize,
RemoteShuffleOutputGate outputGate,
- BiConsumer<SortBuffer.BufferWithChannel, Boolean> statisticsConsumer,
+ BiConsumer<BufferWithSubpartition, Boolean> statisticsConsumer,
int numSubpartitions) {
this.networkBufferSize = networkBufferSize;
this.outputGate = outputGate;
@@ -105,20 +106,20 @@ public void emit(
targetSubpartition == 0, "Target subpartition index can only be 0 when broadcast.");
}

- SortBuffer sortBuffer = isBroadcast ? getBroadcastSortBuffer() : getUnicastSortBuffer();
- if (sortBuffer.append(record, targetSubpartition, dataType)) {
+ DataBuffer dataBuffer = isBroadcast ? getBroadcastDataBuffer() : getUnicastDataBuffer();
+ if (dataBuffer.append(record, targetSubpartition, dataType)) {
return;
}

try {
- if (!sortBuffer.hasRemaining()) {
+ if (!dataBuffer.hasRemaining()) {
// the record can not be appended to the free sort buffer because it is too large
- sortBuffer.finish();
- sortBuffer.release();
+ dataBuffer.finish();
+ dataBuffer.release();
writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
return;
}
- flushSortBuffer(sortBuffer, isBroadcast);
+ flushDataBuffer(dataBuffer, isBroadcast);
} catch (InterruptedException e) {
LOG.error("Failed to flush the sort buffer.", e);
Utils.rethrowAsRuntimeException(e);
@@ -127,70 +128,70 @@ public void emit(
}

@VisibleForTesting
- public SortBuffer getUnicastSortBuffer() throws IOException {
- flushBroadcastSortBuffer();
+ public DataBuffer getUnicastDataBuffer() throws IOException {
+ flushBroadcastDataBuffer();

- if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) {
- return unicastSortBuffer;
+ if (unicastDataBuffer != null && !unicastDataBuffer.isFinished()) {
+ return unicastDataBuffer;
}

- unicastSortBuffer =
- new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
- return unicastSortBuffer;
+ unicastDataBuffer =
+ new SortBasedDataBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
+ return unicastDataBuffer;
}

- public SortBuffer getBroadcastSortBuffer() throws IOException {
- flushUnicastSortBuffer();
+ public DataBuffer getBroadcastDataBuffer() throws IOException {
+ flushUnicastDataBuffer();

- if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) {
- return broadcastSortBuffer;
+ if (broadcastDataBuffer != null && !broadcastDataBuffer.isFinished()) {
+ return broadcastDataBuffer;
}

- broadcastSortBuffer =
- new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
- return broadcastSortBuffer;
+ broadcastDataBuffer =
+ new SortBasedDataBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
+ return broadcastDataBuffer;
}

- public void flushBroadcastSortBuffer() throws IOException {
- flushSortBuffer(broadcastSortBuffer, true);
+ public void flushBroadcastDataBuffer() throws IOException {
+ flushDataBuffer(broadcastDataBuffer, true);
}

- public void flushUnicastSortBuffer() throws IOException {
- flushSortBuffer(unicastSortBuffer, false);
+ public void flushUnicastDataBuffer() throws IOException {
+ flushDataBuffer(unicastDataBuffer, false);
}

@VisibleForTesting
- void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException {
- if (sortBuffer == null || sortBuffer.isReleased()) {
+ void flushDataBuffer(DataBuffer dataBuffer, boolean isBroadcast) throws IOException {
+ if (dataBuffer == null || dataBuffer.isReleased()) {
return;
}
- sortBuffer.finish();
- if (sortBuffer.hasRemaining()) {
+ dataBuffer.finish();
+ if (dataBuffer.hasRemaining()) {
try {
outputGate.regionStart(isBroadcast);
- while (sortBuffer.hasRemaining()) {
+ while (dataBuffer.hasRemaining()) {
MemorySegment segment = outputGate.getBufferPool().requestMemorySegmentBlocking();
- SortBuffer.BufferWithChannel bufferWithChannel;
+ BufferWithSubpartition bufferWithSubpartition;
try {
- bufferWithChannel =
- sortBuffer.copyIntoSegment(
+ bufferWithSubpartition =
+ dataBuffer.getNextBuffer(
segment, outputGate.getBufferPool(), BufferUtils.HEADER_LENGTH);
} catch (Throwable t) {
outputGate.getBufferPool().recycle(segment);
throw new FlinkRuntimeException("Shuffle write failure.", t);
}

- Buffer buffer = bufferWithChannel.getBuffer();
- int subpartitionIndex = bufferWithChannel.getChannelIndex();
- statisticsConsumer.accept(bufferWithChannel, isBroadcast);
+ Buffer buffer = bufferWithSubpartition.getBuffer();
+ int subpartitionIndex = bufferWithSubpartition.getSubpartitionIndex();
+ statisticsConsumer.accept(bufferWithSubpartition, isBroadcast);
writeCompressedBufferIfPossible(buffer, subpartitionIndex);
}
outputGate.regionFinish();
} catch (InterruptedException e) {
throw new IOException("Failed to flush the sort buffer, broadcast=" + isBroadcast, e);
}
}
- releaseSortBuffer(sortBuffer);
+ releaseDataBuffer(dataBuffer);
}

public void writeCompressedBufferIfPossible(Buffer buffer, int targetSubpartition)
@@ -234,9 +235,9 @@ public void writeLargeRecord(
dataType,
toCopy + BufferUtils.HEADER_LENGTH);

- SortBuffer.BufferWithChannel bufferWithChannel =
- new SortBuffer.BufferWithChannel(buffer, targetSubpartition);
- statisticsConsumer.accept(bufferWithChannel, isBroadcast);
+ BufferWithSubpartition bufferWithSubpartition =
+ new BufferWithSubpartition(buffer, targetSubpartition);
+ statisticsConsumer.accept(bufferWithSubpartition, isBroadcast);
writeCompressedBufferIfPossible(buffer, targetSubpartition);
}
outputGate.regionFinish();
@@ -246,17 +247,17 @@ public void broadcast(ByteBuffer record, Buffer.DataType dataType) throws IOExce
emit(record, 0, dataType, true);
}

- public void releaseSortBuffer(SortBuffer sortBuffer) {
- if (sortBuffer != null) {
- sortBuffer.release();
+ public void releaseDataBuffer(DataBuffer dataBuffer) {
+ if (dataBuffer != null) {
+ dataBuffer.release();
}
}

public void finish() throws IOException {
Utils.checkState(
- unicastSortBuffer == null || unicastSortBuffer.isReleased(),
+ unicastDataBuffer == null || unicastDataBuffer.isReleased(),
"The unicast sort buffer should be either null or released.");
- flushBroadcastSortBuffer();
+ flushBroadcastDataBuffer();
try {
outputGate.finish();
} catch (InterruptedException e) {
@@ -265,22 +266,21 @@ public void finish() {
}

public synchronized void close(Runnable closeHandler) {
- Throwable closeException = null;
+ Throwable closeException;
closeException =
checkException(
- () -> releaseSortBuffer(unicastSortBuffer),
- closeException,
+ () -> releaseDataBuffer(unicastDataBuffer),
+ null,
"Failed to release unicast sort buffer.");

closeException =
checkException(
- () -> releaseSortBuffer(broadcastSortBuffer),
+ () -> releaseDataBuffer(broadcastDataBuffer),
closeException,
"Failed to release broadcast sort buffer.");

closeException =
- checkException(
- () -> closeHandler.run(), closeException, "Failed to call super#close() method.");
+ checkException(closeHandler, closeException, "Failed to call super#close() method.");

try {
outputGate.close();
Expand All @@ -307,8 +307,8 @@ public Throwable checkException(Runnable runnable, Throwable exception, String e

public void flushAll() {
try {
- flushUnicastSortBuffer();
- flushBroadcastSortBuffer();
+ flushUnicastDataBuffer();
+ flushBroadcastDataBuffer();
} catch (Throwable t) {
LOG.error("Failed to flush the current sort buffer.", t);
Utils.rethrowAsRuntimeException(t);
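
To illustrate the new callback signature, the sketch below shows how a caller of RemoteShuffleResultPartitionDelegation might wire the statistics consumer now that it receives a BufferWithSubpartition instead of the old SortBuffer.BufferWithChannel. This is a hedged, caller-side example only: the LOG field, the surrounding variables (networkBufferSize, outputGate, numSubpartitions), and the debug-logging body are illustrative assumptions, not code from this commit.

    // Hypothetical caller-side sketch (not part of this commit): the statistics
    // callback now consumes a BufferWithSubpartition.
    BiConsumer<BufferWithSubpartition, Boolean> statisticsConsumer =
        (bufferWithSubpartition, isBroadcast) -> {
          // Placeholder bookkeeping; a real caller would update its partition counters here.
          LOG.debug(
              "Wrote {} bytes to subpartition {} (broadcast={})",
              bufferWithSubpartition.getBuffer().readableBytes(),
              bufferWithSubpartition.getSubpartitionIndex(),
              isBroadcast);
        };

    RemoteShuffleResultPartitionDelegation delegation =
        new RemoteShuffleResultPartitionDelegation(
            networkBufferSize, outputGate, statisticsConsumer, numSubpartitions);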
BufferWithSubpartition.java (new file)
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.plugin.flink.buffer;

import static org.apache.flink.util.Preconditions.checkNotNull;

import org.apache.flink.runtime.io.network.buffer.Buffer;

/** Buffer and the corresponding subpartition index. */
public class BufferWithSubpartition {

private final Buffer buffer;

private final int subpartitionIndex;

public BufferWithSubpartition(Buffer buffer, int subpartitionIndex) {
this.buffer = checkNotNull(buffer);
this.subpartitionIndex = subpartitionIndex;
}

public Buffer getBuffer() {
return buffer;
}

public int getSubpartitionIndex() {
return subpartitionIndex;
}
}
DataBuffer.java (new file)
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.plugin.flink.buffer;

import java.io.IOException;
import java.nio.ByteBuffer;

import javax.annotation.Nullable;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferRecycler;

/**
* Data of different subpartitions can be appended to a {@link DataBuffer} and after the {@link
* DataBuffer} is full or finished, the appended data can be copied from it in subpartition index
* order.
*
* <p>The lifecycle of a {@link DataBuffer} can be: new, write, [read, reset, write], finish, read,
* release. There can be multiple [read, reset, write] operations before finish.
*/
public interface DataBuffer {

/**
* Appends data of the specified subpartition to this {@link DataBuffer} and returns true if this
* {@link DataBuffer} is full.
*/
boolean append(ByteBuffer source, int targetSubpartition, Buffer.DataType dataType)
throws IOException;

/**
* Copies data in this {@link DataBuffer} to the target {@link MemorySegment} in subpartition
* index order and returns {@link BufferWithSubpartition} which contains the copied data and the
* corresponding subpartition index.
*/
BufferWithSubpartition getNextBuffer(
@Nullable MemorySegment transitBuffer, BufferRecycler recycler, int offset);

/** Returns the total number of records written to this {@link DataBuffer}. */
long numTotalRecords();

/** Returns the total number of bytes written to this {@link DataBuffer}. */
long numTotalBytes();

/** Returns true if not all data appended to this {@link DataBuffer} is consumed. */
boolean hasRemaining();

/** Finishes this {@link DataBuffer} which means no record can be appended anymore. */
void finish();

/** Whether this {@link DataBuffer} is finished or not. */
boolean isFinished();

/** Releases this {@link DataBuffer} which releases all resources. */
void release();

/** Whether this {@link DataBuffer} is released or not. */
boolean isReleased();
}
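
As a rough orientation for the lifecycle described in the interface javadoc (write, finish, read, release), the sketch below mirrors the flushDataBuffer loop in RemoteShuffleResultPartitionDelegation above. The wrapper method, its parameters, and the decision to simply recycle each returned buffer are illustrative assumptions, not code from this commit; error handling is reduced to the declared exceptions.

    // Illustrative sketch of the write -> finish -> read -> release lifecycle,
    // assuming a SortBasedDataBuffer backed by a Flink BufferPool.
    static void dataBufferLifecycleSketch(
        BufferPool bufferPool,
        int numSubpartitions,
        int networkBufferSize,
        ByteBuffer record,
        int targetSubpartition)
        throws IOException, InterruptedException {
      DataBuffer dataBuffer =
          new SortBasedDataBuffer(bufferPool, numSubpartitions, networkBufferSize, null);

      // Write phase: append a record for the target subpartition; per the javadoc above,
      // the return value reports whether the buffer is now full.
      dataBuffer.append(record, targetSubpartition, Buffer.DataType.DATA_BUFFER);

      // Finish phase: no further records may be appended afterwards.
      dataBuffer.finish();

      // Read phase: drain the buffered data in subpartition index order.
      while (dataBuffer.hasRemaining()) {
        MemorySegment segment = bufferPool.requestMemorySegmentBlocking();
        BufferWithSubpartition next =
            dataBuffer.getNextBuffer(segment, bufferPool, BufferUtils.HEADER_LENGTH);
        // A real caller would ship the buffer to the output gate; here it is simply recycled.
        next.getBuffer().recycleBuffer();
      }

      // Release phase: return all backing memory segments to the buffer pool.
      dataBuffer.release();
    }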