diff --git a/src/salt_instrument_flang_plugin.cpp b/src/salt_instrument_flang_plugin.cpp index cb51111..24088e9 100644 --- a/src/salt_instrument_flang_plugin.cpp +++ b/src/salt_instrument_flang_plugin.cpp @@ -43,28 +43,34 @@ using namespace Fortran::frontend; * The main action of the Salt instrumentor. * Visits each node in the parse tree. */ -class SaltInstrumentAction : public PluginParseTreeAction { +class SaltInstrumentAction final : public PluginParseTreeAction { enum class SaltInstrumentationPointType { - PROGRAM_BEGIN, // Declare profiler, initialize TAU, set node, start timer - PROCEDURE_BEGIN, // Declare profiler, start timer - PROCEDURE_END // Stop timer + PROGRAM_BEGIN, // Declare profiler, initialize TAU, set node, start timer + PROCEDURE_BEGIN, // Declare profiler, start timer + PROCEDURE_END // Stop timer }; struct SaltInstrumentationPoint { - SaltInstrumentationPoint(SaltInstrumentationPointType instrumentation_point_type, - int start_line, + SaltInstrumentationPoint(const SaltInstrumentationPointType instrumentation_point_type, + const int start_line, const std::optional &timer_name = std::nullopt) : instrumentationPointType(instrumentation_point_type), startLine(start_line), timerName(timer_name) { } + [[nodiscard]] bool instrumentBefore() const { + return instrumentationPointType == SaltInstrumentationPointType::PROGRAM_BEGIN || instrumentationPointType + == SaltInstrumentationPointType::PROCEDURE_BEGIN; + } + SaltInstrumentationPointType instrumentationPointType; int startLine; std::optional timerName; }; + struct SaltInstrumentParseTreeVisitor { explicit SaltInstrumentParseTreeVisitor(Fortran::parser::Parsing *parsing) : parsing(parsing) { @@ -83,7 +89,7 @@ class SaltInstrumentAction : public PluginParseTreeAction { instrumentation_point_type, start_line, timer_name); } - const auto & getInstrumentationPoints() const { + [[nodiscard]] const auto &getInstrumentationPoints() const { return instrumentationPoints_; } @@ -95,7 +101,7 @@ class SaltInstrumentAction : public PluginParseTreeAction { */ [[nodiscard]] Fortran::parser::SourcePosition locationFromSource( const Fortran::parser::CharBlock &charBlock, const bool end = false) const { - const auto & sourceRange{parsing->allCooked().GetSourcePositionRange(charBlock)}; + const auto &sourceRange{parsing->allCooked().GetSourcePositionRange(charBlock)}; if (end) { return sourceRange->second; } @@ -257,7 +263,8 @@ class SaltInstrumentAction : public PluginParseTreeAction { }, [&](const Fortran::common::Indirection &c) -> Fortran::parser::SourcePosition { - return locationFromSource(std::get(c.value().t).source); + return locationFromSource( + std::get(c.value().t).source); }, [&](const Fortran::common::Indirection &c) -> Fortran::parser::SourcePosition { @@ -275,8 +282,8 @@ class SaltInstrumentAction : public PluginParseTreeAction { Fortran::parser::SourcePosition { return getLocation(c.value()); }, - [&](const Fortran::common::Indirection & c)-> - Fortran::parser::SourcePosition { + [&](const Fortran::common::Indirection &c)-> + Fortran::parser::SourcePosition { return locationFromSource(c.value().source); }, [&](const Fortran::common::Indirection &c) -> @@ -319,8 +326,8 @@ class SaltInstrumentAction : public PluginParseTreeAction { }, [&](const Fortran::common::Indirection &c) -> Fortran::parser::SourcePosition { - return locationFromSource( - std::get >(c.value().t).source); + return locationFromSource( + std::get >(c.value().t).source); }, [&](const Fortran::common::Indirection &c) -> Fortran::parser::SourcePosition { @@ -334,9 +341,9 @@ class SaltInstrumentAction : public PluginParseTreeAction { }, [&](const Fortran::common::Indirection &c) -> Fortran::parser::SourcePosition { - return locationFromSource( - std::get >(c.value().t). - source); + return locationFromSource( + std::get >(c.value().t). + source); } }, construct.u); } @@ -358,20 +365,14 @@ class SaltInstrumentAction : public PluginParseTreeAction { [&](const auto &c) -> Fortran::parser::SourcePosition { return locationFromSource(c.source); }, - [&](const Fortran::parser::ErrorRecovery &c) -> Fortran::parser::SourcePosition { + [&](const Fortran::parser::ErrorRecovery &) -> Fortran::parser::SourcePosition { DIE("Should not encounter ErrorRecovery in parse tree"); } }, construct.u); } bool Pre(const Fortran::parser::ExecutionPart &executionPart) { - // TODO Need to get the FIRST and the LAST components - // Insert timer start before first component - // Use main program insert if in main program, else subprogram insert - // Insert timer end after last component - - const Fortran::parser::Block &block = executionPart.v; - if (block.empty()) { + if (const Fortran::parser::Block &block = executionPart.v; block.empty()) { llvm::outs() << "WARNING: Execution part empty.\n"; } else { const Fortran::parser::SourcePosition startLoc{getLocation(block.front())}; @@ -380,7 +381,7 @@ class SaltInstrumentAction : public PluginParseTreeAction { llvm::outs() << "Program begin \"" << mainProgramName_ << "\" at " << startLoc.line << "\n"; addInstrumentationPoint(SaltInstrumentationPointType::PROGRAM_BEGIN, startLoc.line, mainProgramName_); - } else{ + } else { llvm::outs() << "Subprogram begin \"" << subprogramName_ << "\" at " << startLoc.line << "\n"; addInstrumentationPoint(SaltInstrumentationPointType::PROCEDURE_BEGIN, startLoc.line, subprogramName_); @@ -422,6 +423,19 @@ class SaltInstrumentAction : public PluginParseTreeAction { return std::nullopt; } + [[nodiscard]] static std::string getInstrumentationPointString(SaltInstrumentationPointType type) { + switch (type) { + case SaltInstrumentationPointType::PROCEDURE_BEGIN: + return "! PROCEDURE BEGIN"; + case SaltInstrumentationPointType::PROGRAM_BEGIN: + return "! PROGRAM BEGIN"; + case SaltInstrumentationPointType::PROCEDURE_END: + return "! PROCEDURE END"; + default: + CRASH_NO_CASE; + } + } + static void instrumentFile(const std::string &inputFilePath, llvm::raw_pwrite_stream &outputStream, const SaltInstrumentParseTreeVisitor &visitor) { std::ifstream inputStream{inputFilePath}; @@ -431,11 +445,20 @@ class SaltInstrumentAction : public PluginParseTreeAction { } std::string line; int lineNum{0}; + const auto &instPts{visitor.getInstrumentationPoints()}; + auto instIter{instPts.cbegin()}; while (std::getline(inputStream, line)) { ++lineNum; + if (instIter != instPts.cend() && instIter->startLine == lineNum && instIter->instrumentBefore()) { + outputStream << getInstrumentationPointString(instIter->instrumentationPointType) << "\n"; + ++instIter; + } outputStream << line << "\n"; + if (instIter != instPts.cend() && instIter->startLine == lineNum && !instIter->instrumentBefore()) { + outputStream << getInstrumentationPointString(instIter->instrumentationPointType) << "\n"; + ++instIter; + } } - (void)lineNum; } /** @@ -469,11 +492,11 @@ class SaltInstrumentAction : public PluginParseTreeAction { const std::string outputFileExtension = "inst."s + inputFileExtension; const auto outputFileStream = createOutputFile(outputFileExtension); - // Walk the parse tree + // Walk the parse tree -- marks nodes for instrumentation SaltInstrumentParseTreeVisitor visitor{&parsing}; Walk(parsing.parseTree(), visitor); - // TODO write the instrumented code + // Use the instrumentation points stored in the Visitor to write the instrumented file. instrumentFile(*inputFilePath, *outputFileStream, visitor); outputFileStream->flush();