O que é regressão logística?
A regressão logística é usada para prever uma classe, ou seja, uma probabilidade. A regressão logística pode prever um resultado binário com precisão.
Imagine que você deseja prever se um empréstimo será negado / aceito com base em muitos atributos. A regressão logística é da forma 0/1. y = 0 se um empréstimo for rejeitado, y = 1 se aceito.
Um modelo de regressão logística difere do modelo de regressão linear de duas maneiras.
- Em primeiro lugar, a regressão logística aceita apenas entrada dicotômica (binária) como variável dependente (ou seja, um vetor de 0 e 1).
- Em segundo lugar, o resultado é medido pela seguinte função de ligação probabilística chamada sigmóide devido à sua forma de S:
A saída da função está sempre entre 0 e 1. Verifique a imagem abaixo
A função sigmóide retorna valores de 0 a 1. Para a tarefa de classificação, precisamos de uma saída discreta de 0 ou 1.
Para converter um fluxo contínuo em valor discreto, podemos definir um limite de decisão em 0,5. Todos os valores acima deste limite são classificados como 1
Neste tutorial, você aprenderá
- O que é regressão logística?
- Como criar um modelo de revestimento generalizado (GLM)
- Etapa 1) Verifique as variáveis contínuas
- Etapa 2) Verifique as variáveis do fator
- Etapa 3) Engenharia de recursos
- Etapa 4) Estatística resumida
- Etapa 5) Conjunto de treinamento / teste
- Etapa 6) Construir o modelo
- Etapa 7) Avalie o desempenho do modelo
Como criar um modelo de revestimento generalizado (GLM)
Vamos usar o conjunto de dados de adultos para ilustrar a regressão logística. O "adulto" é um ótimo conjunto de dados para a tarefa de classificação. O objetivo é prever se a renda anual em dólares de um indivíduo será superior a 50.000. O conjunto de dados contém 46.033 observações e dez recursos:
- idade: idade do indivíduo. Numérico
- educação: Nível educacional do indivíduo. Fator.
- estado marital: estado civil do indivíduo. Fator ou seja, nunca se casou, casou-civil-cônjuge, ...
- gênero: gênero do indivíduo. Fator, ou seja, masculino ou feminino
- renda: variável alvo. Renda acima ou abaixo de 50K. Fator ou seja,> 50K, <= 50K
entre outros
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Resultado:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Vamos proceder da seguinte forma:
- Etapa 1: verificar as variáveis contínuas
- Etapa 2: verificar as variáveis do fator
- Etapa 3: Engenharia de recursos
- Etapa 4: estatística de resumo
- Etapa 5: conjunto de treinamento / teste
- Etapa 6: construir o modelo
- Etapa 7: avalie o desempenho do modelo
- etapa 8: melhorar o modelo
Sua tarefa é prever qual indivíduo terá uma receita superior a 50 mil.
Neste tutorial, cada etapa será detalhada para realizar uma análise em um conjunto de dados real.
Etapa 1) Verifique as variáveis contínuas
Na primeira etapa, você pode ver a distribuição das variáveis contínuas.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Explicação do código
- contínuo <- select_if (data_adult, is.numeric): Use a função select_if () da biblioteca dplyr para selecionar apenas as colunas numéricas
- resumo (contínuo): Imprime a estatística de resumo
Resultado:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Na tabela acima, você pode ver que os dados têm escalas totalmente diferentes e hours.per.weeks tem grandes outliers (por exemplo, observe o último quartil e o valor máximo).
Você pode lidar com isso seguindo duas etapas:
- 1: Trace a distribuição de horas.por.semana
- 2: Padronizar as variáveis contínuas
- Trace a distribuição
Vamos examinar mais de perto a distribuição de hours.per.week
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Resultado:
A variável tem muitos outliers e uma distribuição não bem definida. Você pode resolver parcialmente esse problema excluindo os primeiros 0,01% das horas semanais.
Sintaxe básica do quantil:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Calculamos o primeiro percentil 2 por cento
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Explicação do código
- quantil (data_adult $ hours.per.week, .99): Calcule o valor de 99 por cento do tempo de trabalho
Resultado:
## 99%## 80
98 por cento da população trabalha menos de 80 horas por semana.
Você pode descartar as observações acima desse limite. Você usa o filtro da biblioteca dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekResultado:
## [1] 45537 10
- Padronizar as variáveis contínuas
Você pode padronizar cada coluna para melhorar o desempenho porque seus dados não têm a mesma escala. Você pode usar a função mutate_if da biblioteca dplyr. A sintaxe básica é:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionVocê pode padronizar as colunas numéricas da seguinte forma:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Explicação do código
- mutate_if (is.numeric, funs (scale)): A condição é apenas coluna numérica e a função é escala
Resultado:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KEtapa 2) Verifique as variáveis do fator
Esta etapa tem dois objetivos:
- Verifique o nível em cada coluna categórica
- Defina novos níveis
Vamos dividir esta etapa em três partes:
- Selecione as colunas categóricas
- Armazene o gráfico de barras de cada coluna em uma lista
- Imprima os gráficos
Podemos selecionar as colunas do fator com o código abaixo:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Explicação do código
- data.frame (select_if (data_adult, is.factor)): Armazenamos as colunas de fator em fator em um tipo de frame de dados. A biblioteca ggplot2 requer um objeto de quadro de dados.
Resultado:
## [1] 6O conjunto de dados contém 6 variáveis categóricas
A segunda etapa é mais habilidosa. Você deseja traçar um gráfico de barras para cada coluna no fator de quadro de dados. É mais conveniente automatizar o processo, especialmente quando há muitas colunas.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Explicação do código
- lapply (): Use a função lapply () para passar uma função em todas as colunas do conjunto de dados. Você armazena a saída em uma lista
- função (x): A função será processada para cada x. Aqui x são as colunas
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Crie um gráfico de barras para cada elemento x. Observe, para retornar x como uma coluna, você precisa incluí-lo dentro de get ()
A última etapa é relativamente fácil. Você deseja imprimir os 6 gráficos.
# Print the graphgraphResultado:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Nota: Use o próximo botão para navegar para o próximo gráfico
Etapa 3) Engenharia de recursos
Reforma da educação
No gráfico acima, você pode ver que a variável educação possui 16 níveis. Isso é substancial e alguns níveis têm um número relativamente baixo de observações. Se você quiser melhorar a quantidade de informações que pode obter dessa variável, pode reformulá-la para um nível superior. Ou seja, você cria grupos maiores com nível de educação semelhante. Por exemplo, baixo nível de educação será convertido em evasão. Os níveis mais elevados de educação serão alterados para mestre.
Aqui está o detalhe:
Nível antigo
Novo nível
Pré escola
cair fora
10º
Cair fora
11º
Cair fora
12º
Cair fora
1o ao 4o
Cair fora
5º a 6º
Cair fora
7 a 8
Cair fora
9º
Cair fora
HS-Grad
HighGrad
Alguma faculdade
Comunidade
Assoc-acdm
Comunidade
Assoc-voc
Comunidade
Solteiros
Solteiros
Mestres
Mestres
Prof-escola
Mestres
Doutorado
PhD
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Explicação do código
- Usamos o verbo mutate da biblioteca dplyr. Mudamos os valores da educação com a declaração ifelse
Na tabela abaixo, você cria uma estatística resumida para ver, em média, quantos anos de escolaridade (valor z) leva para se atingir o Bacharelado, Mestrado ou Doutorado.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Resultado:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Reformulação do estado civil
Também é possível criar níveis mais baixos para o estado civil. No código a seguir, você altera o nível da seguinte maneira:
Nível antigo
Novo nível
Nunca casado
Solteiro
Casado-cônjuge-ausente
Solteiro
Casado-AF-cônjuge
Casado
Cônjuge casada
Separados
Separados
Divorciado
Viúvas
Viúva
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Você pode verificar o número de indivíduos em cada grupo.table(recast_data$marital.status)Resultado:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Etapa 4) Estatística resumida
É hora de verificar algumas estatísticas sobre nossas variáveis de destino. No gráfico abaixo, você conta a porcentagem de indivíduos que ganham mais de 50 mil de acordo com seu gênero.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Resultado:
A seguir, verifique se a origem do indivíduo afeta seus ganhos.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Resultado:
O número de horas de trabalho por gênero.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Resultado:
O gráfico de caixa confirma que a distribuição do tempo de trabalho se ajusta a grupos diferentes. No box plot, ambos os gêneros não apresentam observações homogêneas.
Você pode verificar a densidade do tempo de trabalho semanal por tipo de ensino. As distribuições têm muitas opções distintas. Provavelmente, isso pode ser explicado pelo tipo de contrato nos EUA.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Explicação do código
- ggplot (recast_data, aes (x = hours.per.week)): um gráfico de densidade requer apenas uma variável
- geom_density (aes (color = education), alpha = 0.5): O objeto geométrico para controlar a densidade
Resultado:
Para confirmar suas idéias, você pode realizar um teste ANOVA unilateral:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Resultado:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1O teste ANOVA confirma a diferença de média entre os grupos.
Não-linearidade
Antes de executar o modelo, você pode ver se o número de horas trabalhadas está relacionado à idade.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Explicação do código
- ggplot (recast_data, aes (x = age, y = hours.per.week)): define a estética do gráfico
- geom_point (aes (cor = renda), tamanho = 0,5): Construir o gráfico de pontos
- stat_smooth (): Adicione a linha de tendência com os seguintes argumentos:
- method = 'lm': Plote o valor ajustado se a regressão linear
- formula = y ~ poly (x, 2): Ajustar uma regressão polinomial
- se = TRUE: Adicione o erro padrão
- aes (cor = renda): Divida o modelo por renda
Resultado:
Em suma, você pode testar os termos de interação no modelo para captar o efeito da não linearidade entre o tempo de trabalho semanal e outros recursos. É importante detectar em que condições o tempo de trabalho difere.
Correlação
A próxima verificação é visualizar a correlação entre as variáveis. Você converte o tipo de nível de fator em numérico para que possa plotar um mapa de calor contendo o coeficiente de correlação calculado com o método de Spearman.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Explicação do código
- data.frame (lapply (recast_data, as.integer)): converter dados em numéricos
- ggcorr () plota o mapa de calor com os seguintes argumentos:
- método: Método para calcular a correlação
- nbreaks = 6: Número de quebra
- hjust = 0,8: posição de controle do nome da variável no gráfico
- rótulo = TRUE: Adicionar rótulos no centro das janelas
- label_size = 3: rótulos de tamanho
- color = "grey50"): Cor do rótulo
Resultado:
Etapa 5) Conjunto de treinamento / teste
Qualquer tarefa de aprendizado de máquina supervisionada exige a divisão dos dados entre um conjunto de trens e um conjunto de teste. Você pode usar a "função" criada nos outros tutoriais de aprendizado supervisionado para criar um conjunto de treinamento / teste.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Resultado:
## [1] 36429 9dim(data_test)Resultado:
## [1] 9108 9Etapa 6) Construir o modelo
Para ver o desempenho do algoritmo, use o pacote glm (). O Modelo Linear Generalizado é uma coleção de modelos. A sintaxe básica é:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Você está pronto para estimar o modelo logístico para dividir o nível de renda entre um conjunto de recursos.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Explicação do código
- fórmula <- renda ~.: Crie o modelo para se ajustar
- logit <- glm (formula, data = data_train, family = 'binomial'): Ajuste um modelo logístico (family = 'binomial') com os dados data_train.
- resumo (logit): Imprima o resumo do modelo
Resultado:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6O resumo do nosso modelo revela informações interessantes. O desempenho de uma regressão logística é avaliado com métricas chave específicas.
- AIC (Akaike Information Criteria): É o equivalente a R2 na regressão logística. Ele mede o ajuste quando uma penalidade é aplicada ao número de parâmetros. Valores menores de AIC indicam que o modelo está mais próximo da verdade.
- Desvio nulo: ajusta o modelo apenas com a interceptação. O grau de liberdade é n-1. Podemos interpretá-lo como um valor Qui-quadrado (valor ajustado diferente do teste de hipótese do valor real).
- Desvio residual: Modelo com todas as variáveis. Também é interpretado como um teste de hipótese do qui-quadrado.
- Número de iterações do Fisher Scoring: Número de iterações antes da convergência.
A saída da função glm () é armazenada em uma lista. O código a seguir mostra todos os itens disponíveis na variável logit que construímos para avaliar a regressão logística.
# A lista é muito longa, imprima apenas os três primeiros elementos
lapply(logit, class)[1:3]Resultado:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Cada valor pode ser extraído com o sinal $ seguido do nome das métricas. Por exemplo, você armazenou o modelo como logit. Para extrair os critérios AIC, você usa:
logit$aicResultado:
## [1] 27086.65Etapa 7) Avalie o desempenho do modelo
Matriz de confusão
A matriz de confusão é a melhor escolha para avaliar o desempenho da classificação em comparação com as diferentes métricas que você viu antes. A ideia geral é contar o número de vezes que as instâncias True são classificadas como Falsas.
Para calcular a matriz de confusão, primeiro você precisa ter um conjunto de previsões para que possam ser comparadas aos alvos reais.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matExplicação do código
- predizer (logit, data_test, type = 'resposta'): Calcula a predição no conjunto de teste. Defina type = 'response' para calcular a probabilidade de resposta.
- tabela (data_test $ renda, previsão> 0,5): Calcula a matriz de confusão. predizer> 0,5 significa que retorna 1 se as probabilidades previstas estiverem acima de 0,5, caso contrário, 0.
Resultado:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Cada linha em uma matriz de confusão representa um alvo real, enquanto cada coluna representa um alvo previsto. A primeira linha desta matriz considera a renda inferior a 50k (classe Falsa): 6241 foram corretamente classificados como indivíduos com renda inferior a 50k ( Verdadeiro negativo ), enquanto o restante foi classificado erroneamente como acima de 50k ( Falso positivo ). A segunda linha considera a receita acima de 50k, a classe positiva foi 1229 ( Verdadeiro positivo ), enquanto a Verdadeira negativa foi 1074.
Você pode calcular a precisão do modelo somando o verdadeiro positivo + verdadeiro negativo sobre a observação total
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestExplicação do código
- soma (diag (table_mat)): Soma da diagonal
- sum (table_mat): Soma da matriz.
Resultado:
## [1] 0.8277339O modelo parece ter um problema: ele superestima o número de falsos negativos. Isso é chamado de paradoxo do teste de precisão . Afirmamos que a precisão é a razão entre as previsões corretas e o número total de casos. Podemos ter uma precisão relativamente alta, mas um modelo inútil. Acontece quando existe uma classe dominante. Se você olhar novamente para a matriz de confusão, verá que a maioria dos casos são classificados como negativos verdadeiros. Imagine agora, o modelo classificou todas as classes como negativas (ou seja, abaixo de 50k). Você teria uma precisão de 75 por cento (6718/6718 + 2257). Seu modelo tem um desempenho melhor, mas se esforça para distinguir o verdadeiro positivo do verdadeiro negativo.
Nessa situação, é preferível ter uma métrica mais concisa. Podemos olhar para:
- Precisão = TP / (TP + FP)
- Rechamada = TP / (TP + FN)
Precisão vs recall
A precisão analisa a precisão da previsão positiva. Recall é a proporção de instâncias positivas que são detectadas corretamente pelo classificador;
Você pode construir duas funções para calcular essas duas métricas
- Precisão de construção
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Explicação do código
- mat [1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
- tapete [1,2]; Retorna a primeira célula da segunda coluna do quadro de dados, ou seja, o falso positivo
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Explicação do código
- mat [1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
- tapete [2,1]; Retorna a segunda célula da primeira coluna do quadro de dados, ou seja, o falso negativo
Você pode testar suas funções
prec <- precision(table_mat)precrec <- recall(table_mat)recResultado:
## [1] 0.712877## [2] 0.5336518Quando o modelo diz que é um indivíduo acima de 50k, está correto em apenas 54% dos casos e pode reivindicar indivíduos acima de 50k em 72% dos casos.
Você pode criar a é uma média harmônica dessas duas métricas, o que significa que dá mais peso aos valores mais baixos.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Resultado:
## [1] 0.6103799Troca de precisão x recall
É impossível ter alta precisão e alto recall.
Se aumentarmos a precisão, o indivíduo correto será melhor previsto, mas perderíamos muitos deles (menor recall). Em algumas situações, preferimos maior precisão do que recall. Existe uma relação côncava entre precisão e recall.
- Imagine, você precisa prever se um paciente tem uma doença. Você quer ser o mais preciso possível.
- Se você precisar detectar pessoas potencialmente fraudulentas na rua por meio do reconhecimento facial, seria melhor detectar muitas pessoas rotuladas como fraudulentas, mesmo que a precisão seja baixa. A polícia poderá libertar o indivíduo não fraudulento.
A curva ROC
A curva Receiver Operating Characteristic é outra ferramenta comum usada com classificação binária. É muito semelhante à curva de precisão / rechamada, mas em vez de representar graficamente a precisão versus rechamada, a curva ROC mostra a taxa de verdadeiro positivo (ou seja, rechamada) em relação à taxa de falso positivo. A taxa de falsos positivos é a proporção de instâncias negativas classificadas incorretamente como positivas. É igual a um menos a taxa negativa verdadeira. A verdadeira taxa negativa também é chamada de especificidade . Portanto, a curva ROC plota a sensibilidade (recall) versus 1-especificidade
Para plotar a curva ROC, precisamos instalar uma biblioteca chamada RORC. Podemos encontrar na biblioteca do conda. Você pode digitar o código:
conda install -cr r-rocr - sim
Podemos plotar o ROC com as funções prediction () e performance ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Explicação do código
- previsão (prever, data_test $ renda): a biblioteca ROCR precisa criar um objeto de previsão para transformar os dados de entrada
- performance (ROCRpred, 'tpr', 'fpr'): Retorne as duas combinações para produzir no gráfico. Aqui, tpr e fpr são construídos. Para plotar a precisão e chamar juntos, use "prec", "rec".
Resultado:
Etapa 8) Melhorar o modelo
Você pode tentar adicionar não linearidade ao modelo com a interação entre
- idade e horas.por.semana
- gênero e horas.por.semana.
Você precisa usar o teste de pontuação para comparar os dois modelos
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Resultado:
## [1] 0.6109181A pontuação é ligeiramente superior à anterior. Você pode continuar trabalhando nos dados e tentar bater a pontuação.
Resumo
Podemos resumir a função para treinar uma regressão logística na tabela abaixo:
Pacote
Objetivo
função
argumento
-
Criar conjunto de dados de treinamento / teste
create_train_set ()
dados, tamanho, trem
glm
Treine um modelo linear generalizado
glm ()
fórmula, dados, família *
glm
Resuma o modelo
resumo()
modelo ajustado
base
Fazer previsão
prever()
modelo ajustado, conjunto de dados, tipo = 'resposta'
base
Crie uma matriz de confusão
tabela()
y, predizer ()
base
Criar pontuação de precisão
soma (diag (tabela ()) / soma (tabela ()
ROCR
Criar ROC: Etapa 1 Criar previsão
predição()
predizer (), y
ROCR
Criar ROC: Etapa 2 Criar desempenho
atuação()
predição (), 'tpr', 'fpr'
ROCR
Criar ROC: Etapa 3 Gráfico de plotagem
trama()
atuação()
Os outros tipos de modelos GLM são:
- binomial: (link = "logit")
- gaussian: (link = "identidade")
- Gama: (link = "inverso")
- inverse.gaussian: (link = "1 / mu 2")
- poisson: (link = "log")
- quase: (link = "identidade", variância = "constante")
- quasibinomial: (link = "logit")
- quasipoisson: (link = "log")