From e2c652d4ee289c055386fef3579ca34376e7e02d Mon Sep 17 00:00:00 2001 From: SirLynix Date: Mon, 30 Dec 2024 17:53:42 +0100 Subject: [PATCH] Add support for whole module import --- include/NZSL/Ast/AstSerializer.hpp | 2 + include/NZSL/Ast/Cloner.hpp | 2 + include/NZSL/Ast/Compare.hpp | 2 + include/NZSL/Ast/Compare.inl | 21 +- include/NZSL/Ast/Enums.hpp | 4 +- include/NZSL/Ast/ExpressionType.hpp | 35 ++- include/NZSL/Ast/ExpressionType.inl | 50 ++++ include/NZSL/Ast/NodeList.hpp | 2 + include/NZSL/Ast/Nodes.hpp | 17 ++ include/NZSL/Ast/RecursiveVisitor.hpp | 2 + include/NZSL/Ast/Utils.hpp | 2 + include/NZSL/GlslWriter.hpp | 2 + include/NZSL/Lang/ErrorList.hpp | 2 + include/NZSL/LangWriter.hpp | 7 +- include/NZSL/Parser.hpp | 2 +- include/NZSL/ShaderBuilder.hpp | 7 + include/NZSL/ShaderBuilder.inl | 18 ++ src/NZSL/Ast/AstSerializer.cpp | 60 ++++- src/NZSL/Ast/Cloner.cpp | 25 +- src/NZSL/Ast/ExpressionType.cpp | 28 +- src/NZSL/Ast/RecursiveVisitor.cpp | 10 + src/NZSL/Ast/SanitizeVisitor.cpp | 368 +++++++++++++++----------- src/NZSL/Ast/Utils.cpp | 10 + src/NZSL/GlslWriter.cpp | 10 + src/NZSL/LangWriter.cpp | 91 ++++++- src/NZSL/Parser.cpp | 75 +++++- tests/src/Tests/ErrorsTests.cpp | 7 + tests/src/Tests/ModuleTests.cpp | 236 ++++++++++++++++- 28 files changed, 886 insertions(+), 211 deletions(-) diff --git a/include/NZSL/Ast/AstSerializer.hpp b/include/NZSL/Ast/AstSerializer.hpp index 76b1e3d..b80ad86 100644 --- a/include/NZSL/Ast/AstSerializer.hpp +++ b/include/NZSL/Ast/AstSerializer.hpp @@ -39,6 +39,8 @@ namespace nzsl::Ast void Serialize(IdentifierExpression& node); void Serialize(IntrinsicExpression& node); void Serialize(IntrinsicFunctionExpression& node); + void Serialize(ModuleExpression& node); + void Serialize(NamedExternalBlockExpression& node); void Serialize(StructTypeExpression& node); void Serialize(SwizzleExpression& node); void Serialize(TypeExpression& node); diff --git a/include/NZSL/Ast/Cloner.hpp b/include/NZSL/Ast/Cloner.hpp index 85b23d5..d875d93 100644 --- a/include/NZSL/Ast/Cloner.hpp +++ b/include/NZSL/Ast/Cloner.hpp @@ -55,6 +55,8 @@ namespace nzsl::Ast virtual ExpressionPtr Clone(IdentifierExpression& node); virtual ExpressionPtr Clone(IntrinsicExpression& node); virtual ExpressionPtr Clone(IntrinsicFunctionExpression& node); + virtual ExpressionPtr Clone(ModuleExpression& node); + virtual ExpressionPtr Clone(NamedExternalBlockExpression& node); virtual ExpressionPtr Clone(StructTypeExpression& node); virtual ExpressionPtr Clone(SwizzleExpression& node); virtual ExpressionPtr Clone(TypeExpression& node); diff --git a/include/NZSL/Ast/Compare.hpp b/include/NZSL/Ast/Compare.hpp index cec83fc..0f66975 100644 --- a/include/NZSL/Ast/Compare.hpp +++ b/include/NZSL/Ast/Compare.hpp @@ -60,6 +60,8 @@ namespace nzsl::Ast inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs, const ComparisonParams& params = {}); inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs, const ComparisonParams& params = {}); inline bool Compare(const IntrinsicFunctionExpression& lhs, const IntrinsicFunctionExpression& rhs, const ComparisonParams& params = {}); + inline bool Compare(const ModuleExpression& lhs, const ModuleExpression& rhs, const ComparisonParams& params = {}); + inline bool Compare(const NamedExternalBlockExpression& lhs, const NamedExternalBlockExpression& rhs, const ComparisonParams& params = {}); inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs, const ComparisonParams& params = {}); inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs, const ComparisonParams& params = {}); inline bool Compare(const TypeExpression& lhs, const TypeExpression& rhs, const ComparisonParams& params = {}); diff --git a/include/NZSL/Ast/Compare.inl b/include/NZSL/Ast/Compare.inl index 8ff913e..b0e336d 100644 --- a/include/NZSL/Ast/Compare.inl +++ b/include/NZSL/Ast/Compare.inl @@ -263,7 +263,7 @@ namespace nzsl::Ast if (!Compare(lhs.renamedIdentifierLoc, rhs.renamedIdentifierLoc, params)) return false; - + return true; } @@ -503,6 +503,22 @@ namespace nzsl::Ast return true; } + inline bool Compare(const ModuleExpression& lhs, const ModuleExpression& rhs, const ComparisonParams& params) + { + if (!Compare(lhs.moduleId, rhs.moduleId, params)) + return false; + + return true; + } + + inline bool Compare(const NamedExternalBlockExpression& lhs, const NamedExternalBlockExpression& rhs, const ComparisonParams& params) + { + if (!Compare(lhs.externalBlockId, rhs.externalBlockId, params)) + return false; + + return true; + } + inline bool Compare(const StructTypeExpression& lhs, const StructTypeExpression& rhs, const ComparisonParams& params) { if (!Compare(lhs.structTypeId, rhs.structTypeId, params)) @@ -758,6 +774,9 @@ namespace nzsl::Ast if (params.compareModuleName && !Compare(lhs.moduleName, rhs.moduleName, params)) return false; + if (!Compare(lhs.moduleIdentifier, rhs.moduleIdentifier, params)) + return false; + if (!Compare(lhs.identifiers, rhs.identifiers, params)) return false; diff --git a/include/NZSL/Ast/Enums.hpp b/include/NZSL/Ast/Enums.hpp index ceba0d7..734b8ed 100644 --- a/include/NZSL/Ast/Enums.hpp +++ b/include/NZSL/Ast/Enums.hpp @@ -216,6 +216,8 @@ namespace nzsl::Ast IdentifierExpression = 13, IntrinsicExpression = 14, IntrinsicFunctionExpression = 15, + ModuleExpression = 42, + NamedExternalBlockExpression = 43, StructTypeExpression = 16, SwizzleExpression = 17, TypeExpression = 18, @@ -245,7 +247,7 @@ namespace nzsl::Ast ScopedStatement = 38, WhileStatement = 39, - Max = ContinueStatement + Max = NamedExternalBlockExpression }; enum class PrimitiveType diff --git a/include/NZSL/Ast/ExpressionType.hpp b/include/NZSL/Ast/ExpressionType.hpp index 9c5eae4..223e310 100644 --- a/include/NZSL/Ast/ExpressionType.hpp +++ b/include/NZSL/Ast/ExpressionType.hpp @@ -12,15 +12,12 @@ #include #include #include -#include #include #include #include -#ifdef NAZARA_COMPILER_GCC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#endif +NAZARA_WARNING_PUSH() +NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized") namespace nzsl { @@ -137,6 +134,22 @@ namespace nzsl::Ast inline bool operator!=(const MethodType& rhs) const; }; + struct ModuleType + { + std::size_t moduleIndex; + + inline bool operator==(const ModuleType& rhs) const; + inline bool operator!=(const ModuleType& rhs) const; + }; + + struct NamedExternalBlockType + { + std::size_t namedExternalBlockIndex; + + inline bool operator==(const NamedExternalBlockType& rhs) const; + inline bool operator!=(const NamedExternalBlockType& rhs) const; + }; + struct NoType { inline bool operator==(const NoType& rhs) const; @@ -216,7 +229,7 @@ namespace nzsl::Ast inline bool operator!=(const PushConstantType& rhs) const; }; - using ExpressionType = std::variant; + using ExpressionType = std::variant; struct ContainedType { @@ -252,6 +265,8 @@ namespace nzsl::Ast inline bool IsIntrinsicFunctionType(const ExpressionType& type); inline bool IsMatrixType(const ExpressionType& type); inline bool IsMethodType(const ExpressionType& type); + inline bool IsModuleType(const ExpressionType& type); + inline bool IsNamedExternalBlockType(const ExpressionType& type); inline bool IsNoType(const ExpressionType& type); inline bool IsPrimitiveType(const ExpressionType& type); inline bool IsPushConstantType(const ExpressionType& type); @@ -296,6 +311,8 @@ namespace nzsl::Ast struct Stringifier { std::function aliasStringifier; + std::function moduleStringifier; + std::function namedExternalBlockStringifier; std::function structStringifier; std::function typeStringifier; }; @@ -308,6 +325,8 @@ namespace nzsl::Ast NZSL_API std::string ToString(const IntrinsicFunctionType& type, const Stringifier& stringifier = {}); NZSL_API std::string ToString(const MatrixType& type, const Stringifier& stringifier = {}); NZSL_API std::string ToString(const MethodType& type, const Stringifier& stringifier = {}); + NZSL_API std::string ToString(const ModuleType& type, const Stringifier& stringifier = {}); + NZSL_API std::string ToString(const NamedExternalBlockType& type, const Stringifier& stringifier = {}); NZSL_API std::string ToString(NoType type, const Stringifier& stringifier = {}); NZSL_API std::string ToString(PrimitiveType type, const Stringifier& stringifier = {}); NZSL_API std::string ToString(const PushConstantType& type, const Stringifier& stringifier = {}); @@ -320,9 +339,7 @@ namespace nzsl::Ast NZSL_API std::string ToString(const VectorType& type, const Stringifier& stringifier = {}); } -#ifdef NAZARA_COMPILER_GCC -#pragma GCC diagnostic pop -#endif +NAZARA_WARNING_POP() #include diff --git a/include/NZSL/Ast/ExpressionType.inl b/include/NZSL/Ast/ExpressionType.inl index b924124..52d34ba 100644 --- a/include/NZSL/Ast/ExpressionType.inl +++ b/include/NZSL/Ast/ExpressionType.inl @@ -44,6 +44,17 @@ namespace nzsl::Ast } + inline bool NamedExternalBlockType::operator==(const NamedExternalBlockType& rhs) const + { + return namedExternalBlockIndex == rhs.namedExternalBlockIndex; + } + + inline bool NamedExternalBlockType::operator!=(const NamedExternalBlockType& rhs) const + { + return !operator==(rhs); + } + + inline bool FunctionType::operator==(const FunctionType& rhs) const { return funcIndex == rhs.funcIndex; @@ -83,6 +94,17 @@ namespace nzsl::Ast } + inline bool ModuleType::operator==(const ModuleType& rhs) const + { + return moduleIndex == rhs.moduleIndex; + } + + inline bool ModuleType::operator!=(const ModuleType& rhs) const + { + return !operator==(rhs); + } + + inline bool NoType::operator==(const NoType& /*rhs*/) const { return true; @@ -216,6 +238,16 @@ namespace nzsl::Ast return std::holds_alternative(type); } + inline bool IsModuleType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + + inline bool IsNamedExternalBlockType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsNoType(const ExpressionType& type) { return std::holds_alternative(type); @@ -388,6 +420,24 @@ namespace std } }; + template<> + struct hash + { + std::size_t operator()(const nzsl::Ast::ModuleType& moduleType) const + { + return Nz::HashCombine(moduleType.moduleIndex); + } + }; + + template<> + struct hash + { + std::size_t operator()(const nzsl::Ast::NamedExternalBlockType& namedExternalBlockType) const + { + return Nz::HashCombine(namedExternalBlockType.namedExternalBlockIndex); + } + }; + template<> struct hash { diff --git a/include/NZSL/Ast/NodeList.hpp b/include/NZSL/Ast/NodeList.hpp index 9c8bff0..56f2283 100644 --- a/include/NZSL/Ast/NodeList.hpp +++ b/include/NZSL/Ast/NodeList.hpp @@ -44,6 +44,8 @@ NZSL_SHADERAST_EXPRESSION(Function) NZSL_SHADERAST_EXPRESSION(Identifier) NZSL_SHADERAST_EXPRESSION(Intrinsic) NZSL_SHADERAST_EXPRESSION(IntrinsicFunction) +NZSL_SHADERAST_EXPRESSION(Module) +NZSL_SHADERAST_EXPRESSION(NamedExternalBlock) NZSL_SHADERAST_EXPRESSION(StructType) NZSL_SHADERAST_EXPRESSION(Swizzle) NZSL_SHADERAST_EXPRESSION(Type) diff --git a/include/NZSL/Ast/Nodes.hpp b/include/NZSL/Ast/Nodes.hpp index 82d4425..d806411 100644 --- a/include/NZSL/Ast/Nodes.hpp +++ b/include/NZSL/Ast/Nodes.hpp @@ -223,6 +223,22 @@ namespace nzsl::Ast std::size_t intrinsicId; }; + struct NZSL_API ModuleExpression : Expression + { + NodeType GetType() const override; + void Visit(ExpressionVisitor& visitor) override; + + std::size_t moduleId; + }; + + struct NZSL_API NamedExternalBlockExpression : Expression + { + NodeType GetType() const override; + void Visit(ExpressionVisitor& visitor) override; + + std::size_t externalBlockId; + }; + struct NZSL_API StructTypeExpression : Expression { NodeType GetType() const override; @@ -478,6 +494,7 @@ namespace nzsl::Ast SourceLocation renamedIdentifierLoc; }; + std::string moduleIdentifier; std::string moduleName; std::vector identifiers; }; diff --git a/include/NZSL/Ast/RecursiveVisitor.hpp b/include/NZSL/Ast/RecursiveVisitor.hpp index 4934533..5904229 100644 --- a/include/NZSL/Ast/RecursiveVisitor.hpp +++ b/include/NZSL/Ast/RecursiveVisitor.hpp @@ -35,6 +35,8 @@ namespace nzsl::Ast void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(IntrinsicFunctionExpression& node) override; + void Visit(ModuleExpression& node) override; + void Visit(NamedExternalBlockExpression& node) override; void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; void Visit(TypeExpression& node) override; diff --git a/include/NZSL/Ast/Utils.hpp b/include/NZSL/Ast/Utils.hpp index f0cecc1..ea67973 100644 --- a/include/NZSL/Ast/Utils.hpp +++ b/include/NZSL/Ast/Utils.hpp @@ -46,6 +46,8 @@ namespace nzsl::Ast void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(IntrinsicFunctionExpression& node) override; + void Visit(ModuleExpression& node) override; + void Visit(NamedExternalBlockExpression& node) override; void Visit(StructTypeExpression& node) override; void Visit(SwizzleExpression& node) override; void Visit(TypeExpression& node) override; diff --git a/include/NZSL/GlslWriter.hpp b/include/NZSL/GlslWriter.hpp index d9f5e23..f2dd6b5 100644 --- a/include/NZSL/GlslWriter.hpp +++ b/include/NZSL/GlslWriter.hpp @@ -77,7 +77,9 @@ namespace nzsl void Append(const Ast::IntrinsicFunctionType& intrinsicFunctionType); void Append(const Ast::MatrixType& matrixType); void Append(const Ast::MethodType& methodType); + void Append(const Ast::ModuleType& methodType); void Append(Ast::MemoryLayout layout); + void Append(const Ast::NamedExternalBlockType& namedExternalBlockType); void Append(Ast::NoType); void Append(Ast::PrimitiveType type); void Append(const Ast::PushConstantType& pushConstantType); diff --git a/include/NZSL/Lang/ErrorList.hpp b/include/NZSL/Lang/ErrorList.hpp index af75a2c..a051f9b 100644 --- a/include/NZSL/Lang/ErrorList.hpp +++ b/include/NZSL/Lang/ErrorList.hpp @@ -42,6 +42,8 @@ NZSL_SHADERLANG_PARSER_ERROR(AttributeUnexpectedParameterCount, "attribute {} ex NZSL_SHADERLANG_PARSER_ERROR(ExpectedToken, "expected token {}, got {}", TokenType, TokenType) NZSL_SHADERLANG_PARSER_ERROR(DuplicateIdentifier, "duplicate identifier") NZSL_SHADERLANG_PARSER_ERROR(DuplicateModule, "duplicate module") +NZSL_SHADERLANG_PARSER_ERROR(ModuleImportInvalidIdentifier, "{} is not a valid identifier to import", std::string) +NZSL_SHADERLANG_PARSER_ERROR(ModuleImportMultiple, "a module import can only be a single name") NZSL_SHADERLANG_PARSER_ERROR(InvalidVersion, "\"{}\" is not a valid version", std::string) NZSL_SHADERLANG_PARSER_ERROR(MissingAttribute, "missing attribute {}", Ast::AttributeType) NZSL_SHADERLANG_PARSER_ERROR(ModuleFeatureMultipleUnique, "module feature {} has already been specified", Ast::ModuleFeature) diff --git a/include/NZSL/LangWriter.hpp b/include/NZSL/LangWriter.hpp index 44baf50..4aa4547 100644 --- a/include/NZSL/LangWriter.hpp +++ b/include/NZSL/LangWriter.hpp @@ -12,8 +12,6 @@ #include #include #include -#include -#include #include namespace nzsl @@ -70,6 +68,8 @@ namespace nzsl void Append(const Ast::IntrinsicFunctionType& intrinsicFunctionType); void Append(const Ast::MatrixType& matrixType); void Append(const Ast::MethodType& methodType); + void Append(const Ast::ModuleType& moduleType); + void Append(const Ast::NamedExternalBlockType& namedExternalBlockType); void Append(Ast::NoType); void Append(Ast::PrimitiveType type); void Append(const Ast::PushConstantType& pushConstantType); @@ -120,6 +120,7 @@ namespace nzsl void RegisterAlias(std::size_t aliasIndex, std::string aliasName); void RegisterConstant(std::size_t constantIndex, std::string constantName); void RegisterFunction(std::size_t funcIndex, std::string functionName); + void RegisterModule(std::size_t moduleIndex, std::string moduleName); void RegisterStruct(std::size_t structIndex, std::string structName); void RegisterVariable(std::size_t varIndex, std::string varName); @@ -142,6 +143,8 @@ namespace nzsl void Visit(Ast::FunctionExpression& node) override; void Visit(Ast::IdentifierExpression& node) override; void Visit(Ast::IntrinsicExpression& node) override; + void Visit(Ast::ModuleExpression& node) override; + void Visit(Ast::NamedExternalBlockExpression& node) override; void Visit(Ast::StructTypeExpression& node) override; void Visit(Ast::SwizzleExpression& node) override; void Visit(Ast::VariableValueExpression& node) override; diff --git a/include/NZSL/Parser.hpp b/include/NZSL/Parser.hpp index 80fce58..82dc5d0 100644 --- a/include/NZSL/Parser.hpp +++ b/include/NZSL/Parser.hpp @@ -95,7 +95,7 @@ namespace nzsl Ast::ExpressionPtr ParseStringExpression(); const std::string& ParseIdentifierAsName(SourceLocation* sourceLocation); - std::string ParseModuleName(); + std::string ParseModuleName(SourceLocation* sourceLocation); Ast::ExpressionPtr ParseType(); const std::string& ExtractStringAttribute(Attribute&& attribute); diff --git a/include/NZSL/ShaderBuilder.hpp b/include/NZSL/ShaderBuilder.hpp index 2975e2c..fdbd62b 100644 --- a/include/NZSL/ShaderBuilder.hpp +++ b/include/NZSL/ShaderBuilder.hpp @@ -150,6 +150,7 @@ namespace nzsl::ShaderBuilder struct Import { + inline Ast::ImportStatementPtr operator()(std::string modulePath, std::string moduleIdentifier) const; inline Ast::ImportStatementPtr operator()(std::string modulePath, std::vector identifiers) const; }; @@ -163,6 +164,11 @@ namespace nzsl::ShaderBuilder inline Ast::IntrinsicFunctionExpressionPtr operator()(std::size_t intrinsicFunctionId, Ast::IntrinsicType intrinsicType) const; }; + struct ModuleExpr + { + inline Ast::ModuleExpressionPtr operator()(std::size_t moduleTypeId) const; + }; + struct Multi { inline Ast::MultiStatementPtr operator()(std::vector statements = {}) const; @@ -242,6 +248,7 @@ namespace nzsl::ShaderBuilder constexpr Impl::IntrinsicFunction IntrinsicFunction; constexpr Impl::Import Import; constexpr Impl::Intrinsic Intrinsic; + constexpr Impl::ModuleExpr ModuleExpr; constexpr Impl::Multi MultiStatement; constexpr Impl::NoParam NoOp; constexpr Impl::Return Return; diff --git a/include/NZSL/ShaderBuilder.inl b/include/NZSL/ShaderBuilder.inl index 26d4ded..767f457 100644 --- a/include/NZSL/ShaderBuilder.inl +++ b/include/NZSL/ShaderBuilder.inl @@ -404,6 +404,15 @@ namespace nzsl::ShaderBuilder return identifierNode; } + inline Ast::ImportStatementPtr Impl::Import::operator()(std::string modulePath, std::string moduleIdentifier) const + { + auto importNode = std::make_unique(); + importNode->moduleName = std::move(modulePath); + importNode->moduleIdentifier = std::move(moduleIdentifier); + + return importNode; + } + inline Ast::ImportStatementPtr Impl::Import::operator()(std::string moduleName, std::vector identifiers) const { auto importNode = std::make_unique(); @@ -431,6 +440,15 @@ namespace nzsl::ShaderBuilder return intrinsicTypeExpr; } + inline Ast::ModuleExpressionPtr Impl::ModuleExpr::operator()(std::size_t moduleTypeId) const + { + auto moduleTypeExpr = std::make_unique(); + moduleTypeExpr->cachedExpressionType = Ast::ModuleType{ moduleTypeId }; + moduleTypeExpr->moduleId = moduleTypeId; + + return moduleTypeExpr; + } + inline Ast::MultiStatementPtr Impl::Multi::operator()(std::vector statements) const { auto multiStatement = std::make_unique(); diff --git a/src/NZSL/Ast/AstSerializer.cpp b/src/NZSL/Ast/AstSerializer.cpp index 08eed86..c5aa9f0 100644 --- a/src/NZSL/Ast/AstSerializer.cpp +++ b/src/NZSL/Ast/AstSerializer.cpp @@ -13,7 +13,7 @@ namespace nzsl::Ast namespace { constexpr std::uint32_t s_shaderAstMagicNumber = 0x4E534852; - constexpr std::uint32_t s_shaderAstCurrentVersion = 11; + constexpr std::uint32_t s_shaderAstCurrentVersion = 12; class ShaderSerializerVisitor : public ExpressionVisitor, public StatementVisitor { @@ -254,6 +254,16 @@ namespace nzsl::Ast SizeT(node.intrinsicId); } + void SerializerBase::Serialize(ModuleExpression& node) + { + SizeT(node.moduleId); + } + + void SerializerBase::Serialize(NamedExternalBlockExpression& node) + { + SizeT(node.externalBlockId); + } + void SerializerBase::Serialize(StructTypeExpression& node) { SizeT(node.structTypeId); @@ -470,6 +480,9 @@ namespace nzsl::Ast SourceLoc(identifierEntry.identifierLoc); SourceLoc(identifierEntry.renamedIdentifierLoc); } + + if (IsVersionGreaterOrEqual(12)) + Value(node.moduleIdentifier); } void SerializerBase::Serialize(MultiStatement& node) @@ -697,6 +710,16 @@ namespace nzsl::Ast m_serializer.Serialize(std::uint8_t(17)); SizeT(arg.containedType.structIndex); } + else if constexpr (std::is_same_v) + { + m_serializer.Serialize(std::uint8_t(18)); + SizeT(arg.moduleIndex); + } + else if constexpr (std::is_same_v) + { + m_serializer.Serialize(std::uint8_t(19)); + SizeT(arg.namedExternalBlockIndex); + } else static_assert(Nz::AlwaysFalse(), "non-exhaustive visitor"); }, type); @@ -986,10 +1009,8 @@ namespace nzsl::Ast void ShaderAstDeserializer::Type(ExpressionType& type) { -#ifdef NAZARA_COMPILER_GCC -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#endif +NAZARA_WARNING_PUSH() +NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized") std::uint8_t typeIndex; Value(typeIndex); @@ -1045,8 +1066,8 @@ namespace nzsl::Ast case 5: //< StructType { - std::uint32_t structIndex; - Value(structIndex); + std::size_t structIndex; + SizeT(structIndex); type = StructType{ structIndex @@ -1230,14 +1251,33 @@ namespace nzsl::Ast break; } + case 18: //< ModuleType + { + std::size_t moduleIndex; + SizeT(moduleIndex); + + type = ModuleType{ + moduleIndex + }; + break; + } + + case 19: //< NamedExternalBlockType + { + std::size_t externalBlockIndex; + SizeT(externalBlockIndex); + + type = NamedExternalBlockType{ + externalBlockIndex + }; + break; + } default: throw std::runtime_error("unexpected type index " + std::to_string(typeIndex)); } -#ifdef NAZARA_COMPILER_GCC -#pragma GCC diagnostic pop -#endif +NAZARA_WARNING_POP() } void ShaderAstDeserializer::Value(bool& val) diff --git a/src/NZSL/Ast/Cloner.cpp b/src/NZSL/Ast/Cloner.cpp index 6805b80..763b738 100644 --- a/src/NZSL/Ast/Cloner.cpp +++ b/src/NZSL/Ast/Cloner.cpp @@ -289,8 +289,9 @@ namespace nzsl::Ast StatementPtr Cloner::Clone(ImportStatement& node) { auto clone = std::make_unique(); - clone->moduleName = node.moduleName; clone->identifiers = node.identifiers; + clone->moduleName = node.moduleName; + clone->moduleIdentifier = node.moduleIdentifier; clone->sourceLocation = node.sourceLocation; @@ -559,6 +560,28 @@ namespace nzsl::Ast return clone; } + ExpressionPtr Cloner::Clone(ModuleExpression& node) + { + auto clone = std::make_unique(); + clone->moduleId = node.moduleId; + + clone->cachedExpressionType = node.cachedExpressionType; + clone->sourceLocation = node.sourceLocation; + + return clone; + } + + ExpressionPtr Cloner::Clone(NamedExternalBlockExpression& node) + { + auto clone = std::make_unique(); + clone->externalBlockId = node.externalBlockId; + + clone->cachedExpressionType = node.cachedExpressionType; + clone->sourceLocation = node.sourceLocation; + + return clone; + } + ExpressionPtr Cloner::Clone(StructTypeExpression& node) { auto clone = std::make_unique(); diff --git a/src/NZSL/Ast/ExpressionType.cpp b/src/NZSL/Ast/ExpressionType.cpp index 045a2ce..3e98aa3 100644 --- a/src/NZSL/Ast/ExpressionType.cpp +++ b/src/NZSL/Ast/ExpressionType.cpp @@ -98,14 +98,14 @@ namespace nzsl::Ast return objectType->type == rhs.objectType->type && methodIndex == rhs.methodIndex; } - using ForbiddenStructTypes = Nz::TypeList; + using StructTypes = Nz::TypeList; std::size_t RegisterStructField(FieldOffsets& fieldOffsets, const ExpressionType& type, const StructFinder& structFinder) { return std::visit([&](auto&& arg) -> std::size_t { using T = std::decay_t; - if constexpr (!Nz::TypeListHas) + if constexpr (Nz::TypeListHas) return RegisterStructFieldType(fieldOffsets, arg, structFinder); else throw std::runtime_error("unexpected type (" + ToString(arg) + ") as struct field"); @@ -117,7 +117,7 @@ namespace nzsl::Ast return std::visit([&](auto&& arg) -> std::size_t { using T = std::decay_t; - if constexpr (!Nz::TypeListHas) + if constexpr (Nz::TypeListHas) return RegisterStructFieldType(fieldOffsets, arg, arraySize, structFinder); else throw std::runtime_error("unexpected type (" + ToString(arg) + ") as struct field"); @@ -286,6 +286,8 @@ namespace nzsl::Ast std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -370,6 +372,22 @@ namespace nzsl::Ast return "type) + " type>"; } + std::string ToString(const ModuleType& type, const Stringifier& stringifier) + { + if (stringifier.moduleStringifier) + return "imported module " + stringifier.moduleStringifier(type.moduleIndex); + else + return fmt::format("imported module #{}", type.moduleIndex); + } + + std::string ToString(const NamedExternalBlockType& type, const Stringifier& stringifier) + { + if (stringifier.namedExternalBlockStringifier) + return "named external block " + stringifier.namedExternalBlockStringifier(type.namedExternalBlockIndex); + else + return fmt::format("named external block #{}", type.namedExternalBlockIndex); + } + std::string ToString(NoType /*type*/, const Stringifier& /*stringifier*/) { return "()"; @@ -421,7 +439,7 @@ namespace nzsl::Ast if (stringifier.structStringifier) return "struct " + stringifier.structStringifier(type.structIndex); else - return "struct #" + std::to_string(type.structIndex); + return fmt::format("struct #{}", type.structIndex); } std::string ToString(const TextureType& type, const Stringifier& /*stringifier*/) @@ -445,7 +463,7 @@ namespace nzsl::Ast if (stringifier.typeStringifier) return "type " + stringifier.typeStringifier(type.typeIndex); else - return "type #" + std::to_string(type.typeIndex); + return fmt::format("type #{}", type.typeIndex); } std::string ToString(const UniformType& type, const Stringifier& stringifier) diff --git a/src/NZSL/Ast/RecursiveVisitor.cpp b/src/NZSL/Ast/RecursiveVisitor.cpp index 4bd2b3e..77384c4 100644 --- a/src/NZSL/Ast/RecursiveVisitor.cpp +++ b/src/NZSL/Ast/RecursiveVisitor.cpp @@ -99,6 +99,16 @@ namespace nzsl::Ast /* Nothing to do */ } + void RecursiveVisitor::Visit(ModuleExpression& /*node*/) + { + /* Nothing to do */ + } + + void RecursiveVisitor::Visit(NamedExternalBlockExpression& /*node*/) + { + /* Nothing to do */ + } + void RecursiveVisitor::Visit(StructTypeExpression& /*node*/) { /* Nothing to do */ diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index d9ac3e3..11fd63b 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -179,12 +179,14 @@ namespace nzsl::Ast { std::unordered_map exportedSetByModule; std::shared_ptr environment; + std::string moduleName; std::unique_ptr dependenciesVisitor; }; struct NamedExternalBlockData { std::shared_ptr environment; + std::string name; }; struct UsedExternalData @@ -271,6 +273,7 @@ namespace nzsl::Ast m_context->moduleByName[cloneImportedModule.module->metadata->moduleName] = moduleId; auto& moduleData = m_context->modules.emplace_back(); moduleData.environment = std::move(importedModuleEnv); + moduleData.moduleName = cloneImportedModule.identifier; m_context->currentEnv = m_context->globalEnv; RegisterModule(cloneImportedModule.identifier, moduleId); @@ -318,53 +321,7 @@ namespace nzsl::Ast MandatoryExpr(node.expr, node.sourceLocation); - // Handle module access and named external access (TODO: Add namespace expression?) - if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1) - { - auto& identifierExpr = static_cast(*node.expr); - const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier); - - if (identifierData) - { - switch (identifierData->category) - { - case IdentifierCategory::ExternalBlock: - { - std::size_t namedExternalBlockIndex = m_context->namedExternalBlockIndices.Retrieve(identifierData->index, node.sourceLocation); - - const auto& env = *m_context->namedExternalBlocks[namedExternalBlockIndex].environment; - identifierData = FindIdentifier(env, node.identifiers.front().identifier); - if (identifierData) - { - if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) - return Cloner::Clone(node); - - return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); - } - break; - } - - case IdentifierCategory::Module: - { - std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation); - - const auto& env = *m_context->modules[moduleIndex].environment; - identifierData = FindIdentifier(env, node.identifiers.front().identifier); - if (identifierData) - { - if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) - return Cloner::Clone(node); - - return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); - } - break; - } - - default: - break; - } - } - } + auto previousEnv = m_context->currentEnv; ExpressionPtr indexedExpr = CloneExpression(node.expr); for (const auto& identifierEntry : node.identifiers) @@ -615,10 +572,80 @@ namespace nzsl::Ast indexedExpr = std::move(swizzle); } } + else if (IsNamedExternalBlockType(resolvedType)) + { + const NamedExternalBlockType& externalBlockType = std::get(resolvedType); + std::size_t namedExternalBlockIndex = externalBlockType.namedExternalBlockIndex; + + const IdentifierData* identifierData = FindIdentifier(*m_context->namedExternalBlocks[namedExternalBlockIndex].environment, identifierEntry.identifier); + if (!identifierData) + { + if (m_context->allowUnknownIdentifiers) + return Cloner::Clone(node); + + throw CompilerUnknownIdentifierError{ node.sourceLocation, identifierEntry.identifier }; + } + + if (identifierData->category == IdentifierCategory::Unresolved) + return Cloner::Clone(node); + + if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) + return Cloner::Clone(node); + + indexedExpr = HandleIdentifier(identifierData, identifierEntry.sourceLocation); + } + else if (IsModuleType(resolvedType)) + { + const ModuleType& moduleType = std::get(resolvedType); + std::size_t moduleId = moduleType.moduleIndex; + + m_context->currentEnv = m_context->modules[moduleId].environment; + + const IdentifierData* identifierData = FindIdentifier(*m_context->currentEnv, identifierEntry.identifier); + if (!identifierData) + { + if (m_context->allowUnknownIdentifiers) + return Cloner::Clone(node); + + throw CompilerUnknownIdentifierError{ node.sourceLocation, identifierEntry.identifier }; + } + + if (identifierData->category == IdentifierCategory::Unresolved) + return Cloner::Clone(node); + + if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex) + return Cloner::Clone(node); + + auto& dependencyCheckerPtr = m_context->modules[moduleId].dependenciesVisitor; + if (dependencyCheckerPtr) //< dependency checker can be null when performing partial sanitization + { + switch (identifierData->category) + { + case IdentifierCategory::Constant: + dependencyCheckerPtr->MarkConstantAsUsed(identifierData->index); + break; + + case IdentifierCategory::Function: + dependencyCheckerPtr->MarkFunctionAsUsed(identifierData->index); + break; + + case IdentifierCategory::Struct: + dependencyCheckerPtr->MarkStructAsUsed(identifierData->index); + break; + + default: + break; + } + } + + indexedExpr = HandleIdentifier(identifierData, identifierEntry.sourceLocation); + } else throw CompilerUnexpectedAccessedTypeError{ node.sourceLocation }; } + m_context->currentEnv = std::move(previousEnv); + return indexedExpr; } @@ -1507,6 +1534,7 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized") auto& namedExternal = m_context->namedExternalBlocks.emplace_back(); namedExternal.environment = std::make_shared(); namedExternal.environment->parentEnv = m_context->currentEnv; + namedExternal.name = clone->name; RegisterExternalBlock(clone->name, *namedExternalBlockIndex, clone->sourceLocation); @@ -2481,9 +2509,6 @@ NAZARA_WARNING_POP() StatementPtr SanitizeVisitor::Clone(ImportStatement& node) { - if (node.identifiers.empty()) - throw AstEmptyImportError{ node.sourceLocation }; - tsl::ordered_map> importedSymbols; bool importEverythingElse = false; for (const auto& entry : node.identifiers) @@ -2636,141 +2661,146 @@ NAZARA_WARNING_POP() auto& moduleData = m_context->modules[moduleIndex]; - auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId]; - // Extract exported nodes and their dependencies std::vector aliasStatements; std::vector constStatements; - - auto CheckImport = [&](const std::string& identifier) -> std::pair> + if (!importedSymbols.empty() || importEverythingElse) { - auto it = importedSymbols.find(identifier); - if (it == importedSymbols.end()) - { - if (!importEverythingElse) - return { false, {} }; + // Importing module symbols in global scope + auto& exportedSet = moduleData.exportedSetByModule[m_context->currentEnv->moduleId]; - return { true, { std::string{} } }; - } - else + auto CheckImport = [&](const std::string& identifier) -> std::pair> { - std::vector imports = std::move(it->second); - importedSymbols.erase(it); + auto it = importedSymbols.find(identifier); + if (it == importedSymbols.end()) + { + if (!importEverythingElse) + return { false, {} }; - return { true, std::move(imports) }; - } - }; + return { true, { std::string{} } }; + } + else + { + std::vector imports = std::move(it.value()); + importedSymbols.erase(it); - ExportVisitor::Callbacks callbacks; - callbacks.onExportedConst = [&](DeclareConstStatement& node) - { - assert(node.constIndex); + return { true, std::move(imports) }; + } + }; - auto [imported, aliasesName] = CheckImport(node.name); - if (!imported) - return; + ExportVisitor::Callbacks callbacks; + callbacks.onExportedConst = [&](DeclareConstStatement& node) + { + assert(node.constIndex); - if (moduleData.dependenciesVisitor) - moduleData.dependenciesVisitor->MarkConstantAsUsed(*node.constIndex); + auto [imported, aliasesName] = CheckImport(node.name); + if (!imported) + return; - auto BuildConstant = [&]() -> ExpressionPtr - { - const ConstantValue* value = m_context->constantValues.TryRetrieve(*node.constIndex, node.sourceLocation); - if (!value) - throw AstInvalidConstantIndexError{ node.sourceLocation, *node.constIndex }; + if (moduleData.dependenciesVisitor) + moduleData.dependenciesVisitor->MarkConstantAsUsed(*node.constIndex); - return ShaderBuilder::Constant(*node.constIndex, GetConstantType(*value)); - }; + auto BuildConstant = [&]() -> ExpressionPtr + { + const ConstantValue* value = m_context->constantValues.TryRetrieve(*node.constIndex, node.sourceLocation); + if (!value) + throw AstInvalidConstantIndexError{ node.sourceLocation, *node.constIndex }; - for (const std::string& aliasName : aliasesName) - { - if (aliasName.empty()) + return ShaderBuilder::Constant(*node.constIndex, GetConstantType(*value)); + }; + + for (const std::string& aliasName : aliasesName) { - // symbol not renamed, export it once - if (exportedSet.usedConstants.UnboundedTest(*node.constIndex)) - return; + if (aliasName.empty()) + { + // symbol not renamed, export it once + if (exportedSet.usedConstants.UnboundedTest(*node.constIndex)) + return; - exportedSet.usedConstants.UnboundedSet(*node.constIndex); - constStatements.emplace_back(ShaderBuilder::DeclareConst(node.name, BuildConstant())); + exportedSet.usedConstants.UnboundedSet(*node.constIndex); + constStatements.emplace_back(ShaderBuilder::DeclareConst(node.name, BuildConstant())); + } + else + constStatements.emplace_back(ShaderBuilder::DeclareConst(aliasName, BuildConstant())); } - else - constStatements.emplace_back(ShaderBuilder::DeclareConst(aliasName, BuildConstant())); - } - }; + }; - callbacks.onExportedFunc = [&](DeclareFunctionStatement& node) - { - assert(node.funcIndex); + callbacks.onExportedFunc = [&](DeclareFunctionStatement& node) + { + assert(node.funcIndex); - auto [imported, aliasesName] = CheckImport(node.name); - if (!imported) - return; + auto [imported, aliasesName] = CheckImport(node.name); + if (!imported) + return; - if (moduleData.dependenciesVisitor) - moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex); + if (moduleData.dependenciesVisitor) + moduleData.dependenciesVisitor->MarkFunctionAsUsed(*node.funcIndex); - for (const std::string& aliasName : aliasesName) - { - if (aliasName.empty()) + for (const std::string& aliasName : aliasesName) { - // symbol not renamed, export it once - if (exportedSet.usedFunctions.UnboundedTest(*node.funcIndex)) - return; + if (aliasName.empty()) + { + // symbol not renamed, export it once + if (exportedSet.usedFunctions.UnboundedTest(*node.funcIndex)) + return; - exportedSet.usedFunctions.UnboundedSet(*node.funcIndex); - aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex))); + exportedSet.usedFunctions.UnboundedSet(*node.funcIndex); + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.name, ShaderBuilder::Function(*node.funcIndex))); + } + else + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(aliasName, ShaderBuilder::Function(*node.funcIndex))); } - else - aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(aliasName, ShaderBuilder::Function(*node.funcIndex))); - } - }; + }; - callbacks.onExportedStruct = [&](DeclareStructStatement& node) - { - assert(node.structIndex); + callbacks.onExportedStruct = [&](DeclareStructStatement& node) + { + assert(node.structIndex); - auto [imported, aliasesName] = CheckImport(node.description.name); - if (!imported) - return; + auto [imported, aliasesName] = CheckImport(node.description.name); + if (!imported) + return; - if (moduleData.dependenciesVisitor) - moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex); + if (moduleData.dependenciesVisitor) + moduleData.dependenciesVisitor->MarkStructAsUsed(*node.structIndex); - for (const std::string& aliasName : aliasesName) - { - if (aliasName.empty()) + for (const std::string& aliasName : aliasesName) { - // symbol not renamed, export it once - if (exportedSet.usedStructs.UnboundedTest(*node.structIndex)) - return; + if (aliasName.empty()) + { + // symbol not renamed, export it once + if (exportedSet.usedStructs.UnboundedTest(*node.structIndex)) + return; - exportedSet.usedStructs.UnboundedSet(*node.structIndex); - aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex))); + exportedSet.usedStructs.UnboundedSet(*node.structIndex); + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.description.name, ShaderBuilder::StructType(*node.structIndex))); + } + else + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(aliasName, ShaderBuilder::StructType(*node.structIndex))); } - else - aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(aliasName, ShaderBuilder::StructType(*node.structIndex))); - } - }; + }; - ExportVisitor exportVisitor; - exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks); + ExportVisitor exportVisitor; + exportVisitor.Visit(*m_context->currentModule->importedModules[moduleIndex].module->rootNode, callbacks); - if (!importedSymbols.empty()) - { - std::string symbolList; - for (const auto& [identifier, _] : importedSymbols) + if (!importedSymbols.empty()) { - if (!symbolList.empty()) - symbolList += ", "; + std::string symbolList; + for (const auto& [identifier, _] : importedSymbols) + { + if (!symbolList.empty()) + symbolList += ", "; + + symbolList += identifier; + } - symbolList += identifier; + throw CompilerImportIdentifierNotFoundError{ node.sourceLocation, symbolList, node.moduleName }; } - throw CompilerImportIdentifierNotFoundError{ node.sourceLocation, symbolList, node.moduleName }; + if (aliasStatements.empty() && constStatements.empty()) + return ShaderBuilder::NoOp(); } - - if (aliasStatements.empty() && constStatements.empty()) - return ShaderBuilder::NoOp(); + else + aliasStatements.emplace_back(ShaderBuilder::DeclareAlias(node.moduleIdentifier, ShaderBuilder::ModuleExpr(moduleIndex))); // Register aliases for (auto& aliasPtr : aliasStatements) @@ -2949,7 +2979,15 @@ NAZARA_WARNING_POP() } case IdentifierCategory::ExternalBlock: - throw AstUnexpectedIdentifierError{ sourceLocation, "external" }; + { + // Replace IdentifierExpression by NamedExternalBlockExpression + auto moduleExpr = std::make_unique(); + moduleExpr->cachedExpressionType = NamedExternalBlockType{ identifierData->index }; + moduleExpr->sourceLocation = sourceLocation; + moduleExpr->externalBlockId = identifierData->index; + + return moduleExpr; + } case IdentifierCategory::Function: { @@ -2976,7 +3014,15 @@ NAZARA_WARNING_POP() } case IdentifierCategory::Module: - throw AstUnexpectedIdentifierError{ sourceLocation, "module" }; + { + // Replace IdentifierExpression by ModuleExpression + auto moduleExpr = std::make_unique(); + moduleExpr->cachedExpressionType = ModuleType{ identifierData->index }; + moduleExpr->sourceLocation = sourceLocation; + moduleExpr->moduleId = identifierData->index; + + return moduleExpr; + } case IdentifierCategory::Struct: { @@ -4182,6 +4228,22 @@ NAZARA_WARNING_POP() return m_context->aliases.Retrieve(aliasIndex, sourceLocation).name; }; + stringifier.moduleStringifier = [&](std::size_t moduleIndex) + { + const std::string& moduleName = m_context->modules[moduleIndex].moduleName; + return (!moduleName.empty()) ? moduleName : fmt::format("", moduleIndex); + }; + + stringifier.namedExternalBlockStringifier = [&](std::size_t namedExternalBlockIndex) + { + return m_context->namedExternalBlocks[namedExternalBlockIndex].name; + }; + + stringifier.structStringifier = [&](std::size_t structIndex) + { + return m_context->structs.Retrieve(structIndex, sourceLocation)->name; + }; + stringifier.structStringifier = [&](std::size_t structIndex) { return m_context->structs.Retrieve(structIndex, sourceLocation)->name; @@ -4258,10 +4320,14 @@ NAZARA_WARNING_POP() const AliasType& alias = std::get(resolvedType); aliasIdentifier.target = { alias.aliasIndex, IdentifierCategory::Alias }; } + else if (IsModuleType(resolvedType)) + { + const ModuleType& module = std::get(resolvedType); + aliasIdentifier.target = { module.moduleIndex, IdentifierCategory::Module }; + } else throw CompilerAliasUnexpectedTypeError{ node.sourceLocation, ToString(*exprType, node.expression->sourceLocation) }; - node.aliasIndex = RegisterAlias(node.name, std::move(aliasIdentifier), node.aliasIndex, node.sourceLocation); return ValidationResult::Validated; } diff --git a/src/NZSL/Ast/Utils.cpp b/src/NZSL/Ast/Utils.cpp index 681bae7..7f9f844 100644 --- a/src/NZSL/Ast/Utils.cpp +++ b/src/NZSL/Ast/Utils.cpp @@ -102,6 +102,16 @@ namespace nzsl::Ast m_expressionCategory = ExpressionCategory::LValue; } + void ValueCategory::Visit(ModuleExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + + void ValueCategory::Visit(NamedExternalBlockExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ValueCategory::Visit(StructTypeExpression& /*node*/) { m_expressionCategory = ExpressionCategory::LValue; diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index 35ef9ed..fcbd671 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -682,6 +682,16 @@ namespace nzsl throw std::runtime_error("unexpected method type"); } + void GlslWriter::Append(const Ast::ModuleType& /*moduleType*/) + { + throw std::runtime_error("unexpected module type"); + } + + void GlslWriter::Append(const Ast::NamedExternalBlockType& /*namedExternalBlockType*/) + { + throw std::runtime_error("unexpected named external block type"); + } + void GlslWriter::Append(Ast::PrimitiveType type) { switch (type) diff --git a/src/NZSL/LangWriter.cpp b/src/NZSL/LangWriter.cpp index eadf569..9e6812f 100644 --- a/src/NZSL/LangWriter.cpp +++ b/src/NZSL/LangWriter.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -187,6 +188,7 @@ namespace nzsl std::unordered_map aliases; std::unordered_map constants; std::unordered_map functions; + std::unordered_map modules; std::unordered_map structs; std::unordered_map variables; std::vector externalBlockNames; @@ -223,6 +225,11 @@ namespace nzsl } m_currentState->currentModuleIndex = std::numeric_limits::max(); + + std::size_t moduleIndex = 0; + for (const auto& importedModule : module.importedModules) + RegisterModule(moduleIndex++, importedModule.identifier); + module.rootNode->Visit(previsitor); } @@ -319,6 +326,16 @@ namespace nzsl throw std::runtime_error("unexpected method type"); } + void LangWriter::Append(const Ast::ModuleType& moduleType) + { + AppendIdentifier(m_currentState->modules, moduleType.moduleIndex); + } + + void LangWriter::Append(const Ast::NamedExternalBlockType& namedExternalBlockType) + { + Append(m_currentState->externalBlockNames[namedExternalBlockType.namedExternalBlockIndex]); + } + void LangWriter::Append(Ast::PrimitiveType type) { switch (type) @@ -915,6 +932,16 @@ namespace nzsl m_currentState->functions.emplace(funcIndex, std::move(identifier)); } + void LangWriter::RegisterModule(std::size_t moduleIndex, std::string moduleName) + { + State::Identifier identifier; + identifier.moduleIndex = m_currentState->currentModuleIndex; + identifier.name = std::move(moduleName); + + assert(m_currentState->modules.find(moduleIndex) == m_currentState->modules.end()); + m_currentState->modules.emplace(moduleIndex, std::move(identifier)); + } + void LangWriter::RegisterStruct(std::size_t structIndex, std::string structName) { State::Identifier identifier; @@ -1259,6 +1286,16 @@ namespace nzsl Append(")"); } + void LangWriter::Visit(Ast::ModuleExpression& node) + { + AppendIdentifier(m_currentState->modules, node.moduleId); + } + + void LangWriter::Visit(Ast::NamedExternalBlockExpression& node) + { + Append(m_currentState->externalBlockNames[node.externalBlockId]); + } + void LangWriter::Visit(Ast::StructTypeExpression& node) { AppendIdentifier(m_currentState->structs, node.structTypeId); @@ -1358,6 +1395,14 @@ namespace nzsl Append("alias ", node.name, " = "); assert(node.expression); node.expression->Visit(*this); + + // Special case, if that alias points to a module, use it instead to try to keep source code readable + if (node.expression->GetType() == Ast::NodeType::ModuleExpression) + { + auto& moduleExpr = Nz::SafeCast(*node.expression); + m_currentState->moduleNames[moduleExpr.moduleId] = node.name; + } + AppendLine(";"); } @@ -1580,25 +1625,45 @@ namespace nzsl { Append("import "); - bool first = true; - for (const auto& entry : node.identifiers) + if (node.identifiers.empty()) { - if (!first) - Append(", "); + // Whole module import + Append(node.moduleName); - first = false; + std::string_view defaultIdentifierName; + std::size_t lastSep = node.moduleName.find_last_of('.'); + if (lastSep != std::string::npos) + defaultIdentifierName = std::string_view(node.moduleName).substr(lastSep + 1); + else + defaultIdentifierName = node.moduleName; + + if (node.moduleIdentifier != node.moduleName) + Append(" as ", node.moduleIdentifier); - if (!entry.identifier.empty()) + AppendLine(";"); + } + else + { + // Module identifier import + bool first = true; + for (const auto& entry : node.identifiers) { - Append(entry.identifier); - if (!entry.renamedIdentifier.empty()) - Append(" as ", entry.renamedIdentifier); + if (!first) + Append(", "); + + first = false; + + if (!entry.identifier.empty()) + { + Append(entry.identifier); + if (!entry.renamedIdentifier.empty()) + Append(" as ", entry.renamedIdentifier); + } + else + Append("*"); } - else - Append("*"); + AppendLine(" from ", node.moduleName, ";"); } - - AppendLine(" from ", node.moduleName, ";"); } void LangWriter::Visit(Ast::MultiStatement& node) diff --git a/src/NZSL/Parser.cpp b/src/NZSL/Parser.cpp index 5937046..c92b2a3 100644 --- a/src/NZSL/Parser.cpp +++ b/src/NZSL/Parser.cpp @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -403,7 +403,7 @@ namespace nzsl if (m_context->module) { - moduleMetadata->moduleName = ParseModuleName(); + moduleMetadata->moduleName = ParseModuleName(nullptr); auto module = std::make_shared(std::move(moduleMetadata)); // Imported module @@ -434,7 +434,7 @@ namespace nzsl { std::string moduleName; if (Peek().type == TokenType::Identifier) - moduleName = ParseModuleName(); + moduleName = ParseModuleName(nullptr); else { moduleName.resize(33); @@ -1051,7 +1051,7 @@ namespace nzsl identifier.identifierLoc = token.location; } else - identifier.identifier = ParseIdentifierAsName(&identifier.identifierLoc); + identifier.identifier = ParseModuleName(&identifier.identifierLoc); // at this point it could be an identifier or a module name (allowing dots), parse module name for now if (Peek().type == TokenType::As) { @@ -1060,18 +1060,60 @@ namespace nzsl identifier.renamedIdentifier = ParseIdentifierAsName(&identifier.renamedIdentifierLoc); } } - while (Peek().type != TokenType::From); + while (Peek().type == TokenType::Comma); + + const Token& token = Peek(); + if (token.type == TokenType::From) + { + // import from ; + Consume(); //< From - Consume(); //< From + for (auto& identifierData : identifiers) + { + if (identifierData.identifier.find('.') != std::string::npos) + throw ParserModuleImportInvalidIdentifierError{ identifierData.identifierLoc, identifierData.identifier }; + } - std::string moduleName = ParseModuleName(); + std::string moduleName = ParseModuleName(nullptr); - const Token& endtoken = Expect(Advance(), TokenType::Semicolon); + const Token& endtoken = Expect(Advance(), TokenType::Semicolon); - auto importStatement = ShaderBuilder::Import(std::move(moduleName), std::move(identifiers)); - importStatement->sourceLocation = SourceLocation::BuildFromTo(importToken.location, endtoken.location); + auto importStatement = ShaderBuilder::Import(std::move(moduleName), std::move(identifiers)); + importStatement->sourceLocation = SourceLocation::BuildFromTo(importToken.location, endtoken.location); - return importStatement; + return importStatement; + } + else + { + // import (as identifier); -- (where modules comes from identifiers) + if (identifiers.size() != 1) + { + const auto& firstIdentifier = identifiers.front(); + const auto& lastIdentifier = identifiers.back(); + SourceLocation importLoc = SourceLocation::BuildFromTo(firstIdentifier.identifierLoc, lastIdentifier.renamedIdentifierLoc.IsValid() ? lastIdentifier.renamedIdentifierLoc : lastIdentifier.identifierLoc); + throw ParserModuleImportMultipleError{ importLoc }; + } + + auto& firstIdentifier = identifiers.front(); + + std::string identifierName = std::move(firstIdentifier.renamedIdentifier); + if (identifierName.empty()) + { + // When importing a module with a dot separator, the default identifier is the last part; + std::size_t lastSep = firstIdentifier.identifier.find_last_of('.'); + if (lastSep != std::string::npos) + identifierName = firstIdentifier.identifier.substr(lastSep + 1); + else + identifierName = firstIdentifier.identifier; + } + + const Token& endtoken = Expect(Advance(), TokenType::Semicolon); + + auto importStatement = ShaderBuilder::Import(std::move(firstIdentifier.identifier), std::move(identifierName)); + importStatement->sourceLocation = SourceLocation::BuildFromTo(importToken.location, endtoken.location); + + return importStatement; + } } Ast::StatementPtr Parser::ParseOptionDeclaration() @@ -1815,14 +1857,19 @@ namespace nzsl return std::get(identifierToken.data); } - std::string Parser::ParseModuleName() + std::string Parser::ParseModuleName(SourceLocation* sourceLocation) { - std::string moduleName = ParseIdentifierAsName(nullptr); + std::string moduleName = ParseIdentifierAsName(sourceLocation); while (Peek().type == TokenType::Dot) { + SourceLocation identifierLocation; + Consume(); moduleName += '.'; - moduleName += ParseIdentifierAsName(nullptr); + moduleName += ParseIdentifierAsName(&identifierLocation); + + if (sourceLocation) + sourceLocation->ExtendToRight(identifierLocation); } return moduleName; diff --git a/tests/src/Tests/ErrorsTests.cpp b/tests/src/Tests/ErrorsTests.cpp index 09fc789..0178849 100644 --- a/tests/src/Tests/ErrorsTests.cpp +++ b/tests/src/Tests/ErrorsTests.cpp @@ -81,6 +81,13 @@ module; import Stuff; )"), "(5,2 -> 11): PUnexpectedAttribute error: unexpected attribute cond on import statement"); + CHECK_THROWS_WITH(nzsl::Parse(R"( +[nzsl_version("1.0")] +module; + +import Foo.Bar from Baz; +)"), "(5,8 -> 14): PModuleImportInvalidIdentifier error: Foo.Bar is not a valid identifier to import"); + // option statements don't support attributes CHECK_THROWS_WITH(nzsl::Parse(R"( [nzsl_version("1.0")] diff --git a/tests/src/Tests/ModuleTests.cpp b/tests/src/Tests/ModuleTests.cpp index a5d4580..af2b92f 100644 --- a/tests/src/Tests/ModuleTests.cpp +++ b/tests/src/Tests/ModuleTests.cpp @@ -28,7 +28,7 @@ void RegisterModule(const std::shared_ptr& modul TEST_CASE("Modules", "[Shader]") { - WHEN("using a simple module") + WHEN("Importing a simple module") { std::string_view importedSource = R"( [nzsl_version("1.0")] @@ -266,7 +266,7 @@ OpReturn OpFunctionEnd)"); } - WHEN("Using nested modules") + WHEN("Importing nested modules") { std::string_view dataModule = R"( [nzsl_version("1.0")] @@ -965,7 +965,7 @@ fn FragMain() -> FragOut WHEN("Testing a more complex hierarchy") { - // Tests a more complex hierarchy where the same module is imported at multiple levels, which caused a bug + // Tests a more complex hierarchy where the same module is imported at multiple levels, which caused a bug at some point std::string_view lightingLightData = R"( [nzsl_version("1.0")] @@ -1105,4 +1105,234 @@ fn FragMain() -> FragOut } )"); } + + WHEN("Importing a simple module by name") + { + std::string_view importedSource = R"( +[nzsl_version("1.0")] +[author("Lynix")] +[desc("Simple \"module\" for testing")] +[license("Public domain")] +module Simple.Module; + +[export] +const Pi = 3.141592; + +[layout(std140)] +struct Data +{ + value: f32 +} + +[export] +[layout(std140)] +struct Block +{ + data: Data +} + +[export] +fn GetDataValue(data: Data) -> f32 +{ + return data.value; +} + +struct Unused {} + +[export] +struct InputData +{ + value: f32 +} + +[export] +struct OutputData +{ + value: f32 +} +)"; + + std::string_view shaderSource = R"( +[nzsl_version("1.0")] +[author("Sir Lynix")] +[desc("Main file")] +[license("MIT")] +module; + +import Simple.Module as SimpleModule; + +external ExtData +{ + [binding(0)] block: uniform[SimpleModule.Block] +} + +[entry(frag)] +fn main(input: SimpleModule.InputData) -> SimpleModule.OutputData +{ + let data = ExtData.block.data; + + let output: SimpleModule.OutputData; + output.value = SimpleModule.GetDataValue(data) * input.value * SimpleModule.Pi; + return output; +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(shaderSource); + + auto directoryModuleResolver = std::make_shared(); + RegisterModule(directoryModuleResolver, importedSource); + + nzsl::Ast::SanitizeVisitor::Options sanitizeOpt; + sanitizeOpt.moduleResolver = directoryModuleResolver; + + shaderModule = SanitizeModule(*shaderModule, sanitizeOpt); + + ExpectGLSL(*shaderModule, R"( +// Module Simple.Module +// Author: Lynix +// Description: Simple "module" for testing +// License: Public domain + +struct Data_Simple_Module +{ + float value; +}; + +// struct Block_Simple_Module omitted (used as UBO/SSBO) + +float GetDataValue_Simple_Module(Data_Simple_Module data) +{ + return data.value; +} + +struct InputData_Simple_Module +{ + float value; +}; + +struct OutputData_Simple_Module +{ + float value; +}; + +// Main module +// Author: Sir Lynix +// Description: Main file +// License: MIT + +layout(std140) uniform _nzslBindingExtData_block +{ + Data_Simple_Module data; +} ExtData_block; + +/**************** Inputs ****************/ +in float _nzslInvalue; + +/*************** Outputs ***************/ +out float _nzslOutvalue; + +void main() +{ + InputData_Simple_Module input_; + input_.value = _nzslInvalue; + + Data_Simple_Module data; + data.value = ExtData_block.data.value; + OutputData_Simple_Module output_; + output_.value = ((GetDataValue_Simple_Module(data)) * input_.value) * (3.141592); + + _nzslOutvalue = output_.value; + return; +} +)"); + + ExpectNZSL(*shaderModule, R"( +[nzsl_version("1.0")] +[author("Sir Lynix"), desc("Main file")] +[license("MIT")] +module; + +[nzsl_version("1.0")] +[author("Lynix"), desc("Simple \"module\" for testing")] +[license("Public domain")] +module _Simple_Module +{ + const Pi: f32 = 3.141592; + + [layout(std140)] + struct Data + { + value: f32 + } + + [layout(std140)] + struct Block + { + data: Data + } + + fn GetDataValue(data: Data) -> f32 + { + return data.value; + } + + struct InputData + { + value: f32 + } + + struct OutputData + { + value: f32 + } + +} +alias SimpleModule = _Simple_Module; + +external ExtData +{ + [set(0), binding(0)] block: uniform[SimpleModule.Block] +} + +[entry(frag)] +fn main(input: SimpleModule.InputData) -> SimpleModule.OutputData +{ + let data: SimpleModule.Data = ExtData.block.data; + let output: SimpleModule.OutputData; + output.value = ((SimpleModule.GetDataValue(data)) * input.value) * SimpleModule.Pi; + return output; +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpFunctionParameter +OpLabel +OpAccessChain +OpLoad +OpReturnValue +OpFunctionEnd +OpFunction +OpLabel +OpVariable +OpVariable +OpVariable +OpVariable +OpAccessChain +OpLoad +OpAccessChain +OpStore +OpLoad +OpStore +OpFunctionCall +OpAccessChain +OpLoad +OpFMul +OpFMul +OpAccessChain +OpStore +OpLoad +OpReturn +OpFunctionEnd)"); + } }