Skip to content

Commit

Permalink
[aarch64][win] Add support for import call optimization (equivalent t…
Browse files Browse the repository at this point in the history
…o MSVC /d2ImportCallOptimization) (#121516)

This change implements import call optimization for AArch64 Windows
(equivalent to the undocumented MSVC `/d2ImportCallOptimization` flag).

Import call optimization adds additional data to the binary which can be
used by the Windows kernel loader to rewrite indirect calls to imported
functions as direct calls. It uses the same [Dynamic Value Relocation
Table mechanism that was leveraged on x64 to implement
`/d2GuardRetpoline`](https://techcommunity.microsoft.com/blog/windowsosplatform/mitigating-spectre-variant-2-with-retpoline-on-windows/295618).

The change to the obj file is to add a new `.impcall` section with the
following layout:
```cpp
  // Per section that contains calls to imported functions:
  //  uint32_t SectionSize: Size in bytes for information in this section.
  //  uint32_t Section Number
  //  Per call to imported function in section:
  //    uint32_t Kind: the kind of imported function.
  //    uint32_t BranchOffset: the offset of the branch instruction in its
  //                            parent section.
  //    uint32_t TargetSymbolId: the symbol id of the called function.
```

NOTE: If the import call optimization feature is enabled, then the
`.impcall` section must be emitted, even if there are no calls to
imported functions.

The implementation is split across a few parts of LLVM:
* During AArch64 instruction selection, the `GlobalValue` for each call
to a global is recorded into the Extra Information for that node.
* During lowering to machine instructions, the called global value for
each call is noted in its containing `MachineFunction`.
* During AArch64 asm printing, if the import call optimization feature
is enabled:
- A (new) `.impcall` directive is emitted for each call to an imported
function.
- The `.impcall` section is emitted with its magic header (but is not
filled in).
* During COFF object writing, the `.impcall` section is filled in based
on each `.impcall` directive that were encountered.

The `.impcall` section can only be filled in when we are writing the
COFF object as it requires the actual section numbers, which are only
assigned at that point (i.e., they don't exist during asm printing).

I had tried to avoid using the Extra Information during instruction
selection and instead implement this either purely during asm printing
or in a `MachineFunctionPass` (as suggested in [on the
forums](https://discourse.llvm.org/t/design-gathering-locations-of-instructions-to-emit-into-a-section/83729/3))
but this was not possible due to how loading and calling an imported
function works on AArch64. Specifically, they are emitted as `ADRP` +
`LDR` (to load the symbol) then a `BR` (to do the call), so at the point
when we have machine instructions, we would have to work backwards
through the instructions to discover what is being called. An initial
prototype did work by inspecting instructions; however, it didn't
correctly handle the case where the same function was called twice in a
row, which caused LLVM to elide the `ADRP` + `LDR` and reuse the
previously loaded address. Worse than that, sometimes for the
double-call case LLVM decided to spill the loaded address to the stack
and then reload it before making the second call. So, instead of trying
to implement logic to discover where the value in a register came from,
I instead recorded the symbol being called at the last place where it
was easy to do: instruction selection.
  • Loading branch information
dpaoliello authored Jan 12, 2025
1 parent 4f6fabd commit 5ee0a71
Show file tree
Hide file tree
Showing 25 changed files with 673 additions and 38 deletions.
45 changes: 35 additions & 10 deletions llvm/include/llvm/CodeGen/MIRYamlMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,16 @@ template <> struct ScalarTraits<FrameIndex> {
static QuotingType mustQuote(StringRef S) { return needsQuotes(S); }
};

/// Identifies call instruction location in machine function.
struct MachineInstrLoc {
unsigned BlockNum;
unsigned Offset;

bool operator==(const MachineInstrLoc &Other) const {
return BlockNum == Other.BlockNum && Offset == Other.Offset;
}
};

/// Serializable representation of CallSiteInfo.
struct CallSiteInfo {
// Representation of call argument and register which is used to
Expand All @@ -470,16 +480,6 @@ struct CallSiteInfo {
}
};

/// Identifies call instruction location in machine function.
struct MachineInstrLoc {
unsigned BlockNum;
unsigned Offset;

bool operator==(const MachineInstrLoc &Other) const {
return BlockNum == Other.BlockNum && Offset == Other.Offset;
}
};

MachineInstrLoc CallLocation;
std::vector<ArgRegPair> ArgForwardingRegs;

Expand Down Expand Up @@ -595,6 +595,26 @@ template <> struct MappingTraits<MachineJumpTable::Entry> {
}
};

struct CalledGlobal {
MachineInstrLoc CallSite;
StringValue Callee;
unsigned Flags;

bool operator==(const CalledGlobal &Other) const {
return CallSite == Other.CallSite && Callee == Other.Callee &&
Flags == Other.Flags;
}
};

template <> struct MappingTraits<CalledGlobal> {
static void mapping(IO &YamlIO, CalledGlobal &CG) {
YamlIO.mapRequired("bb", CG.CallSite.BlockNum);
YamlIO.mapRequired("offset", CG.CallSite.Offset);
YamlIO.mapRequired("callee", CG.Callee);
YamlIO.mapRequired("flags", CG.Flags);
}
};

} // end namespace yaml
} // end namespace llvm

Expand All @@ -606,6 +626,7 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::FixedMachineStackObject)
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CallSiteInfo)
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::MachineConstantPoolValue)
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::MachineJumpTable::Entry)
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)

namespace llvm {
namespace yaml {
Expand Down Expand Up @@ -764,6 +785,7 @@ struct MachineFunction {
std::vector<DebugValueSubstitution> DebugValueSubstitutions;
MachineJumpTable JumpTableInfo;
std::vector<StringValue> MachineMetadataNodes;
std::vector<CalledGlobal> CalledGlobals;
BlockStringValue Body;
};

Expand Down Expand Up @@ -822,6 +844,9 @@ template <> struct MappingTraits<MachineFunction> {
if (!YamlIO.outputting() || !MF.MachineMetadataNodes.empty())
YamlIO.mapOptional("machineMetadataNodes", MF.MachineMetadataNodes,
std::vector<StringValue>());
if (!YamlIO.outputting() || !MF.CalledGlobals.empty())
YamlIO.mapOptional("calledGlobals", MF.CalledGlobals,
std::vector<CalledGlobal>());
YamlIO.mapOptional("body", MF.Body, BlockStringValue());
}
};
Expand Down
25 changes: 25 additions & 0 deletions llvm/include/llvm/CodeGen/MachineFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ class LLVM_ABI MachineFunction {
/// a table of valid targets for Windows EHCont Guard.
std::vector<MCSymbol *> CatchretTargets;

/// Mapping of call instruction to the global value and target flags that it
/// calls, if applicable.
DenseMap<const MachineInstr *, std::pair<const GlobalValue *, unsigned>>
CalledGlobalsMap;

/// \name Exception Handling
/// \{

Expand Down Expand Up @@ -1182,6 +1187,26 @@ class LLVM_ABI MachineFunction {
CatchretTargets.push_back(Target);
}

/// Tries to get the global and target flags for a call site, if the
/// instruction is a call to a global.
std::pair<const GlobalValue *, unsigned>
tryGetCalledGlobal(const MachineInstr *MI) const {
return CalledGlobalsMap.lookup(MI);
}

/// Notes the global and target flags for a call site.
void addCalledGlobal(const MachineInstr *MI,
std::pair<const GlobalValue *, unsigned> Details) {
assert(MI && "MI must not be null");
assert(Details.first && "Global must not be null");
CalledGlobalsMap.insert({MI, Details});
}

/// Iterates over the full set of call sites and their associated globals.
auto getCalledGlobals() const {
return llvm::make_range(CalledGlobalsMap.begin(), CalledGlobalsMap.end());
}

/// \name Exception Handling
/// \{

Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class SelectionDAG {
MDNode *HeapAllocSite = nullptr;
MDNode *PCSections = nullptr;
MDNode *MMRA = nullptr;
std::pair<const GlobalValue *, unsigned> CalledGlobal{};
bool NoMerge = false;
};
/// Out-of-line extra information for SDNodes.
Expand Down Expand Up @@ -2373,6 +2374,19 @@ class SelectionDAG {
auto It = SDEI.find(Node);
return It != SDEI.end() ? It->second.MMRA : nullptr;
}
/// Set CalledGlobal to be associated with Node.
void addCalledGlobal(const SDNode *Node, const GlobalValue *GV,
unsigned OpFlags) {
SDEI[Node].CalledGlobal = {GV, OpFlags};
}
/// Return CalledGlobal associated with Node, or a nullopt if none exists.
std::optional<std::pair<const GlobalValue *, unsigned>>
getCalledGlobal(const SDNode *Node) {
auto I = SDEI.find(Node);
return I != SDEI.end()
? std::make_optional(std::move(I->second).CalledGlobal)
: std::nullopt;
}
/// Set NoMergeSiteInfo to be associated with Node if NoMerge is true.
void addNoMergeSiteInfo(const SDNode *Node, bool NoMerge) {
if (NoMerge)
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/MC/MCObjectFileInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class MCObjectFileInfo {
/// to emit them into.
MCSection *CompactUnwindSection = nullptr;

/// If import call optimization is supported by the target, this is the
/// section to emit import call data to.
MCSection *ImportCallSection = nullptr;

// Dwarf sections for debug info. If a target supports debug info, these must
// be set.
MCSection *DwarfAbbrevSection = nullptr;
Expand Down Expand Up @@ -269,6 +273,7 @@ class MCObjectFileInfo {
MCSection *getBSSSection() const { return BSSSection; }
MCSection *getReadOnlySection() const { return ReadOnlySection; }
MCSection *getLSDASection() const { return LSDASection; }
MCSection *getImportCallSection() const { return ImportCallSection; }
MCSection *getCompactUnwindSection() const { return CompactUnwindSection; }
MCSection *getDwarfAbbrevSection() const { return DwarfAbbrevSection; }
MCSection *getDwarfInfoSection() const { return DwarfInfoSection; }
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/MC/MCStreamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,14 @@ class MCStreamer {
/// \param Symbol - Symbol the image relative relocation should point to.
virtual void emitCOFFImgRel32(MCSymbol const *Symbol, int64_t Offset);

/// Emits the physical number of the section containing the given symbol as
/// assigned during object writing (i.e., this is not a runtime relocation).
virtual void emitCOFFSecNumber(MCSymbol const *Symbol);

/// Emits the offset of the symbol from the beginning of the section during
/// object writing (i.e., this is not a runtime relocation).
virtual void emitCOFFSecOffset(MCSymbol const *Symbol);

/// Emits an lcomm directive with XCOFF csect information.
///
/// \param LabelSym - Label on the block of storage.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/MC/MCWinCOFFObjectWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class WinCOFFObjectWriter final : public MCObjectWriter {
const MCFixup &Fixup, MCValue Target,
uint64_t &FixedValue) override;
uint64_t writeObject(MCAssembler &Asm) override;
int getSectionNumber(const MCSection &Section) const;
};

/// Construct a new Win COFF writer instance.
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/MC/MCWinCOFFStreamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class MCWinCOFFStreamer : public MCObjectStreamer {
void emitCOFFSectionIndex(MCSymbol const *Symbol) override;
void emitCOFFSecRel32(MCSymbol const *Symbol, uint64_t Offset) override;
void emitCOFFImgRel32(MCSymbol const *Symbol, int64_t Offset) override;
void emitCOFFSecNumber(MCSymbol const *Symbol) override;
void emitCOFFSecOffset(MCSymbol const *Symbol) override;
void emitCommonSymbol(MCSymbol *Symbol, uint64_t Size,
Align ByteAlignment) override;
void emitLocalCommonSymbol(MCSymbol *Symbol, uint64_t Size,
Expand Down
74 changes: 62 additions & 12 deletions llvm/lib/CodeGen/MIRParser/MIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ class MIRParserImpl {
MachineFunction &MF,
const yaml::MachineFunction &YMF);

bool parseCalledGlobals(PerFunctionMIParsingState &PFS, MachineFunction &MF,
const yaml::MachineFunction &YMF);

private:
bool parseMDNode(PerFunctionMIParsingState &PFS, MDNode *&Node,
const yaml::StringValue &Source);
Expand All @@ -183,6 +186,9 @@ class MIRParserImpl {

void setupDebugValueTracking(MachineFunction &MF,
PerFunctionMIParsingState &PFS, const yaml::MachineFunction &YamlMF);

bool parseMachineInst(MachineFunction &MF, yaml::MachineInstrLoc MILoc,
MachineInstr const *&MI);
};

} // end namespace llvm
Expand Down Expand Up @@ -457,24 +463,34 @@ bool MIRParserImpl::computeFunctionProperties(
return false;
}

bool MIRParserImpl::parseMachineInst(MachineFunction &MF,
yaml::MachineInstrLoc MILoc,
MachineInstr const *&MI) {
if (MILoc.BlockNum >= MF.size()) {
return error(Twine(MF.getName()) +
Twine(" instruction block out of range.") +
" Unable to reference bb:" + Twine(MILoc.BlockNum));
}
auto BB = std::next(MF.begin(), MILoc.BlockNum);
if (MILoc.Offset >= BB->size())
return error(
Twine(MF.getName()) + Twine(" instruction offset out of range.") +
" Unable to reference instruction at bb: " + Twine(MILoc.BlockNum) +
" at offset:" + Twine(MILoc.Offset));
MI = &*std::next(BB->instr_begin(), MILoc.Offset);
return false;
}

bool MIRParserImpl::initializeCallSiteInfo(
PerFunctionMIParsingState &PFS, const yaml::MachineFunction &YamlMF) {
MachineFunction &MF = PFS.MF;
SMDiagnostic Error;
const TargetMachine &TM = MF.getTarget();
for (auto &YamlCSInfo : YamlMF.CallSitesInfo) {
yaml::CallSiteInfo::MachineInstrLoc MILoc = YamlCSInfo.CallLocation;
if (MILoc.BlockNum >= MF.size())
return error(Twine(MF.getName()) +
Twine(" call instruction block out of range.") +
" Unable to reference bb:" + Twine(MILoc.BlockNum));
auto CallB = std::next(MF.begin(), MILoc.BlockNum);
if (MILoc.Offset >= CallB->size())
return error(Twine(MF.getName()) +
Twine(" call instruction offset out of range.") +
" Unable to reference instruction at bb: " +
Twine(MILoc.BlockNum) + " at offset:" + Twine(MILoc.Offset));
auto CallI = std::next(CallB->instr_begin(), MILoc.Offset);
yaml::MachineInstrLoc MILoc = YamlCSInfo.CallLocation;
const MachineInstr *CallI;
if (parseMachineInst(MF, MILoc, CallI))
return true;
if (!CallI->isCall(MachineInstr::IgnoreBundle))
return error(Twine(MF.getName()) +
Twine(" call site info should reference call "
Expand Down Expand Up @@ -641,6 +657,9 @@ MIRParserImpl::initializeMachineFunction(const yaml::MachineFunction &YamlMF,
if (initializeCallSiteInfo(PFS, YamlMF))
return true;

if (parseCalledGlobals(PFS, MF, YamlMF))
return true;

setupDebugValueTracking(MF, PFS, YamlMF);

MF.getSubtarget().mirFileLoaded(MF);
Expand Down Expand Up @@ -1111,6 +1130,37 @@ bool MIRParserImpl::parseMachineMetadataNodes(
return false;
}

bool MIRParserImpl::parseCalledGlobals(PerFunctionMIParsingState &PFS,
MachineFunction &MF,
const yaml::MachineFunction &YMF) {
Function &F = MF.getFunction();
for (const auto &YamlCG : YMF.CalledGlobals) {
yaml::MachineInstrLoc MILoc = YamlCG.CallSite;
const MachineInstr *CallI;
if (parseMachineInst(MF, MILoc, CallI))
return true;
if (!CallI->isCall(MachineInstr::IgnoreBundle))
return error(Twine(MF.getName()) +
Twine(" called global should reference call "
"instruction. Instruction at bb:") +
Twine(MILoc.BlockNum) + " at offset:" + Twine(MILoc.Offset) +
" is not a call instruction");

auto Callee =
F.getParent()->getValueSymbolTable().lookup(YamlCG.Callee.Value);
if (!Callee)
return error(YamlCG.Callee.SourceRange.Start,
"use of undefined global '" + YamlCG.Callee.Value + "'");
if (!isa<GlobalValue>(Callee))
return error(YamlCG.Callee.SourceRange.Start,
"use of non-global value '" + YamlCG.Callee.Value + "'");

MF.addCalledGlobal(CallI, {cast<GlobalValue>(Callee), YamlCG.Flags});
}

return false;
}

SMDiagnostic MIRParserImpl::diagFromMIStringDiag(const SMDiagnostic &Error,
SMRange SourceRange) {
assert(SourceRange.isValid() && "Invalid source range");
Expand Down
33 changes: 32 additions & 1 deletion llvm/lib/CodeGen/MIRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class MIRPrinter {
void convertMachineMetadataNodes(yaml::MachineFunction &YMF,
const MachineFunction &MF,
MachineModuleSlotTracker &MST);
void convertCalledGlobals(yaml::MachineFunction &YMF,
const MachineFunction &MF,
MachineModuleSlotTracker &MST);

private:
void initRegisterMaskIds(const MachineFunction &MF);
Expand Down Expand Up @@ -269,6 +272,8 @@ void MIRPrinter::print(const MachineFunction &MF) {
// function.
convertMachineMetadataNodes(YamlMF, MF, MST);

convertCalledGlobals(YamlMF, MF, MST);

yaml::Output Out(OS);
if (!SimplifyMIR)
Out.setWriteDefaultValues(true);
Expand Down Expand Up @@ -555,7 +560,7 @@ void MIRPrinter::convertCallSiteObjects(yaml::MachineFunction &YMF,
const auto *TRI = MF.getSubtarget().getRegisterInfo();
for (auto CSInfo : MF.getCallSitesInfo()) {
yaml::CallSiteInfo YmlCS;
yaml::CallSiteInfo::MachineInstrLoc CallLocation;
yaml::MachineInstrLoc CallLocation;

// Prepare instruction position.
MachineBasicBlock::const_instr_iterator CallI = CSInfo.first->getIterator();
Expand Down Expand Up @@ -596,6 +601,32 @@ void MIRPrinter::convertMachineMetadataNodes(yaml::MachineFunction &YMF,
}
}

void MIRPrinter::convertCalledGlobals(yaml::MachineFunction &YMF,
const MachineFunction &MF,
MachineModuleSlotTracker &MST) {
for (const auto [CallInst, CG] : MF.getCalledGlobals()) {
// If the call instruction was dropped, then we don't need to print it.
auto BB = CallInst->getParent();
if (BB) {
yaml::MachineInstrLoc CallSite;
CallSite.BlockNum = CallInst->getParent()->getNumber();
CallSite.Offset = std::distance(CallInst->getParent()->instr_begin(),
CallInst->getIterator());

yaml::CalledGlobal YamlCG{CallSite, CG.first->getName().str(), CG.second};
YMF.CalledGlobals.push_back(YamlCG);
}
}

// Sort by position of call instructions.
llvm::sort(YMF.CalledGlobals.begin(), YMF.CalledGlobals.end(),
[](yaml::CalledGlobal A, yaml::CalledGlobal B) {
if (A.CallSite.BlockNum == B.CallSite.BlockNum)
return A.CallSite.Offset < B.CallSite.Offset;
return A.CallSite.BlockNum < B.CallSite.BlockNum;
});
}

void MIRPrinter::convert(yaml::MachineFunction &MF,
const MachineConstantPool &ConstantPool) {
unsigned ID = 0;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/ScheduleDAGSDNodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,10 @@ EmitSchedule(MachineBasicBlock::iterator &InsertPos) {
It->setMMRAMetadata(MF, MMRA);
}

if (auto CalledGlobal = DAG->getCalledGlobal(Node))
if (CalledGlobal->first)
MF.addCalledGlobal(MI, *CalledGlobal);

return MI;
};

Expand Down
Loading

0 comments on commit 5ee0a71

Please sign in to comment.