Skip to content

Commit

Permalink
[Attributes][HLSL] Teach EnumArgument to refer to an external enum (#…
Browse files Browse the repository at this point in the history
…70835)

Rather than write a bunch of logic to shepherd between enums with the
same sets of values, add the ability for EnumArgument to refer to an
external enum in the first place.
  • Loading branch information
bogner authored Nov 1, 2023
1 parent f2c24cc commit 1c6c01f
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 155 deletions.
6 changes: 3 additions & 3 deletions clang-tools-extra/modularize/ModularizeUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ std::error_code ModularizeUtilities::loadModuleMap(
// Walks the modules and collects referenced headers into
// HeaderFileNames.
bool ModularizeUtilities::collectModuleMapHeaders(clang::ModuleMap *ModMap) {
SmallVector<std::pair<StringRef, const Module *>, 0> Vec;
SmallVector<std::pair<StringRef, const clang::Module *>, 0> Vec;
for (auto &M : ModMap->modules())
Vec.emplace_back(M.first(), M.second);
llvm::sort(Vec, llvm::less_first());
Expand All @@ -349,14 +349,14 @@ bool ModularizeUtilities::collectModuleHeaders(const clang::Module &Mod) {
for (auto *Submodule : Mod.submodules())
collectModuleHeaders(*Submodule);

if (std::optional<Module::Header> UmbrellaHeader =
if (std::optional<clang::Module::Header> UmbrellaHeader =
Mod.getUmbrellaHeaderAsWritten()) {
std::string HeaderPath = getCanonicalPath(UmbrellaHeader->Entry.getName());
// Collect umbrella header.
HeaderFileNames.push_back(HeaderPath);

// FUTURE: When needed, umbrella header header collection goes here.
} else if (std::optional<Module::DirectoryName> UmbrellaDir =
} else if (std::optional<clang::Module::DirectoryName> UmbrellaDir =
Mod.getUmbrellaDirAsWritten()) {
// If there normal headers, assume these are umbrellas and skip collection.
if (Mod.Headers->size() == 0) {
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
#include "clang/AST/Type.h"
#include "clang/Basic/AttrKinds.h"
#include "clang/Basic/AttributeCommonInfo.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/Basic/Sanitizers.h"
#include "clang/Basic/SourceLocation.h"
#include "llvm/Frontend/HLSL/HLSLResource.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/VersionTuple.h"
#include "llvm/Support/raw_ostream.h"
Expand Down
35 changes: 20 additions & 15 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,28 @@ class DefaultIntArgument<string name, int default> : IntArgument<name, 1> {
int Default = default;
}

// This argument is more complex, it includes the enumerator type name,
// a list of strings to accept, and a list of enumerators to map them to.
// This argument is more complex, it includes the enumerator type
// name, whether the enum type is externally defined, a list of
// strings to accept, and a list of enumerators to map them to.
class EnumArgument<string name, string type, list<string> values,
list<string> enums, bit opt = 0, bit fake = 0>
list<string> enums, bit opt = 0, bit fake = 0,
bit isExternalType = 0>
: Argument<name, opt, fake> {
string Type = type;
list<string> Values = values;
list<string> Enums = enums;
bit IsExternalType = isExternalType;
}

// FIXME: There should be a VariadicArgument type that takes any other type
// of argument and generates the appropriate type.
class VariadicEnumArgument<string name, string type, list<string> values,
list<string> enums> : Argument<name, 1> {
list<string> enums, bit isExternalType = 0>
: Argument<name, 1> {
string Type = type;
list<string> Values = values;
list<string> Enums = enums;
bit IsExternalType = isExternalType;
}

// This handles one spelling of an attribute.
Expand Down Expand Up @@ -4182,26 +4187,26 @@ def HLSLResource : InheritableAttr {
let Spellings = [];
let Subjects = SubjectList<[Struct]>;
let LangOpts = [HLSL];
let Args = [EnumArgument<"ResourceType", "ResourceClass",
let Args = [EnumArgument<"ResourceClass", "llvm::hlsl::ResourceClass",
["SRV", "UAV", "CBuffer", "Sampler"],
["SRV", "UAV", "CBuffer", "Sampler"]
>,
EnumArgument<"ResourceShape", "ResourceKind",
["SRV", "UAV", "CBuffer", "Sampler"],
/*opt=*/0, /*fake=*/0, /*isExternalType=*/1>,
EnumArgument<"ResourceKind", "llvm::hlsl::ResourceKind",
["Texture1D", "Texture2D", "Texture2DMS",
"Texture3D", "TextureCube", "Texture1DArray",
"Texture2DArray", "Texture2DMSArray",
"TextureCubeArray", "TypedBuffer", "RawBuffer",
"StructuredBuffer", "CBufferKind", "SamplerKind",
"TBuffer", "RTAccelerationStructure", "FeedbackTexture2D",
"FeedbackTexture2DArray"],
"StructuredBuffer", "CBuffer", "Sampler",
"TBuffer", "RTAccelerationStructure",
"FeedbackTexture2D", "FeedbackTexture2DArray"],
["Texture1D", "Texture2D", "Texture2DMS",
"Texture3D", "TextureCube", "Texture1DArray",
"Texture2DArray", "Texture2DMSArray",
"TextureCubeArray", "TypedBuffer", "RawBuffer",
"StructuredBuffer", "CBufferKind", "SamplerKind",
"TBuffer", "RTAccelerationStructure", "FeedbackTexture2D",
"FeedbackTexture2DArray"]
>
"StructuredBuffer", "CBuffer", "Sampler",
"TBuffer", "RTAccelerationStructure",
"FeedbackTexture2D", "FeedbackTexture2DArray"],
/*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
];
let Documentation = [InternalOnly];
}
Expand Down
59 changes: 3 additions & 56 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,56 +223,6 @@ void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
ResourceMD->addOperand(Res.getMetadata());
}

static llvm::hlsl::ResourceKind
castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) {
switch (RK) {
case HLSLResourceAttr::ResourceKind::Texture1D:
return llvm::hlsl::ResourceKind::Texture1D;
case HLSLResourceAttr::ResourceKind::Texture2D:
return llvm::hlsl::ResourceKind::Texture2D;
case HLSLResourceAttr::ResourceKind::Texture2DMS:
return llvm::hlsl::ResourceKind::Texture2DMS;
case HLSLResourceAttr::ResourceKind::Texture3D:
return llvm::hlsl::ResourceKind::Texture3D;
case HLSLResourceAttr::ResourceKind::TextureCube:
return llvm::hlsl::ResourceKind::TextureCube;
case HLSLResourceAttr::ResourceKind::Texture1DArray:
return llvm::hlsl::ResourceKind::Texture1DArray;
case HLSLResourceAttr::ResourceKind::Texture2DArray:
return llvm::hlsl::ResourceKind::Texture2DArray;
case HLSLResourceAttr::ResourceKind::Texture2DMSArray:
return llvm::hlsl::ResourceKind::Texture2DMSArray;
case HLSLResourceAttr::ResourceKind::TextureCubeArray:
return llvm::hlsl::ResourceKind::TextureCubeArray;
case HLSLResourceAttr::ResourceKind::TypedBuffer:
return llvm::hlsl::ResourceKind::TypedBuffer;
case HLSLResourceAttr::ResourceKind::RawBuffer:
return llvm::hlsl::ResourceKind::RawBuffer;
case HLSLResourceAttr::ResourceKind::StructuredBuffer:
return llvm::hlsl::ResourceKind::StructuredBuffer;
case HLSLResourceAttr::ResourceKind::CBufferKind:
return llvm::hlsl::ResourceKind::CBuffer;
case HLSLResourceAttr::ResourceKind::SamplerKind:
return llvm::hlsl::ResourceKind::Sampler;
case HLSLResourceAttr::ResourceKind::TBuffer:
return llvm::hlsl::ResourceKind::TBuffer;
case HLSLResourceAttr::ResourceKind::RTAccelerationStructure:
return llvm::hlsl::ResourceKind::RTAccelerationStructure;
case HLSLResourceAttr::ResourceKind::FeedbackTexture2D:
return llvm::hlsl::ResourceKind::FeedbackTexture2D;
case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray:
return llvm::hlsl::ResourceKind::FeedbackTexture2DArray;
}
// Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to
// hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for
// HLSLResourceAttr::ResourceKind.
static_assert(
static_cast<uint32_t>(
HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) ==
(static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2));
llvm_unreachable("all switch cases should be covered");
}

void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
const Type *Ty = D->getType()->getPointeeOrArrayElementType();
if (!Ty)
Expand All @@ -284,15 +234,12 @@ void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
if (!Attr)
return;

HLSLResourceAttr::ResourceClass RC = Attr->getResourceType();
llvm::hlsl::ResourceKind RK =
castResourceShapeToResourceKind(Attr->getResourceShape());
llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
llvm::hlsl::ResourceKind RK = Attr->getResourceKind();

QualType QT(Ty, 0);
BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
addBufferResourceAnnotation(GV, QT.getAsString(),
static_cast<llvm::hlsl::ResourceClass>(RC), RK,
Binding);
addBufferResourceAnnotation(GV, QT.getAsString(), RC, RK, Binding);
}

CGHLSLRuntime::BufferResBinding::BufferResBinding(
Expand Down
8 changes: 3 additions & 5 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ struct BuiltinTypeDeclBuilder {
return addMemberVariable("h", Ty, Access);
}

BuiltinTypeDeclBuilder &
annotateResourceClass(HLSLResourceAttr::ResourceClass RC,
HLSLResourceAttr::ResourceKind RK) {
BuiltinTypeDeclBuilder &annotateResourceClass(ResourceClass RC,
ResourceKind RK) {
if (Record->isCompleteDefinition())
return *this;
Record->addAttr(
Expand Down Expand Up @@ -503,7 +502,6 @@ void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
.addHandleMember()
.addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
.addArraySubscriptOperators()
.annotateResourceClass(HLSLResourceAttr::UAV,
HLSLResourceAttr::TypedBuffer)
.annotateResourceClass(ResourceClass::UAV, ResourceKind::TypedBuffer)
.completeDefinition();
}
2 changes: 1 addition & 1 deletion clang/unittests/Sema/SemaNoloadLookupTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class NoloadLookupConsumer : public SemaConsumer {
if (!ID)
return true;

Module *M = ID->getImportedModule();
clang::Module *M = ID->getImportedModule();
assert(M);
if (M->Name != "R")
return true;
Expand Down
Loading

0 comments on commit 1c6c01f

Please sign in to comment.