package com.dong.java.test; import java.util.HashMap; import java.util.Map; public class CheckSQL { private static String[] DBNAMES={""}; private static String[] keywords={"alter ","create ", "delete ", "do ", "drop ", "handler ", "insert ", "load data infile ", "rename ", "replace ", "subquery ", "truncate ", "update "}; /** * 检查输入的sql * @param sql * @return */ public boolean checkSql(String sql) { sql=sql.toLowerCase(); if(!sql.startsWith("select")) { // sql不是以select开头 return false; } for(String keyword: keywords) { if(sql.indexOf(keyword) >= 0) { // sql中包含不支持的关键字 return false; } } String[] dbs=DBNAMES; String[] sqlsFrom=sql.split(" from "); String[] sqlsJoin=sql.split(" join "); Map fromMap=dbname(sqlsFrom); Map joinMap=dbname(sqlsJoin); fromMap.putAll(joinMap); for(String s: fromMap.keySet()) { // 配对sql中的数据库名和允许的数据库 boolean flag=false; for(String db: dbs) { if(db.equals(s)) { flag=true; } } if(flag == false) { // 如果有一个不匹配,为false return false; } } return true; } public Map dbname(String[] sql) { Map names=new HashMap(); if(sql.length > 1) { for(int i=0; i < sql.length; i++) { if(i == 0) { continue; } String s=sql[i].split("\\.")[0]; if(s.charAt(0) >= 'a' && s.charAt(0) <= 'z') { names.put(s.trim(), s.trim()); } } } return names; } }