Skip to content

Commit

Permalink
Decode sodium vertices directly into FP32 in compute shaders right be…
Browse files Browse the repository at this point in the history
…fore BLAS build, no more CPU decoding time / uploading & uses less memory
  • Loading branch information
bobcao3 committed Sep 27, 2023
1 parent b4f4b9d commit 70758f5
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@
import me.cortex.vulkanite.lib.base.VContext;
import me.cortex.vulkanite.lib.cmd.VCmdBuff;
import me.cortex.vulkanite.lib.cmd.VCommandPool;
import me.cortex.vulkanite.lib.descriptors.DescriptorSetLayoutBuilder;
import me.cortex.vulkanite.lib.descriptors.VDescriptorSetLayout;
import me.cortex.vulkanite.lib.memory.VAccelerationStructure;
import me.cortex.vulkanite.lib.memory.VBuffer;
import me.cortex.vulkanite.lib.other.VQueryPool;
import me.cortex.vulkanite.lib.other.sync.VFence;
import me.cortex.vulkanite.lib.other.sync.VSemaphore;
import me.cortex.vulkanite.lib.pipeline.ComputePipelineBuilder;
import me.cortex.vulkanite.lib.pipeline.VComputePipeline;
import me.cortex.vulkanite.lib.shader.ShaderCompiler;
import me.cortex.vulkanite.lib.shader.ShaderModule;
import me.cortex.vulkanite.lib.shader.VShader;
import me.jellysquid.mods.sodium.client.render.chunk.compile.ChunkBuildOutput;
import me.jellysquid.mods.sodium.client.render.chunk.data.BuiltSectionMeshParts;
import me.jellysquid.mods.sodium.client.render.chunk.terrain.DefaultTerrainRenderPasses;
Expand All @@ -35,7 +42,7 @@

public class AccelerationBlasBuilder {
private final VContext context;
private record BLASTriangleData(int quadCount, NativeBuffer geometry, int geometryFlags) {}
private record BLASTriangleData(int quadCount, int geometryFlags) {}
private record BLASBuildJob(List<BLASTriangleData> geometries, JobPassThroughData data) {}
public record BLASBuildResult(VAccelerationStructure structure, JobPassThroughData data) {}
public record BLASBatchResult(List<BLASBuildResult> results, VSemaphore semaphore) { }
Expand All @@ -50,12 +57,74 @@ public record BLASBatchResult(List<BLASBuildResult> results, VSemaphore semaphor
private final Semaphore awaitingJobBatchess = new Semaphore(0);//Note: this is done to avoid spin locking on the job consumer
private final ConcurrentLinkedDeque<List<BLASBuildJob>> batchedJobs = new ConcurrentLinkedDeque<>();

private final VComputePipeline gpuVertexDecodePipeline;

public AccelerationBlasBuilder(VContext context, int asyncQueue, Consumer<BLASBatchResult> resultConsumer) {
this.sinlgeUsePool = context.cmd.createSingleUsePool();
this.queryPool = new VQueryPool(context.device, 10000, VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR);
this.context = context;
this.asyncQueue = asyncQueue;
this.resultConsumer = resultConsumer;

var decodeShader = VShader.compileLoad(context, """
#version 460
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_16bit_storage : require
layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
struct InputVertex {
u16vec4 position;
u8vec4 color;
u16vec2 blockTexture;
u16vec2 lightTexture;
u16vec2 midTexCoord;
i8vec4 tangent;
i8vec3 normal;
int8_t padA__;
i16vec2 blockId;
i8vec3 midBlock;
int8_t padB__;
};
layout(buffer_reference, std430) buffer InputVertices {
InputVertex vertices[];
};
layout(buffer_reference, std430) buffer OutputVertices {
float vertices[];
};
layout(push_constant) uniform PushConstants {
uint32_t nVertices;
uint64_t inAddr;
uint64_t outAddr;
};
void main() {
uint32_t idx = gl_GlobalInvocationID.x;
uint32_t gridSize = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
InputVertices inputs = InputVertices(inAddr);
OutputVertices outputs = OutputVertices(outAddr);
for (idx; idx < nVertices; idx += gridSize) {
vec3 position = vec3(inputs.vertices[idx].position.xyz) * (32.0 / 65536.0) - 8.0;
outputs.vertices[idx * 3 + 0] = position.x;
outputs.vertices[idx * 3 + 1] = position.y;
outputs.vertices[idx * 3 + 2] = position.z;
}
}
""",
VK_SHADER_STAGE_COMPUTE_BIT);

var decodePipeBuilder = new ComputePipelineBuilder();
decodePipeBuilder.addPushConstantRange(8 * 3, 0);
decodePipeBuilder.set(decodeShader.named());
gpuVertexDecodePipeline = decodePipeBuilder.build(context);

decodeShader.free();

worker = new Thread(this::run);
worker.setName("Acceleration blas worker");
worker.start();
Expand Down Expand Up @@ -121,13 +190,9 @@ private void run() {

long buildBufferSize = 0;

int vertexStride = 4 * 3;
for (var geometry : job.geometries) {
if (geometry.geometry.getLength() <= 0) {
throw new IllegalStateException("Geometry size <= 0");
}

buildBufferSize += geometry.geometry.getLength();
//After we copy it over, we can free the native buffer
buildBufferSize += geometry.quadCount * 4 * vertexStride;
}

if (buildBufferSize <= 0) {
Expand All @@ -150,9 +215,17 @@ private void run() {
var geometryInfo = geometryInfos.get().sType$Default();
var br = brs.get();

uploadBuildCmd.encodeDataUpload(context.memory,
MemoryUtil.memAddress(geometry.geometry.getDirectBuffer()), buildBuffer,
buildBufferOffset, geometry.geometry.getLength());
// We know the geometry data has been uploaded
// 0: n_vertices
// 1: inAddr
// 2: outAddr
var pushConstant = new long[3];
pushConstant[0] = geometry.quadCount * 4;
pushConstant[1] = job.data.geometryBuffers().get(geoIdx).deviceAddress();
pushConstant[2] = buildBufferAddr + buildBufferOffset;
vkCmdBindPipeline(uploadBuildCmd.buffer, VK_PIPELINE_BIND_POINT_COMPUTE, gpuVertexDecodePipeline.pipeline());
vkCmdPushConstants(uploadBuildCmd.buffer, gpuVertexDecodePipeline.layout(), VK_SHADER_STAGE_ALL, 0, pushConstant);
vkCmdDispatch(uploadBuildCmd.buffer, Math.min((geometry.quadCount * 4 + 255) / 256, 128), 1, 1);

VkDeviceOrHostAddressConstKHR indexData = SharedQuadVkIndexBuffer.getIndexBuffer(context,
uploadBuildCmd,
Expand All @@ -161,8 +234,7 @@ private void run() {

VkDeviceOrHostAddressConstKHR vertexData = VkDeviceOrHostAddressConstKHR.calloc(stack)
.deviceAddress(buildBufferAddr + buildBufferOffset);
int vertexFormat = VK_FORMAT_R16G16B16_SFLOAT;//VK_FORMAT_R32G32B32_SFLOAT;
int vertexStride = 2*3;
int vertexFormat = VK_FORMAT_R32G32B32_SFLOAT;

geometryInfo.geometry(VkAccelerationStructureGeometryDataKHR.calloc(stack)
.triangles(VkAccelerationStructureGeometryTrianglesDataKHR.calloc(stack)
Expand All @@ -176,15 +248,14 @@ private void run() {
.indexData(indexData)
.indexType(indexType)))
.geometryType(VK_GEOMETRY_TYPE_TRIANGLES_KHR)
.flags(geometry.geometryFlags);//TODO: ADD VkGeometryFlagsKHR VK_GEOMETRY_OPAQUE_BIT_KHR
.flags(geometry.geometryFlags);

maxPrims.put(geometry.quadCount * 2);
br.primitiveCount(geometry.quadCount * 2);
//maxPrims.put(2);
//br.primitiveCount(2);

buildBufferOffset += geometry.geometry.getLength();
geometry.geometry.free();
buildBufferOffset += geometry.quadCount * 4 * vertexStride;
}

uploadBuildCmd.encodeBufferBarrier(buildBuffer, 0, VK_WHOLE_SIZE, VK_PIPELINE_STAGE_TRANSFER_BIT,
Expand Down Expand Up @@ -370,9 +441,9 @@ public void enqueue(List<ChunkBuildOutput> batch) {
List<BLASTriangleData> buildData = new ArrayList<>();
List<VBuffer> geometryBuffers = new ArrayList<>();
for (var entry : acbr.entrySet()) {
// TODO: dont hardcode the stride size
int flag = entry.getKey() == DefaultTerrainRenderPasses.SOLID ? VK_GEOMETRY_OPAQUE_BIT_KHR : 0;
buildData.add(new BLASTriangleData(entry.getValue().quadCount(), entry.getValue().data(), flag));
buildData.add(new BLASTriangleData(entry.getValue().quadCount(), flag));
// TODO: Just don't create this data in the first place

var geometry = cbr.getMesh(entry.getKey());
if (geometry.getVertexData().getLength() == 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/me/cortex/vulkanite/compat/GeometryData.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

import me.jellysquid.mods.sodium.client.util.NativeBuffer;

public record GeometryData(int quadCount, NativeBuffer data) {
public record GeometryData(int quadCount) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,7 @@ public static void compute(ChunkBuildOutput buildResult) {
if (vertices % 4 != 0)
throw new IllegalStateException("Non multiple 4 vertex count");

NativeBuffer geometryBuffer = new NativeBuffer(vertices*(2*3));
long addr = MemoryUtil.memAddress(geometryBuffer.getDirectBuffer());
long srcVert = MemoryUtil.memAddress(vertData.getDirectBuffer());

for (var faceData : pass.getValue().getVertexRanges()) {
if (faceData == null) continue;
for (int i = 0; i < faceData.vertexCount(); i++) {
long base = srcVert + (long) stride * (i + faceData.vertexStart());
float x = decodePosition(MemoryUtil.memGetShort(base));
float y = decodePosition(MemoryUtil.memGetShort(base + 2));
float z = decodePosition(MemoryUtil.memGetShort(base + 4));

MemoryUtil.memPutShort(addr, (short) fromFloat(x));
MemoryUtil.memPutShort(addr + 2, (short) fromFloat(y));
MemoryUtil.memPutShort(addr + 4, (short) fromFloat(z));
addr += 6;
}
}
map.put(pass.getKey(), new GeometryData(vertices>>2, geometryBuffer));
map.put(pass.getKey(), new GeometryData(vertices>>2));
}

if (!map.isEmpty()) {
Expand All @@ -53,36 +35,4 @@ public static void compute(ChunkBuildOutput buildResult) {
ebr.setAccelerationGeometryData(null);
}
}


private static float decodePosition(short v) {
return Short.toUnsignedInt(v)*(1f/2048.0f)-8.0f;
}

public static int fromFloat( float fval )
{
int fbits = Float.floatToIntBits( fval );
int sign = fbits >>> 16 & 0x8000; // sign only
int val = ( fbits & 0x7fffffff ) + 0x1000; // rounded value

if( val >= 0x47800000 ) // might be or become NaN/Inf
{ // avoid Inf due to rounding
if( ( fbits & 0x7fffffff ) >= 0x47800000 )
{ // is or must become NaN/Inf
if( val < 0x7f800000 ) // was value but too large
return sign | 0x7c00; // make it +/-Inf
return sign | 0x7c00 | // remains +/-Inf or NaN
( fbits & 0x007fffff ) >>> 13; // keep NaN (and Inf) bits
}
return sign | 0x7bff; // unrounded not quite Inf
}
if( val >= 0x38800000 ) // remains normalized value
return sign | val - 0x38000000 >>> 13; // exp - 127 + 15
if( val < 0x33000000 ) // too small for subnormal
return sign; // becomes +/-0
val = ( fbits & 0x7fffffff ) >>> 23; // tmp exp for subnormal calc
return sign | ( ( fbits & 0x7fffff | 0x800000 ) // add subnormal bit
+ ( 0x800000 >>> val - 102 ) // round depending on cut off
>>> 126 - val ); // div by 2^(1-(exp-127+15)) and >> 13 | exp=0
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ public ComputePipelineBuilder set(ShaderModule shader) {
return this;
}

private record PushConstant(int size, int offset) {}
private List<PushConstant> pushConstants = new ArrayList<>();

public void addPushConstantRange(int size, int offset) {
pushConstants.add(new PushConstant(size, offset));
}

public VComputePipeline build(VContext context) {
try (var stack = stackPush()) {

Expand All @@ -41,6 +48,18 @@ public VComputePipeline build(VContext context) {
layoutCreateInfo.pSetLayouts(stack.longs(layouts.stream().mapToLong(a->a.layout).toArray()));
}

if (pushConstants.size() > 0) {
var pushConstantRanges = VkPushConstantRange.calloc(pushConstants.size(), stack);
for (int i = 0; i < pushConstants.size(); i++) {
var pushConstant = pushConstants.get(i);
pushConstantRanges.get(i)
.stageFlags(VK_SHADER_STAGE_ALL)
.offset(pushConstant.offset)
.size(pushConstant.size);
}
layoutCreateInfo.pPushConstantRanges(pushConstantRanges);
}

LongBuffer pLayout = stack.mallocLong(1);
_CHECK_(vkCreatePipelineLayout(context.device, layoutCreateInfo, null, pLayout));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ public VComputePipeline(VContext context, long layout, long pipeline) {
this.pipeline = pipeline;
}

public long layout() {
return layout;
}

public long pipeline() {
return pipeline;
}

@Override
public void free() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ private void onDestroy(CallbackInfo ci) {
@Redirect(method = "destroy", at = @At(value = "INVOKE", target = "Lme/jellysquid/mods/sodium/client/render/chunk/compile/ChunkBuildOutput;delete()V"))
private void destroyAccelerationData(ChunkBuildOutput instance) {
var data = ((IAccelerationBuildResult)instance).getAccelerationGeometryData();
if (data != null) {
data.values().forEach(entry->entry.data().free());
}
instance.delete();
//TODO: need to ingest and cleanup all the blas builds and tlas updates
}
Expand All @@ -51,19 +48,7 @@ private void processResults(ArrayList<ChunkBuildOutput> results, CallbackInfo ci
ChunkBuildOutput previous = map.get(render);
if (previous == null || previous.buildTime < output.buildTime) {
var prev = map.put(render, output);
if (prev != null) {
var data = ((IAccelerationBuildResult)output).getAccelerationGeometryData();
data.values().forEach(a->a.data().free());
}
} else {
//Else need to free the injected result
var data = ((IAccelerationBuildResult)output).getAccelerationGeometryData();
data.values().forEach(a->a.data().free());
}
} else {
//Else need to free the injected result
var data = ((IAccelerationBuildResult)output).getAccelerationGeometryData();
data.values().forEach(a->a.data().free());
}
}
if (!map.values().isEmpty()) {
Expand Down

0 comments on commit 70758f5

Please sign in to comment.