20 Building a Naive Bayes Classifier

  • Datasets: iris.csv
  • Algorithms:
    • Naive Bayes

https://www.machinelearningplus.com/predictive-modeling/how-naive-bayes-algorithm-works-with-example-and-full-code/

20.1 Building a Naive Bayes Classifier in R

Understanding Naive Bayes was the (slightly) tricky part. Implementing it is fairly straightforward.
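As a one-line refresher, the classifier assigns the class that maximizes the prior times the product of the per-feature likelihoods (the "naive" conditional-independence assumption); implementations typically estimate each likelihood for a numeric feature with a normal or kernel density:

$$
\hat{y} \;=\; \arg\max_{k}\; P(Y = k)\,\prod_{i=1}^{n} P(x_i \mid Y = k)
$$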

In R, the Naive Bayes classifier is implemented in packages such as e1071, klaR, and bnlearn. In Python, it is available in scikit-learn.
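As a quick taste before the main example, here is a minimal sketch with e1071, fit on the copy of iris that ships with R rather than on a train/test split (assuming e1071 is installed; the klaR version, on a proper split, is what we build below):

# Minimal e1071 sketch: fit and predict on the built-in iris data
library(e1071)
m <- naiveBayes(Species ~ ., data = iris)
head(predict(m, iris))                # predicted classes
head(predict(m, iris, type = "raw"))  # per-class posterior probabilities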

For the sake of demonstration, let’s use the standard iris dataset to predict the Species of a flower from four features: Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width.

# Import Data
training <- read.csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_train.csv')
test <- read.csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_test.csv')
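Assuming both downloads succeed, it is worth a quick look at the split before modeling (the exact row counts depend on how the hosted CSVs were generated):

# Sanity-check the split: dimensions and class balance
dim(training); dim(test)
table(training$Species)
table(test$Species)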

The training data is now contained in the training dataframe and the test data in test. Let's load the klaR package and build the Naive Bayes model.

# Using klaR for Naive Bayes
library(klaR)
#> Loading required package: MASS
nb_mod <- NaiveBayes(Species ~ ., data=training)  # fit on all four features
pred <- predict(nb_mod, test)                     # predict on the held-out test set
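The object returned by predict() is a list with two components: the predicted class labels and the per-class posterior probabilities. A quick peek:

# pred$class: predicted labels; pred$posterior: class probability matrix
head(pred$class)
head(round(pred$posterior, 3))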

Let's see the confusion matrix.

# Confusion Matrix
tab <- table(pred$class, test$Species)
caret::confusionMatrix(tab)  
#> Confusion Matrix and Statistics
#> 
#>             
#>              setosa versicolor virginica
#>   setosa         15          0         0
#>   versicolor      0         11         0
#>   virginica       0          4        15
#> 
#> Overall Statistics
#>                                         
#>                Accuracy : 0.911         
#>                  95% CI : (0.788, 0.975)
#>     No Information Rate : 0.333         
#>     P-Value [Acc > NIR] : 8.47e-16      
#>                                         
#>                   Kappa : 0.867         
#>                                         
#>  Mcnemar's Test P-Value : NA            
#> 
#> Statistics by Class:
#> 
#>                      Class: setosa Class: versicolor Class: virginica
#> Sensitivity                  1.000             0.733            1.000
#> Specificity                  1.000             1.000            0.867
#> Pos Pred Value               1.000             1.000            0.789
#> Neg Pred Value               1.000             0.882            1.000
#> Prevalence                   0.333             0.333            0.333
#> Detection Rate               0.333             0.244            0.333
#> Detection Prevalence         0.333             0.244            0.422
#> Balanced Accuracy            1.000             0.867            0.933
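
The headline accuracy can also be read straight off the contingency table; a base-R cross-check:

# Correct predictions lie on the diagonal: (15 + 11 + 15) / 45
sum(diag(tab)) / sum(tab)
#> [1] 0.9111111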

Let's also plot the class-conditional densities the model estimated for each feature.

# Plot density of each feature using nb_mod
opar <- par(mfrow=c(2, 2), mar=c(4,0,0,0))
plot(nb_mod, main="")
par(opar)

Finally, let's visualize the predicted classes against the observed ones.

# Plot the Confusion Matrix
library(ggplot2)
test$pred <- pred$class
ggplot(test, aes(Species, pred, color = Species)) +
  geom_jitter(width = 0.2, height = 0.1, size=2) +
  labs(title="Confusion Matrix", 
       subtitle="Predicted vs. Observed from Iris dataset", 
       y="Predicted", 
       x="Truth",
       caption="machinelearningplus.com")