Skip to content

Commit ff3c31b

Browse files
committed
Backend/SPIR-V: Fix handling of std430 alignment and size
1 parent 8bb3e0b commit ff3c31b

13 files changed

+178
-61
lines changed

include/NZSL/SpirV/SpirvAstVisitor.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <NZSL/Ast/ExpressionVisitorExcept.hpp>
1313
#include <NZSL/Ast/StatementVisitorExcept.hpp>
1414
#include <NZSL/SpirV/SpirvBlock.hpp>
15+
#include <NZSL/SpirV/SpirvConstantCache.hpp>
1516
#include <functional>
1617
#include <optional>
1718
#include <unordered_map>
@@ -107,6 +108,7 @@ namespace nzsl
107108
{
108109
std::uint32_t pointerId;
109110
std::uint32_t typeId;
111+
SpirvConstantCache::TypePtr type;
110112
};
111113

112114
ShaderStageType stageType;
@@ -130,12 +132,14 @@ namespace nzsl
130132
{
131133
std::uint32_t pointerTypeId;
132134
std::uint32_t typeId;
135+
SpirvConstantCache::TypePtr typePtr;
133136
};
134137

135138
struct Variable
136139
{
137140
std::uint32_t typeId;
138141
std::uint32_t varId;
142+
SpirvConstantCache::TypePtr typePtr;
139143
SourceLocation sourceLocation;
140144
};
141145

@@ -157,7 +161,7 @@ namespace nzsl
157161
void PushResultId(std::uint32_t value);
158162
std::uint32_t PopResultId();
159163

160-
inline void RegisterVariable(std::size_t varIndex, std::uint32_t typeId, std::uint32_t pointerId, SpirvStorageClass storageClass);
164+
inline void RegisterVariable(std::size_t varIndex, SpirvConstantCache::TypePtr typePtr, std::uint32_t typeId, std::uint32_t pointerId, SpirvStorageClass storageClass);
161165

162166
void ResetSourceLocation();
163167

include/NZSL/SpirV/SpirvAstVisitor.inl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ namespace nzsl
1414
{
1515
}
1616

17-
inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, std::uint32_t typeId, std::uint32_t pointerId, SpirvStorageClass storageClass)
17+
inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, SpirvConstantCache::TypePtr typePtr, std::uint32_t typeId, std::uint32_t pointerId, SpirvStorageClass storageClass)
1818
{
1919
assert(m_variables.find(varIndex) == m_variables.end());
2020
m_variables[varIndex] = SpirvVariable{
2121
pointerId,
2222
typeId,
23+
std::move(typePtr),
2324
storageClass
2425
};
2526
}

include/NZSL/SpirV/SpirvConstantCache.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ namespace nzsl
118118
std::string name;
119119
std::vector<Member> members;
120120
std::vector<SpirvDecoration> decorations;
121+
StructLayout layout;
121122
};
122123

123124
using AnyType = std::variant<Array, Bool, Float, Function, Image, Integer, Matrix, Pointer, SampledImage, Structure, Vector, Void>;
@@ -232,8 +233,11 @@ namespace nzsl
232233
SpirvConstantCache& operator=(const SpirvConstantCache& cache) = delete;
233234
SpirvConstantCache& operator=(SpirvConstantCache&& cache) noexcept = default;
234235

236+
static TypePtr GetIndexedType(const Type& typeHolder, std::int32_t index = -1);
237+
235238
private:
236239
struct DepRegisterer;
240+
struct LayoutVisitor;
237241
struct Eq;
238242
struct Internal;
239243
template<typename T, typename Enable = void> struct TypeBuilder;

include/NZSL/SpirV/SpirvExpressionLoad.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <NZSL/Config.hpp>
1111
#include <NZSL/Ast/ExpressionVisitorExcept.hpp>
12+
#include <NZSL/SpirV/SpirvConstantCache.hpp>
1213
#include <NZSL/SpirV/SpirvData.hpp>
1314
#include <vector>
1415

@@ -48,13 +49,15 @@ namespace nzsl
4849
{
4950
std::vector<std::uint32_t> indicesId;
5051
const Ast::ExpressionType* exprType;
52+
SpirvConstantCache::TypePtr pointedTypePtr;
5153
SpirvStorageClass storage;
5254
std::uint32_t pointerId;
5355
std::uint32_t pointedTypeId;
5456
};
5557

5658
struct Pointer
5759
{
60+
SpirvConstantCache::TypePtr pointedTypePtr;
5861
SpirvStorageClass storage;
5962
std::uint32_t pointerId;
6063
std::uint32_t pointedTypeId;

include/NZSL/SpirV/SpirvExpressionStore.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <NZSL/Config.hpp>
1111
#include <NZSL/Ast/Enums.hpp>
1212
#include <NZSL/Ast/ExpressionVisitorExcept.hpp>
13+
#include <NZSL/SpirV/SpirvConstantCache.hpp>
1314
#include <NZSL/SpirV/SpirvData.hpp>
1415

1516
namespace nzsl
@@ -39,6 +40,7 @@ namespace nzsl
3940
private:
4041
struct Pointer
4142
{
43+
SpirvConstantCache::TypePtr pointedTypePtr;
4244
SpirvStorageClass storage;
4345
std::uint32_t pointerId;
4446
};

include/NZSL/SpirV/SpirvVariable.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
#include <NZSL/Config.hpp>
1111
#include <NZSL/SpirV/SpirvData.hpp>
12+
#include <NZSL/SpirV/SpirvConstantCache.hpp>
1213

1314
namespace nzsl
1415
{
1516
struct SpirvVariable
1617
{
1718
std::uint32_t pointerId;
1819
std::uint32_t typeId;
20+
SpirvConstantCache::TypePtr typePtr;
1921
SpirvStorageClass storageClass;
2022
};
2123
}

include/NZSL/SpirvWriter.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,25 @@ namespace nzsl
6464

6565
void AppendHeader();
6666

67+
SpirvConstantCache::TypePtr BuildType(const Ast::ExpressionType& type);
6768
SpirvConstantCache::TypePtr BuildFunctionType(const Ast::DeclareFunctionStatement& functionNode);
6869

6970
std::uint32_t GetArrayConstantId(const Ast::ConstantArrayValue& values) const;
7071
std::uint32_t GetSingleConstantId(const Ast::ConstantSingleValue& value) const;
7172
std::uint32_t GetExtendedInstructionSet(const std::string& instructionSetName) const;
7273
const SpirvVariable& GetExtVar(std::size_t varIndex) const;
7374
std::uint32_t GetFunctionTypeId(const Ast::DeclareFunctionStatement& functionNode);
75+
std::uint32_t GetPointerTypeId(const SpirvConstantCache::TypePtr& typePtr, SpirvStorageClass storageClass) const;
7476
std::uint32_t GetPointerTypeId(const Ast::ExpressionType& type, SpirvStorageClass storageClass) const;
7577
std::uint32_t GetSourceFileId(const std::shared_ptr<const std::string>& filepathPtr);
78+
std::uint32_t GetTypeId(const SpirvConstantCache::Type& type) const;
7679
std::uint32_t GetTypeId(const Ast::ExpressionType& type) const;
7780

7881
bool HasDebugInfo(DebugLevel debugInfo) const;
7982

8083
std::uint32_t RegisterArrayConstant(const Ast::ConstantArrayValue& value);
8184
std::uint32_t RegisterFunctionType(const Ast::DeclareFunctionStatement& functionNode);
85+
std::uint32_t RegisterPointerType(const SpirvConstantCache::TypePtr& typePtr, SpirvStorageClass storageClass);
8286
std::uint32_t RegisterPointerType(Ast::ExpressionType type, SpirvStorageClass storageClass);
8387
std::uint32_t RegisterSingleConstant(const Ast::ConstantSingleValue& value);
8488
std::uint32_t RegisterType(Ast::ExpressionType type);

src/NZSL/SpirV/SpirvAstVisitor.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ namespace nzsl
751751
std::uint32_t paramResultId = m_writer.AllocateResultId();
752752
m_instructions.Append(SpirvOp::OpFunctionParameter, m_currentFunc->parameters[i].pointerTypeId, paramResultId);
753753

754-
RegisterVariable(*node.parameters[i].varIndex, m_currentFunc->parameters[i].typeId, paramResultId, SpirvStorageClass::Function);
754+
RegisterVariable(*node.parameters[i].varIndex, m_currentFunc->parameters[i].typePtr, m_currentFunc->parameters[i].typeId, paramResultId, SpirvStorageClass::Function);
755755
}
756756
}
757757

@@ -790,7 +790,7 @@ namespace nzsl
790790
m_currentBlock->Append(SpirvOp::OpCopyMemory, resultId, input.varId);
791791
}
792792

793-
RegisterVariable(*node.parameters.front().varIndex, inputStruct.typeId, paramId, SpirvStorageClass::Function);
793+
RegisterVariable(*node.parameters.front().varIndex, inputStruct.type, inputStruct.typeId, paramId, SpirvStorageClass::Function);
794794
}
795795
}
796796

@@ -818,13 +818,14 @@ namespace nzsl
818818

819819
void SpirvAstVisitor::Visit(Ast::DeclareVariableStatement& node)
820820
{
821-
std::uint32_t typeId = m_writer.GetTypeId(node.varType.GetResultingValue());
821+
SpirvConstantCache::TypePtr typePtr = m_writer.BuildType(node.varType.GetResultingValue());
822+
std::uint32_t typeId = m_writer.GetTypeId(*typePtr);
822823

823824
assert(node.varIndex);
824825
auto varIt = m_currentFunc->varIndexToVarId.find(*node.varIndex);
825826
std::uint32_t varId = m_currentFunc->variables[varIt->second].varId;
826827

827-
RegisterVariable(*node.varIndex, typeId, varId, SpirvStorageClass::Function);
828+
RegisterVariable(*node.varIndex, std::move(typePtr), typeId, varId, SpirvStorageClass::Function);
828829

829830
if (node.initialExpression)
830831
{

src/NZSL/SpirV/SpirvConstantCache.cpp

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// For conditions of distribution and use, see copyright notice in Config.hpp
44

55
#include <NZSL/SpirV/SpirvConstantCache.hpp>
6+
#include <NazaraUtils/Assert.hpp>
67
#include <NZSL/Ast/Nodes.hpp>
78
#include <NZSL/Math/FieldOffsets.hpp>
89
#include <NZSL/SpirvWriter.hpp>
@@ -113,6 +114,9 @@ namespace nzsl
113114
if (lhs.decorations != rhs.decorations)
114115
return false;
115116

117+
if (lhs.layout != rhs.layout)
118+
return false;
119+
116120
if (!Compare(lhs.members, rhs.members))
117121
return false;
118122

@@ -443,7 +447,7 @@ namespace nzsl
443447
StructCallback structCallback;
444448
std::uint32_t& nextResultId;
445449
SpirvWriter& writer;
446-
bool isInBlockStruct = false;
450+
std::optional<StructLayout> currentBlockLayout;
447451
};
448452

449453
SpirvConstantCache::SpirvConstantCache(SpirvWriter& writer, std::uint32_t& resultId) :
@@ -539,7 +543,7 @@ namespace nzsl
539543

540544
FieldOffsets SpirvConstantCache::BuildFieldOffsets(const Structure& structData) const
541545
{
542-
FieldOffsets structOffsets(StructLayout::Std140);
546+
FieldOffsets structOffsets(structData.layout);
543547

544548
for (const Structure::Member& member : structData.members)
545549
{
@@ -667,32 +671,27 @@ namespace nzsl
667671

668672
auto SpirvConstantCache::BuildPointerType(const TypePtr& type, SpirvStorageClass storageClass) const -> TypePtr
669673
{
670-
bool wasInblockStruct = m_internal->isInBlockStruct;
671-
if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
672-
m_internal->isInBlockStruct = true;
673-
674-
auto typePtr = std::make_shared<Type>(Pointer{
674+
return std::make_shared<Type>(Pointer{
675675
type,
676676
storageClass
677677
});
678-
679-
m_internal->isInBlockStruct = wasInblockStruct;
680-
681-
return typePtr;
682678
}
683679

684680
auto SpirvConstantCache::BuildPointerType(const Ast::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
685681
{
686-
bool wasInblockStruct = m_internal->isInBlockStruct;
682+
std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout;
687683
if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
688-
m_internal->isInBlockStruct = true;
684+
{
685+
if (!prevBlockLayout)
686+
m_internal->currentBlockLayout = (storageClass == SpirvStorageClass::Uniform) ? StructLayout::Std140 : StructLayout::Std430; // FIXME: When does that happen?
687+
}
689688

690689
auto typePtr = std::make_shared<Type>(Pointer{
691690
BuildType(type),
692691
storageClass
693692
});
694693

695-
m_internal->isInBlockStruct = wasInblockStruct;
694+
m_internal->currentBlockLayout = prevBlockLayout;
696695

697696
return typePtr;
698697
}
@@ -711,9 +710,9 @@ namespace nzsl
711710

712711
// ArrayStride
713712
std::optional<std::uint32_t> arrayStride;
714-
if (m_internal->isInBlockStruct)
713+
if (m_internal->currentBlockLayout)
715714
{
716-
FieldOffsets fieldOffset(StructLayout::Std140);
715+
FieldOffsets fieldOffset(*m_internal->currentBlockLayout);
717716
RegisterArrayField(fieldOffset, builtContainedType->type, 1);
718717

719718
arrayStride = Nz::SafeCast<std::uint32_t>(fieldOffset.GetAlignedSize());
@@ -736,9 +735,9 @@ namespace nzsl
736735

737736
// ArrayStride
738737
std::optional<std::uint32_t> arrayStride;
739-
if (m_internal->isInBlockStruct)
738+
if (m_internal->currentBlockLayout)
740739
{
741-
FieldOffsets fieldOffset(StructLayout::Std140);
740+
FieldOffsets fieldOffset(*m_internal->currentBlockLayout);
742741
RegisterArrayField(fieldOffset, builtContainedType->type, 1);
743742

744743
arrayStride = Nz::SafeCast<std::uint32_t>(fieldOffset.GetAlignedSize());
@@ -761,16 +760,19 @@ namespace nzsl
761760

762761
auto SpirvConstantCache::BuildType(const Ast::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
763762
{
764-
bool wasInblockStruct = m_internal->isInBlockStruct;
763+
std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout;
765764
if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
766-
m_internal->isInBlockStruct = true;
765+
{
766+
if (!prevBlockLayout)
767+
m_internal->currentBlockLayout = (storageClass == SpirvStorageClass::Uniform) ? StructLayout::Std140 : StructLayout::Std430; // FIXME: When does that happen?
768+
}
767769

768770
auto typePtr = std::visit([&](auto&& arg) -> TypePtr
769771
{
770772
return BuildType(arg);
771773
}, type);
772774

773-
m_internal->isInBlockStruct = wasInblockStruct;
775+
m_internal->currentBlockLayout = prevBlockLayout;
774776

775777
return typePtr;
776778
}
@@ -871,11 +873,24 @@ namespace nzsl
871873
sType.name = structDesc.name;
872874
sType.decorations = std::move(decorations);
873875

874-
bool wasInBlock = m_internal->isInBlockStruct;
875-
if (!wasInBlock)
876+
sType.layout = StructLayout::Std140;
877+
if (structDesc.layout.HasValue())
876878
{
877-
m_internal->isInBlockStruct = std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::Block) != sType.decorations.end()
878-
|| std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::BufferBlock) != sType.decorations.end();
879+
switch (structDesc.layout.GetResultingValue())
880+
{
881+
case Ast::MemoryLayout::Std140: sType.layout = StructLayout::Std140; break;
882+
case Ast::MemoryLayout::Std430: sType.layout = StructLayout::Std430; break;
883+
}
884+
}
885+
886+
std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout;
887+
if (!prevBlockLayout)
888+
{
889+
bool isInBlock = std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::Block) != sType.decorations.end()
890+
|| std::find(sType.decorations.begin(), sType.decorations.end(), SpirvDecoration::BufferBlock) != sType.decorations.end();
891+
892+
if (isInBlock)
893+
m_internal->currentBlockLayout = sType.layout;
879894
}
880895

881896
for (const auto& member : structDesc.members)
@@ -888,7 +903,7 @@ namespace nzsl
888903
sMembers.type = BuildType(member.type.GetResultingValue());
889904
}
890905

891-
m_internal->isInBlockStruct = wasInBlock;
906+
m_internal->currentBlockLayout = prevBlockLayout;
892907

893908
return std::make_shared<Type>(std::move(sType));
894909
}
@@ -1128,7 +1143,7 @@ namespace nzsl
11281143
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const
11291144
{
11301145
assert(type.componentCount > 0 && type.componentCount <= 4);
1131-
return fieldOffsets.AddFieldArray(static_cast<StructFieldType>(Nz::UnderlyingCast(SpirvTypeToStructFieldType(type.componentType->type)) + type.componentCount), arrayLength);
1146+
return fieldOffsets.AddFieldArray(static_cast<StructFieldType>(Nz::UnderlyingCast(SpirvTypeToStructFieldType(type.componentType->type)) + type.componentCount - 1), arrayLength);
11321147
}
11331148

11341149
std::size_t SpirvConstantCache::RegisterArrayField(FieldOffsets& /*fieldOffsets*/, const Void& /*type*/, std::size_t /*arrayLength*/) const
@@ -1209,6 +1224,37 @@ namespace nzsl
12091224
}
12101225
}
12111226

1227+
auto SpirvConstantCache::GetIndexedType(const Type& typeHolder, std::int32_t index) -> TypePtr
1228+
{
1229+
if (std::holds_alternative<SpirvConstantCache::Structure>(typeHolder.type))
1230+
{
1231+
NazaraAssertMsg(index >= 0, "struct access must have a known index");
1232+
1233+
const auto& structData = std::get<SpirvConstantCache::Structure>(typeHolder.type);
1234+
NazaraAssert(std::uint32_t(index) < structData.members.size());
1235+
return structData.members[index].type;
1236+
}
1237+
else if (std::holds_alternative<SpirvConstantCache::Array>(typeHolder.type))
1238+
{
1239+
const auto& arrayData = std::get<SpirvConstantCache::Array>(typeHolder.type);
1240+
return arrayData.elementType;
1241+
}
1242+
else if (std::holds_alternative<SpirvConstantCache::Matrix>(typeHolder.type))
1243+
{
1244+
const auto& matrixData = std::get<SpirvConstantCache::Matrix>(typeHolder.type);
1245+
NazaraAssert(index < 0 || std::uint32_t(index) < matrixData.columnCount);
1246+
return matrixData.columnType;
1247+
}
1248+
else if (std::holds_alternative<SpirvConstantCache::Vector>(typeHolder.type))
1249+
{
1250+
const auto& vectorData = std::get<SpirvConstantCache::Vector>(typeHolder.type);
1251+
NazaraAssert(index < 0 || std::uint32_t(index) < vectorData.componentCount);
1252+
return vectorData.componentType;
1253+
}
1254+
else
1255+
throw std::runtime_error("an internal error occurred");
1256+
}
1257+
12121258
void SpirvConstantCache::Write(const AnyConstant& constant, std::uint32_t resultId, SpirvSection& constants)
12131259
{
12141260
std::visit([&](auto&& arg)

0 commit comments

Comments
 (0)