Skip to content

Commit

Permalink
Migrate TypeAsmInterface to OpAsmTypeInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Feb 20, 2025
1 parent 27df7a0 commit 6d893c1
Show file tree
Hide file tree
Showing 28 changed files with 84 additions and 250 deletions.
1 change: 0 additions & 1 deletion lib/Analysis/SelectVariableNames/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ cc_library(
srcs = ["SelectVariableNames.cpp"],
hdrs = ["SelectVariableNames.h"],
deps = [
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
10 changes: 6 additions & 4 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#include <map>
#include <string>

#include "lib/Utils/Tablegen/AsmInterfaces.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
Expand All @@ -15,9 +15,11 @@ namespace mlir {
namespace heir {

std::string SelectVariableNames::suggestNameForValue(Value value) {
if (auto typeAsmInterface =
mlir::dyn_cast<TypeAsmInterface>(value.getType())) {
return typeAsmInterface.suggestedName();
if (auto opAsmTypeInterface =
mlir::dyn_cast<OpAsmTypeInterface>(value.getType())) {
std::string asmName;
opAsmTypeInterface.getAsmName([&](StringRef name) { asmName = name; });
return asmName;
}
return defaultPrefix;
}
Expand Down
13 changes: 1 addition & 12 deletions lib/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,13 @@ include "lib/Dialect/LWE/IR/LWETraits.td"
include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpAsmInterface.td"

class BGV_Op<string mnemonic, list<Trait> traits = []> :
Op<BGV_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
Op<BGV_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::bgv";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
// OpAsmOpInterface Methods
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto result : getOperation()->getResults()) {
if (auto ty = dyn_cast<TypeAsmInterface>(result.getType()))
setNameFn(result, ty.suggestedName());
}
}
}];
}

class BGV_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
Expand Down
12 changes: 1 addition & 11 deletions lib/Dialect/CKKS/IR/CKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,11 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"

class CKKS_Op<string mnemonic, list<Trait> traits = []> :
Op<CKKS_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
Op<CKKS_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::ckks";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
// OpAsmOpInterface Methods
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto result : getOperation()->getResults()) {
if (auto ty = dyn_cast<TypeAsmInterface>(result.getType()))
setNameFn(result, ty.suggestedName());
}
}
}];
}

class CKKS_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
Expand Down
2 changes: 0 additions & 2 deletions lib/Dialect/LWE/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ cc_library(
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down Expand Up @@ -63,7 +62,6 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Utils/Tablegen:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
13 changes: 1 addition & 12 deletions lib/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpAsmInterface.td"

class HasEncoding<
string encodingHolder,
Expand Down Expand Up @@ -86,21 +85,11 @@ class KeyAndCiphertextMatch<

// LWE Operations are always Pure by design
class LWE_Op<string mnemonic, list<Trait> traits = []> :
Op<LWE_Dialect, mnemonic, traits # [Pure, OpAsmOpInterface]> {
Op<LWE_Dialect, mnemonic, traits # [Pure]> {
let cppNamespace = "::mlir::heir::lwe";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
// OpAsmOpInterface Methods
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto result : getOperation()->getResults()) {
if (auto ty = dyn_cast<TypeAsmInterface>(result.getType()))
setNameFn(result, ty.suggestedName());
}
}
}];
}

class LWE_BinOp<string mnemonic, list<Trait> traits = []> :
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/LWE/IR/LWETypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Utils/Tablegen/AsmInterfaces.h"
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/LWE/IR/LWETypes.h.inc"
Expand Down
12 changes: 6 additions & 6 deletions lib/Dialect/LWE/IR/LWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def LWECiphertext : LWE_Type<"LWECiphertext", "lwe_ciphertext", [MemRefElementTy
"::mlir::Attribute":$encoding,
OptionalParameter<"LWEParamsAttr">:$lwe_params
);
let nameSuggestion = "ct";
let asmName = "ct";
}

def LWECiphertextLike : TypeOrValueSemanticsContainer<LWECiphertext, "ciphertext-like">;
Expand All @@ -39,21 +39,21 @@ def RLWECiphertext : LWE_Type<"RLWECiphertext", "rlwe_ciphertext"> {
"RLWEParamsAttr":$rlwe_params,
"Type":$underlying_type
);
let nameSuggestion = "ct";
let asmName = "ct";
}

def RLWECiphertextLike : TypeOrValueSemanticsContainer<RLWECiphertext, "ciphertext-like">;

def RLWESecretKey : LWE_Type<"RLWESecretKey", "rlwe_secret_key"> {
let summary = "A secret key for RLWE";
let parameters = (ins "RLWEParamsAttr":$rlwe_params);
let nameSuggestion = "sk";
let asmName = "sk";
}

def RLWEPublicKey : LWE_Type<"RLWEPublicKey", "rlwe_public_key"> {
let summary = "A public key for RLWE";
let parameters = (ins "RLWEParamsAttr":$rlwe_params);
let nameSuggestion = "pk";
let asmName = "pk";
}

def RLWESecretOrPublicKey : AnyTypeOf<[RLWESecretKey, RLWEPublicKey]>;
Expand All @@ -71,7 +71,7 @@ def LWEPlaintext : LWE_Type<"LWEPlaintext", "lwe_plaintext"> {
let parameters = (ins
"::mlir::Attribute":$encoding
);
let nameSuggestion = "pt";
let asmName = "pt";
}

def LWEPlaintextLike : TypeOrValueSemanticsContainer<LWEPlaintext, "lwe-plaintext-like">;
Expand All @@ -84,7 +84,7 @@ def RLWEPlaintext : LWE_Type<"RLWEPlaintext", "rlwe_plaintext"> {
"::mlir::heir::polynomial::RingAttr":$ring,
"Type":$underlying_type
);
let nameSuggestion = "pt";
let asmName = "pt";
}

def RLWEPlaintextLike : TypeOrValueSemanticsContainer<RLWEPlaintext, "rlwe-plaintext-like">;
Expand Down
23 changes: 13 additions & 10 deletions lib/Dialect/LWE/IR/NewLWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@

include "lib/Dialect/LWE/IR/LWEDialect.td"
include "lib/Dialect/LWE/IR/NewLWEAttributes.td"
include "lib/Utils/Tablegen/AsmInterfaces.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpAsmInterface.td"

// A base class for all types in this dialect
class LWE_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<LWE_Dialect, name, traits # [
DeclareTypeInterfaceMethods<TypeAsmInterface, ["suggestedName"]>
]> {
: TypeDef<LWE_Dialect, name, traits # [OpAsmTypeInterface]> {
let mnemonic = typeMnemonic;
let assemblyFormat = "`<` struct(params) `>`";

string nameSuggestion = ?;
let extraClassDeclaration = "std::string suggestedName() {return \"" # nameSuggestion # "\"; }";
string asmName = ?;
let extraClassDeclaration = [{
// OpAsmTypeInterface method
void getAsmName(::mlir::OpAsmSetNameFn setNameFn) const {
setNameFn("}] # asmName # [{");
}
}];
}

// This file defines new LWE types following
Expand All @@ -30,7 +33,7 @@ def NewLWESecretKey : LWE_Type<"NewLWESecretKey", "new_lwe_secret_key"> {
"KeyAttr":$key,
"::mlir::heir::polynomial::RingAttr":$ring
);
let nameSuggestion = "sk";
let asmName = "sk";
}

def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> {
Expand All @@ -39,7 +42,7 @@ def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> {
"KeyAttr":$key,
"::mlir::heir::polynomial::RingAttr":$ring
);
let nameSuggestion = "pk";
let asmName = "pk";
}

def NewLWESecretOrPublicKey : AnyTypeOf<[NewLWESecretKey, NewLWEPublicKey]>;
Expand All @@ -50,7 +53,7 @@ def NewLWEPlaintext : LWE_Type<"NewLWEPlaintext", "new_lwe_plaintext"> {
"ApplicationDataAttr":$application_data,
"PlaintextSpaceAttr":$plaintext_space
);
let nameSuggestion = "pt";
let asmName = "pt";
}

def NewLWEPlaintextLike : TypeOrValueSemanticsContainer<NewLWEPlaintext, "new-lwe-plaintext-like">;
Expand All @@ -77,7 +80,7 @@ def NewLWECiphertext : LWE_Type<"NewLWECiphertext", "new_lwe_ciphertext"> {
);

let genVerifyDecl = 1;
let nameSuggestion = "ct";
let asmName = "ct";
}

def NewLWECiphertextLike : TypeOrValueSemanticsContainer<NewLWECiphertext, "new-lwe-ciphertext-like">;
Expand Down
4 changes: 0 additions & 4 deletions lib/Dialect/Lattigo/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ cc_library(
":LattigoAttributes",
":LattigoOps",
":LattigoTypes",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
Expand Down Expand Up @@ -64,7 +63,6 @@ cc_library(
":attributes_inc_gen",
":dialect_inc_gen",
":types_inc_gen",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//mlir:IR",
],
)
Expand All @@ -85,7 +83,6 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
Expand All @@ -110,7 +107,6 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Utils/Tablegen:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Lattigo/IR/LattigoBGVTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ def Lattigo_BGVParameter : Lattigo_BGVType<"Parameter", "parameter"> {
let description = [{
This type represents the parameters for the BGV encryption scheme.
}];
let nameSuggestion = "param";
let asmName = "param";
}

// BGVEvaluator type definition
def Lattigo_BGVEvaluator : Lattigo_BGVType<"Evaluator", "evaluator"> {
let description = [{
This type represents the evaluator for the BGV encryption scheme.
}];
let nameSuggestion = "evaluator";
let asmName = "evaluator";
}

// BGVEncoder type definition
def Lattigo_BGVEncoder : Lattigo_BGVType<"Encoder", "encoder"> {
let description = [{
This type represents the encoder for the BGV encryption scheme.
}];
let nameSuggestion = "encoder";
let asmName = "encoder";
}


Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Lattigo/IR/LattigoCKKSTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ def Lattigo_CKKSParameter : Lattigo_CKKSType<"Parameter", "parameter"> {
let description = [{
This type represents the parameters for the CKKS encryption scheme.
}];
let nameSuggestion = "param";
let asmName = "param";
}

// CKKSEvaluator type definition
def Lattigo_CKKSEvaluator : Lattigo_CKKSType<"Evaluator", "evaluator"> {
let description = [{
This type represents the evaluator for the CKKS encryption scheme.
}];
let nameSuggestion = "evaluator";
let asmName = "evaluator";
}

// CKKSEncoder type definition
def Lattigo_CKKSEncoder : Lattigo_CKKSType<"Encoder", "encoder"> {
let description = [{
This type represents the encoder for the CKKS encryption scheme.
}];
let nameSuggestion = "encoder";
let asmName = "encoder";
}


Expand Down
13 changes: 1 addition & 12 deletions lib/Dialect/Lattigo/IR/LattigoOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,12 @@
include "LattigoDialect.td"
include "LattigoTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"

class Lattigo_Op<string mnemonic, list<Trait> traits = []> :
Op<Lattigo_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
Op<Lattigo_Dialect, mnemonic, traits> {
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
// OpAsmOpInterface Methods
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto result : getOperation()->getResults()) {
if (auto ty = dyn_cast<TypeAsmInterface>(result.getType()))
setNameFn(result, ty.suggestedName());
}
}
}];
}

include "LattigoBGVOps.td"
Expand Down
Loading

0 comments on commit 6d893c1

Please sign in to comment.