// S2E instrumentation tool.
//
// A Clang LibTooling program that parses the given source files, walks their
// ASTs and records textual edits in a global Rewriter:
//   * after a variable is initialized from / assigned the result of
//     malloc() or calloc(), a call to s2e_concretize_fork() is inserted;
//   * the first token of every call to a function named "func" is replaced
//     with "s2e" and the source text of its arguments is printed.
// The rewritten main file is written to stderr at the end of main().

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/APSInt.h"
#include "clang/Driver/Options.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Lex/Lexer.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Rewrite/Frontend/FrontendActions.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
// NOTE(review): the names of the two system headers were lost in the source
// as received; <set> and <string> are what the code below actually needs.
#include <set>
#include <string>

using namespace std;
using namespace clang;
using namespace clang::driver;
using namespace clang::tooling;
using namespace llvm;

typedef struct {
    int line_no;
} s;

//
// Global state shared by the visitor and main(): the edit buffer holder and
// the language options used when lexing argument text.
Rewriter rewriter;
LangOptions languageOptions;

// AST visitor that performs the instrumentation. Every Visit* hook returns
// true so that RecursiveASTVisitor keeps walking; returning false from a hook
// aborts the traversal of the declaration currently being walked.
class S2EInstrumentationVisitor : public RecursiveASTVisitor<S2EInstrumentationVisitor> {
private:
    // NOTE(review): the element types of these two sets were lost in the
    // source as received; they are reconstructed from how the sets are used.
    std::set<Expr *> expressions;
    std::set<VarDecl *> declarations;
    FunctionDecl *currentFunctionDecl;
    ASTContext *astContext; // used for getting additional AST info

    // Inserts `str` just past the token that starts at `lastTokenLoc`, plus
    // one extra character.
    //
    // getLocEnd() returns the location of the START of the last token, e.g.
    //     var = 123;
    //           ^---- getLocEnd() points here.
    // MeasureTokenLength gets us past the last token, and adding 1 gets
    // us past the ';' (this assumes the ';' immediately follows the token).
    void InsertAfterTrailingSemi(SourceLocation lastTokenLoc, const std::string& str) {
        int offset = Lexer::MeasureTokenLength(lastTokenLoc,
                                               rewriter.getSourceMgr(),
                                               rewriter.getLangOpts()) + 1;
        SourceLocation insertAt = lastTokenLoc.getLocWithOffset(offset);
        rewriter.InsertText(insertAt, str, true, true);
    }

public:
    // Binds the visitor to the AST context of `CI` and points the global
    // rewriter at that context's source manager and language options.
    explicit S2EInstrumentationVisitor(CompilerInstance *CI)
        : currentFunctionDecl(NULL),
          astContext(&(CI->getASTContext())) // initialize private members
    {
        rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
    }

    // Called for every declaration.
    //  * FunctionDecl: remembers the function if it has C language linkage
    //    (isExternC()), otherwise forgets it, and resets the per-function sets.
    //  * VarDecl initialized directly by a call to malloc/calloc: records the
    //    variable and inserts an s2e_concretize_fork() call after it.
    bool VisitDecl(Decl *Declaration) {
        if (FunctionDecl *funcDecl = dyn_cast<FunctionDecl>(Declaration)) {
            currentFunctionDecl = funcDecl->isExternC() ? funcDecl : NULL;
            expressions.clear();
            declarations.clear();
#if 0
            llvm::outs() << "function name: " << funcDecl->getNameAsString()
                         << " (return type = " << funcDecl->getResultType().getAsString() << ")\n";
            unsigned paramCount = funcDecl->getNumParams();
            llvm::outs() << "function param count: " << paramCount << "\n";
            for(unsigned i = 0; i < paramCount; ++i) {
                llvm::outs() << "-param #" << i << "\n";
                const ParmVarDecl *currentParam = funcDecl->getParamDecl(i);
                QualType userType = currentParam->getType();
                while(userType->isPointerType()) {
                    llvm::outs() << "\tpointer to" << "\n";
                    userType = userType->getPointeeType();
                }
                if(userType.isConstQualified()) {
                    llvm::outs() << "\tconst" << "\n";
                }
                if(userType->isReferenceType()) {
                    llvm::outs() << "\treference to" << "\n";
                }
                userType = userType.getNonReferenceType().getUnqualifiedType();
                llvm::outs() << "\t(type = " << userType.getAsString()
                             << ", name = " << currentParam->getNameAsString() << ")\n";
            }
            llvm::outs() << "\n";
#endif
        }

        VarDecl *varDecl = dyn_cast<VarDecl>(Declaration);
        if (!varDecl) {
            return true;
        }
        // NOTE(review): the excluded declaration kind was lost in the source
        // as received; ParmVarDecl (skip function parameters) is assumed.
        if (dyn_cast<ParmVarDecl>(Declaration)) {
            return true;
        }
        //llvm::outs() << "variable type: " << varDecl->getType().getAsString() << ", variable name: " << varDecl->getNameAsString();
        if (!varDecl->hasInit()) {
            return true;
        }

        std::string name = varDecl->getNameAsString();
        Expr *varInit = varDecl->getInit();
        if (!varInit->isRValue()) {
            return true;
        }
        // Works
#if 0
        SourceRange varSourceRange = varInit->getSourceRange();
        if(!varSourceRange.isValid())
            return true;
        CharSourceRange charSourceRange(varSourceRange, true);
        StringRef sourceText = Lexer::getSourceText(charSourceRange, astContext->getSourceManager(), astContext->getLangOpts(), 0);
        //llvm::outs() << ", initialization value: " << sourceText.str();
#endif
        CallExpr *Call = dyn_cast<CallExpr>(varInit);
        if (!Call) {
            return true;
        }
        // Works
        // getDirectCallee() is null for calls that do not name a function
        // directly (e.g. through a function pointer); nothing to match then.
        FunctionDecl *FD = Call->getDirectCallee();
        if (!FD || !FD->isExternC()) {
            return true;
        }
        std::string fname = FD->getNameInfo().getAsString();
        if (fname == "malloc" || fname == "calloc") {
            declarations.insert(varDecl);
            std::string str = "\ns2e_concretize_fork(" + name + ", " + "sizeof(" + name + "), " + "0" + ");\n";
            //llvm::outs() << str;
            InstrumentStmtAfter(varDecl, str);
        }
        return true;
    }

    // Get assigned variable.
    // If `s` is an assignment whose left-hand side is a direct reference to a
    // variable, stores that variable's qualified name in `name` and returns
    // true; otherwise leaves `name` untouched and returns false.
    bool GetAssignedVar(Stmt *s, std::string& name) {
        BinaryOperator *BinOp = dyn_cast<BinaryOperator>(s);
        if (!BinOp || !BinOp->isAssignmentOp()) {
            return false;
        }
        DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinOp->getLHS());
        if (!DRE) {
            return false;
        }
        VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl());
        if (!VD) {
            return false;
        }
        name = VD->getQualifiedNameAsString();
        return true;
    }

    // Override Statements which includes expressions and more.
    // Currently a no-op; the disabled fragment below is kept for reference.
    bool VisitStmt(Stmt *s) {
#if 0
            Stmt *TH = If->getThen();
            // Add braces if needed to then clause
            InstrumentStmt(TH);
            Stmt *EL = If->getElse();
            if (EL) {
                // Add braces if needed to else clause
                InstrumentStmt(EL);
            }
        } else if (isa<ForStmt>(s)) {
            ForStmt *For = cast<ForStmt>(s);
            Stmt *BODY = For->getBody();
            //InstrumentStmt(BODY);
        }
#endif
        return true; // returning false aborts the traversal
    }

    // For a direct call to a function named "func": replaces the first token
    // of the call expression with "s2e", prints the begin/end locations of the
    // call, and prints the source text of every argument. All other calls
    // (including indirect ones) are left alone.
    virtual bool VisitCallExpr(CallExpr *CallE) {
        // Null for indirect calls; those can never be the function we rewrite.
        FunctionDecl *FD = CallE->getDirectCallee();
        if (!FD) {
            return true;
        }
        std::string fname = FD->getNameInfo().getAsString();
        if (fname != "func") {
            return true;
        }

        //SourceLocation START = s->getLocStart();
        /** Replace function **/
        SourceRange range = CallE->getSourceRange();
        SourceLocation source = range.getBegin();
        rewriter.ReplaceText(source, "s2e");
        llvm::outs() << "Begin: " << range.getBegin().printToString(rewriter.getSourceMgr())
                     << " End: " << range.getEnd().printToString(rewriter.getSourceMgr()) << "\n";

        /** Replace function argument **/
        for (CallExpr::arg_iterator it = CallE->arg_begin(), ite = CallE->arg_end(); it != ite; ++it) {
            Expr *arg = *it;
            SourceLocation begin(arg->getLocStart()), _e(arg->getLocEnd());
            // getLocEnd() is the start of the last token; advance to its end
            // so the printed text covers the whole argument.
            SourceLocation end(clang::Lexer::getLocForEndOfToken(_e, 0, rewriter.getSourceMgr(), languageOptions));
            const char *first = rewriter.getSourceMgr().getCharacterData(begin);
            const char *last = rewriter.getSourceMgr().getCharacterData(end);
            llvm::outs() << std::string(first, last - first) << "\n";
            //rewriter.ReplaceText(source, "val");
            //llvm::outs() << source.printToString(rewriter.getSourceMgr()) << "\n";
        }
        return true;
    }

    // For an assignment whose right-hand side is directly a call to
    // malloc/calloc: records the assigned variable, dumps the call, and
    // inserts an s2e_concretize_fork() call after the assignment.
    virtual bool VisitBinaryOperator(BinaryOperator* BinaryOp) {
        if (!BinaryOp->isAssignmentOp() || !isa<CallExpr>(BinaryOp->getRHS())) {
            return true;
        }

        std::string name;
        if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinaryOp->getLHS())) {
            if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                declarations.insert(VD);
                name = VD->getQualifiedNameAsString();
            }
        }

        CallExpr *CallE = cast<CallExpr>(BinaryOp->getRHS());
        // Null for indirect calls; those cannot be malloc/calloc by name.
        FunctionDecl *FD = CallE->getDirectCallee();
        if (!FD) {
            return true;
        }
        std::string fname = FD->getNameInfo().getAsString();
        if (fname == "malloc" || fname == "calloc") {
            //expressions.insert(BinaryOp);
            //CallE->dumpPretty(*astContext);
            CallE->dumpColor();
            llvm::outs() << "\n";
            // Without a variable name the inserted call would read
            // "s2e_concretize_fork(, sizeof(), 0);", which is not valid code.
            if (!name.empty()) {
                std::string str = "\ns2e_concretize_fork(" + name + ", " + "sizeof(" + name + "), " + "0" + ");\n";
                //llvm::outs() << str;
                InstrumentStmtAfter(BinaryOp, str);
            }
        }
        return true;
    }

    // Returns true if the condition was simple boolean, i.e. `p` or `!p`
    // where the (possibly negated) operand is an implicit cast. When the cast
    // wraps a direct variable reference its qualified name is printed (an
    // empty line is printed otherwise). Returns false for any other shape.
    // This is a helper called from VisitIfStmt, not a traversal hook.
    virtual bool VisitBooleanCondition(Expr *Cond) {
        Expr *Var = Cond;
        bool negation = false;
        UnaryOperator *UnaryOp = dyn_cast<UnaryOperator>(Cond);
        // Handles if (p) or if (!p) cases. Logical negation is UO_LNot;
        // UO_Not is bitwise '~'.
        if (UnaryOp && UnaryOp->getOpcode() == UO_LNot) {
            Var = UnaryOp->getSubExpr();
            negation = true;
        }
        (void)negation; // not consumed yet

        ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(Var);
        if (!Cast) {
            return false;
        }

        std::string name;
        Expr *OriginalCast = Cast->getSubExpr();
        if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(OriginalCast)) {
            if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                name = VD->getQualifiedNameAsString();
            }
        }
        llvm::outs() << name << "\n";
        return true;
    }

    // Inspects the condition of each `if`. Simple boolean conditions are
    // handled by VisitBooleanCondition; handling of comparisons is disabled.
    virtual bool VisitIfStmt(IfStmt* If) {
        Expr *Cond = If->getCond();
        if (VisitBooleanCondition(Cond)) {
            return true;
        }
#if 0
        BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(Cond);
        if (!BinaryOp->isEqualityOp() && !BinaryOp->isRelationalOp() && !BinaryOp->isComparisonOp()) {
            return true;
        }
        Expr *Lhs = BinaryOp->getLHS();
        std::string name;
        VarDecl *VarD;
        // a == 5, p == NULL, s.x == 0
        if (ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(Lhs)) {
            Expr *OriginalCast = Cast->getSubExpr();
            if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(OriginalCast)) {
                if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                    VarD = VD;
                    name = VD->getQualifiedNameAsString();
                }
            } else if (MemberExpr *Member = dyn_cast<MemberExpr>(OriginalCast)) {
                // Doesn't work
                ValueDecl *ValueD = Member->getMemberDecl();
                DeclarationNameInfo Name = Member->getMemberNameInfo();
                name = Name.getAsString();
                //llvm::outs() << "LHS is " << name << "\n";
                if (VarDecl *VD = dyn_cast<VarDecl>(Member->getMemberDecl())) {
                    VarD = VD;
                    name = VD->getQualifiedNameAsString();
                }
            }
        }
        //
        else if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Lhs)) {
            if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                VarD = VD;
                name = VD->getQualifiedNameAsString();
            }
        }
        Expr *Rhs = BinaryOp->getRHS();
        if (isa<IntegerLiteral>(Rhs)) {
            IntegerLiteral *IntLit = cast<IntegerLiteral>(Rhs);
            llvm::APSInt Result;
            Rhs->EvaluateAsInt(Result, *astContext);
            //int64_t result = Result.getExtValue();
            llvm::APSInt concreteValue;
            BinaryOperatorKind opc = BinaryOp->getOpcode();
            switch (opc) {
                case (BO_LE) :
                case (BO_GE) :
                case (BO_EQ) : {
                    concreteValue = Result;
                    break;
                }
                case (BO_NE) : {
                    concreteValue = Result++;
                    break;
                }
                case (BO_LT) : {
                    concreteValue = Result--;
                    break;
                }
                case (BO_GT) : {
                    concreteValue = Result++;
                    break;
                }
                default : {
                    break;
                }
            }
            llvm::outs() << "Set value " << concreteValue.toString(10) << "\n";
#if 0
            if (BinaryOp->isEqualityOp()) {
            } else if (BinaryOp->isComparisonOp()) {
            } else if (BinaryOp->isRelationalOp()) {
            }
#endif
            //llvm::outs() << "Integer Literal: " << result.toString(10) << "\n";
            //InstrumentStmtBefore(If);
        } else if (isa<CharacterLiteral>(Rhs)) {
            CharacterLiteral *CharLit = cast<CharacterLiteral>(Rhs);
        } else if (Rhs->isNullPointerConstant(*astContext, Expr::NPC_ValueDependentIsNotNull)) {
            // Works
            //InstrumentStmtBefore(If);
        }
#endif
        return true;
    }

    // Inserts `str` at the start of statement `s`. getLocStart() and
    // getSourceRange().getBegin() name the same location, so no distinction
    // between statement kinds is needed.
    void InstrumentStmtBefore(Stmt *s, const std::string& str) {
        SourceLocation START = s->getLocStart();
        rewriter.InsertText(START, str, true, true);
    }

    // Inserts `str` after the ';' that terminates statement `s`.
    // Only performed if the statement is not compound.
    void InstrumentStmtAfter(Stmt *s, const std::string& str) {
        if (isa<CompoundStmt>(s)) {
            return;
        }
        InsertAfterTrailingSemi(s->getLocEnd(), str);
    }

    // Inserts `str` after the ';' that terminates declaration `d`.
    void InstrumentStmtAfter(Decl *d, const std::string& str) {
        InsertAfterTrailingSemi(d->getLocEnd(), str);
    }

#if 0
    virtual bool VisitFunctionDecl(FunctionDecl *func) {
        numFunctions++;
        string funcName = func->getNameInfo().getName().getAsString();
        if (funcName == "do_math") {
            rewriter.ReplaceText(func->getLocation(), funcName.length(), "add5");
            errs() << "** Rewrote function def: " << funcName << "\n";
        }
        return true;
    }

    virtual bool VisitStmt(Stmt *st) {
        if (ReturnStmt *ret = dyn_cast<ReturnStmt>(st)) {
            rewriter.ReplaceText(ret->getRetValue()->getLocStart(), 6, "val");
            errs() << "** Rewrote ReturnStmt\n";
        }
        if (CallExpr *call = dyn_cast<CallExpr>(st)) {
            rewriter.ReplaceText(call->getLocStart(), 7, "add5");
            errs() << "** Rewrote function call\n";
        }
        return true;
    }

    // Override Binary Operator expressions
    virtual Expr *VisitBinaryOperator(BinaryOperator *E) {
        // Determine type of binary operator
        if (E->isLogicalOp()) {
            // Insert function call at start of first expression.
            // Note getLocStart() should work as well as getExprLoc()
            rewriter.InsertText(E->getLHS()->getExprLoc(),
                                E->getOpcode() == BO_LAnd ? "L_AND(" : "L_OR(", true);
            // Replace operator ("||" or "&&") with ","
            rewriter.ReplaceText(E->getOperatorLoc(), E->getOpcodeStr().size(), ",");
            // Insert closing paren at end of right-hand expression
            rewriter.InsertTextAfterToken(E->getRHS()->getLocEnd(), ")");
        } else
        // Note isComparisonOp() is like isRelationalOp() but includes == and !=
        if (E->isRelationalOp()) {
            llvm::errs() << "Relational Op " << E->getOpcodeStr() << "\n";
        } else
        // Handles == and != comparisons
        if (E->isEqualityOp()) {
            llvm::errs() << "Equality Op " << E->getOpcodeStr() << "\n";
        }
        return E;
    }

    /*
    virtual bool VisitReturnStmt(ReturnStmt *ret) {
        rewriter.ReplaceText(ret->getRetValue()->getLocStart(), 6, "val");
        errs() << "** Rewrote ReturnStmt\n";
        return true;
    }

    virtual bool VisitCallExpr(CallExpr *call) {
        rewriter.ReplaceText(call->getLocStart(), 7, "add5");
        errs() << "** Rewrote function call\n";
        return true;
    }
    */
#endif
};

// AST consumer that runs the instrumentation visitor over every top-level
// declaration as the parser produces it.
class S2EInstrumentationASTConsumer : public ASTConsumer {
private:
    // Held by value so it is destroyed together with the consumer.
    S2EInstrumentationVisitor visitor;

public:
    // override the constructor in order to pass CI
    explicit S2EInstrumentationASTConsumer(CompilerInstance *CI)
        : visitor(CI) // initialize the visitor
    { }

#if 0
    // override this to call our ExampleVisitor on the entire source file
    virtual void HandleTranslationUnit(ASTContext &Context) {
        /* we can use ASTContext to get the TranslationUnitDecl, which is
           a single Decl that collectively represents the entire source file */
        visitor->TraverseDecl(Context.getTranslationUnitDecl());
        //visitor->TraverseStmt();
    }
#endif

    // override this to call our visitor on each top-level Decl
    virtual bool HandleTopLevelDecl(DeclGroupRef DG) {
        // a DeclGroupRef may have multiple Decls, so we iterate through each one
        for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; ++i) {
            Decl *D = *i;
            visitor.TraverseDecl(D); // recursively visit each AST node in Decl "D"
            //D->dump();
        }
        return true;
    }
};

// Front-end action that creates one S2EInstrumentationASTConsumer per file.
class S2EInstrumentationFrontendAction : public ASTFrontendAction {
public:
    virtual ASTConsumer *CreateASTConsumer(CompilerInstance &CI, StringRef file) {
        return new S2EInstrumentationASTConsumer(&CI); // pass CI pointer to ASTConsumer
    }
};

// Parses the command line, runs the instrumentation action over every listed
// source file, and writes the rewritten main file to stderr.
// Returns the result of ClangTool::run().
int main(int argc, const char **argv) {
    // Parse the command-line args passed to your code
    CommonOptionsParser op(argc, argv);
    // Create a new Clang Tool instance (a LibTooling environment)
    ClangTool Tool(op.getCompilations(), op.getSourcePathList());

    languageOptions.GNUMode = 1;
    languageOptions.CXXExceptions = 1;
    languageOptions.RTTI = 1;
    languageOptions.Bool = 1;
    languageOptions.CPlusPlus = 1;

    // Run the Clang Tool, creating a new FrontendAction for each file
    int result = Tool.run(newFrontendActionFactory<S2EInstrumentationFrontendAction>());

    // Print out the rewritten source code ("rewriter" is a global var.)
    rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(errs());
    //rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(outs());
    return result;
}