40 Wine classification with neuralnet

Source: https://www.r-bloggers.com/multilabel-classification-with-neuralnet-package/

The neuralnet package is perhaps not the best option in R for working with neural networks. For starters, it does not recognize the typical formula y~., it does not support factors, and it does not provide many models beyond a standard MLP. It also has strong competitors: the nnet package, which seems better integrated in R and can be used with the caret package, and MXNet, a high-level deep learning library that provides a wide variety of neural networks.

But still, I think there is some value in the ease of use of the neuralnet package, especially for a beginner, therefore I’ll be using it.

I’m going to be using both the neuralnet package and, curiously enough, the nnet package. Let’s load them:

# load libs
require(neuralnet)
#> Loading required package: neuralnet
require(nnet)
#> Loading required package: nnet
require(ggplot2)
#> Loading required package: ggplot2
set.seed(10)

40.1 The dataset

I looked in the UCI Machine Learning Repository and found the wine dataset.

This dataset contains the results of a chemical analysis of 3 different kinds of wine. The target variable is the wine label, a factor with 3 (unordered) levels. The 13 predictors are all continuous and were obtained from chemical measurements.
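
In case you do not have the file locally yet, here is a minimal sketch of how it could be downloaded (I am assuming the standard UCI location for wine.data and the same data_raw_dir directory used in the chunk below):

# download wine.data into the raw data directory if it is not there yet
# (the URL is assumed to be the usual UCI location for this dataset)
uci_url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data"
if (!file.exists(file.path(data_raw_dir, "wine.data"))) {
    download.file(uci_url, destfile = file.path(data_raw_dir, "wine.data"))
}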

# get the data file from the package location
wine_dataset_path <- file.path(data_raw_dir, "wine.data")
wine_dataset_path
#> [1] "../data/wine.data"
wines <- read.csv(wine_dataset_path)
head(wines)
#>     X1 X14.23 X1.71 X2.43 X15.6 X127 X2.8 X3.06 X.28 X2.29 X5.64 X1.04 X3.92
#> 1    1   13.2  1.78  2.14  11.2  100 2.65  2.76 0.26  1.28  4.38 1.050  3.40
#> 2    1   13.2  2.36  2.67  18.6  101 2.80  3.24 0.30  2.81  5.68 1.030  3.17
#> 3    1   14.4  1.95  2.50  16.8  113 3.85  3.49 0.24  2.18  7.80 0.860  3.45
#> 4    1   13.2  2.59  2.87  21.0  118 2.80  2.69 0.39  1.82  4.32 1.040  2.93
#> 5    1   14.2  1.76  2.45  15.2  112 3.27  3.39 0.34  1.97  6.75 1.050  2.85
#> 6    1   14.4  1.87  2.45  14.6   96 2.50  2.52 0.30  1.98  5.25 1.020  3.58
#>     X1065
#> 1    1050
#> 2    1185
#> 3    1480
#> 4     735
#> 5    1450
#> 6    1290
names(wines) <- c("label",
                  "Alcohol",
                  "Malic_acid",
                  "Ash",
                  "Alcalinity_of_ash",
                  "Magnesium",
                  "Total_phenols",
                  "Flavanoids",
                  "Nonflavanoid_phenols",
                  "Proanthocyanins",
                  "Color_intensity",
                  "Hue",
                  "OD280_OD315_of_diluted_wines",
                  "Proline")
head(wines)
#>   label Alcohol Malic_acid  Ash Alcalinity_of_ash Magnesium Total_phenols
#> 1     1    13.2       1.78 2.14              11.2       100          2.65
#> 2     1    13.2       2.36 2.67              18.6       101          2.80
#> 3     1    14.4       1.95 2.50              16.8       113          3.85
#> 4     1    13.2       2.59 2.87              21.0       118          2.80
#> 5     1    14.2       1.76 2.45              15.2       112          3.27
#> 6     1    14.4       1.87 2.45              14.6        96          2.50
#>   Flavanoids Nonflavanoid_phenols Proanthocyanins Color_intensity  Hue
#> 1       2.76                 0.26            1.28            4.38 1.05
#> 2       3.24                 0.30            2.81            5.68 1.03
#> 3       3.49                 0.24            2.18            7.80 0.86
#> 4       2.69                 0.39            1.82            4.32 1.04
#> 5       3.39                 0.34            1.97            6.75 1.05
#> 6       2.52                 0.30            1.98            5.25 1.02
#>   OD280_OD315_of_diluted_wines Proline
#> 1                         3.40    1050
#> 2                         3.17    1185
#> 3                         3.45    1480
#> 4                         2.93     735
#> 5                         2.85    1450
#> 6                         3.58    1290
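
Before plotting, a quick check of how the observations are spread across the three labels does not hurt (a small addition of mine, not in the original post):

# how many wines fall into each label
table(wines$label)
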
plt1 <- ggplot(wines, aes(x = Alcohol, y = Magnesium, colour = as.factor(label))) +
    geom_point(size=3) +
    ggtitle("Wines")

plt1
plt2 <- ggplot(wines, aes(x = Alcohol, y = Proline, colour = as.factor(label))) +
    geom_point(size=3) +
    ggtitle("Wines")
plt2

40.2 Preprocessing

During the preprocessing phase, I have to do at least the following two things:

1. Encode the categorical variables.
2. Scale the predictors.

First of all, let’s encode our target variable. Encoding the categorical variables is needed when using neuralnet since it does not like factors at all: it will shout at you if you try to feed in a factor (I am told nnet likes factors, though).

In the wine dataset the variable label contains three different labels: 1,2 and 3.

The usual practice, as far as I know, is to encode categorical variables as a “one hot” vector. For instance, if I had three classes, like in this case, I’d need to replace the label variable with three variables like these:

#   l1,l2,l3
#   1,0,0
#   0,0,1
#   ...

In this example the first observation belongs to class 1, the second to class 3, and so on. Ironically, the nnet package provides a function to perform this encoding in a painless way:

# Encode as a one hot vector multilabel data
train <- cbind(wines[, 2:14], class.ind(as.factor(wines$label)))

# Set labels name
names(train) <- c(names(wines)[2:14],"l1","l2","l3")

By the way, since the predictors are all continuous, you do not need to encode any of them. If you did need to, however, you could apply the same strategy shown above to all the categorical predictors, unless of course you’d like to try some other kind of custom encoding.
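
For illustration only, here is how the same trick could be applied to a hypothetical categorical predictor (the toy region factor below is made up, it is not part of the wine data):

# toy example: one hot encode a made-up categorical predictor with class.ind()
toy <- data.frame(region = factor(c("A", "B", "A", "C")))
cbind(toy, class.ind(toy$region))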

Now let’s scale the predictors to the [0, 1] interval by leveraging the lapply function:

# Scale data
scl <- function(x) { (x - min(x))/(max(x) - min(x)) }
train[, 1:13] <- data.frame(lapply(train[, 1:13], scl))
head(train)
#>   Alcohol Malic_acid   Ash Alcalinity_of_ash Magnesium Total_phenols Flavanoids
#> 1   0.571      0.206 0.417            0.0309     0.326         0.576      0.511
#> 2   0.561      0.320 0.701            0.4124     0.337         0.628      0.612
#> 3   0.879      0.239 0.610            0.3196     0.467         0.990      0.665
#> 4   0.582      0.366 0.807            0.5361     0.522         0.628      0.496
#> 5   0.834      0.202 0.583            0.2371     0.457         0.790      0.643
#> 6   0.884      0.223 0.583            0.2062     0.283         0.524      0.460
#>   Nonflavanoid_phenols Proanthocyanins Color_intensity   Hue
#> 1                0.245           0.274           0.265 0.463
#> 2                0.321           0.757           0.375 0.447
#> 3                0.208           0.558           0.556 0.309
#> 4                0.491           0.445           0.259 0.455
#> 5                0.396           0.492           0.467 0.463
#> 6                0.321           0.495           0.339 0.439
#>   OD280_OD315_of_diluted_wines Proline l1 l2 l3
#> 1                        0.780   0.551  1  0  0
#> 2                        0.696   0.647  1  0  0
#> 3                        0.799   0.857  1  0  0
#> 4                        0.608   0.326  1  0  0
#> 5                        0.579   0.836  1  0  0
#> 6                        0.846   0.722  1  0  0
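
One caveat worth keeping in mind (a minimal sketch of mine, not part of the original pipeline): if you later want to score new observations, they should be rescaled with the minima and maxima of the training predictors rather than their own. Something along these lines would do:

# store the training minima and maxima so new data can be scaled consistently
mins <- sapply(wines[, 2:14], min)
maxs <- sapply(wines[, 2:14], max)

# rescale a new data frame of predictors using the stored training ranges
# (assumes newdata has the same 13 predictor columns in the same order)
scl_new <- function(newdata) {
    as.data.frame(Map(function(x, mn, mx) (x - mn) / (mx - mn),
                      newdata, mins, maxs))
}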

40.3 Fitting the model with neuralnet

Now it is finally time to fit the model. As you might remember from my older post, neuralnet does not accept the y~. shorthand. Fear not: you can build the full formula in a single step:

# Set up formula
n <- names(train)
f <- as.formula(paste("l1 + l2 + l3 ~", paste(n[!n %in% c("l1","l2","l3")], collapse = " + ")))
f
#> l1 + l2 + l3 ~ Alcohol + Malic_acid + Ash + Alcalinity_of_ash + 
#>     Magnesium + Total_phenols + Flavanoids + Nonflavanoid_phenols + 
#>     Proanthocyanins + Color_intensity + Hue + OD280_OD315_of_diluted_wines + 
#>     Proline

Note how the predictor names in n (everything except l1, l2 and l3) are pasted to the right of the “~” symbol. Just remember to check that the formula is indeed correct and then you are good to go.
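
A quick sanity check I like to add (not in the original post) is to make sure that every variable referenced by the formula is actually a column of the training data:

# should be TRUE: all formula variables exist in the training data frame
all(all.vars(f) %in% names(train))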

Let’s train the neural network with the full dataset. It should take very little time to converge; if you had not scaled the predictors it could take much longer.

nn <- neuralnet(f,
                data = train,
                hidden = c(13, 10, 3),
                act.fct = "logistic",
                linear.output = FALSE,
                lifesign = "minimal")
#> hidden: 13, 10, 3    thresh: 0.01    rep: 1/1    steps:      88  error: 0.03039  time: 0.05 secs

Note that I set the argument linear.output to FALSE in order to tell the model that I want to apply the activation function act.fct and that I am not doing a regression task. I then set the activation function to logistic (which is the default) in order to apply the logistic function. The only other available option is tanh, but the model seems to perform a little worse with it, so I stuck with the default. As far as I know these two are the only available options: there is no ReLU activation, even though it is a common choice in other packages.
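
If you want to see the difference yourself, a quick way is to refit the very same architecture with the tanh activation and compare; this is just a sketch, and since the weights are initialised randomly your numbers may differ from run to run:

# same network as above, only the activation function changes to tanh
nn_tanh <- neuralnet(f,
                     data = train,
                     hidden = c(13, 10, 3),
                     act.fct = "tanh",
                     linear.output = FALSE,
                     lifesign = "minimal")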

As for the number of hidden neurons, I tried a few combinations and the one used here seems to perform slightly better than the others (around a 1% difference in cross-validation accuracy).
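
As a rough illustration of that kind of comparison (a sketch only, using training accuracy; a proper comparison should rely on the cross-validation loop shown later), you could loop over a few candidate architectures:

# compare a few hidden layer configurations on the training set
candidates <- list(c(5), c(13, 10, 3), c(20, 10))
for (h in candidates) {
    nn_h <- neuralnet(f, data = train, hidden = h,
                      act.fct = "logistic", linear.output = FALSE)
    pred_h <- max.col(compute(nn_h, train[, 1:13])$net.result)
    cat("hidden:", paste(h, collapse = "-"),
        "training accuracy:", mean(pred_h == max.col(train[, 14:16])), "\n")
}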

By using the built-in plot method you can get a visual take on what is actually happening inside the model, although I do not find the plot particularly helpful:

plot(nn)

Let’s have a look at the accuracy on the training set:

# Compute predictions
pr.nn <- compute(nn, train[, 1:13])

# Extract results
pr.nn_ <- pr.nn$net.result
head(pr.nn_)
#>       [,1]    [,2]     [,3]
#> [1,] 0.990 0.00317 6.99e-06
#> [2,] 0.991 0.00233 8.69e-06
#> [3,] 0.991 0.00210 8.65e-06
#> [4,] 0.986 0.00442 8.74e-06
#> [5,] 0.992 0.00212 8.32e-06
#> [6,] 0.992 0.00214 8.34e-06
# Accuracy (training set)
original_values <- max.col(train[, 14:16])
pr.nn_2 <- max.col(pr.nn_)
mean(pr.nn_2 == original_values)
#> [1] 1
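
Besides the raw accuracy, a confusion matrix gives a better feel for where the (few) mistakes end up; this is a small addition of mine:

# confusion matrix on the training set: predicted class vs actual class
table(predicted = pr.nn_2, actual = original_values)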

100%, not bad! But wait, this may be because our model overfitted the data. Furthermore, evaluating accuracy on the training set is kind of cheating, since the model already “knows” (or should know) the answers. In order to assess the “true” accuracy of the model you need to perform some kind of cross-validation.

40.4 Cross validating the classifier

Let’s cross-validate the model using 10 rounds of random train/test splits (strictly speaking this is repeated hold-out validation rather than classic 10-fold cross-validation): in each round 95% of the dataset is used as the training set and the remaining 5% as the test set.

Just out of curiosity, I decided to run a LOOCV-style round too. In case you’d like to run this cross-validation technique, just set the proportion variable to 0.995: this will select a single observation as the test set and leave all the other observations as the training set. Running LOOCV you should get results similar to those of the 10-round cross-validation.

# Set seed for reproducibility purposes
set.seed(500)
# 10 fold cross validation
k <- 10
# Results from cv
outs <- NULL
# Train test split proportions
proportion <- 0.95 # Set to 0.995 for LOOCV

# Crossvalidate, go!
for(i in 1:k)
{
    index <- sample(1:nrow(train), round(proportion*nrow(train)))
    train_cv <- train[index, ]
    test_cv <- train[-index, ]
    nn_cv <- neuralnet(f,
                        data = train_cv,
                        hidden = c(13, 10, 3),
                        act.fct = "logistic",
                        linear.output = FALSE)
    
    # Compute predictions
    pr.nn <- compute(nn_cv, test_cv[, 1:13])
    # Extract results
    pr.nn_ <- pr.nn$net.result
    # Accuracy (test set)
    original_values <- max.col(test_cv[, 14:16])
    pr.nn_2 <- max.col(pr.nn_)
    outs[i] <- mean(pr.nn_2 == original_values)
}

mean(outs)
#> [1] 0.978
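
The mean alone hides the variability across the rounds; a quick look at the spread of the accuracies (a small sketch of mine) can be obtained like this:

# spread of the accuracies over the 10 cross validation rounds
sd(outs)
boxplot(outs, main = "Cross validation accuracy", ylab = "Accuracy")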

97.8%, awesome! Next time you are invited to a relaxing evening that includes a wine tasting competition, I think you should definitely bring your laptop as a contestant!

Aside from that poor-taste joke (I made it again!), this dataset is admittedly not the most challenging, and I think that with some more tweaking an even better cross-validation score could be achieved. Nevertheless, I hope you found this tutorial useful. A gist with the entire code for this tutorial can be found here.

Thank you for reading this article. Please feel free to leave a comment if you have any questions or suggestions, and share the post with others if you find it useful.
