Skip to content

Commit

Permalink
remove SVFFunction in AbstractInterpretation.h (#1604)
Browse files Browse the repository at this point in the history
* Create class CallGraph

* fix bug

* samll fix

* remove useless functions

* small chantes

* remove csToIdMap & idToCSMap in CallGraph.h

* small fix

* small fix

* small fix

* adding field CallGraphNode *callGraphNode in SVFFunction class for easing refactoring

(cherry picked from commit cb39cc1)

* remove SVFFunction in AbstractInterpretation.h

(cherry picked from commit 1dd9d41)

* change

---------

Co-authored-by: hwg <geoffreyhe2@gmail.com>
  • Loading branch information
Geoffrey1014 and hwg authored Dec 9, 2024
1 parent 1464994 commit 91b0eeb
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 24 deletions.
2 changes: 1 addition & 1 deletion svf-llvm/include/SVF-LLVM/LLVMModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class LLVMModuleSet
LLVMFunc2SVFFunc[func] = svfFunc;
setValueAttr(func,svfFunc);
}
void addFunctionMap(const Function* func, CallGraphNode* svfFunc);
void addFunctionMap(const Function* func, CallGraphNode* cgNode);

inline void addBasicBlockMap(const BasicBlock* bb, SVFBasicBlock* svfBB)
{
Expand Down
13 changes: 10 additions & 3 deletions svf-llvm/lib/LLVMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,18 @@ void LLVMModuleSet::build()

createSVFDataStructure();
initSVFFunction();


ICFGBuilder icfgbuilder;
icfg = icfgbuilder.build();

CallGraphBuilder callGraphBuilder;
callgraph = callGraphBuilder.buildSVFIRCallGraph(svfModule);
for (const auto& func : svfModule->getFunctionSet())
{
SVFFunction* svffunc = const_cast<SVFFunction*>(func);
svffunc->setCallGraphNode(callgraph->getCallGraphNode(func));
}

for (const auto& it : *callgraph)
{
Expand Down Expand Up @@ -1217,10 +1224,10 @@ void LLVMModuleSet::dumpModulesToFile(const std::string& suffix)
}
}

void LLVMModuleSet::addFunctionMap(const SVF::Function* func, SVF::CallGraphNode* svfFunc)
void LLVMModuleSet::addFunctionMap(const SVF::Function* func, SVF::CallGraphNode* cgNode)
{
LLVMFunc2CallGraphNode[func] = svfFunc;
setValueAttr(func,svfFunc);
LLVMFunc2CallGraphNode[func] = cgNode;
setValueAttr(func, cgNode);
}

void LLVMModuleSet::setValueAttr(const Value* val, SVFValue* svfvalue)
Expand Down
4 changes: 2 additions & 2 deletions svf/include/AE/Svfexe/AbstractInterpretation.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ class AbstractInterpretation
AEStat* stat;

std::vector<const CallICFGNode*> callSiteStack;
Map<const SVFFunction*, ICFGWTO*> funcToWTO;
Set<const SVFFunction*> recursiveFuns;
Map<const CallGraphNode*, ICFGWTO*> funcToWTO;
Set<const CallGraphNode*> recursiveFuns;


AbstractState& getAbsStateFromTrace(const ICFGNode* node)
Expand Down
2 changes: 2 additions & 0 deletions svf/include/Graphs/CallGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ class CallGraph : public GenericCallGraphTy

void addCallGraphNode(const SVFFunction* fun);

const CallGraphNode* getCallGraphNode(const std::string& name);

/// Destructor
virtual ~CallGraph()
{
Expand Down
13 changes: 12 additions & 1 deletion svf/include/SVFIR/SVFValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace SVF
/// LLVM Aliases and constants
typedef SVF::GraphPrinter GraphPrinter;


class CallGraphNode;
class SVFInstruction;
class SVFBasicBlock;
class SVFArgument;
Expand Down Expand Up @@ -320,8 +320,14 @@ class SVFFunction : public SVFValue
std::vector<const SVFBasicBlock*> allBBs; /// all BasicBlocks of this function
std::vector<const SVFArgument*> allArgs; /// all formal arguments of this function
SVFBasicBlock *exitBlock; /// a 'single' basic block having no successors and containing return instruction in a function
const CallGraphNode *callGraphNode; /// call graph node for this function

protected:
inline void setCallGraphNode(CallGraphNode *cgn)
{
callGraphNode = cgn;
}

///@{ attributes to be set only through Module builders e.g., LLVMModule
inline void addBasicBlock(const SVFBasicBlock* bb)
{
Expand Down Expand Up @@ -354,6 +360,11 @@ class SVFFunction : public SVFValue
SVFFunction(void) = delete;
virtual ~SVFFunction();

inline const CallGraphNode* getCallGraphNode() const
{
return callGraphNode;
}

static inline bool classof(const SVFValue *node)
{
return node->getKind() == SVFFunc;
Expand Down
38 changes: 21 additions & 17 deletions svf/lib/AE/Svfexe/AbstractInterpretation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,19 @@ void AbstractInterpretation::initWTO()
// Detect if the call graph has cycles by finding its strongly connected components (SCC)
Andersen::CallGraphSCC* callGraphScc = ander->getCallGraphSCC();
callGraphScc->find();
auto callGraph = ander->getCallGraph();
CallGraph* svfirCallGraph = PAG::getPAG()->getCallGraph();

// Iterate through the call graph
for (auto it = callGraph->begin(); it != callGraph->end(); it++)
for (auto it = svfirCallGraph->begin(); it != svfirCallGraph->end(); it++)
{
// Check if the current function is part of a cycle
if (callGraphScc->isInCycle(it->second->getId()))
recursiveFuns.insert(it->second->getFunction()); // Mark the function as recursive
}

// Initialize WTO for each function in the module
for (const SVFFunction* fun : svfir->getModule()->getFunctionSet())
{
if(fun->isDeclaration())
recursiveFuns.insert(it->second); // Mark the function as recursive
if (it->second->getFunction()->isDeclaration())
continue;
auto* wto = new ICFGWTO(icfg, icfg->getFunEntryICFGNode(fun));
auto* wto = new ICFGWTO(icfg, icfg->getFunEntryICFGNode(it->second->getFunction()));
wto->init();
funcToWTO[fun] = wto;
funcToWTO[it->second] = wto;
}
}
/// Program entry
Expand All @@ -112,9 +107,9 @@ void AbstractInterpretation::analyse()
handleGlobalNode();
getAbsStateFromTrace(
icfg->getGlobalICFGNode())[PAG::getPAG()->getBlkPtr()] = IntervalValue::top();
if (const SVFFunction* fun = svfir->getModule()->getSVFFunction("main"))
if (const CallGraphNode* cgn = svfir->getCallGraph()->getCallGraphNode("main"))
{
ICFGWTO* wto = funcToWTO[fun];
ICFGWTO* wto = funcToWTO[cgn];
handleWTOComponents(wto->getWTOComponents());
}
}
Expand Down Expand Up @@ -586,7 +581,11 @@ void AbstractInterpretation::extCallPass(const SVF::CallICFGNode *callNode)

bool AbstractInterpretation::isRecursiveCall(const SVF::CallICFGNode *callNode)
{
return recursiveFuns.find(callNode->getCalledFunction()) != recursiveFuns.end();
const SVFFunction *callfun = callNode->getCalledFunction();
if (!callfun)
return false;
else
return recursiveFuns.find(callfun->getCallGraphNode()) != recursiveFuns.end();
}

void AbstractInterpretation::recursiveCallPass(const SVF::CallICFGNode *callNode)
Expand All @@ -610,7 +609,11 @@ void AbstractInterpretation::recursiveCallPass(const SVF::CallICFGNode *callNode

bool AbstractInterpretation::isDirectCall(const SVF::CallICFGNode *callNode)
{
return funcToWTO.find(callNode->getCalledFunction()) != funcToWTO.end();
const SVFFunction *callfun =callNode->getCalledFunction();
if (!callfun)
return false;
else
return funcToWTO.find(callfun->getCallGraphNode()) != funcToWTO.end();
}
void AbstractInterpretation::directCallFunPass(const SVF::CallICFGNode *callNode)
{
Expand All @@ -619,7 +622,8 @@ void AbstractInterpretation::directCallFunPass(const SVF::CallICFGNode *callNode

abstractTrace[callNode] = as;

ICFGWTO* wto = funcToWTO[callNode->getCalledFunction()];
const SVFFunction *callfun =callNode->getCalledFunction();
ICFGWTO* wto = funcToWTO[callfun->getCallGraphNode()];
handleWTOComponents(wto->getWTOComponents());

callSiteStack.pop_back();
Expand Down Expand Up @@ -649,7 +653,7 @@ void AbstractInterpretation::indirectCallFunPass(const SVF::CallICFGNode *callNo
SVFVar *func_var = svfir->getGNode(AbstractState::getInternalID(addr));
if(const FunObjVar*funObjVar = SVFUtil::dyn_cast<FunObjVar>(func_var))
{
const SVFFunction* callfun = funObjVar->getCallGraphNode()->getFunction();
const CallGraphNode* callfun = funObjVar->getCallGraphNode();
callSiteStack.push_back(callNode);
abstractTrace[callNode] = as;

Expand Down
10 changes: 10 additions & 0 deletions svf/lib/Graphs/CallGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ void CallGraph::view()
SVF::ViewGraph(this, "Call Graph");
}

const CallGraphNode* CallGraph::getCallGraphNode(const std::string& name)
{
for (const auto& item : *this)
{
if (item.second->getName() == name)
return item.second;
}
return nullptr;
}

namespace SVF
{

Expand Down

0 comments on commit 91b0eeb

Please sign in to comment.