Predições XGBoost diretamente pelo SQL

O que você vai aprender a fazer ao final da leitura é:

Ou seja, uma vez que você tenha um modelo de XGBoost ajustado dentro do seu R, dá para pedir para o {tidypredict} gerar a query de SQL que calcularia o score. PS: o XGBoost do gif tem 2 árvores.

Motivos para usar

Se existir um servidor com SQL muito poderoso (ou cluster, etc), a predição pode ser feita diretamente dentro do banco de dados, sem necessidade de outras ferramentas. Além disso, colocar em produção estará a um CTRL+C/CTRL+V de distância =P.

Preditor do XGBoost

O valor predito pelo XGBoost é, no fundo, uma soma de CASE WHEN’s (ou IF ELSE’s). O número de árvores que vai definir quantos CASE WHEN’s serão somados.

A cara do preditor de um XGBoost com 2 árvores ficaria assim no SQL:

<SQL> 0.0 + CASE
WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.95000005)
WHEN (`dose` >= 14.0 AND `dose` >= 5.0) THEN (-1.6500001)
WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 14.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.05000007)
WHEN (`dose` >= 10.0 AND (`dose` < 14.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.35000002)
END + CASE
WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.36500013)
WHEN (`dose` >= 14.0 AND `dose` >= 5.0) THEN (-1.15499997)
WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 14.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.734999955)
WHEN (`dose` >= 10.0 AND (`dose` < 14.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.945000052)
END + 0.5

Código minimal

O verdadeiro herói do código abaixo é a linha tidypredict_sql(xgb_fit$fit, con) do pacote {tidypredict} que recebe o objeto com o XGBoost ajustado e retorna a query de SQL com os CASE WHEN’s.

Ajuste do XGBoost com {tidymodels}

library(tidymodels)
library(tidypredict)
library(xgboost)

# dados ----------------------------------------------------
data <- tibble(
  dose = c(2, 8, 12, 18),
  eficacia = c(-6, 4, 5, -5)
)

# especificacao do modelo ---------------------------------
xgb_model <- boost_tree(min_n = 1, 
  trees = 2
) %>%
  set_engine("xgboost", lambda = 0) %>%
  set_mode("regression")

# ajuste do modelo -----------------------------------------
xgb_fit <- fit(xgb_model, eficacia ~ dose, data = data)

# Conexão com o banco SQL ----------------------------------
con <- DBI::dbConnect(RSQLite::SQLite(), "meu_sqlite_db.db")

# Transcrevendo XGBoost para SQL ---------------------------
tidypredict_sql(xgb_fit$fit, con)
## <SQL> 0.0 + CASE
## WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.95000005)
## WHEN (`dose` >= 15.0 AND `dose` >= 5.0) THEN (-1.6500001)
## WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.05000007)
## WHEN (`dose` >= 10.0 AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.35000002)
## END + CASE
## WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.36500013)
## WHEN (`dose` >= 15.0 AND `dose` >= 5.0) THEN (-1.15499997)
## WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.734999955)
## WHEN (`dose` >= 10.0 AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.945000052)
## END + 0.5

Usando o {dplyr} para rodar código no SQL

# apenas subindo os dados no banco de dados SQL ------
copy_to(con, data, "data", overwrite = TRUE)

# Criando coluna `pred` com as predições em SQL ------
data_sql <- tbl(con, "data") %>%
  mutate(
    pred = !!tidypredict_sql(xgb_fit$fit, con)
  )

# resultado -------------------------------------------
data_sql
## # Source:   lazy query [?? x 3]
## # Database: sqlite 3.35.5 [D:\blog\content\posts\meu_sqlite_db.db]
##    dose eficacia  pred
##   <dbl>    <dbl> <dbl>
## 1     2       -6 -2.82
## 2     8        4  2.29
## 3    12        5  2.80
## 4    18       -5 -2.31

PS: a função tidypredict_sql() devolve uma string e por isso usamos o operador !! para interpretar como código de R literal. Essa parte pode ser confusa porque o dplyr vai pegar o código, traduzir para o literal, mas não vai rodar o código no R, vai traduzir do R para o SQL diretamente (e é por isso que dá certo porque a string não é um código de R válido =P).

SQL por trás dos panos

A tabela data_sql gerada acima pelo {dplyr} é, na verdade, uma query SQL. A função show_query() mostra a tradução resultante:

show_query(data_sql)
## <SQL>
## SELECT `dose`, `eficacia`, 0.0 + CASE
## WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.95000005)
## WHEN (`dose` >= 15.0 AND `dose` >= 5.0) THEN (-1.6500001)
## WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.05000007)
## WHEN (`dose` >= 10.0 AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (1.35000002)
## END + CASE
## WHEN ((`dose` < 5.0 OR ((`dose`) IS NULL))) THEN (-1.36500013)
## WHEN (`dose` >= 15.0 AND `dose` >= 5.0) THEN (-1.15499997)
## WHEN ((`dose` < 10.0 OR ((`dose`) IS NULL)) AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.734999955)
## WHEN (`dose` >= 10.0 AND (`dose` < 15.0 OR ((`dose`) IS NULL)) AND `dose` >= 5.0) THEN (0.945000052)
## END + 0.5 AS `pred`
## FROM `data`

Resumo

  • XGBoost para SQL é simples como uma linha de código: tidypredict_sql(xgb_fit$fit, con)
  • A query SQL pode ser usada dentro do mutate() com o !!
  • Predições com XGBoost não é nenhum bicho de sete cabeças, são singelas somas de CASE WHEN’s.

Apêndice: as duas árvores

library(DiagrammeR)
xgb.plot.tree(model=xgb_fit$fit)
comments powered by Disqus