Skip to content

Commit

Permalink
Add support for s32le, and improve reading of Wav files. (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
devoxin authored May 18, 2024
1 parent 6460c04 commit 701df40
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import static com.sedmelluq.discord.lavaplayer.container.MediaContainerDetection.checkNextBytes;

/**
* Loads either WAV header information or a WAV track provider from a stream.
*/
public class WavFileLoader {
static final int[] WAV_RIFF_HEADER = new int[]{0x52, 0x49, 0x46, 0x46, -1, -1, -1, -1, 0x57, 0x41, 0x56, 0x45};
static final int[] WAV_RIFF_HEADER = new int[] { 0x52, 0x49, 0x46, 0x46, -1, -1, -1, -1, 0x57, 0x41, 0x56, 0x45 };
static final byte[] FORMAT_SUBTYPE_PCM = { 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, (byte) 0x80, 0x00, 0x00, (byte) 0xaa, 0x00, 0x38, (byte) 0x9b, 0x71 };

private final SeekableInputStream inputStream;

Expand Down Expand Up @@ -44,10 +46,11 @@ public WavFileInfo parseHeaders() throws IOException {
long chunkSize = Integer.toUnsignedLong(Integer.reverseBytes(dataInput.readInt()));

if ("fmt ".equals(chunkName)) {
readFormatChunk(builder, dataInput);
int bytesRead = readFormatChunk(builder, dataInput);
long chunkBytesRemaining = chunkSize - bytesRead;

if (chunkSize > 16) {
inputStream.skipFully(chunkSize - 16);
if (chunkBytesRemaining > 0) {
inputStream.skipFully(chunkBytesRemaining);
}
} else if ("data".equals(chunkName)) {
builder.sampleAreaSize = chunkSize;
Expand All @@ -65,8 +68,8 @@ private String readChunkName(DataInput dataInput) throws IOException {
return new String(buffer, StandardCharsets.US_ASCII);
}

private void readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IOException {
builder.audioFormat = Short.reverseBytes(dataInput.readShort()) & 0xFFFF;
private int readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IOException {
builder.setAudioFormat(Short.reverseBytes(dataInput.readShort()) & 0xFFFF);
builder.channelCount = Short.reverseBytes(dataInput.readShort()) & 0xFFFF;
builder.sampleRate = Integer.reverseBytes(dataInput.readInt());

Expand All @@ -75,6 +78,16 @@ private void readFormatChunk(InfoBuilder builder, DataInput dataInput) throws IO

builder.blockAlign = Short.reverseBytes(dataInput.readShort()) & 0xFFFF;
builder.bitsPerSample = Short.reverseBytes(dataInput.readShort()) & 0xFFFF;

if (builder.formatType == WaveFormatType.WAVE_FORMAT_EXTENSIBLE) {
dataInput.skipBytes(8);
byte[] subFormat = new byte[16];
dataInput.readFully(subFormat);
builder.subFormat = subFormat;
return 40;
}

return 16;
}

/**
Expand All @@ -90,13 +103,20 @@ public WavTrackProvider loadTrack(AudioProcessingContext context) throws IOExcep

private static class InfoBuilder {
private int audioFormat;
private WaveFormatType formatType;
private byte[] subFormat;
private int channelCount;
private int sampleRate;
private int bitsPerSample;
private int blockAlign;
private long sampleAreaSize;
private long startOffset;

/**
 * Stores the raw audio format code and resolves the format type that matches it.
 *
 * @param audioFormat Audio format code as read from the format chunk
 */
private void setAudioFormat(int audioFormat) {
  // Both fields are derived from the same parameter, so they always stay in sync.
  this.formatType = WaveFormatType.getByCode(audioFormat);
  this.audioFormat = audioFormat;
}

private WavFileInfo build() {
validateFormat();
validateAlignment();
Expand All @@ -105,13 +125,15 @@ private WavFileInfo build() {
}

private void validateFormat() {
if (audioFormat != 1) {
throw new IllegalStateException("Invalid audio format " + audioFormat + ", must be 1 (PCM)");
if (formatType == WaveFormatType.WAVE_FORMAT_UNKNOWN) {
throw new IllegalStateException("Invalid audio format " + audioFormat + ", must be 1 (PCM) or 65534 (WAVE_FORMAT_EXTENSIBLE)");
} else if (subFormat != null && !Arrays.equals(subFormat, FORMAT_SUBTYPE_PCM)) {
throw new IllegalStateException("Invalid subformat " + Arrays.toString(subFormat));
} else if (channelCount < 1 || channelCount > 16) {
throw new IllegalStateException("Invalid channel count: " + channelCount);
} else if (sampleRate < 100 || sampleRate > 384000) {
throw new IllegalStateException("Invalid sample rate: " + sampleRate);
} else if (bitsPerSample != 16 && bitsPerSample != 24) {
} else if (bitsPerSample != 16 && bitsPerSample != 24 && bitsPerSample != 32) {
throw new IllegalStateException("Unsupported bits per sample: " + bitsPerSample);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ShortBuffer;

import static java.nio.ByteOrder.LITTLE_ENDIAN;

Expand All @@ -23,11 +22,12 @@ public class WavTrackProvider {
private final SeekableInputStream inputStream;
private final DataInput dataInput;
private final WavFileInfo info;
private final int bytesPerSample;
private final AudioPipeline downstream;

private final short[] buffer;
private final byte[] rawBuffer;
private final ByteBuffer byteBuffer;
private final ShortBuffer nioBuffer;

/**
* @param context Configuration and output information for processing
Expand All @@ -38,12 +38,12 @@ public WavTrackProvider(AudioProcessingContext context, SeekableInputStream inpu
this.inputStream = inputStream;
this.dataInput = new DataInputStream(inputStream);
this.info = info;
this.bytesPerSample = info.bitsPerSample >> 3;
this.downstream = AudioPipelineFactory.create(context, new PcmFormat(info.channelCount, info.sampleRate));
this.buffer = info.getPadding() > 0 ? new short[info.channelCount * BLOCKS_IN_BUFFER] : null;

this.byteBuffer = ByteBuffer.allocate(info.blockAlign * BLOCKS_IN_BUFFER).order(LITTLE_ENDIAN);
this.rawBuffer = byteBuffer.array();
this.nioBuffer = byteBuffer.asShortBuffer();
}

/**
Expand Down Expand Up @@ -101,10 +101,10 @@ private void processChunkWithPadding(int blockCount) throws IOException, Interru
int indexInBlock = 0;

for (int i = 0; i < sampleCount; i++) {
buffer[i] = nioBuffer.get();
buffer[i] = byteBuffer.getShort();

if (++indexInBlock == info.channelCount) {
nioBuffer.position(nioBuffer.position() + padding);
byteBuffer.position(byteBuffer.position() + padding);
indexInBlock = 0;
}
}
Expand All @@ -115,27 +115,23 @@ private void processChunkWithPadding(int blockCount) throws IOException, Interru
private void processChunk(int blockCount) throws IOException, InterruptedException {
int sampleCount = readChunkToBuffer(blockCount);

if (info.bitsPerSample == 16) {
downstream.process(nioBuffer);
} else if (info.bitsPerSample == 24) {
short[] samples = new short[sampleCount];

if (info.bitsPerSample != 16) {
for (int i = 0; i < sampleCount; i++) {
samples[i] = (short) (byteBuffer.get((i * 3) + 2) << 8 | byteBuffer.get((i * 3) + 1) & 0xFF);
byteBuffer.putShort(i * 2, byteBuffer.getShort((i * bytesPerSample) + bytesPerSample - 2));
}

downstream.process(samples, 0, sampleCount);
byteBuffer.limit(sampleCount * 2);
}

downstream.process(byteBuffer.asShortBuffer());
}

private int readChunkToBuffer(int blockCount) throws IOException {
int bytesPerSample = info.bitsPerSample >> 3;
int bytesToRead = blockCount * info.blockAlign;
dataInput.readFully(rawBuffer, 0, bytesToRead);

byteBuffer.position(0);
nioBuffer.position(0);
nioBuffer.limit(bytesToRead / bytesPerSample);
byteBuffer.limit(bytesToRead);

return bytesToRead / bytesPerSample;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.sedmelluq.discord.lavaplayer.container.wav;

import java.util.Arrays;

/**
 * Audio format codes recognized when reading the format chunk of a WAV file.
 * <p>
 * Reference: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/Pages%20from%20mmreg.h.pdf
 */
public enum WaveFormatType {
  WAVE_FORMAT_UNKNOWN(0x0000),
  WAVE_FORMAT_PCM(0x0001),
  WAVE_FORMAT_EXTENSIBLE(0xFFFE);

  /**
   * Numeric format code associated with this type.
   */
  final int code;

  WaveFormatType(int code) {
    this.code = code;
  }

  /**
   * Finds the format type for a numeric format code.
   *
   * @param code Format code to look up
   * @return The first type (in declaration order) whose code matches, or
   *         {@link #WAVE_FORMAT_UNKNOWN} when no type has that code
   */
  static WaveFormatType getByCode(int code) {
    for (WaveFormatType candidate : values()) {
      if (candidate.code == code) {
        return candidate;
      }
    }

    return WAVE_FORMAT_UNKNOWN;
  }
}

0 comments on commit 701df40

Please sign in to comment.