Skip to content

Instantly share code, notes, and snippets.

@shanebutler
Last active August 9, 2025 03:32
Show Gist options
  • Select an option

  • Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.

Select an option

Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.

Revisions

  1. shanebutler revised this gist Aug 24, 2015. 1 changed file with 2 additions and 3 deletions.
    5 changes: 2 additions & 3 deletions sql.export.randomForest.R
    Original file line number Diff line number Diff line change
    @@ -1,5 +1,5 @@
    # sql.export.rf(): save a randomForest model as SQL
    # v0.03
    # v0.04
    # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
    #
    # sql.export.rf is free software: you can redistribute it and/or modify it
    @@ -41,7 +41,7 @@ sql.export.rf <- function (model, file, input.table="source_table",

    sink(file, type="output")

    if (model$type == "classification" && is.numeric(t$prediction)==FALSE) {
    if (model$type == "classification") {
    pred.type <- "VARCHAR"
    } else {
    pred.type <- "FLOAT"
    @@ -153,4 +153,3 @@ sql.export.rf <- function (model, file, input.table="source_table",
    # close the file
    sink()
    }

  2. shanebutler revised this gist Aug 11, 2014. 1 changed file with 7 additions and 7 deletions.
    14 changes: 7 additions & 7 deletions sql.export.randomForest.R
    Original file line number Diff line number Diff line change
    @@ -1,7 +1,7 @@
    # sql.export.rf(): save a randomForest model as SQL
    # v0.02
    # v0.03
    # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
    #
    #
    # sql.export.rf is free software: you can redistribute it and/or modify it
    # under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 2 of the License, or
    @@ -50,21 +50,21 @@ sql.export.rf <- function (model, file, input.table="source_table",
    if (variant == "teradata") {
    cat(paste("CREATE VOLATILE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    "\tpred ",pred.type,"\n",
    ") ON COMMIT PRESERVE ROWS;\n\n",
    "CREATE VOLATILE TABLE tmp_rf (\n",
    "CREATE MULTISET VOLATILE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    "\tpred ",pred.type,"\n",
    ") ON COMMIT PRESERVE ROWS;\n\n",sep=""))
    } else {
    cat(paste("CREATE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    "\tpred ",pred.type,"\n",
    ");\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n\n",
    "CREATE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    "\tpred ",pred.type,"\n",
    ");\n\n",sep=""))
    }

  3. shanebutler revised this gist Aug 10, 2014. 1 changed file with 38 additions and 17 deletions.
    55 changes: 38 additions & 17 deletions sql.export.randomForest.R
    Original file line number Diff line number Diff line change
    @@ -1,5 +1,5 @@
    # sql.export.rf(): save a randomForest model as SQL
    # v0.01
    # v0.02
    # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
    #
    # sql.export.rf is free software: you can redistribute it and/or modify it
    @@ -16,6 +16,12 @@
    # along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
    #
    #
    ## NOTE:
    # This code generates SQL scoring code from your randomForest model.
    # Currently the generated code is not optimal since it makes as many
    # passes over the input data as there are trees (ie. if there are 500
    # trees there will be 500 INSERT... SELECT statements)
    #
    ## USAGE:
    # sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
    #
    @@ -24,9 +30,7 @@
    #

    sql.export.rf <- function (model, file, input.table="source_table",
    #n.trees=NULL, # need this?
    id="id",
    trees.per.query=1,
    variant="generic") {
    require (randomForest, quietly=TRUE)

    @@ -42,16 +46,27 @@ sql.export.rf <- function (model, file, input.table="source_table",
    } else {
    pred.type <- "FLOAT"
    }

    cat(paste("CREATE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n\n",
    "CREATE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",sep=""))

    if (variant == "teradata") {
    cat(paste("CREATE VOLATILE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ") ON COMMIT PRESERVE ROWS;\n\n",
    "CREATE VOLATILE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ") ON COMMIT PRESERVE ROWS;\n\n",sep=""))
    } else {
    cat(paste("CREATE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n\n",
    "CREATE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",sep=""))
    }

    for (tree.num in 1:(model$ntree)) {
    cat(paste("INSERT INTO tmp_rf\nSELECT ",id,",", sep=""))
    @@ -114,22 +129,28 @@ sql.export.rf <- function (model, file, input.table="source_table",

    if (model$type == "classification") {
    # This code is not optimal but many SQL implementations do not support window functions (eg. SQLite)
    # Had to remove use of WITH because not supported by all SQL variants
    cat(paste("INSERT INTO rf_predictions\n",
    "SELECT a.id, a.pred\n",
    "FROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) a\n",
    "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
    "\t\t\tFROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) cc\n",
    "\t\t\tGROUP BY id) b\n",
    "ON a.id = b.id AND a.cnt = b.cnt;\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n", sep=""))
    "ON a.id = b.id AND a.cnt = b.cnt;\n\n", sep=""))
    } else {
    cat(paste("INSERT INTO rf_predictions\n",
    "SELECT ",id,", AVG(pred)\n",
    "FROM tmp_rf\n",
    "GROUP BY ",id,";\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n", sep=""))
    "GROUP BY ",id,";\n\n", sep=""))
    }

    if (variant == "teradata") {
    cat("DROP TABLE tmp_rf;\n\n")
    } else {
    cat("DROP TABLE IF EXISTS tmp_rf;\n\n")
    }

    # close the file
    sink()
    }

  4. shanebutler created this gist Aug 9, 2014.
    135 changes: 135 additions & 0 deletions sql.export.randomForest.R
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,135 @@
    # sql.export.rf(): save a randomForest model as SQL
    # v0.01
    # Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
    #
    # sql.export.rf is free software: you can redistribute it and/or modify it
    # under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 2 of the License, or
    # (at your option) any later version.
    #
    # sql.export.rf is distributed in the hope that it will be useful, but
    # WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
    # General Public License for more details.
    #
    # You should have received a copy of the GNU General Public License
    # along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
    #
    #
    ## USAGE:
    # sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
    #
    ## ARGUMENTS:
    # variant: Optional argument for Teradata variant="teradata"
    #

    sql.export.rf <- function (model, file, input.table="source_table",
    #n.trees=NULL, # need this?
    id="id",
    trees.per.query=1,
    variant="generic") {
    require (randomForest, quietly=TRUE)

    if (!("randomForest" %in% class(model))) {
    stop ("Expected a randomForest object")
    return
    }

    sink(file, type="output")

    if (model$type == "classification" && is.numeric(t$prediction)==FALSE) {
    pred.type <- "VARCHAR"
    } else {
    pred.type <- "FLOAT"
    }

    cat(paste("CREATE TABLE rf_predictions (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n\n",
    "CREATE TABLE tmp_rf (\n",
    "\t",id," INT NOT NULL,\n",
    "\tpred ",pred.type," NOT NULL\n",
    ");\n\n",sep=""))

    for (tree.num in 1:(model$ntree)) {
    cat(paste("INSERT INTO tmp_rf\nSELECT ",id,",", sep=""))
    recurse.rf <- function(model, tree.data, tree.row.num, ind=0) {
    tree.row <- tree.data[tree.row.num,]
    indent.str <- paste(rep("\t", ind), collapse="")
    split.var <- as.character(tree.row[,"split var"])
    split.point <- tree.row[,"split point"]
    if(tree.row[,"status"] != -1) { # splitting node
    if(is.numeric(unlist(model$forest$xlevels[split.var]))) {
    cat(paste("\n",indent.str,"CASE WHEN", gsub("[.]","_",split.var), "IS NULL THEN NULL",
    "\n",indent.str,"WHEN", gsub("[.]","_",split.var), "<=", split.point, "THEN "))
    recurse.rf(model, tree.data, tree.row[,"left daughter"], ind=(ind+1))
    cat("\n",indent.str,"ELSE ")
    recurse.rf(model, tree.data, tree.row[,"right daughter"], ind=(ind+1))
    cat("END ")
    } else { # categorical
    # function to convert from binary coding to the category values it represents
    conv.to.binary <- function (ncat, num.to.convert) {
    ret <- numeric()
    if((2^ncat) <= num.to.convert) {
    return (NULL)
    } else {
    for (x in (ncat - 1):0) {
    if (num.to.convert >= (2^x)) {
    num.to.convert <- num.to.convert - (2^x)
    ret <- c(ret, 1)
    } else {
    ret <- c(ret, 0)
    }
    }
    return(ret)
    }
    }
    categ.bin <- conv.to.binary(model$forest$ncat[split.var], split.point)
    categ.flags <- (categ.bin[length(categ.bin):1] == 1)

    categ.values <- unlist(model$forest$xlevels[split.var])
    cat(paste("\n",indent.str,"CASE WHEN ", gsub("[.]","_",split.var), " IN ('",
    paste(categ.values[categ.flags], sep="", collapse="', '"), #FIXME replace quotes dependant on var type
    "') THEN ", sep=""))
    recurse.rf(model, tree.data, tree.row[,"left daughter"], ind=(ind+1))
    cat(paste("\n",indent.str,"WHEN ", gsub("[.]","_",split.var), " IN ('",
    paste(categ.values[!categ.flags], sep="", collapse="', '"),
    "') THEN ", sep=""))
    recurse.rf(model, tree.data, tree.row[,"right daughter"], ind=(ind+1))
    cat(paste("\n", indent.str,"ELSE NULL END ", sep="")) #FIXME: null or a new category
    }
    } else { # terminal node
    if (is.numeric(tree.data$prediction)) {
    cat(paste(tree.row[,"prediction"], " ", sep=""))
    } else {
    cat(paste("'", tree.row[,"prediction"], "' ", sep=""))
    }
    }
    }
    recurse.rf(model, getTree(model,k=tree.num,labelVar=TRUE), 1)
    cat(paste("as tree",tree.num,"\nFROM ",input.table,";\n\n", sep=""))
    }

    if (model$type == "classification") {
    # This code is not optimal but many SQL implementations do not support window functions (eg. SQLite)
    cat(paste("INSERT INTO rf_predictions\n",
    "SELECT a.id, a.pred\n",
    "FROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) a\n",
    "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
    "\t\t\tFROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) cc\n",
    "\t\t\tGROUP BY id) b\n",
    "ON a.id = b.id AND a.cnt = b.cnt;\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n", sep=""))
    } else {
    cat(paste("INSERT INTO rf_predictions\n",
    "SELECT ",id,", AVG(pred)\n",
    "FROM tmp_rf\n",
    "GROUP BY ",id,";\n\n",
    "DROP TABLE IF EXISTS tmp_rf;\n", sep=""))
    }

    # close the file
    sink()
    }