-
Notifications
You must be signed in to change notification settings - Fork 22
/
Example.cpp
127 lines (104 loc) · 4.15 KB
/
Example.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include "clang/Driver/Options.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "clang/Rewrite/Core/Rewriter.h"
using namespace std;
using namespace clang;
using namespace clang::driver;
using namespace clang::tooling;
using namespace llvm;
// Global rewriter: ExampleVisitor binds it to the source manager and records
// text replacements in it; main() prints its edit buffer at the end.
Rewriter rewriter;
// Number of FunctionDecl nodes seen; incremented in
// ExampleVisitor::VisitFunctionDecl and reported by main().
int numFunctions = 0;
// AST visitor that records source rewrites in the global `rewriter`:
//   * a function declaration named "do_math" is renamed to "add5",
//   * calls whose direct callee is "do_math" are renamed to "add5",
//   * the start of every return value is replaced with "val".
// It also counts every FunctionDecl it visits in the global `numFunctions`.
class ExampleVisitor : public RecursiveASTVisitor<ExampleVisitor> {
private:
    ASTContext *astContext; // used for getting additional AST info

public:
    // Captures the ASTContext of `CI` and points the global rewriter at its
    // source manager and language options so later ReplaceText calls work.
    explicit ExampleVisitor(CompilerInstance *CI)
        : astContext(&(CI->getASTContext())) // initialize private members
    {
        rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
    }

    // Called for each function declaration. Counts it and, if it is named
    // "do_math", replaces the name at the declaration's location with "add5".
    // Always returns true so the traversal continues.
    virtual bool VisitFunctionDecl(FunctionDecl *func) {
        numFunctions++;
        std::string funcName = func->getNameInfo().getName().getAsString();
        if (funcName == "do_math") {
            rewriter.ReplaceText(func->getLocation(), funcName.length(), "add5");
            errs() << "** Rewrote function def: " << funcName << "\n";
        }
        return true;
    }

    // Called for each statement. Handles return statements and call
    // expressions; everything else is ignored. Always returns true so the
    // traversal continues.
    virtual bool VisitStmt(Stmt *st) {
        if (ReturnStmt *ret = dyn_cast<ReturnStmt>(st)) {
            // A bare `return;` has no value expression, so getRetValue() is
            // null and must not be dereferenced.
            if (Expr *retValue = ret->getRetValue()) {
                // NOTE(review): the fixed length 6 presumably matches a
                // specific identifier in the sample input -- confirm.
                rewriter.ReplaceText(retValue->getLocStart(), 6, "val");
                errs() << "** Rewrote ReturnStmt\n";
            }
        }
        if (CallExpr *call = dyn_cast<CallExpr>(st)) {
            // Only rename calls that actually target do_math, mirroring
            // VisitFunctionDecl; other calls (and indirect calls, which have
            // no direct callee) are left untouched.
            if (FunctionDecl *callee = call->getDirectCallee()) {
                std::string calleeName = callee->getNameInfo().getName().getAsString();
                if (calleeName == "do_math") {
                    rewriter.ReplaceText(call->getLocStart(), calleeName.length(), "add5");
                    errs() << "** Rewrote function call\n";
                }
            }
        }
        return true;
    }

    // Disabled alternative: type-specific hooks instead of the generic VisitStmt.
    /*
    virtual bool VisitReturnStmt(ReturnStmt *ret) {
        rewriter.ReplaceText(ret->getRetValue()->getLocStart(), 6, "val");
        errs() << "** Rewrote ReturnStmt\n";
        return true;
    }
    virtual bool VisitCallExpr(CallExpr *call) {
        rewriter.ReplaceText(call->getLocStart(), 7, "add5");
        errs() << "** Rewrote function call\n";
        return true;
    }
    */
};
// AST consumer that runs ExampleVisitor over the whole translation unit once
// parsing has finished.
class ExampleASTConsumer : public ASTConsumer {
private:
    // Held by value so it is destroyed together with the consumer; a raw
    // `new` here would never be freed. (Doesn't have to be private.)
    ExampleVisitor visitor;

public:
    // Takes the CompilerInstance so it can be passed on to the visitor.
    explicit ExampleASTConsumer(CompilerInstance *CI)
        : visitor(CI) // initialize the visitor
    { }

    // Override this to call our ExampleVisitor on the entire source file.
    virtual void HandleTranslationUnit(ASTContext &Context) {
        /* we can use ASTContext to get the TranslationUnitDecl, which is
           a single Decl that collectively represents the entire source file */
        visitor.TraverseDecl(Context.getTranslationUnitDecl());
    }

    /*
    // override this to call our ExampleVisitor on each top-level Decl
    virtual bool HandleTopLevelDecl(DeclGroupRef DG) {
        // a DeclGroupRef may have multiple Decls, so we iterate through each one
        for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) {
            Decl *D = *i;
            visitor.TraverseDecl(D); // recursively visit each AST node in Decl "D"
        }
        return true;
    }
    */
};
// Frontend action whose only job is to supply an ExampleASTConsumer for the
// compiler instance it is given.
class ExampleFrontendAction : public ASTFrontendAction {
public:
    // Creates the consumer for the current input; `file` is not used here.
    virtual ASTConsumer *CreateASTConsumer(CompilerInstance &CI, StringRef file) {
        // hand the consumer a pointer to the compiler instance
        ASTConsumer *consumer = new ExampleASTConsumer(&CI);
        return consumer;
    }
};
int main(int argc, const char **argv) {
// parse the command-line args passed to your code
CommonOptionsParser op(argc, argv);
// create a new Clang Tool instance (a LibTooling environment)
ClangTool Tool(op.getCompilations(), op.getSourcePathList());
// run the Clang Tool, creating a new FrontendAction (explained below)
int result = Tool.run(newFrontendActionFactory<ExampleFrontendAction>());
errs() << "\nFound " << numFunctions << " functions.\n\n";
// print out the rewritten source code ("rewriter" is a global var.)
rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(errs());
return result;
}