Skip to content

Commit

Permalink
Compiler: Allow aliases to be conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
SirLynix committed Jul 25, 2024
1 parent 3708a6c commit 95b85a7
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 19 deletions.
2 changes: 1 addition & 1 deletion include/NZSL/Parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace nzsl
Ast::ExpressionPtr BuildUnary(Ast::UnaryType unaryType, Ast::ExpressionPtr expr);

// Statements
Ast::StatementPtr ParseAliasDeclaration();
Ast::StatementPtr ParseAliasDeclaration(std::vector<Attribute> attributes = {});
Ast::StatementPtr ParseBranchStatement();
Ast::StatementPtr ParseBreakStatement();
Ast::StatementPtr ParseConstStatement(std::vector<Attribute> attributes = {});
Expand Down
31 changes: 21 additions & 10 deletions src/NZSL/Ast/SanitizeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3541,8 +3541,14 @@ namespace nzsl::Ast

std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional<Identifier> aliasData, std::optional<std::size_t> index, const SourceLocation& sourceLocation)
{
if (!IsIdentifierAvailable(name))
throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name };
bool unresolved = false;
if (const IdentifierData* identifierData = FindIdentifier(name))
{
if (!m_context->inConditionalStatement || !identifierData->isConditional)
throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name };
else
unresolved = true;
}

std::size_t aliasIndex;
if (aliasData)
Expand All @@ -3555,14 +3561,19 @@ namespace nzsl::Ast
else
aliasIndex = m_context->aliases.RegisterNewIndex(true);

m_context->currentEnv->identifiersInScope.push_back({
std::move(name),
{
aliasIndex,
IdentifierCategory::Alias,
m_context->inConditionalStatement
}
});
if (!unresolved)
{
m_context->currentEnv->identifiersInScope.push_back({
std::move(name),
{
aliasIndex,
IdentifierCategory::Alias,
m_context->inConditionalStatement
}
});
}
else
RegisterUnresolved(std::move(name));

return aliasIndex;
}
Expand Down
27 changes: 21 additions & 6 deletions src/NZSL/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,25 @@ namespace nzsl
return unaryExpr;
}

Ast::StatementPtr Parser::ParseAliasDeclaration()
Ast::StatementPtr Parser::ParseAliasDeclaration(std::vector<Attribute> attributes)
{
const Token& aliasToken = Expect(Advance(), TokenType::Alias);

Ast::ExpressionValue<bool> condition;

for (auto&& attribute : attributes)
{
switch (attribute.type)
{
case Ast::AttributeType::Cond:
HandleUniqueAttribute(condition, std::move(attribute));
break;

default:
throw ParserUnexpectedAttributeError{ attribute.sourceLocation, attribute.type, "alias declaration" };
}
}

std::string name = ParseIdentifierAsName(nullptr);

Expect(Advance(), TokenType::Assign);
Expand All @@ -533,7 +548,10 @@ namespace nzsl
auto aliasStatement = ShaderBuilder::DeclareAlias(std::move(name), std::move(expr));
aliasStatement->sourceLocation = SourceLocation::BuildFromTo(aliasToken.location, endToken.location);

return aliasStatement;
if (condition.HasValue())
return ShaderBuilder::ConditionalStatement(std::move(condition).GetExpression(), std::move(aliasStatement));
else
return aliasStatement;
}

Ast::StatementPtr Parser::ParseBranchStatement()
Expand Down Expand Up @@ -1069,10 +1087,7 @@ namespace nzsl
switch (nextToken.type)
{
case TokenType::Alias:
if (!attributes.empty())
throw ParserUnexpectedAttributeError{ attributes.front().sourceLocation, attributes.front().type, "alias declaration" };

return ParseAliasDeclaration();
return ParseAliasDeclaration(std::move(attributes));

case TokenType::Const:
return ParseConstStatement(std::move(attributes));
Expand Down
197 changes: 197 additions & 0 deletions tests/src/Tests/AliasTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,201 @@ OpStore
OpReturn
OpFunctionEnd)");
}

SECTION("Conditional aliases")
{
std::string_view nzslSource = R"(
[nzsl_version("1.0")]
module;
struct ForwardOutput
{
[location(0)] color: vec4[f32]
}
struct DeferredOutput
{
[location(0)] color: vec4[f32],
[location(1)] normal: vec3[f32]
}
option ForwardPass: bool;
[cond(ForwardPass)]
alias FragOut = ForwardOutput;
[cond(!ForwardPass)]
alias FragOut = DeferredOutput;
[entry(frag)]
fn main() -> FragOut
{
let output: FragOut;
output.color = vec4[f32](0.0, 0.0, 1.0, 1.0);
const if (!ForwardPass)
output.normal = vec3[f32](0.0, 1.0, 0.0);
return output;
}
)";

nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource);

WHEN("We perform a partial sanitization")
{
nzsl::Ast::SanitizeVisitor::Options options;
options.partialSanitization = true;

shaderModule = SanitizeModule(*shaderModule, options);
}

WHEN("We enable ForwardPass")
{
nzsl::Ast::SanitizeVisitor::Options options;
options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = true;
options.removeOptionDeclaration = true;

shaderModule = SanitizeModule(*shaderModule, options);

ExpectGLSL(*shaderModule, R"(
struct ForwardOutput
{
vec4 color;
};
struct DeferredOutput
{
vec4 color;
vec3 normal;
};
/*************** Outputs ***************/
layout(location = 0) out vec4 _nzslOutcolor;
void main()
{
ForwardOutput output_;
output_.color = vec4(0.0, 0.0, 1.0, 1.0);
_nzslOutcolor = output_.color;
return;
}
)");

ExpectNZSL(*shaderModule, R"(
struct ForwardOutput
{
[location(0)] color: vec4[f32]
}
struct DeferredOutput
{
[location(0)] color: vec4[f32],
[location(1)] normal: vec3[f32]
}
alias FragOut = ForwardOutput;
[entry(frag)]
fn main() -> FragOut
{
let output: FragOut;
output.color = vec4[f32](0.0, 0.0, 1.0, 1.0);
return output;
}
)");

ExpectSPIRV(*shaderModule, R"(
OpFunction
OpLabel
OpVariable
OpCompositeConstruct
OpAccessChain
OpStore
OpLoad
OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}

WHEN("We disable ForwardPass")
{
nzsl::Ast::SanitizeVisitor::Options options;
options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = false;
options.removeOptionDeclaration = true;

shaderModule = SanitizeModule(*shaderModule, options);

ExpectGLSL(*shaderModule, R"(
struct ForwardOutput
{
vec4 color;
};
struct DeferredOutput
{
vec4 color;
vec3 normal;
};
/*************** Outputs ***************/
layout(location = 0) out vec4 _nzslOutcolor;
layout(location = 1) out vec3 _nzslOutnormal;
void main()
{
DeferredOutput output_;
output_.color = vec4(0.0, 0.0, 1.0, 1.0);
output_.normal = vec3(0.0, 1.0, 0.0);
_nzslOutcolor = output_.color;
_nzslOutnormal = output_.normal;
return;
}
)");

ExpectNZSL(*shaderModule, R"(
struct ForwardOutput
{
[location(0)] color: vec4[f32]
}
struct DeferredOutput
{
[location(0)] color: vec4[f32],
[location(1)] normal: vec3[f32]
}
alias FragOut = DeferredOutput;
[entry(frag)]
fn main() -> FragOut
{
let output: FragOut;
output.color = vec4[f32](0.0, 0.0, 1.0, 1.0);
output.normal = vec3[f32](0.0, 1.0, 0.0);
return output;
}
)");

ExpectSPIRV(*shaderModule, R"(
OpFunction
OpLabel
OpVariable
OpCompositeConstruct
OpAccessChain
OpStore
OpCompositeConstruct
OpAccessChain
OpStore
OpLoad
OpCompositeExtract
OpStore
OpCompositeExtract
OpStore
OpReturn
OpFunctionEnd)");
}
}
}
4 changes: 2 additions & 2 deletions tests/src/Tests/ErrorsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ module;
[nzsl_version("1.0")]
module;
[cond(false)]
[layout(std140)]
alias vec3f32 = vec3[f32];
)"), "(5,2 -> 12): PUnexpectedAttribute error: unexpected attribute cond on alias declaration");
)"), "(5,2 -> 15): PUnexpectedAttribute error: unexpected attribute layout on alias declaration");

// import statements don't support cond attribute
CHECK_THROWS_WITH(nzsl::Parse(R"(
Expand Down

0 comments on commit 95b85a7

Please sign in to comment.