Last active
August 9, 2025 03:32
-
-
Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Revisions
-
shanebutler revised this gist
Aug 24, 2015 . 1 changed file with 2 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,5 @@ # sql.export.rf(): save a randomForest model as SQL # v0.04 # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com> # # sql.export.rf is free software: you can redistribute it and/or modify it @@ -41,7 +41,7 @@ sql.export.rf <- function (model, file, input.table="source_table", sink(file, type="output") if (model$type == "classification") { pred.type <- "VARCHAR" } else { pred.type <- "FLOAT" @@ -153,4 +153,3 @@ sql.export.rf <- function (model, file, input.table="source_table", # close the file sink() } -
shanebutler revised this gist
Aug 11, 2014 . 1 changed file with 7 additions and 7 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,7 +1,7 @@ # sql.export.rf(): save a randomForest model as SQL # v0.03 # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com> # # sql.export.rf is free software: you can redistribute it and/or modify it # under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 2 of the License, or @@ -50,21 +50,21 @@ sql.export.rf <- function (model, file, input.table="source_table", if (variant == "teradata") { cat(paste("CREATE VOLATILE TABLE rf_predictions (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type,"\n", ") ON COMMIT PRESERVE ROWS;\n\n", "CREATE MULTISET VOLATILE TABLE tmp_rf (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type,"\n", ") ON COMMIT PRESERVE ROWS;\n\n",sep="")) } else { cat(paste("CREATE TABLE rf_predictions (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type,"\n", ");\n\n", "DROP TABLE IF EXISTS tmp_rf;\n\n", "CREATE TABLE tmp_rf (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type,"\n", ");\n\n",sep="")) } -
shanebutler revised this gist
Aug 10, 2014 . 1 changed file with 38 additions and 17 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,5 @@ # sql.export.rf(): save a randomForest model as SQL # v0.02 # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com> # # sql.export.rf is free software: you can redistribute it and/or modify it @@ -16,6 +16,12 @@ # along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>. # # ## NOTE: # This code generates SQL scoring code from your randomForest model. # Currently the generated code is not optimal since it makes as many # passes over the input data as there are trees (ie. if there are 500 # trees there will be 500 INSERT... SELECT statements) # ## USAGE: # sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id") # @@ -24,9 +30,7 @@ # sql.export.rf <- function (model, file, input.table="source_table", id="id", variant="generic") { require (randomForest, quietly=TRUE) @@ -42,16 +46,27 @@ sql.export.rf <- function (model, file, input.table="source_table", } else { pred.type <- "FLOAT" } if (variant == "teradata") { cat(paste("CREATE VOLATILE TABLE rf_predictions (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ") ON COMMIT PRESERVE ROWS;\n\n", "CREATE VOLATILE TABLE tmp_rf (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ") ON COMMIT PRESERVE ROWS;\n\n",sep="")) } else { cat(paste("CREATE TABLE rf_predictions (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ");\n\n", "DROP TABLE IF EXISTS tmp_rf;\n\n", "CREATE TABLE tmp_rf (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ");\n\n",sep="")) } for (tree.num in 1:(model$ntree)) { cat(paste("INSERT INTO tmp_rf\nSELECT ",id,",", sep="")) @@ -114,22 +129,28 @@ sql.export.rf <- function (model, file, input.table="source_table", if (model$type == "classification") { # This code is not optimal but many SQL implementations do not support window functions (eg. SQLite) # Had to remove use of WITH because not supported by all SQL variants cat(paste("INSERT INTO rf_predictions\n", "SELECT a.id, a.pred\n", "FROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) a\n", "INNER JOIN (SELECT id, MAX(cnt) as cnt\n", "\t\t\tFROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) cc\n", "\t\t\tGROUP BY id) b\n", "ON a.id = b.id AND a.cnt = b.cnt;\n\n", sep="")) } else { cat(paste("INSERT INTO rf_predictions\n", "SELECT ",id,", AVG(pred)\n", "FROM tmp_rf\n", "GROUP BY ",id,";\n\n", sep="")) } if (variant == "teradata") { cat("DROP TABLE tmp_rf;\n\n") } else { cat("DROP TABLE IF EXISTS tmp_rf;\n\n") } # close the file sink() } -
shanebutler created this gist
Aug 9, 2014 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,135 @@ # sql.export.rf(): save a randomForest model as SQL # v0.01 # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com> # # sql.export.rf is free software: you can redistribute it and/or modify it # under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 2 of the License, or # (at your option) any later version. # # sql.export.rf is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # General Public License for more details. # # You should have received a copy of the GNU General Public License # along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>. # # ## USAGE: # sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id") # ## ARGUMENTS: # variant: Optional argument for Teradata variant="teradata" # sql.export.rf <- function (model, file, input.table="source_table", #n.trees=NULL, # need this? id="id", trees.per.query=1, variant="generic") { require (randomForest, quietly=TRUE) if (!("randomForest" %in% class(model))) { stop ("Expected a randomForest object") return } sink(file, type="output") if (model$type == "classification" && is.numeric(t$prediction)==FALSE) { pred.type <- "VARCHAR" } else { pred.type <- "FLOAT" } cat(paste("CREATE TABLE rf_predictions (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ");\n\n", "DROP TABLE IF EXISTS tmp_rf;\n\n", "CREATE TABLE tmp_rf (\n", "\t",id," INT NOT NULL,\n", "\tpred ",pred.type," NOT NULL\n", ");\n\n",sep="")) for (tree.num in 1:(model$ntree)) { cat(paste("INSERT INTO tmp_rf\nSELECT ",id,",", sep="")) recurse.rf <- function(model, tree.data, tree.row.num, ind=0) { tree.row <- tree.data[tree.row.num,] indent.str <- paste(rep("\t", ind), collapse="") split.var <- as.character(tree.row[,"split var"]) split.point <- tree.row[,"split point"] if(tree.row[,"status"] != -1) { # splitting node if(is.numeric(unlist(model$forest$xlevels[split.var]))) { cat(paste("\n",indent.str,"CASE WHEN", gsub("[.]","_",split.var), "IS NULL THEN NULL", "\n",indent.str,"WHEN", gsub("[.]","_",split.var), "<=", split.point, "THEN ")) recurse.rf(model, tree.data, tree.row[,"left daughter"], ind=(ind+1)) cat("\n",indent.str,"ELSE ") recurse.rf(model, tree.data, tree.row[,"right daughter"], ind=(ind+1)) cat("END ") } else { # categorical # function to convert from binary coding to the category values it represents conv.to.binary <- function (ncat, num.to.convert) { ret <- numeric() if((2^ncat) <= num.to.convert) { return (NULL) } else { for (x in (ncat - 1):0) { if (num.to.convert >= (2^x)) { num.to.convert <- num.to.convert - (2^x) ret <- c(ret, 1) } else { ret <- c(ret, 0) } } return(ret) } } categ.bin <- conv.to.binary(model$forest$ncat[split.var], split.point) categ.flags <- (categ.bin[length(categ.bin):1] == 1) categ.values <- unlist(model$forest$xlevels[split.var]) cat(paste("\n",indent.str,"CASE WHEN ", gsub("[.]","_",split.var), " IN ('", paste(categ.values[categ.flags], sep="", collapse="', '"), #FIXME replace quotes dependant on var type "') THEN ", sep="")) recurse.rf(model, tree.data, tree.row[,"left daughter"], ind=(ind+1)) cat(paste("\n",indent.str,"WHEN ", gsub("[.]","_",split.var), " IN ('", paste(categ.values[!categ.flags], sep="", collapse="', '"), "') THEN ", sep="")) recurse.rf(model, tree.data, tree.row[,"right daughter"], ind=(ind+1)) cat(paste("\n", indent.str,"ELSE NULL END ", sep="")) #FIXME: null or a new category } } else { # terminal node if (is.numeric(tree.data$prediction)) { cat(paste(tree.row[,"prediction"], " ", sep="")) } else { cat(paste("'", tree.row[,"prediction"], "' ", sep="")) } } } recurse.rf(model, getTree(model,k=tree.num,labelVar=TRUE), 1) cat(paste("as tree",tree.num,"\nFROM ",input.table,";\n\n", sep="")) } if (model$type == "classification") { # This code is not optimal but many SQL implementations do not support window functions (eg. SQLite) cat(paste("INSERT INTO rf_predictions\n", "SELECT a.id, a.pred\n", "FROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) a\n", "INNER JOIN (SELECT id, MAX(cnt) as cnt\n", "\t\t\tFROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) cc\n", "\t\t\tGROUP BY id) b\n", "ON a.id = b.id AND a.cnt = b.cnt;\n\n", "DROP TABLE IF EXISTS tmp_rf;\n", sep="")) } else { cat(paste("INSERT INTO rf_predictions\n", "SELECT ",id,", AVG(pred)\n", "FROM tmp_rf\n", "GROUP BY ",id,";\n\n", "DROP TABLE IF EXISTS tmp_rf;\n", sep="")) } # close the file sink() }