Skip to content

Commit

Permalink
Add comments at every instrumentation point.
Browse files Browse the repository at this point in the history
The last thing to implement is to read the config file and use the
strings defined there in place of the comments.
  • Loading branch information
nchaimov committed Dec 12, 2024
1 parent d022c36 commit b9ee595
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions src/salt_instrument_flang_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,34 @@ using namespace Fortran::frontend;
* The main action of the Salt instrumentor.
* Visits each node in the parse tree.
*/
class SaltInstrumentAction : public PluginParseTreeAction {
class SaltInstrumentAction final : public PluginParseTreeAction {
enum class SaltInstrumentationPointType {
PROGRAM_BEGIN, // Declare profiler, initialize TAU, set node, start timer
PROCEDURE_BEGIN, // Declare profiler, start timer
PROCEDURE_END // Stop timer
PROGRAM_BEGIN, // Declare profiler, initialize TAU, set node, start timer
PROCEDURE_BEGIN, // Declare profiler, start timer
PROCEDURE_END // Stop timer
};

struct SaltInstrumentationPoint {
SaltInstrumentationPoint(SaltInstrumentationPointType instrumentation_point_type,
int start_line,
SaltInstrumentationPoint(const SaltInstrumentationPointType instrumentation_point_type,
const int start_line,
const std::optional<std::string> &timer_name = std::nullopt)
: instrumentationPointType(instrumentation_point_type),
startLine(start_line),
timerName(timer_name) {
}

[[nodiscard]] bool instrumentBefore() const {
return instrumentationPointType == SaltInstrumentationPointType::PROGRAM_BEGIN || instrumentationPointType
== SaltInstrumentationPointType::PROCEDURE_BEGIN;
}


SaltInstrumentationPointType instrumentationPointType;
int startLine;
std::optional<std::string> timerName;
};


struct SaltInstrumentParseTreeVisitor {
explicit SaltInstrumentParseTreeVisitor(Fortran::parser::Parsing *parsing)
: parsing(parsing) {
Expand All @@ -83,7 +89,7 @@ class SaltInstrumentAction : public PluginParseTreeAction {
instrumentation_point_type, start_line, timer_name);
}

const auto & getInstrumentationPoints() const {
[[nodiscard]] const auto &getInstrumentationPoints() const {
return instrumentationPoints_;
}

Expand All @@ -95,7 +101,7 @@ class SaltInstrumentAction : public PluginParseTreeAction {
*/
[[nodiscard]] Fortran::parser::SourcePosition locationFromSource(
const Fortran::parser::CharBlock &charBlock, const bool end = false) const {
const auto & sourceRange{parsing->allCooked().GetSourcePositionRange(charBlock)};
const auto &sourceRange{parsing->allCooked().GetSourcePositionRange(charBlock)};
if (end) {
return sourceRange->second;
}
Expand Down Expand Up @@ -257,7 +263,8 @@ class SaltInstrumentAction : public PluginParseTreeAction {
},
[&](const Fortran::common::Indirection<Fortran::parser::CUFKernelDoConstruct> &c) ->
Fortran::parser::SourcePosition {
return locationFromSource(std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(c.value().t).source);
return locationFromSource(
std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(c.value().t).source);
},
[&](const Fortran::common::Indirection<Fortran::parser::OmpEndLoopDirective> &c) ->
Fortran::parser::SourcePosition {
Expand All @@ -275,8 +282,8 @@ class SaltInstrumentAction : public PluginParseTreeAction {
Fortran::parser::SourcePosition {
return getLocation(c.value());
},
[&](const Fortran::common::Indirection<Fortran::parser::CompilerDirective> & c)->
Fortran::parser::SourcePosition {
[&](const Fortran::common::Indirection<Fortran::parser::CompilerDirective> &c)->
Fortran::parser::SourcePosition {
return locationFromSource(c.value().source);
},
[&](const Fortran::common::Indirection<Fortran::parser::ForallConstruct> &c) ->
Expand Down Expand Up @@ -319,8 +326,8 @@ class SaltInstrumentAction : public PluginParseTreeAction {
},
[&](const Fortran::common::Indirection<Fortran::parser::ChangeTeamConstruct> &c) ->
Fortran::parser::SourcePosition {
return locationFromSource(
std::get<Fortran::parser::Statement<Fortran::parser::ChangeTeamStmt> >(c.value().t).source);
return locationFromSource(
std::get<Fortran::parser::Statement<Fortran::parser::ChangeTeamStmt> >(c.value().t).source);
},
[&](const Fortran::common::Indirection<Fortran::parser::CaseConstruct> &c) ->
Fortran::parser::SourcePosition {
Expand All @@ -334,9 +341,9 @@ class SaltInstrumentAction : public PluginParseTreeAction {
},
[&](const Fortran::common::Indirection<Fortran::parser::AssociateConstruct> &c) ->
Fortran::parser::SourcePosition {
return locationFromSource(
std::get<Fortran::parser::Statement<Fortran::parser::AssociateStmt> >(c.value().t).
source);
return locationFromSource(
std::get<Fortran::parser::Statement<Fortran::parser::AssociateStmt> >(c.value().t).
source);
}
}, construct.u);
}
Expand All @@ -358,20 +365,14 @@ class SaltInstrumentAction : public PluginParseTreeAction {
[&](const auto &c) -> Fortran::parser::SourcePosition {
return locationFromSource(c.source);
},
[&](const Fortran::parser::ErrorRecovery &c) -> Fortran::parser::SourcePosition {
[&](const Fortran::parser::ErrorRecovery &) -> Fortran::parser::SourcePosition {
DIE("Should not encounter ErrorRecovery in parse tree");
}
}, construct.u);
}

bool Pre(const Fortran::parser::ExecutionPart &executionPart) {
// TODO Need to get the FIRST and the LAST components
// Insert timer start before first component
// Use main program insert if in main program, else subprogram insert
// Insert timer end after last component

const Fortran::parser::Block &block = executionPart.v;
if (block.empty()) {
if (const Fortran::parser::Block &block = executionPart.v; block.empty()) {
llvm::outs() << "WARNING: Execution part empty.\n";
} else {
const Fortran::parser::SourcePosition startLoc{getLocation(block.front())};
Expand All @@ -380,7 +381,7 @@ class SaltInstrumentAction : public PluginParseTreeAction {
llvm::outs() << "Program begin \"" << mainProgramName_ << "\" at " << startLoc.line << "\n";
addInstrumentationPoint(SaltInstrumentationPointType::PROGRAM_BEGIN, startLoc.line,
mainProgramName_);
} else{
} else {
llvm::outs() << "Subprogram begin \"" << subprogramName_ << "\" at " << startLoc.line << "\n";
addInstrumentationPoint(SaltInstrumentationPointType::PROCEDURE_BEGIN, startLoc.line,
subprogramName_);
Expand Down Expand Up @@ -422,6 +423,19 @@ class SaltInstrumentAction : public PluginParseTreeAction {
return std::nullopt;
}

[[nodiscard]] static std::string getInstrumentationPointString(SaltInstrumentationPointType type) {
switch (type) {
case SaltInstrumentationPointType::PROCEDURE_BEGIN:
return "! PROCEDURE BEGIN";
case SaltInstrumentationPointType::PROGRAM_BEGIN:
return "! PROGRAM BEGIN";
case SaltInstrumentationPointType::PROCEDURE_END:
return "! PROCEDURE END";
default:
CRASH_NO_CASE;
}
}

static void instrumentFile(const std::string &inputFilePath, llvm::raw_pwrite_stream &outputStream,
const SaltInstrumentParseTreeVisitor &visitor) {
std::ifstream inputStream{inputFilePath};
Expand All @@ -431,11 +445,20 @@ class SaltInstrumentAction : public PluginParseTreeAction {
}
std::string line;
int lineNum{0};
const auto &instPts{visitor.getInstrumentationPoints()};
auto instIter{instPts.cbegin()};
while (std::getline(inputStream, line)) {
++lineNum;
if (instIter != instPts.cend() && instIter->startLine == lineNum && instIter->instrumentBefore()) {
outputStream << getInstrumentationPointString(instIter->instrumentationPointType) << "\n";
++instIter;
}
outputStream << line << "\n";
if (instIter != instPts.cend() && instIter->startLine == lineNum && !instIter->instrumentBefore()) {
outputStream << getInstrumentationPointString(instIter->instrumentationPointType) << "\n";
++instIter;
}
}
(void)lineNum;
}

/**
Expand Down Expand Up @@ -469,11 +492,11 @@ class SaltInstrumentAction : public PluginParseTreeAction {
const std::string outputFileExtension = "inst."s + inputFileExtension;
const auto outputFileStream = createOutputFile(outputFileExtension);

// Walk the parse tree
// Walk the parse tree -- marks nodes for instrumentation
SaltInstrumentParseTreeVisitor visitor{&parsing};
Walk(parsing.parseTree(), visitor);

// TODO write the instrumented code
// Use the instrumentation points stored in the Visitor to write the instrumented file.
instrumentFile(*inputFilePath, *outputFileStream, visitor);

outputFileStream->flush();
Expand Down

0 comments on commit b9ee595

Please sign in to comment.