Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkchan committed Apr 14, 2024
1 parent 6ce8073 commit f258ee4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 43 deletions.
34 changes: 17 additions & 17 deletions rasterize_gaussians.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx,
torch::Device device = xys.device();

auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight,
xys.to(torch::kCPU),
conics.to(torch::kCPU),
colors.to(torch::kCPU),
opacity.to(torch::kCPU),
background.to(torch::kCPU),
cov2d.to(torch::kCPU),
camDepths.to(torch::kCPU)
xys,
conics,
colors,
opacity,
background,
cov2d,
camDepths
);
// Final image
torch::Tensor outImg = std::get<0>(t).to(device);
Expand Down Expand Up @@ -200,17 +200,17 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr
torch::Device device = xys.device();

auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth,
xys.to(torch::kCPU),
conics.to(torch::kCPU),
colors.to(torch::kCPU),
opacity.to(torch::kCPU),
background.to(torch::kCPU),
cov2d.to(torch::kCPU),
camDepths.to(torch::kCPU),
finalTs.to(torch::kCPU),
xys,
conics,
colors,
opacity,
background,
cov2d,
camDepths,
finalTs,
px2gid,
v_outImg.to(torch::kCPU),
v_outAlpha.to(torch::kCPU));
v_outImg,
v_outAlpha);

// delete[] px2gid;

Expand Down
3 changes: 1 addition & 2 deletions vendor/gsplat-metal/gsplat_metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,7 @@ kernel void project_gaussians_forward_kernel(
write_packed_float2(xys, idx, center);
}

// TODO(achan): this is actually the nd_rasterize_forward_kernel
kernel void rasterize_forward_kernel(
kernel void nd_rasterize_forward_kernel(
constant uint3& tile_bounds,
constant uint3& img_size,
constant uint& channels,
Expand Down
28 changes: 4 additions & 24 deletions vendor/gsplat-metal/gsplat_metal.mm
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
dispatch_queue_t d_queue;

id<MTLComputePipelineState> nd_rasterize_backward_kernel_cpso;
id<MTLComputePipelineState> rasterize_forward_kernel_cpso;
id<MTLComputePipelineState> nd_rasterize_forward_kernel_cpso;
id<MTLComputePipelineState> rasterize_backward_kernel_cpso;
id<MTLComputePipelineState> project_gaussians_forward_kernel_cpso;
id<MTLComputePipelineState> project_gaussians_backward_kernel_cpso;
Expand Down Expand Up @@ -108,7 +108,7 @@ @implementation DummyClassForPathHack
}

GSPLAT_METAL_ADD_KERNEL(nd_rasterize_backward_kernel);
GSPLAT_METAL_ADD_KERNEL(rasterize_forward_kernel);
GSPLAT_METAL_ADD_KERNEL(nd_rasterize_forward_kernel);
GSPLAT_METAL_ADD_KERNEL(rasterize_backward_kernel);
GSPLAT_METAL_ADD_KERNEL(project_gaussians_forward_kernel);
GSPLAT_METAL_ADD_KERNEL(project_gaussians_backward_kernel);
Expand All @@ -123,26 +123,6 @@ @implementation DummyClassForPathHack
return ctx;
}

// TODO(achan): Where do I call this?
void free_gsplat_metal_context(MetalContext* ctx) {
[ctx->nd_rasterize_backward_kernel_cpso release];
[ctx->rasterize_forward_kernel_cpso release];
[ctx->rasterize_backward_kernel_cpso release];
[ctx->project_gaussians_forward_kernel_cpso release];
[ctx->project_gaussians_backward_kernel_cpso release];
[ctx->compute_sh_forward_kernel_cpso release];
[ctx->compute_sh_backward_kernel_cpso release];
[ctx->compute_cov2d_bounds_kernel_cpso release];
[ctx->map_gaussian_to_intersects_kernel_cpso release];
[ctx->get_tile_bin_edges_kernel_cpso release];

[ctx->queue release];
[ctx->device release];
// We do not need to release `d_queue` here as that is managed by torch.

free(ctx);
}

MetalContext* get_global_context() {
static MetalContext* ctx = NULL;
if (ctx == NULL) {
Expand Down Expand Up @@ -616,7 +596,7 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
MetalContext* ctx = get_global_context();
MTLSize grid_size = MTLSizeMake(img_height, img_width, 1);
MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1);
dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, {
dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, {
EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)),
EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)),
EncodeArg::scalar(channels),
Expand Down Expand Up @@ -687,7 +667,7 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
MetalContext* ctx = get_global_context();
MTLSize grid_size = MTLSizeMake(img_height, img_width, 1);
MTLSize thread_group_size = MTLSizeMake(block_size_dim2[0], block_size_dim2[1], 1);
dispatchKernel(ctx, ctx->rasterize_forward_kernel_cpso, grid_size, thread_group_size, {
dispatchKernel(ctx, ctx->nd_rasterize_forward_kernel_cpso, grid_size, thread_group_size, {
EncodeArg::array(tile_bounds_arr, sizeof(tile_bounds_arr)),
EncodeArg::array(img_size_dim3, sizeof(img_size_dim3)),
EncodeArg::scalar(channels),
Expand Down

0 comments on commit f258ee4

Please sign in to comment.