Skip to content

Commit

Permalink
Less crashy in the BLAS builder
Browse files (browse the repository at this point in the history)
  • Loading branch information
bobcao3 committed Dec 10, 2023
1 parent d791dc7 commit ef5a5a1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -97,7 +97,7 @@ public AccelerationBlasBuilder(VContext context, int asyncQueue, Consumer<BLASBa
};
layout(push_constant) uniform PushConstants {
uint32_t nVertices;
uint64_t nVertices;
uint64_t inAddr;
uint64_t outAddr;
};
Expand All @@ -107,7 +107,7 @@ void main() {
uint32_t gridSize = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
InputVertices inputs = InputVertices(inAddr);
OutputVertices outputs = OutputVertices(outAddr);
for (idx; idx < nVertices; idx += gridSize) {
for (idx; idx < uint32_t(nVertices); idx += gridSize) {
vec3 position = vec3(inputs.vertices[idx].position.xyz) * (32.0 / 65536.0) - 8.0;
outputs.vertices[idx * 3 + 0] = position.x;
outputs.vertices[idx * 3 + 1] = position.y;
Expand Down Expand Up @@ -144,7 +144,7 @@ private void run() {
}
int i = -1;
//Collect the job batch
while (!this.batchedJobs.isEmpty()) {
while (!this.batchedJobs.isEmpty() && jobs.size() < 32) {
i++;
jobs.addAll(this.batchedJobs.poll());
}
Expand All @@ -159,10 +159,7 @@ private void run() {
}
var sinlgeUsePoolWorker = context.cmd.getSingleUsePool();
sinlgeUsePoolWorker.doReleases();
if (jobs.size() > 100) {
System.err.println("EXCESSIVE JOBS FOR SOME REASON AAAAAAAAAA");
//while (true);
}

//Jobs are batched and built on the async vulkan queue then block synchronized with fence
// which then results in compaction and dispatch to consumer

Expand All @@ -181,6 +178,7 @@ private void run() {

//Fill in the buildInfo and buildRanges
int i = -1;
vkCmdBindPipeline(uploadBuildCmd.buffer, VK_PIPELINE_BIND_POINT_COMPUTE, gpuVertexDecodePipeline.pipeline());
for (var job : jobs) {
i++;
var brs = VkAccelerationStructureBuildRangeInfoKHR.calloc(job.geometries.size(), stack);
Expand All @@ -200,7 +198,7 @@ private void run() {
}

var buildBuffer = context.memory.createBuffer(buildBufferSize,
VK_BUFFER_USAGE_TRANSFER_DST_BIT
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT
| VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR
| VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
Expand All @@ -210,6 +208,9 @@ private void run() {
var buildBufferAddr = buildBuffer.deviceAddress();
long buildBufferOffset = 0;

uploadBuildCmd.encodeBufferBarrier(buildBuffer, 0, VK_WHOLE_SIZE, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);

for (int geoIdx = 0; geoIdx < job.geometries.size(); geoIdx++) {
var geometry = job.geometries.get(geoIdx);
//TODO: Fill in geometryInfo, maxPrims and buildRangeInfo
Expand All @@ -221,12 +222,20 @@ private void run() {
// 1: inAddr
// 2: outAddr
var pushConstant = new long[3];
var geometryInputBuffer = job.data.geometryBuffers().get(geoIdx);
pushConstant[0] = geometry.quadCount * 4;
pushConstant[1] = job.data.geometryBuffers().get(geoIdx).deviceAddress();
pushConstant[1] = geometryInputBuffer.deviceAddress();
pushConstant[2] = buildBufferAddr + buildBufferOffset;
vkCmdBindPipeline(uploadBuildCmd.buffer, VK_PIPELINE_BIND_POINT_COMPUTE, gpuVertexDecodePipeline.pipeline());
if (pushConstant[1] == 0) {
throw new IllegalStateException("Geometry input buffer address is 0");
}
if (pushConstant[2] == 0) {
throw new IllegalStateException("Build buffer address is 0");
}
vkCmdPushConstants(uploadBuildCmd.buffer, gpuVertexDecodePipeline.layout(), VK_SHADER_STAGE_ALL, 0, pushConstant);
vkCmdDispatch(uploadBuildCmd.buffer, Math.min((geometry.quadCount * 4 + 255) / 256, 128), 1, 1);
uploadBuildCmd.encodeBufferBarrier(geometryInputBuffer, 0, VK_WHOLE_SIZE, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
vkCmdDispatch(uploadBuildCmd.buffer, Math.min((geometry.quadCount * 4 + 255) / 256, 256), 1, 1);

VkDeviceOrHostAddressConstKHR indexData = SharedQuadVkIndexBuffer.getIndexBuffer(context,
uploadBuildCmd,
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,13 +15,16 @@
import static org.lwjgl.vulkan.KHRAccelerationStructure.VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR;
import static org.lwjgl.vulkan.VK10.*;

import java.util.ArrayList;
import java.util.List;

public class DescriptorUpdateBuilder {
private final VContext ctx;
private final MemoryStack stack;
private final VkWriteDescriptorSet.Buffer updates;
private final VImageView placeholderImageView;
private ArrayList<VkDescriptorBufferInfo.Buffer> bulkBufferInfos = new ArrayList<>();
private ArrayList<VkDescriptorImageInfo.Buffer> bulkImageInfos = new ArrayList<>();
private ShaderReflection.Set refSet = null;

public DescriptorUpdateBuilder(VContext ctx, int maxUpdates) {
Expand All @@ -30,7 +33,15 @@ public DescriptorUpdateBuilder(VContext ctx, int maxUpdates) {

/**
 * Creates a builder with room for up to {@code maxUpdates} descriptor writes.
 *
 * <p>The constructor allocates the {@code VkWriteDescriptorSet} array from a
 * {@link MemoryStack} held by this builder and opens a frame on that stack;
 * {@code apply()} later pops the frame after submitting the updates.
 *
 * @param ctx                  context stored on the builder ({@code apply()} reads {@code ctx.device})
 * @param maxUpdates           number of {@code VkWriteDescriptorSet} entries allocated up front;
 *                             also used to size the builder's private stack
 * @param placeholderImageView image view stored on the builder; how it is used is not
 *                             visible in this section
 */
public DescriptorUpdateBuilder(VContext ctx, int maxUpdates, VImageView placeholderImageView) {
this.ctx = ctx;
// NOTE(review): this text is a rendered diff without +/- markers. The assignment below
// looks like the pre-change line that the commit replaces (it reappears commented out on
// the next line, and `stack` is assigned again further down) - confirm against the real file.
this.stack = MemoryStack.stackPush();
// this.stack = MemoryStack.stackPush();
// Take the largest of the three per-update payload struct sizes
// (buffer info, image info, acceleration-structure write).
int objSize = Integer.max(
Integer.max(
VkDescriptorBufferInfo.SIZEOF,
VkDescriptorImageInfo.SIZEOF),
VkWriteDescriptorSetAccelerationStructureKHR.SIZEOF);
// Round that size up to the next multiple of 16 bytes.
objSize = ((objSize + 15) / 16) * 16;
// Dedicated stack sized as: 1024 bytes of headroom + one VkWriteDescriptorSet per update
// + one rounded payload slot per update. Presumably this budgets one payload struct per
// update - TODO confirm against the methods that allocate from `stack`.
this.stack = MemoryStack.create(1024 + maxUpdates * VkWriteDescriptorSet.SIZEOF + maxUpdates * objSize);
// Open a frame on the new stack; the matching stack.pop() is in apply().
this.stack.push();
// The write-descriptor array lives in the frame opened above.
this.updates = VkWriteDescriptorSet.calloc(maxUpdates, stack);
this.placeholderImageView = placeholderImageView;
}
Expand Down Expand Up @@ -81,7 +92,7 @@ public DescriptorUpdateBuilder buffer(int binding, int dstArrayElement, List<VBu
if (refSet != null && refSet.getBindingAt(binding) == null) {
return this;
}
var bufInfo = VkDescriptorBufferInfo.calloc(buffers.size(), stack);
var bufInfo = VkDescriptorBufferInfo.calloc(buffers.size());
for (int i = 0; i < buffers.size(); i++) {
bufInfo.get(i)
.buffer(buffers.get(i).buffer())
Expand All @@ -96,7 +107,7 @@ public DescriptorUpdateBuilder buffer(int binding, int dstArrayElement, List<VBu
.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER)
.descriptorCount(buffers.size())
.pBufferInfo(bufInfo);

bulkBufferInfos.add(bufInfo);
return this;
}

Expand Down Expand Up @@ -147,7 +158,7 @@ public DescriptorUpdateBuilder imageStore(int binding, int dstArrayElement, List
if (refSet != null && refSet.getBindingAt(binding) == null) {
return this;
}
var imgInfo = VkDescriptorImageInfo.calloc(views.size(), stack);
var imgInfo = VkDescriptorImageInfo.calloc(views.size());
for (int i = 0; i < views.size(); i++) {
imgInfo.get(i)
.imageLayout(VK_IMAGE_LAYOUT_GENERAL)
Expand All @@ -160,6 +171,7 @@ public DescriptorUpdateBuilder imageStore(int binding, int dstArrayElement, List
.descriptorType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
.descriptorCount(views.size())
.pImageInfo(imgInfo);
bulkImageInfos.add(imgInfo);
return this;
}
public DescriptorUpdateBuilder imageStore(int binding, VImageView view) {
Expand Down Expand Up @@ -209,5 +221,13 @@ public void apply() {
updates.rewind();
vkUpdateDescriptorSets(ctx.device, updates, null);
stack.pop();
for (var bufInfo : bulkBufferInfos) {
bufInfo.free();
}
bulkBufferInfos.clear();
for (var imgInfo : bulkImageInfos) {
imgInfo.free();
}
bulkImageInfos.clear();
}
}

0 comments on commit ef5a5a1

Please sign in to comment.