Skip to content

Commit

Permalink
Merge pull request #19 from arhik/splatting
Browse files Browse the repository at this point in the history
Splatting Code
  • Loading branch information
arhik authored Jan 4, 2024
2 parents e095e53 + 456ba4d commit e7b0528
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 47 deletions.
3 changes: 2 additions & 1 deletion examples/axisCombo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ run(TracyProfiler_jll.tracy(); wait=false)
sleep(5)

using WGPUgfx
using WGPUgfx: updateViewTransform!
using WGPUCore
using WGPUCanvas
using GLFW
Expand Down Expand Up @@ -79,7 +80,7 @@ mainApp = () -> begin
rot = RotXY(0.01, 0.02)
mat = MMatrix{4, 4, Float32}(I)
mat[1:3, 1:3] = rot
camera1.transform = camera1.transform*mat
updateViewTransform!(camera1, camera1.uniformData.viewMatrix*mat)
theta = time()
quad.uniformData = translate((
1.0*(sin(theta)),
Expand Down
3 changes: 2 additions & 1 deletion examples/axisExample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Pkg
# ])

using WGPUgfx
using WGPUgfx: updateViewTransform!
using WGPUCore
using GLFW
using GLFW: WindowShouldClose, PollEvents, DestroyWindow
Expand Down Expand Up @@ -74,7 +75,7 @@ main = () -> begin
rot = RotXY(0.01, 0.02)
mat = MMatrix{4, 4, Float32}(I)
mat[1:3, 1:3] = rot
camera1.transform = camera1.transform*mat
updateViewTransform!(camera1, camera1.uniformData.viewMatrix*mat)
runApp(renderer)
PollEvents()
end
Expand Down
8 changes: 6 additions & 2 deletions examples/splat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ WGPUCore.SetLogLevel(WGPUCore.WGPULogLevel_Off)
scene = Scene()
renderer = getRenderer(scene)

pc = defaultGSplat(joinpath(pkgdir(WGPUgfx), "assets", "bonsai", "bonsai_30000.ply"))
# pc = defaultGSplat(joinpath("C:\\", "Users", "arhik", "Downloads", "bonsai_30000.compressed.ply"))
# pc = defaultGSplat(joinpath(pkgdir(WGPUgfx), "assets", "bonsai", "bonsai_30000.ply"))
pc = defaultGSplat(joinpath(ENV["HOME"], "Downloads", "train", "train_30000.ply"))
# pc = defaultGSplat(joinpath(ENV["HOME"], "Downloads", "bonsai", "bonsai_30000.ply"))

axis = defaultAxis()

addObject!(renderer, pc)
addObject!(renderer, axis)

attachEventSystem(renderer)

Expand Down
18 changes: 10 additions & 8 deletions src/events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ function attachScrollCallback(scene, camera::Camera)
WGPUCanvas.setScrollCallback(
scene.canvas,
(_, xoff, yoff) -> begin
# @info "MouseScroll" xoff, yoff
camera.scale = camera.scale .+ yoff.*maximum(mouseState.speed)
camera.eye += (camera.eye - camera.lookat)*yoff.*maximum(mouseState.speed)
camera.lookat += (camera.eye-camera.lookat)*yoff.*maximum(mouseState.speed)
end
)
end
Expand All @@ -69,21 +69,23 @@ function attachCursorPosCallback(scene, camera::Camera)
if all(((x, y) .- scene.canvas.size) .< 0)
if mouseState.leftClick
delta = -1.0.*(mouseState.prevPosition .- (y, x)).*mouseState.speed
# @info delta
rot = RotXY(delta...)
#camera.eye = rot*camera.eye
mat = MMatrix{4, 4, Float32}(I)
mat[1:3, 1:3] = rot
updateViewTransform!(camera, camera.uniformData.viewMatrix*mat)
mouseState.prevPosition = (y, x)
elseif mouseState.rightClick
delta = -1.0.*(mouseState.prevPosition .- (y, x)).*mouseState.speed
mat = MMatrix{4, 4, Float32}(I)
mat[1:3, 3] .= [delta..., 0]
updateViewTransform!(camera, camera.uniformData.viewMatrix*mat)
#camera.lookat += [delta..., 0]
#camera.eye += [delta..., 0]
#mat = MMatrix{4, 4, Float32}(I)
#mat[1:3, 3] .= [delta..., 0]
#updateViewTransform!(camera, camera.uniformData.viewMatrix*mat)
mouseState.prevPosition = (y, x)
elseif mouseState.middleClick
mat = MMatrix{4, 4, Float32}(I)
updateViewTransform!(camera, mat)
#mat = MMatrix{4, 4, Float32}(I)
#updateViewTransform!(camera, mat)
mouseState.prevPosition = (y, x)
else
mouseState.prevPosition = (y, x)
Expand Down
6 changes: 4 additions & 2 deletions src/renderable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,16 @@ function getShaderCode(mesh::Renderable, cameraId::Int; binding=0)
@vertex function vs_main(vertexIn::$vertexInputName)::$vertexOutputName
@var out::$vertexOutputName
out.pos = $(name).transform*vertexIn.pos
out.pos = camera.transform*out.pos
out.pos = camera.viewMatrix*out.pos
out.pos = camera.projMatrix*out.pos
out.vColor = vertexIn.vColor
if $isTexture
out.vTexCoords = vertexIn.vTexCoords
end
if $isLight
out.vNormal = $(name).transform*vertexIn.vNormal
out.vNormal = camera.transform*out.vNormal
out.vNormal = camera.viewMatrix*out.vNormal
out.vNormal = camera.projMatrix*out.vNormal
end
return out
end
Expand Down
24 changes: 13 additions & 11 deletions src/ui/fontsplat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ end


function getSplatData()
point = [1.0, 0.0, 0.0] .|> Float32
scale = [0.86, 0.15, 0.05] .|> Float32
color = [0.0, 0.0, 1.0] .|> Float32
quat = [pi/2, 0, 0, 0] .|> Float32
point = [0.4, 0.0, 0.0] .|> Float32
scale = [0.6, 0.4, 0.4] .|> Float32
color = [1.0, 0.0, 0.0] .|> Float32
quat = [pi/2, 0, 0] .|> Float32
scales = []
points = []
colors = []
Expand All @@ -56,10 +56,11 @@ function getSplatData()
push!(colors, circshift(color, (i-1,)))
push!(points, circshift(point, (i-1,)))
push!(scales, circshift(scale, (i-1,)))
push!(quats, circshift(quat, (i-1,)) |> getQuaternion)
end
colors = cat(colors..., dims=(2,))
scales = cat(scales..., dims=(2,))
quats = repeat(quat, inner=(1,size(colors, 2)))
quats = cat(quats..., dims=(2,))
points = cat(points..., dims=(2,))
splatData = GSplatAxisData(points, scales, colors, quats)
return splatData
Expand Down Expand Up @@ -235,19 +236,20 @@ function getShaderCode(gsplat::GSplatAxis, cameraId::Int; binding=0)
@let sigma = transpose(M)*M
@let pos = Vec4{Float32}(splatIn.pos, 1.0)
out.pos = $(name).transform*pos
out.pos = camera.viewMatrix*out.pos
@let tx = out.pos.x
@let ty = out.pos.y
@let tz = out.pos.z
out.pos = camera.transform*out.pos
# out.pos = out.pos/out.pos.w
@let f::Float32 = 4.0 #1.32 # 40.0*(tan(camera.fov/2.0))
out.pos = camera.projMatrix*out.pos
out.pos = out.pos/out.pos.w
@let f::Float32 = 2.0*(tan(camera.fov/2.0))

@let J = SMatrix{2, 3, Float32, 6}(
f/tz, 0.0, -f*tx/(tz*tz),
0.0, f/tz, -f*ty/(tz*tz),
)

@let Rcam = transToRotMat(camera.transform)
@let Rcam = transToRotMat(camera.viewMatrix)
@let W = transpose(Rcam)*J
@let covinter = transpose(sigma)*W
@let cov4D = transpose(W)*covinter
Expand All @@ -258,7 +260,7 @@ function getShaderCode(gsplat::GSplatAxis, cameraId::Int; binding=0)
)

cov2D[0] = cov2D[0] + 0.3
cov2D[3] = cov2D[0] + 0.3
cov2D[3] = cov2D[3] + 0.3

@let a = cov2D[0]
@let b = cov2D[1]
Expand All @@ -280,7 +282,7 @@ function getShaderCode(gsplat::GSplatAxis, cameraId::Int; binding=0)
#@let result = SH_C0 * splatIn.sh[0] + 0.5;
out.cov2d = cov2D
out.color = Vec4{Float32}(splatIn.color, 1.0)
out.opacity = 0.3
out.opacity = 1.0
return out
end

Expand Down
29 changes: 7 additions & 22 deletions src/ui/gaussiansplat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ function getShaderCode(gsplat::GSplat, cameraId::Int; binding=0)
@let ty = out.pos.y
@let tz = out.pos.z

@let f::Float32 = 2.2*(tan(camera.fov/2.0))
@let f::Float32 = 2.0*(tan(camera.fov/2.0))
#@let tx = pos.x
#@let ty = pos.y
#@let tz = pos.z
Expand Down Expand Up @@ -305,12 +305,10 @@ function getShaderCode(gsplat::GSplat, cameraId::Int; binding=0)

@let intensity::Float32 = 0.5*dot(invCov2d*delta, delta)


@escif if (intensity < 0.0)
@esc discard
end


@let alpha = min(0.99, opacity*exp(-intensity))

@let color::Vec4{Float32} = Vec4{Float32}(
Expand All @@ -332,20 +330,7 @@ function prepareObject(gpuDevice, gsplat::GSplat)
["Uniform", "CopyDst", "CopySrc"]
)

splatData = readPlyFile(gsplat.filepath); # TODO remove

# TODO make mutable structs as default vector operation
# buffer = zeros(UInt8, sizeof(WGSLTypes.GSplatIn)*size(splatData.points, 1))

# storageArray = reinterpret(WGSLTypes.GSplatIn, buffer)

# for (idx, splat) in enumerate(storageArray)
# splat.pos = splatData.points[idx, :]
# #splat.scale = splatData.scale[idx, :]
# #splat.opacity = splatData.opacity[idx, :]
# #splat.quaternions = splatData.quaternion[idx, :]
# #splat.sh = splatData.sphericalHarmonics[idx, :]
# end
splatData = readPlyFile(gsplat.filepath);

points = splatData.points .|> Float32;
scale = splatData.scale .|> Float32;
Expand Down Expand Up @@ -400,12 +385,12 @@ function getBindingLayouts(gsplat::GSplat; binding=0)
:visibility => ["Vertex", "Fragment"],
:type => "Uniform"
],
WGPUCore.WGPUBufferEntry => [ # TODO hardcoded
WGPUCore.WGPUBufferEntry => [
:binding => binding + 1,
:visibility=> ["Vertex", "Fragment"],
:type => "ReadOnlyStorage" # TODO VERTEXWRITABLESTORAGE feature needs to be enabled if its not read-only
],
WGPUCore.WGPUBufferEntry => [ # TODO hardcoded
WGPUCore.WGPUBufferEntry => [
:binding => binding + 2,
:visibility => ["Vertex", "Fragment"],
:type => "ReadOnlyStorage"
Expand Down Expand Up @@ -512,7 +497,7 @@ function getRenderPipelineOptions(renderer, splat::GSplat)
camIdx = scene.cameraId
renderpipelineOptions = [
WGPUCore.GPUVertexState => [
:_module => splat.cshaders[camIdx].internal[], # SET THIS (AUTOMATICALLY)
:_module => splat.cshaders[camIdx].internal[], # SET THIS (AUTOMATICALLY)
:entryPoint => "vs_main", # SET THIS (FIXED FOR NOW)
:buffers => [
getVertexBufferLayout(splat)
Expand All @@ -535,11 +520,11 @@ function getRenderPipelineOptions(renderer, splat::GSplat)
:alphaToCoverageEnabled=>false,
],
WGPUCore.GPUFragmentState => [
:_module => splat.cshaders[camIdx].internal[], # SET THIS
:_module => splat.cshaders[camIdx].internal[], # SET THIS
:entryPoint => "fs_main", # SET THIS (FIXED FOR NOW)
:targets => [
WGPUCore.GPUColorTargetState => [
:format => renderer.renderTextureFormat, # SET THIS
:format => renderer.renderTextureFormat, # SET THIS
:color => [
:srcFactor => "One",
:dstFactor => "OneMinusSrcAlpha",
Expand Down

0 comments on commit e7b0528

Please sign in to comment.