3
3
// For conditions of distribution and use, see copyright notice in Config.hpp
4
4
5
5
#include < NZSL/SpirV/SpirvConstantCache.hpp>
6
+ #include < NazaraUtils/Assert.hpp>
6
7
#include < NZSL/Ast/Nodes.hpp>
7
8
#include < NZSL/Math/FieldOffsets.hpp>
8
9
#include < NZSL/SpirvWriter.hpp>
@@ -113,6 +114,9 @@ namespace nzsl
113
114
if (lhs.decorations != rhs.decorations )
114
115
return false ;
115
116
117
+ if (lhs.layout != rhs.layout )
118
+ return false ;
119
+
116
120
if (!Compare (lhs.members , rhs.members ))
117
121
return false ;
118
122
@@ -443,7 +447,7 @@ namespace nzsl
443
447
StructCallback structCallback;
444
448
std::uint32_t & nextResultId;
445
449
SpirvWriter& writer;
446
- bool isInBlockStruct = false ;
450
+ std::optional<StructLayout> currentBlockLayout ;
447
451
};
448
452
449
453
SpirvConstantCache::SpirvConstantCache (SpirvWriter& writer, std::uint32_t & resultId) :
@@ -539,7 +543,7 @@ namespace nzsl
539
543
540
544
FieldOffsets SpirvConstantCache::BuildFieldOffsets (const Structure& structData) const
541
545
{
542
- FieldOffsets structOffsets (StructLayout::Std140 );
546
+ FieldOffsets structOffsets (structData. layout );
543
547
544
548
for (const Structure::Member& member : structData.members )
545
549
{
@@ -667,32 +671,27 @@ namespace nzsl
667
671
668
672
auto SpirvConstantCache::BuildPointerType (const TypePtr& type, SpirvStorageClass storageClass) const -> TypePtr
669
673
{
670
- bool wasInblockStruct = m_internal->isInBlockStruct ;
671
- if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
672
- m_internal->isInBlockStruct = true ;
673
-
674
- auto typePtr = std::make_shared<Type>(Pointer{
674
+ return std::make_shared<Type>(Pointer{
675
675
type,
676
676
storageClass
677
677
});
678
-
679
- m_internal->isInBlockStruct = wasInblockStruct;
680
-
681
- return typePtr;
682
678
}
683
679
684
680
auto SpirvConstantCache::BuildPointerType (const Ast::PrimitiveType& type, SpirvStorageClass storageClass) const -> TypePtr
685
681
{
686
- bool wasInblockStruct = m_internal->isInBlockStruct ;
682
+ std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout ;
687
683
if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
688
- m_internal->isInBlockStruct = true ;
684
+ {
685
+ if (!prevBlockLayout)
686
+ m_internal->currentBlockLayout = (storageClass == SpirvStorageClass::Uniform) ? StructLayout::Std140 : StructLayout::Std430; // FIXME: When does that happen?
687
+ }
689
688
690
689
auto typePtr = std::make_shared<Type>(Pointer{
691
690
BuildType (type),
692
691
storageClass
693
692
});
694
693
695
- m_internal->isInBlockStruct = wasInblockStruct ;
694
+ m_internal->currentBlockLayout = prevBlockLayout ;
696
695
697
696
return typePtr;
698
697
}
@@ -711,9 +710,9 @@ namespace nzsl
711
710
712
711
// ArrayStride
713
712
std::optional<std::uint32_t > arrayStride;
714
- if (m_internal->isInBlockStruct )
713
+ if (m_internal->currentBlockLayout )
715
714
{
716
- FieldOffsets fieldOffset (StructLayout::Std140 );
715
+ FieldOffsets fieldOffset (*m_internal-> currentBlockLayout );
717
716
RegisterArrayField (fieldOffset, builtContainedType->type , 1 );
718
717
719
718
arrayStride = Nz::SafeCast<std::uint32_t >(fieldOffset.GetAlignedSize ());
@@ -736,9 +735,9 @@ namespace nzsl
736
735
737
736
// ArrayStride
738
737
std::optional<std::uint32_t > arrayStride;
739
- if (m_internal->isInBlockStruct )
738
+ if (m_internal->currentBlockLayout )
740
739
{
741
- FieldOffsets fieldOffset (StructLayout::Std140 );
740
+ FieldOffsets fieldOffset (*m_internal-> currentBlockLayout );
742
741
RegisterArrayField (fieldOffset, builtContainedType->type , 1 );
743
742
744
743
arrayStride = Nz::SafeCast<std::uint32_t >(fieldOffset.GetAlignedSize ());
@@ -761,16 +760,19 @@ namespace nzsl
761
760
762
761
auto SpirvConstantCache::BuildType (const Ast::ExpressionType& type, SpirvStorageClass storageClass) const -> TypePtr
763
762
{
764
- bool wasInblockStruct = m_internal->isInBlockStruct ;
763
+ std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout ;
765
764
if (storageClass == SpirvStorageClass::Uniform || storageClass == SpirvStorageClass::StorageBuffer)
766
- m_internal->isInBlockStruct = true ;
765
+ {
766
+ if (!prevBlockLayout)
767
+ m_internal->currentBlockLayout = (storageClass == SpirvStorageClass::Uniform) ? StructLayout::Std140 : StructLayout::Std430; // FIXME: When does that happen?
768
+ }
767
769
768
770
auto typePtr = std::visit ([&](auto && arg) -> TypePtr
769
771
{
770
772
return BuildType (arg);
771
773
}, type);
772
774
773
- m_internal->isInBlockStruct = wasInblockStruct ;
775
+ m_internal->currentBlockLayout = prevBlockLayout ;
774
776
775
777
return typePtr;
776
778
}
@@ -871,11 +873,24 @@ namespace nzsl
871
873
sType .name = structDesc.name ;
872
874
sType .decorations = std::move (decorations);
873
875
874
- bool wasInBlock = m_internal-> isInBlockStruct ;
875
- if (!wasInBlock )
876
+ sType . layout = StructLayout::Std140 ;
877
+ if (structDesc. layout . HasValue () )
876
878
{
877
- m_internal->isInBlockStruct = std::find (sType .decorations .begin (), sType .decorations .end (), SpirvDecoration::Block) != sType .decorations .end ()
878
- || std::find (sType .decorations .begin (), sType .decorations .end (), SpirvDecoration::BufferBlock) != sType .decorations .end ();
879
+ switch (structDesc.layout .GetResultingValue ())
880
+ {
881
+ case Ast::MemoryLayout::Std140: sType .layout = StructLayout::Std140; break ;
882
+ case Ast::MemoryLayout::Std430: sType .layout = StructLayout::Std430; break ;
883
+ }
884
+ }
885
+
886
+ std::optional<StructLayout> prevBlockLayout = m_internal->currentBlockLayout ;
887
+ if (!prevBlockLayout)
888
+ {
889
+ bool isInBlock = std::find (sType .decorations .begin (), sType .decorations .end (), SpirvDecoration::Block) != sType .decorations .end ()
890
+ || std::find (sType .decorations .begin (), sType .decorations .end (), SpirvDecoration::BufferBlock) != sType .decorations .end ();
891
+
892
+ if (isInBlock)
893
+ m_internal->currentBlockLayout = sType .layout ;
879
894
}
880
895
881
896
for (const auto & member : structDesc.members )
@@ -888,7 +903,7 @@ namespace nzsl
888
903
sMembers .type = BuildType (member.type .GetResultingValue ());
889
904
}
890
905
891
- m_internal->isInBlockStruct = wasInBlock ;
906
+ m_internal->currentBlockLayout = prevBlockLayout ;
892
907
893
908
return std::make_shared<Type>(std::move (sType ));
894
909
}
@@ -1128,7 +1143,7 @@ namespace nzsl
1128
1143
std::size_t SpirvConstantCache::RegisterArrayField (FieldOffsets& fieldOffsets, const Vector& type, std::size_t arrayLength) const
1129
1144
{
1130
1145
assert (type.componentCount > 0 && type.componentCount <= 4 );
1131
- return fieldOffsets.AddFieldArray (static_cast <StructFieldType>(Nz::UnderlyingCast (SpirvTypeToStructFieldType (type.componentType ->type )) + type.componentCount ), arrayLength);
1146
+ return fieldOffsets.AddFieldArray (static_cast <StructFieldType>(Nz::UnderlyingCast (SpirvTypeToStructFieldType (type.componentType ->type )) + type.componentCount - 1 ), arrayLength);
1132
1147
}
1133
1148
1134
1149
std::size_t SpirvConstantCache::RegisterArrayField (FieldOffsets& /* fieldOffsets*/ , const Void& /* type*/ , std::size_t /* arrayLength*/ ) const
@@ -1209,6 +1224,37 @@ namespace nzsl
1209
1224
}
1210
1225
}
1211
1226
1227
+ auto SpirvConstantCache::GetIndexedType (const Type& typeHolder, std::int32_t index) -> TypePtr
1228
+ {
1229
+ if (std::holds_alternative<SpirvConstantCache::Structure>(typeHolder.type ))
1230
+ {
1231
+ NazaraAssertMsg (index >= 0 , " struct access must have a known index" );
1232
+
1233
+ const auto & structData = std::get<SpirvConstantCache::Structure>(typeHolder.type );
1234
+ NazaraAssert (std::uint32_t (index) < structData.members .size ());
1235
+ return structData.members [index].type ;
1236
+ }
1237
+ else if (std::holds_alternative<SpirvConstantCache::Array>(typeHolder.type ))
1238
+ {
1239
+ const auto & arrayData = std::get<SpirvConstantCache::Array>(typeHolder.type );
1240
+ return arrayData.elementType ;
1241
+ }
1242
+ else if (std::holds_alternative<SpirvConstantCache::Matrix>(typeHolder.type ))
1243
+ {
1244
+ const auto & matrixData = std::get<SpirvConstantCache::Matrix>(typeHolder.type );
1245
+ NazaraAssert (index < 0 || std::uint32_t (index) < matrixData.columnCount );
1246
+ return matrixData.columnType ;
1247
+ }
1248
+ else if (std::holds_alternative<SpirvConstantCache::Vector>(typeHolder.type ))
1249
+ {
1250
+ const auto & vectorData = std::get<SpirvConstantCache::Vector>(typeHolder.type );
1251
+ NazaraAssert (index < 0 || std::uint32_t (index) < vectorData.componentCount );
1252
+ return vectorData.componentType ;
1253
+ }
1254
+ else
1255
+ throw std::runtime_error (" an internal error occurred" );
1256
+ }
1257
+
1212
1258
void SpirvConstantCache::Write (const AnyConstant& constant, std::uint32_t resultId, SpirvSection& constants)
1213
1259
{
1214
1260
std::visit ([&](auto && arg)
0 commit comments