Chapter 12 Neural Networks using NumPy, r-base, rTorch and PyTorch

We will compare four neural networks:

  • a neural network written in numpy

  • a neural network written in r-base

  • a neural network written in PyTorch

  • a neural network written in rTorch

12.1 A neural network with numpy

We start by building the neural network with plain NumPy. Note that the first line of the chunk below is R (loading rTorch makes the Python environment available through reticulate); the code that follows runs as Python:

library(rTorch)
# A simple neural network using NumPy
# Code in file tensor/two_layer_net_numpy.py
import time
import numpy as np

tic = time.process_time()

np.random.seed(123)   # set a seed for reproducibility
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)
# print(x.shape)
# print(y.shape)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)
# print(w1.shape)
# print(w2.shape)

learning_rate = 1e-6
for t in range(500):
  # Forward pass: compute predicted y
  h = x.dot(w1)
  # print(t, h.max())
  h_relu = np.maximum(h, 0)
  y_pred = h_relu.dot(w2)
  
  # Compute and print loss
  sq = np.square(y_pred - y)
  loss = sq.sum()
  print(t, loss)
  
  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred = 2.0 * (y_pred - y)
  grad_w2 = h_relu.T.dot(grad_y_pred)
  grad_h_relu = grad_y_pred.dot(w2.T)
  grad_h = grad_h_relu.copy()
  grad_h[h < 0] = 0
  grad_w1 = x.T.dot(grad_h)
 
  # Update weights
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2
# processing time  
#> 0 28624200.800938517
#> 1 24402861.381040636
#> 2 23157437.29147552
#> 3 21617191.63397175
#> 4 18598190.361558598
#> 5 14198211.419692844
#> 6 9786244.45261814
#> 7 6233451.217340663
#> 8 3862647.267829599
#> 9 2412366.632764836
#> 10 1569915.4392193707
#> 11 1078501.3381487518
#> 12 785163.9233288621
#> 13 601495.2825043725
#> 14 479906.0403613456
#> 15 394555.19331746205
#> 16 331438.6987273826
#> 17 282679.6687873873
#> 18 243807.84432087594
#> 19 211970.18110708205
#> 20 185451.6861514274
#> 21 163078.20881862927
#> 22 144011.80160918707
#> 23 127662.96132466741
#> 24 113546.29175681781
#> 25 101291.55288493488
#> 26 90623.20833654879
#> 27 81307.32590692889
#> 28 73135.24710426925
#> 29 65937.50294095621
#> 30 59570.26425368039
#> 31 53923.82804264227
#> 32 48909.69273028215
#> 33 44438.89933807681
#> 34 40445.34031569733
#> 35 36873.30041989413
#> 36 33664.990437423825
#> 37 30781.198962949587
#> 38 28184.24227268406
#> 39 25843.99793108194
#> 40 23727.282448406426
#> 41 21810.062067327668
#> 42 20071.326437572196
#> 43 18492.63752543329
#> 44 17056.72779714255
#> 45 15749.299484025236
#> 46 14557.324481207237
#> 47 13468.469764338035
#> 48 12473.575866914027
#> 49 11562.485809665774
#> 50 10727.865926563407
#> 51 9962.411372816146
#> 52 9259.619803682268
#> 53 8613.269071227103
#> 54 8018.523834750763
#> 55 7471.080819104451
#> 56 6966.00651845651
#> 57 6499.96685422581
#> 58 6069.576425345411
#> 59 5671.2821228408475
#> 60 5302.644980086279
#> 61 4961.339043761728
#> 62 4645.02541423451
#> 63 4351.473575805103
#> 64 4079.2165446062972
#> 65 3826.1480820887655
#> 66 3590.887308956795
#> 67 3372.0103280622666
#> 68 3168.173408650748
#> 69 2978.362100081684
#> 70 2801.302649097963
#> 71 2636.037950790892
#> 72 2481.7354010452655
#> 73 2337.6093944873246
#> 74 2202.8250425683987
#> 75 2076.8872560589616
#> 76 1958.9976460120263
#> 77 1848.5060338548483
#> 78 1744.9993380824799
#> 79 1647.9807349258715
#> 80 1556.9947585282196
#> 81 1471.7081797400347
#> 82 1391.6136870762566
#> 83 1316.3329239757227
#> 84 1245.5902641069824
#> 85 1179.0691783286234
#> 86 1116.5095209528572
#> 87 1057.6662051951396
#> 88 1002.2519686823666
#> 89 950.0167505993219
#> 90 900.7916929993518
#> 91 854.3816389576979
#> 92 810.6277767708903
#> 93 769.3592041348505
#> 94 730.3836012940042
#> 95 693.5644048073411
#> 96 658.7807027999521
#> 97 625.9238747325827
#> 98 594.8758111695068
#> 99 565.4973547949257
#> 100 537.7012178149556
#> 101 511.3901106843991
#> 102 486.4837276215478
#> 103 462.90746955458474
#> 104 440.5787622887435
#> 105 419.4121231392399
#> 106 399.34612374957226
#> 107 380.3221777272873
#> 108 362.2821345456067
#> 109 345.18049757120184
#> 110 328.94028615976936
#> 111 313.5191206271147
#> 112 298.8754770672758
#> 113 284.96926791620496
#> 114 271.7642984526849
#> 115 259.2246266311472
#> 116 247.30122156531897
#> 117 235.96203976771662
#> 118 225.17874184522793
#> 119 214.9253969806085
#> 120 205.16916168826197
#> 121 195.88920014324063
#> 122 187.0522150132689
#> 123 178.6428873875804
#> 124 170.63479897325027
#> 125 163.00806018890546
#> 126 155.7440191346056
#> 127 148.83352898111042
#> 128 142.2496666996878
#> 129 135.97509122834504
#> 130 129.98982612428355
#> 131 124.28418865778005
#> 132 118.84482149781273
#> 133 113.65645952102406
#> 134 108.7054397008061
#> 135 103.98144604072209
#> 136 99.47512083365962
#> 137 95.17318303450762
#> 138 91.06775169947714
#> 139 87.14952592945869
#> 140 83.4075554849774
#> 141 79.8333553283839
#> 142 76.41993249926654
#> 143 73.159531678603
#> 144 70.04535899921396
#> 145 67.0700037713867
#> 146 64.22536514818646
#> 147 61.50715956099643
#> 148 58.90970110703718
#> 149 56.42818157298958
#> 150 54.053456343974474
#> 151 51.78409899250521
#> 152 49.613042222061935
#> 153 47.537088681832714
#> 154 45.55073951374691
#> 155 43.651385230775375
#> 156 41.8333828820336
#> 157 40.0944925576898
#> 158 38.4304655768987
#> 159 36.83773398481151
#> 160 35.313368600585044
#> 161 33.85436928433868
#> 162 32.457997092726586
#> 163 31.120973836567913
#> 164 29.841057186484246
#> 165 28.61536631365921
#> 166 27.441646501921213
#> 167 26.31767712811449
#> 168 25.241065734351473
#> 169 24.210568668753154
#> 170 23.223366825888164
#> 171 22.27691447596546
#> 172 21.370561777029383
#> 173 20.502013041055037
#> 174 19.669605151002397
#> 175 18.872156637147214
#> 176 18.107932697664136
#> 177 17.375347093063624
#> 178 16.67329705241241
#> 179 16.000313127916616
#> 180 15.355056259809643
#> 181 14.736642044314163
#> 182 14.143657665391123
#> 183 13.575482981169435
#> 184 13.03055792072713
#> 185 12.507813624903267
#> 186 12.00650847964371
#> 187 11.525873890625666
#> 188 11.064924569594556
#> 189 10.622845128602144
#> 190 10.199224278747348
#> 191 9.79248532294249
#> 192 9.40221537769526
#> 193 9.027996925837858
#> 194 8.668895520243254
#> 195 8.324385761675554
#> 196 7.99390867066041
#> 197 7.676665609325665
#> 198 7.3722991001285685
#> 199 7.080233920966563
#> 200 6.7999405980009
#> 201 6.530984430178585
#> 202 6.2728878687947365
#> 203 6.025197539285438
#> 204 5.787473375780924
#> 205 5.559253501791474
#> 206 5.340172472449113
#> 207 5.129896948041436
#> 208 4.928007606815918
#> 209 4.734225282679221
#> 210 4.548186858907342
#> 211 4.369651328446663
#> 212 4.198236457646962
#> 213 4.033565011138579
#> 214 3.8754625080281464
#> 215 3.7236914115521316
#> 216 3.5779627242857224
#> 217 3.4379821914239286
#> 218 3.303565587540205
#> 219 3.174454405800678
#> 220 3.0504743070396323
#> 221 2.931383709316906
#> 222 2.8170418304762785
#> 223 2.7072412196038553
#> 224 2.6017277000868093
#> 225 2.50040409121904
#> 226 2.403078781570677
#> 227 2.309594481835507
#> 228 2.219794799730801
#> 229 2.133526678637347
#> 230 2.0506760423604566
#> 231 1.9710453639295484
#> 232 1.894559024310974
#> 233 1.8211210547720629
#> 234 1.7505340383436803
#> 235 1.6826932948721067
#> 236 1.6175070289508109
#> 237 1.5549072300348752
#> 238 1.4947316986695944
#> 239 1.436912502600996
#> 240 1.381372987946563
#> 241 1.3279854205041584
#> 242 1.2766884038688984
#> 243 1.2273848146334094
#> 244 1.1800217450316255
#> 245 1.1344919105891025
#> 246 1.0907369940975837
#> 247 1.0486826235693274
#> 248 1.0082656206399931
#> 249 0.9694282665755529
#> 250 0.9320976601575675
#> 251 0.8962339607475229
#> 252 0.8617533865905884
#> 253 0.8286151485833971
#> 254 0.7967578289852474
#> 255 0.7661404678425654
#> 256 0.7367202044072118
#> 257 0.708422713667491
#> 258 0.6812311487720265
#> 259 0.6550822696783506
#> 260 0.6299469090210432
#> 261 0.605786995355434
#> 262 0.5825650778276774
#> 263 0.5602382140936045
#> 264 0.5387735503110371
#> 265 0.5181403816556053
#> 266 0.49830590931295304
#> 267 0.47922937308117297
#> 268 0.46088901492620127
#> 269 0.44325464817119054
#> 270 0.42630408406116316
#> 271 0.41000543380657917
#> 272 0.39433673295843236
#> 273 0.37927114581493265
#> 274 0.36478176529460243
#> 275 0.35085044445134994
#> 276 0.3374578361158044
#> 277 0.32457682402453136
#> 278 0.31219123729919207
#> 279 0.300296586147234
#> 280 0.28884848624094894
#> 281 0.27783526470539743
#> 282 0.26724487697010957
#> 283 0.2570618106928273
#> 284 0.2472693951468085
#> 285 0.23785306876436113
#> 286 0.22879648231270536
#> 287 0.22008909643106767
#> 288 0.21171318526106842
#> 289 0.2036578219834066
#> 290 0.19591133993811427
#> 291 0.18846041746510728
#> 292 0.18129477007162065
#> 293 0.174405315161736
#> 294 0.16777998120837712
#> 295 0.16140610523836268
#> 296 0.1552756501716649
#> 297 0.14937904644542377
#> 298 0.14370793039467633
#> 299 0.13825290527822973
#> 300 0.13300640130439656
#> 301 0.12796012311324031
#> 302 0.12310750541656884
#> 303 0.11844182274749851
#> 304 0.11395158652041627
#> 305 0.10963187686672912
#> 306 0.10547640155933785
#> 307 0.10148022089409026
#> 308 0.0976363799328684
#> 309 0.09393976586801374
#> 310 0.09038186218007657
#> 311 0.08696004033318867
#> 312 0.08366808215670352
#> 313 0.08050159133387036
#> 314 0.0774556507265311
#> 315 0.07452541616811464
#> 316 0.07170677388789805
#> 317 0.06899492388917926
#> 318 0.06638632065320674
#> 319 0.06387707772657374
#> 320 0.06146291085125196
#> 321 0.0591402294396231
#> 322 0.05690662209831464
#> 323 0.05475707395743591
#> 324 0.05268944906989688
#> 325 0.05069984545069233
#> 326 0.048785688597973095
#> 327 0.046944795197577285
#> 328 0.045173966618895535
#> 329 0.043469382749897256
#> 330 0.04182932192085659
#> 331 0.04025154186795582
#> 332 0.038733588417595735
#> 333 0.03727299017402862
#> 334 0.03586799441058297
#> 335 0.03451589218265247
#> 336 0.03321501089199479
#> 337 0.03196371785309425
#> 338 0.030759357425241718
#> 339 0.029600888472444742
#> 340 0.028485919148238392
#> 341 0.02741317225069457
#> 342 0.026380963792005673
#> 343 0.025387828276963217
#> 344 0.02443225636975702
#> 345 0.02351279471955997
#> 346 0.02262815392798661
#> 347 0.02177684408442846
#> 348 0.02095765200803268
#> 349 0.02016947466161515
#> 350 0.019410962895712616
#> 351 0.018681045066734122
#> 352 0.017978879513468316
#> 353 0.017303468563130222
#> 354 0.016653437842251186
#> 355 0.01602766278432409
#> 356 0.015425464893044428
#> 357 0.01484594678906112
#> 358 0.014288249850265784
#> 359 0.01375163575426638
#> 360 0.01323528665049373
#> 361 0.012738339025978556
#> 362 0.012260186918304262
#> 363 0.011799970856220952
#> 364 0.011357085981162363
#> 365 0.010930950268775873
#> 366 0.010520842685022909
#> 367 0.010126145830079638
#> 368 0.009746393154855839
#> 369 0.009380889339520658
#> 370 0.009029161386689313
#> 371 0.00869059833698051
#> 372 0.00836477207696539
#> 373 0.008051209390678065
#> 374 0.0077494325069793705
#> 375 0.007459023266150334
#> 376 0.007179590434333104
#> 377 0.006910623445853765
#> 378 0.006651749941578513
#> 379 0.006402648026678379
#> 380 0.006162978285307884
#> 381 0.005932194796367616
#> 382 0.005710085052295781
#> 383 0.005496310244895275
#> 384 0.0052906289241425215
#> 385 0.0050926241688279104
#> 386 0.004902076613033862
#> 387 0.004718638851167859
#> 388 0.004542078962047164
#> 389 0.004372164586665975
#> 390 0.004208618626839021
#> 391 0.004051226677923414
#> 392 0.0038997374494828298
#> 393 0.003753918301513866
#> 394 0.003613561837935153
#> 395 0.0034784786917529164
#> 396 0.003348462575629662
#> 397 0.003223327362263324
#> 398 0.0031028635490837437
#> 399 0.002986912218213565
#> 400 0.002875348146367024
#> 401 0.0027679524720207994
#> 402 0.0026645903412969877
#> 403 0.00256506728009952
#> 404 0.0024692701898842025
#> 405 0.0023770671718814063
#> 406 0.0022883091777422303
#> 407 0.0022029269889801703
#> 408 0.0021207379368966914
#> 409 0.0020415781423120893
#> 410 0.001965380838191689
#> 411 0.0018920388674650765
#> 412 0.0018214489876606395
#> 413 0.0017534990549357195
#> 414 0.0016880979054376358
#> 415 0.0016251364192863505
#> 416 0.0015645343026947606
#> 417 0.0015062064772070694
#> 418 0.0014500530088225327
#> 419 0.0013959868097274688
#> 420 0.001343946421404061
#> 421 0.0012938496041169677
#> 422 0.001245622397754905
#> 423 0.0011992050880615885
#> 424 0.0011545283489900085
#> 425 0.0011115075856686302
#> 426 0.001070100670544413
#> 427 0.0010302364937566674
#> 428 0.0009918591300819473
#> 429 0.000954924393232083
#> 430 0.0009193639132775486
#> 431 0.0008851308467932729
#> 432 0.0008521777959560448
#> 433 0.0008204570911784497
#> 434 0.0007899223397731109
#> 435 0.0007605278374214596
#> 436 0.0007322343466954752
#> 437 0.0007049830914115257
#> 438 0.0006787512341473519
#> 439 0.00065350212037464
#> 440 0.0006291921955255096
#> 441 0.0006057856348208776
#> 442 0.0005832525024800561
#> 443 0.0005615598539424442
#> 444 0.0005406761235200468
#> 445 0.0005205750249286578
#> 446 0.0005012184845940066
#> 447 0.0004825848028301716
#> 448 0.0004646447575300741
#> 449 0.0004473739461918762
#> 450 0.0004307513759213604
#> 451 0.00041474810355609723
#> 452 0.00039933580480713945
#> 453 0.0003844970781264902
#> 454 0.0003702109250696993
#> 455 0.00035645948619340297
#> 456 0.0003432213223641764
#> 457 0.0003304723731848576
#> 458 0.00031819830164465815
#> 459 0.00030638121798918724
#> 460 0.0002950045353519474
#> 461 0.0002840533130499193
#> 462 0.00027350873727298176
#> 463 0.00026335657398426546
#> 464 0.000253581258369829
#> 465 0.00024416913722126747
#> 466 0.0002351142689424904
#> 467 0.0002263919313737711
#> 468 0.00021799257674327073
#> 469 0.00020990427540056088
#> 470 0.0002021174506938248
#> 471 0.00019462054044199915
#> 472 0.00018740325426984858
#> 473 0.00018045252249983815
#> 474 0.000173759960543912
#> 475 0.00016731630060690805
#> 476 0.0001611122710715995
#> 477 0.00015513993832625702
#> 478 0.00014938925941558148
#> 479 0.00014385207870578823
#> 480 0.00013852014130375656
#> 481 0.00013338601187671428
#> 482 0.000128442793294424
#> 483 0.0001236841045646944
#> 484 0.00011910150087090696
#> 485 0.00011468967274610794
#> 486 0.00011044058002490428
#> 487 0.00010634983745106246
#> 488 0.00010241132940006558
#> 489 9.861901302344988e-05
#> 490 9.496682985475842e-05
#> 491 9.144989845880715e-05
#> 492 8.806354488018214e-05
#> 493 8.480312707749194e-05
#> 494 8.166404591653792e-05
#> 495 7.864135637113095e-05
#> 496 7.573027443124469e-05
#> 497 7.292787602990206e-05
#> 498 7.023030228370285e-05
#> 499 6.763183953445079e-05
toc = time.process_time()
print(toc - tic, "seconds")
#> 7.387020459 seconds

12.2 A neural network with r-base

This is the same algorithm as the NumPy version above, but written in base R.
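Before the full listing, the short chunk below (an added aside, not part of the original code) shows on tiny matrices how the NumPy operations used above map to base R:

# NumPy-to-base-R equivalences, demonstrated on small matrices
a <- matrix(rnorm(6), nrow = 2)      # like np.random.randn(2, 3)
b <- matrix(rnorm(12), nrow = 3)     # like np.random.randn(3, 4)

h      <- a %*% b                    # a.dot(b): matrix multiplication
h_relu <- pmax(h, 0)                 # np.maximum(h, 0): element-wise ReLU
h_t    <- t(h_relu)                  # h_relu.T: transpose
loss   <- sum((h_relu - 1)^2)        # np.square(...).sum(): sum of squared differences
g      <- h_relu                     # grad_h_relu.copy(): R copies on modification,
g[h < 0] <- 0                        #   so an ordinary assignment behaves like .copy()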

library(tictoc)

tic()
set.seed(123)
N <- 64; D_in <- 1000; H <- 100; D_out <- 10;
# Create random input and output data
x <- array(rnorm(N * D_in),  dim = c(N, D_in))
y <- array(rnorm(N * D_out), dim = c(N, D_out))
# Randomly initialize weights
w1 <- array(rnorm(D_in * H),  dim = c(D_in, H))
w2 <- array(rnorm(H * D_out),  dim = c(H, D_out))
learning_rate <-  1e-6

for (t in seq(1, 500)) {
  # Forward pass: compute predicted y
  h = x %*% w1
  h_relu = pmax(h, 0)
  y_pred = h_relu %*% w2

  # Compute and print loss
  sq <- (y_pred - y)^2
  loss = sum(sq)
  cat(t, loss, "\n")
  
  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred = 2.0 * (y_pred - y)
  grad_w2 = t(h_relu) %*% grad_y_pred
  grad_h_relu = grad_y_pred %*% t(w2)
  # equivalent of NumPy's grad_h = grad_h_relu.copy()
  grad_h <- rlang::duplicate(grad_h_relu)
  grad_h[h < 0] <-  0
  grad_w1 = t(x) %*% grad_h
  
  # Update weights
  w1 = w1 - learning_rate * grad_w1
  w2 = w2 - learning_rate * grad_w2
}
toc()
#> 1 2.8e+07 
#> 2 25505803 
#> 3 29441299 
#> 4 35797650 
#> 5 39517126 
#> 6 34884942 
#> 7 23333535 
#> 8 11927525 
#> 9 5352787 
#> 10 2496984 
#> 11 1379780 
#> 12 918213 
#> 13 695760 
#> 14 564974 
#> 15 474479 
#> 16 405370 
#> 17 349747 
#> 18 303724 
#> 19 265075 
#> 20 232325 
#> 21 204394 
#> 22 180414 
#> 23 159752 
#> 24 141895 
#> 25 126374 
#> 26 112820 
#> 27 100959 
#> 28 90536 
#> 29 81352 
#> 30 73244 
#> 31 66058 
#> 32 59675 
#> 33 53993 
#> 34 48921 
#> 35 44388 
#> 36 40328 
#> 37 36687 
#> 38 33414 
#> 39 30469 
#> 40 27816 
#> 41 25419 
#> 42 23251 
#> 43 21288 
#> 44 19508 
#> 45 17893 
#> 46 16426 
#> 47 15092 
#> 48 13877 
#> 49 12769 
#> 50 11758 
#> 51 10835 
#> 52 9991 
#> 53 9218 
#> 54 8510 
#> 55 7862 
#> 56 7267 
#> 57 6719 
#> 58 6217 
#> 59 5754 
#> 60 5329 
#> 61 4938 
#> 62 4577 
#> 63 4245 
#> 64 3938 
#> 65 3655 
#> 66 3394 
#> 67 3153 
#> 68 2930 
#> 69 2724 
#> 70 2533 
#> 71 2357 
#> 72 2193 
#> 73 2042 
#> 74 1902 
#> 75 1772 
#> 76 1651 
#> 77 1539 
#> 78 1435 
#> 79 1338 
#> 80 1249 
#> 81 1165 
#> 82 1088 
#> 83 1016 
#> 84 949 
#> 85 886 
#> 86 828 
#> 87 774 
#> 88 724 
#> 89 677 
#> 90 633 
#> 91 592 
#> 92 554 
#> 93 519 
#> 94 486 
#> 95 455 
#> 96 426 
#> 97 399 
#> 98 374 
#> 99 350 
#> 100 328 
#> 101 308 
#> 102 289 
#> 103 271 
#> 104 254 
#> 105 238 
#> 106 224 
#> 107 210 
#> 108 197 
#> 109 185 
#> 110 174 
#> 111 163 
#> 112 153 
#> 113 144 
#> 114 135 
#> 115 127 
#> 116 119 
#> 117 112 
#> 118 106 
#> 119 99.2 
#> 120 93.3 
#> 121 87.8 
#> 122 82.6 
#> 123 77.7 
#> 124 73.1 
#> 125 68.8 
#> 126 64.7 
#> 127 60.9 
#> 128 57.4 
#> 129 54 
#> 130 50.9 
#> 131 47.9 
#> 132 45.1 
#> 133 42.5 
#> 134 40.1 
#> 135 37.8 
#> 136 35.6 
#> 137 33.5 
#> 138 31.6 
#> 139 29.8 
#> 140 28.1 
#> 141 26.5 
#> 142 25 
#> 143 23.6 
#> 144 22.2 
#> 145 21 
#> 146 19.8 
#> 147 18.7 
#> 148 17.6 
#> 149 16.6 
#> 150 15.7 
#> 151 14.8 
#> 152 14 
#> 153 13.2 
#> 154 12.5 
#> 155 11.8 
#> 156 11.1 
#> 157 10.5 
#> 158 9.94 
#> 159 9.39 
#> 160 8.87 
#> 161 8.38 
#> 162 7.92 
#> 163 7.49 
#> 164 7.08 
#> 165 6.69 
#> 166 6.32 
#> 167 5.98 
#> 168 5.65 
#> 169 5.35 
#> 170 5.06 
#> 171 4.78 
#> 172 4.52 
#> 173 4.28 
#> 174 4.05 
#> 175 3.83 
#> 176 3.62 
#> 177 3.43 
#> 178 3.25 
#> 179 3.07 
#> 180 2.91 
#> 181 2.75 
#> 182 2.6 
#> 183 2.47 
#> 184 2.33 
#> 185 2.21 
#> 186 2.09 
#> 187 1.98 
#> 188 1.88 
#> 189 1.78 
#> 190 1.68 
#> 191 1.6 
#> 192 1.51 
#> 193 1.43 
#> 194 1.36 
#> 195 1.29 
#> 196 1.22 
#> 197 1.15 
#> 198 1.09 
#> 199 1.04 
#> 200 0.983 
#> 201 0.932 
#> 202 0.883 
#> 203 0.837 
#> 204 0.794 
#> 205 0.753 
#> 206 0.714 
#> 207 0.677 
#> 208 0.642 
#> 209 0.609 
#> 210 0.577 
#> 211 0.548 
#> 212 0.519 
#> 213 0.493 
#> 214 0.467 
#> 215 0.443 
#> 216 0.421 
#> 217 0.399 
#> 218 0.379 
#> 219 0.359 
#> 220 0.341 
#> 221 0.324 
#> 222 0.307 
#> 223 0.292 
#> 224 0.277 
#> 225 0.263 
#> 226 0.249 
#> 227 0.237 
#> 228 0.225 
#> 229 0.213 
#> 230 0.203 
#> 231 0.192 
#> 232 0.183 
#> 233 0.173 
#> 234 0.165 
#> 235 0.156 
#> 236 0.149 
#> 237 0.141 
#> 238 0.134 
#> 239 0.127 
#> 240 0.121 
#> 241 0.115 
#> 242 0.109 
#> 243 0.104 
#> 244 0.0985 
#> 245 0.0936 
#> 246 0.0889 
#> 247 0.0845 
#> 248 0.0803 
#> 249 0.0763 
#> 250 0.0725 
#> 251 0.0689 
#> 252 0.0655 
#> 253 0.0623 
#> 254 0.0592 
#> 255 0.0563 
#> 256 0.0535 
#> 257 0.0508 
#> 258 0.0483 
#> 259 0.0459 
#> 260 0.0437 
#> 261 0.0415 
#> 262 0.0395 
#> 263 0.0375 
#> 264 0.0357 
#> 265 0.0339 
#> 266 0.0323 
#> 267 0.0307 
#> 268 0.0292 
#> 269 0.0278 
#> 270 0.0264 
#> 271 0.0251 
#> 272 0.0239 
#> 273 0.0227 
#> 274 0.0216 
#> 275 0.0206 
#> 276 0.0196 
#> 277 0.0186 
#> 278 0.0177 
#> 279 0.0168 
#> 280 0.016 
#> 281 0.0152 
#> 282 0.0145 
#> 283 0.0138 
#> 284 0.0131 
#> 285 0.0125 
#> 286 0.0119 
#> 287 0.0113 
#> 288 0.0108 
#> 289 0.0102 
#> 290 0.00975 
#> 291 0.00927 
#> 292 0.00883 
#> 293 0.0084 
#> 294 0.008 
#> 295 0.00761 
#> 296 0.00724 
#> 297 0.0069 
#> 298 0.00656 
#> 299 0.00625 
#> 300 0.00595 
#> 301 0.00566 
#> 302 0.00539 
#> 303 0.00513 
#> 304 0.00489 
#> 305 0.00465 
#> 306 0.00443 
#> 307 0.00422 
#> 308 0.00401 
#> 309 0.00382 
#> 310 0.00364 
#> 311 0.00347 
#> 312 0.0033 
#> 313 0.00314 
#> 314 0.00299 
#> 315 0.00285 
#> 316 0.00271 
#> 317 0.00259 
#> 318 0.00246 
#> 319 0.00234 
#> 320 0.00223 
#> 321 0.00213 
#> 322 0.00203 
#> 323 0.00193 
#> 324 0.00184 
#> 325 0.00175 
#> 326 0.00167 
#> 327 0.00159 
#> 328 0.00151 
#> 329 0.00144 
#> 330 0.00137 
#> 331 0.00131 
#> 332 0.00125 
#> 333 0.00119 
#> 334 0.00113 
#> 335 0.00108 
#> 336 0.00103 
#> 337 0.000979 
#> 338 0.000932 
#> 339 0.000888 
#> 340 0.000846 
#> 341 0.000807 
#> 342 0.000768 
#> 343 0.000732 
#> 344 0.000698 
#> 345 0.000665 
#> 346 0.000634 
#> 347 0.000604 
#> 348 0.000575 
#> 349 0.000548 
#> 350 0.000523 
#> 351 0.000498 
#> 352 0.000475 
#> 353 0.000452 
#> 354 0.000431 
#> 355 0.000411 
#> 356 0.000392 
#> 357 0.000373 
#> 358 0.000356 
#> 359 0.000339 
#> 360 0.000323 
#> 361 0.000308 
#> 362 0.000294 
#> 363 0.00028 
#> 364 0.000267 
#> 365 0.000254 
#> 366 0.000243 
#> 367 0.000231 
#> 368 0.00022 
#> 369 0.00021 
#> 370 2e-04 
#> 371 0.000191 
#> 372 0.000182 
#> 373 0.000174 
#> 374 0.000165 
#> 375 0.000158 
#> 376 0.00015 
#> 377 0.000143 
#> 378 0.000137 
#> 379 0.00013 
#> 380 0.000124 
#> 381 0.000119 
#> 382 0.000113 
#> 383 0.000108 
#> 384 0.000103 
#> 385 9.8e-05 
#> 386 9.34e-05 
#> 387 8.91e-05 
#> 388 8.49e-05 
#> 389 8.1e-05 
#> 390 7.72e-05 
#> 391 7.37e-05 
#> 392 7.02e-05 
#> 393 6.7e-05 
#> 394 6.39e-05 
#> 395 6.09e-05 
#> 396 5.81e-05 
#> 397 5.54e-05 
#> 398 5.28e-05 
#> 399 5.04e-05 
#> 400 4.81e-05 
#> 401 4.58e-05 
#> 402 4.37e-05 
#> 403 4.17e-05 
#> 404 3.98e-05 
#> 405 3.79e-05 
#> 406 3.62e-05 
#> 407 3.45e-05 
#> 408 3.29e-05 
#> 409 3.14e-05 
#> 410 2.99e-05 
#> 411 2.86e-05 
#> 412 2.72e-05 
#> 413 2.6e-05 
#> 414 2.48e-05 
#> 415 2.36e-05 
#> 416 2.25e-05 
#> 417 2.15e-05 
#> 418 2.05e-05 
#> 419 1.96e-05 
#> 420 1.87e-05 
#> 421 1.78e-05 
#> 422 1.7e-05 
#> 423 1.62e-05 
#> 424 1.55e-05 
#> 425 1.48e-05 
#> 426 1.41e-05 
#> 427 1.34e-05 
#> 428 1.28e-05 
#> 429 1.22e-05 
#> 430 1.17e-05 
#> 431 1.11e-05 
#> 432 1.06e-05 
#> 433 1.01e-05 
#> 434 9.66e-06 
#> 435 9.22e-06 
#> 436 8.79e-06 
#> 437 8.39e-06 
#> 438 8e-06 
#> 439 7.64e-06 
#> 440 7.29e-06 
#> 441 6.95e-06 
#> 442 6.63e-06 
#> 443 6.33e-06 
#> 444 6.04e-06 
#> 445 5.76e-06 
#> 446 5.5e-06 
#> 447 5.25e-06 
#> 448 5.01e-06 
#> 449 4.78e-06 
#> 450 4.56e-06 
#> 451 4.35e-06 
#> 452 4.15e-06 
#> 453 3.96e-06 
#> 454 3.78e-06 
#> 455 3.61e-06 
#> 456 3.44e-06 
#> 457 3.28e-06 
#> 458 3.13e-06 
#> 459 2.99e-06 
#> 460 2.85e-06 
#> 461 2.72e-06 
#> 462 2.6e-06 
#> 463 2.48e-06 
#> 464 2.37e-06 
#> 465 2.26e-06 
#> 466 2.15e-06 
#> 467 2.06e-06 
#> 468 1.96e-06 
#> 469 1.87e-06 
#> 470 1.79e-06 
#> 471 1.71e-06 
#> 472 1.63e-06 
#> 473 1.55e-06 
#> 474 1.48e-06 
#> 475 1.42e-06 
#> 476 1.35e-06 
#> 477 1.29e-06 
#> 478 1.23e-06 
#> 479 1.17e-06 
#> 480 1.12e-06 
#> 481 1.07e-06 
#> 482 1.02e-06 
#> 483 9.74e-07 
#> 484 9.3e-07 
#> 485 8.88e-07 
#> 486 8.47e-07 
#> 487 8.09e-07 
#> 488 7.72e-07 
#> 489 7.37e-07 
#> 490 7.03e-07 
#> 491 6.71e-07 
#> 492 6.41e-07 
#> 493 6.12e-07 
#> 494 5.84e-07 
#> 495 5.57e-07 
#> 496 5.32e-07 
#> 497 5.08e-07 
#> 498 4.85e-07 
#> 499 4.63e-07 
#> 500 4.42e-07 
#> 2.034 sec elapsed

12.3 The neural network written in PyTorch

Here is the same example used above, but written in PyTorch. Notice the following differences from the NumPy code:

  • we select the computation device, which can be either the CPU or a GPU

  • when building or creating the tensors, we specify which device we want to use

  • the tensors have torch methods and properties, for example mm(), clamp(), sum(), clone(), and t()

  • also notice the use of some torch functions, such as device() and randn()

reticulate::use_condaenv("r-torch")
# Code in file tensor/two_layer_net_tensor.py
import torch
import time

ms = torch.manual_seed(0)
tic = time.process_time()
device = torch.device('cpu')
# device = torch.device('cuda')  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

learning_rate = 1e-6
for t in range(500):
  # Forward pass: compute predicted y
  h = x.mm(w1)
  h_relu = h.clamp(min=0)
  y_pred = h_relu.mm(w2)

  # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
  # of shape (); we can get its value as a Python number with loss.item().
  loss = (y_pred - y).pow(2).sum()
  print(t, loss.item())

  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred = 2.0 * (y_pred - y)
  grad_w2 = h_relu.t().mm(grad_y_pred)
  grad_h_relu = grad_y_pred.mm(w2.t())
  grad_h = grad_h_relu.clone()
  grad_h[h < 0] = 0
  grad_w1 = x.t().mm(grad_h)

  # Update weights using gradient descent
  w1 -= learning_rate * grad_w1
  w2 -= learning_rate * grad_w2
#> 0 29428664.0
#> 1 22739448.0
#> 2 20605260.0
#> 3 19520372.0
#> 4 17810224.0
#> 5 14999204.0
#> 6 11483334.0
#> 7 8096649.0
#> 8 5398717.5
#> 9 3521559.75
#> 10 2315861.5
#> 11 1570273.5
#> 12 1111700.375
#> 13 825062.8125
#> 14 639684.4375
#> 15 514220.625
#> 16 425155.3125
#> 17 358904.5625
#> 18 307636.71875
#> 19 266625.90625
#> 20 232998.625
#> 21 204887.296875
#> 22 181051.0625
#> 23 160643.0
#> 24 143036.09375
#> 25 127729.578125
#> 26 114360.25
#> 27 102621.0234375
#> 28 92276.9375
#> 29 83144.0859375
#> 30 75053.3984375
#> 31 67870.3984375
#> 32 61485.79296875
#> 33 55786.6328125
#> 34 50690.8515625
#> 35 46128.6328125
#> 36 42029.546875
#> 37 38341.875
#> 38 35017.33203125
#> 39 32016.68359375
#> 40 29303.43359375
#> 41 26847.1484375
#> 42 24620.376953125
#> 43 22599.46875
#> 44 20762.5625
#> 45 19090.986328125
#> 46 17568.359375
#> 47 16180.1083984375
#> 48 14911.99609375
#> 49 13753.8525390625
#> 50 12694.0205078125
#> 51 11723.640625
#> 52 10834.490234375
#> 53 10019.25390625
#> 54 9270.923828125
#> 55 8583.36328125
#> 56 7950.5625
#> 57 7368.46875
#> 58 6832.73779296875
#> 59 6339.20703125
#> 60 5884.1484375
#> 61 5464.44384765625
#> 62 5077.45849609375
#> 63 4719.9833984375
#> 64 4389.5400390625
#> 65 4084.009765625
#> 66 3801.313232421875
#> 67 3539.627197265625
#> 68 3297.266845703125
#> 69 3072.8017578125
#> 70 2864.869140625
#> 71 2672.025390625
#> 72 2493.096435546875
#> 73 2326.89697265625
#> 74 2172.523193359375
#> 75 2029.1279296875
#> 76 1895.768310546875
#> 77 1771.71435546875
#> 78 1656.3409423828125
#> 79 1548.9505615234375
#> 80 1448.9840087890625
#> 81 1355.846923828125
#> 82 1269.0556640625
#> 83 1188.1507568359375
#> 84 1112.7042236328125
#> 85 1042.3167724609375
#> 86 976.61328125
#> 87 915.2999267578125
#> 88 858.0404052734375
#> 89 804.5496826171875
#> 90 754.5780029296875
#> 91 707.8599243164062
#> 92 664.1988525390625
#> 93 623.3640747070312
#> 94 585.147216796875
#> 95 549.3995971679688
#> 96 515.9583740234375
#> 97 484.6272277832031
#> 98 455.28955078125
#> 99 427.81829833984375
#> 100 402.0847473144531
#> 101 377.9535827636719
#> 102 355.3477783203125
#> 103 334.1396179199219
#> 104 314.2633361816406
#> 105 295.61749267578125
#> 106 278.1217346191406
#> 107 261.7001953125
#> 108 246.2969512939453
#> 109 231.8272247314453
#> 110 218.24240112304688
#> 111 205.48812866210938
#> 112 193.5052490234375
#> 113 182.24417114257812
#> 114 171.66690063476562
#> 115 161.72601318359375
#> 116 152.3784942626953
#> 117 143.59078979492188
#> 118 135.32354736328125
#> 119 127.55582427978516
#> 120 120.24463653564453
#> 121 113.36481475830078
#> 122 106.89350128173828
#> 123 100.80726623535156
#> 124 95.07266998291016
#> 125 89.6752700805664
#> 126 84.59477233886719
#> 127 79.80913543701172
#> 128 75.30223083496094
#> 129 71.0572509765625
#> 130 67.05980682373047
#> 131 63.292694091796875
#> 132 59.7408447265625
#> 133 56.394203186035156
#> 134 53.243412017822266
#> 135 50.2683219909668
#> 136 47.46772003173828
#> 137 44.82497787475586
#> 138 42.33271408081055
#> 139 39.983646392822266
#> 140 37.76749801635742
#> 141 35.67666244506836
#> 142 33.70509338378906
#> 143 31.84467124938965
#> 144 30.089385986328125
#> 145 28.432872772216797
#> 146 26.869369506835938
#> 147 25.39266586303711
#> 148 23.999008178710938
#> 149 22.684724807739258
#> 150 21.4434757232666
#> 151 20.270301818847656
#> 152 19.164194107055664
#> 153 18.11824607849121
#> 154 17.131380081176758
#> 155 16.199291229248047
#> 156 15.318136215209961
#> 157 14.486746788024902
#> 158 13.700006484985352
#> 159 12.957758903503418
#> 160 12.256866455078125
#> 161 11.593376159667969
#> 162 10.96681022644043
#> 163 10.374650955200195
#> 164 9.815613746643066
#> 165 9.286172866821289
#> 166 8.78611946105957
#> 167 8.313515663146973
#> 168 7.866476058959961
#> 169 7.443814754486084
#> 170 7.044161319732666
#> 171 6.666952133178711
#> 172 6.309534072875977
#> 173 5.9717559814453125
#> 174 5.652008056640625
#> 175 5.3500075340271
#> 176 5.06421422958374
#> 177 4.793882846832275
#> 178 4.538228511810303
#> 179 4.296501159667969
#> 180 4.067446708679199
#> 181 3.8510499000549316
#> 182 3.6461739540100098
#> 183 3.4524216651916504
#> 184 3.2690694332122803
#> 185 3.0956828594207764
#> 186 2.9311866760253906
#> 187 2.7758116722106934
#> 188 2.628840684890747
#> 189 2.4897918701171875
#> 190 2.357895851135254
#> 191 2.2333240509033203
#> 192 2.1151578426361084
#> 193 2.003354072570801
#> 194 1.897698998451233
#> 195 1.7976123094558716
#> 196 1.7029246091842651
#> 197 1.6131364107131958
#> 198 1.5283033847808838
#> 199 1.4478871822357178
#> 200 1.371699333190918
#> 201 1.2994897365570068
#> 202 1.231500267982483
#> 203 1.1667163372039795
#> 204 1.1054186820983887
#> 205 1.0472912788391113
#> 206 0.9924129247665405
#> 207 0.9405249953269958
#> 208 0.8911417722702026
#> 209 0.8445178866386414
#> 210 0.8003085851669312
#> 211 0.758423388004303
#> 212 0.7187696099281311
#> 213 0.6812056303024292
#> 214 0.6455042362213135
#> 215 0.6117878556251526
#> 216 0.5798596739768982
#> 217 0.5495442152023315
#> 218 0.5209972858428955
#> 219 0.4938827455043793
#> 220 0.46809014678001404
#> 221 0.4436979293823242
#> 222 0.42065465450286865
#> 223 0.3987467288970947
#> 224 0.3779408633708954
#> 225 0.35838788747787476
#> 226 0.3397265076637268
#> 227 0.3221140503883362
#> 228 0.30536866188049316
#> 229 0.2895379662513733
#> 230 0.27451151609420776
#> 231 0.2602919638156891
#> 232 0.24681799113750458
#> 233 0.23405984044075012
#> 234 0.22187164425849915
#> 235 0.2103630006313324
#> 236 0.19945508241653442
#> 237 0.18917179107666016
#> 238 0.1794165074825287
#> 239 0.1700771450996399
#> 240 0.1613144725561142
#> 241 0.152926966547966
#> 242 0.14506009221076965
#> 243 0.1375567466020584
#> 244 0.13043273985385895
#> 245 0.12370903044939041
#> 246 0.11734490096569061
#> 247 0.11129261553287506
#> 248 0.10555146634578705
#> 249 0.10010744631290436
#> 250 0.09495128691196442
#> 251 0.09006303548812866
#> 252 0.08542166650295258
#> 253 0.08105342835187912
#> 254 0.07687549293041229
#> 255 0.07293462008237839
#> 256 0.06918356567621231
#> 257 0.06564081460237503
#> 258 0.062239713966846466
#> 259 0.059055205434560776
#> 260 0.05602336302399635
#> 261 0.05314234644174576
#> 262 0.05042209476232529
#> 263 0.04785769432783127
#> 264 0.045423999428749084
#> 265 0.04309770092368126
#> 266 0.04090772941708565
#> 267 0.03880797326564789
#> 268 0.03683297708630562
#> 269 0.03495331108570099
#> 270 0.03315659612417221
#> 271 0.031475357711315155
#> 272 0.029864072799682617
#> 273 0.028345633298158646
#> 274 0.026901375502347946
#> 275 0.025526201352477074
#> 276 0.024225471541285515
#> 277 0.023021651431918144
#> 278 0.021845556795597076
#> 279 0.020738258957862854
#> 280 0.01967737451195717
#> 281 0.01868186891078949
#> 282 0.017737826332449913
#> 283 0.016843702644109726
#> 284 0.015994098037481308
#> 285 0.015187159180641174
#> 286 0.014432456344366074
#> 287 0.013691866770386696
#> 288 0.013026118278503418
#> 289 0.012365361675620079
#> 290 0.011741021648049355
#> 291 0.011153185740113258
#> 292 0.010602883994579315
#> 293 0.010070282965898514
#> 294 0.009570850059390068
#> 295 0.009099053218960762
#> 296 0.008648849092423916
#> 297 0.008217266760766506
#> 298 0.007814647629857063
#> 299 0.007436459884047508
#> 300 0.007072300184518099
#> 301 0.006720009259879589
#> 302 0.006387100555002689
#> 303 0.00608158390969038
#> 304 0.00578821636736393
#> 305 0.005504274740815163
#> 306 0.005235536955296993
#> 307 0.004986326675862074
#> 308 0.004750200547277927
#> 309 0.004520890768617392
#> 310 0.004305804148316383
#> 311 0.004104197025299072
#> 312 0.003908107057213783
#> 313 0.0037259890232235193
#> 314 0.0035482768435031176
#> 315 0.0033842488192021847
#> 316 0.0032260832376778126
#> 317 0.0030806262511759996
#> 318 0.002938204212114215
#> 319 0.002802144968882203
#> 320 0.002674166578799486
#> 321 0.0025522327050566673
#> 322 0.0024338625371456146
#> 323 0.002325983252376318
#> 324 0.0022217126097530127
#> 325 0.002122103003785014
#> 326 0.0020273567643016577
#> 327 0.0019368595676496625
#> 328 0.0018519405275583267
#> 329 0.0017723542405292392
#> 330 0.0016958083724603057
#> 331 0.00162519421428442
#> 332 0.001555908122099936
#> 333 0.0014901482500135899
#> 334 0.0014247691724449396
#> 335 0.0013653874630108476
#> 336 0.001307258615270257
#> 337 0.0012546550715342164
#> 338 0.0012025412870571017
#> 339 0.0011545777088031173
#> 340 0.001107968739233911
#> 341 0.0010642317356541753
#> 342 0.0010200864635407925
#> 343 0.0009793058270588517
#> 344 0.0009410151396878064
#> 345 0.0009048299980349839
#> 346 0.0008693647105246782
#> 347 0.000835308397654444
#> 348 0.0008031500619836152
#> 349 0.0007735351100564003
#> 350 0.000744393328204751
#> 351 0.00071698147803545
#> 352 0.00069050322053954
#> 353 0.0006645384710282087
#> 354 0.0006397517863661051
#> 355 0.0006177832838147879
#> 356 0.0005949471960775554
#> 357 0.0005744362715631723
#> 358 0.0005537742399610579
#> 359 0.0005348395789042115
#> 360 0.0005162699380889535
#> 361 0.000499469693750143
#> 362 0.00048172459355555475
#> 363 0.0004661969724111259
#> 364 0.0004515194450505078
#> 365 0.0004358708392828703
#> 366 0.0004218583053443581
#> 367 0.00040883725159801543
#> 368 0.0003956131695304066
#> 369 0.0003827497421298176
#> 370 0.000370656605809927
#> 371 0.00036004791036248207
#> 372 0.0003480703162495047
#> 373 0.0003388348559383303
#> 374 0.000327684567309916
#> 375 0.0003175089950673282
#> 376 0.0003082627372350544
#> 377 0.0002986858307849616
#> 378 0.00028960598865523934
#> 379 0.0002815576735883951
#> 380 0.0002736181777436286
#> 381 0.0002657140721566975
#> 382 0.00025785667821764946
#> 383 0.0002509196347091347
#> 384 0.00024437913089059293
#> 385 0.00023740741016808897
#> 386 0.0002299495681654662
#> 387 0.0002234804560430348
#> 388 0.0002169939107261598
#> 389 0.00021134663256816566
#> 390 0.0002056143421214074
#> 391 0.00020046206191182137
#> 392 0.00019536828040145338
#> 393 0.00019056514429394156
#> 394 0.00018598540918901563
#> 395 0.00018159380124416202
#> 396 0.00017640764417592436
#> 397 0.00017208821373060346
#> 398 0.000168110869708471
#> 399 0.00016350964142475277
#> 400 0.00015964081103447825
#> 401 0.00015596051525790244
#> 402 0.00015269994037225842
#> 403 0.00014866374840494245
#> 404 0.00014477886725217104
#> 405 0.00014148686022963375
#> 406 0.00013842849875800312
#> 407 0.00013507613039109856
#> 408 0.0001322997995885089
#> 409 0.00012896949192509055
#> 410 0.00012618394976016134
#> 411 0.00012356613297015429
#> 412 0.00012060831068083644
#> 413 0.00011798611376434565
#> 414 0.0001152795521193184
#> 415 0.00011272911069681868
#> 416 0.00011033188638975844
#> 417 0.00010773474059533328
#> 418 0.00010584026313154027
#> 419 0.00010329326323699206
#> 420 0.00010140397353097796
#> 421 9.970468090614304e-05
#> 422 9.72362540778704e-05
#> 423 9.54945498961024e-05
#> 424 9.346337174065411e-05
#> 425 9.128850797424093e-05
#> 426 8.97917925613001e-05
#> 427 8.779048221185803e-05
#> 428 8.59305146150291e-05
#> 429 8.416303899139166e-05
#> 430 8.247063669841737e-05
#> 431 8.109148620860651e-05
#> 432 7.982019451446831e-05
#> 433 7.818565791239962e-05
#> 434 7.673520303796977e-05
#> 435 7.54009815864265e-05
#> 436 7.374506094492972e-05
#> 437 7.267539331223816e-05
#> 438 7.122510578483343e-05
#> 439 6.98604853823781e-05
#> 440 6.852982915006578e-05
#> 441 6.75098126521334e-05
#> 442 6.636354373767972e-05
#> 443 6.522039620904252e-05
#> 444 6.410140485968441e-05
#> 445 6.307245348580182e-05
#> 446 6.221079092938453e-05
#> 447 6.089429371058941e-05
#> 448 5.975936437607743e-05
#> 449 5.893126945011318e-05
#> 450 5.780566425528377e-05
#> 451 5.694766514352523e-05
#> 452 5.5986300139920786e-05
#> 453 5.502309068106115e-05
#> 454 5.420695379143581e-05
#> 455 5.31858422618825e-05
#> 456 5.239694655756466e-05
#> 457 5.1775907195406035e-05
#> 458 5.109262929181568e-05
#> 459 5.0413200369803235e-05
#> 460 4.956878183293156e-05
#> 461 4.8856254579732195e-05
#> 462 4.8221645556623116e-05
#> 463 4.7429402911802754e-05
#> 464 4.700458885054104e-05
#> 465 4.615000216290355e-05
#> 466 4.5314704038901255e-05
#> 467 4.466490645427257e-05
#> 468 4.406480729812756e-05
#> 469 4.344138142187148e-05
#> 470 4.302451270632446e-05
#> 471 4.255307430867106e-05
#> 472 4.1863419028231874e-05
#> 473 4.148659354541451e-05
#> 474 4.099802754353732e-05
#> 475 4.034798257634975e-05
#> 476 3.994005237473175e-05
#> 477 3.94669477827847e-05
#> 478 3.9117549022194e-05
#> 479 3.8569156458834186e-05
#> 480 3.8105612475192174e-05
#> 481 3.753463170141913e-05
#> 482 3.679965084302239e-05
#> 483 3.646357436082326e-05
#> 484 3.597680915845558e-05
#> 485 3.555299190338701e-05
#> 486 3.504360938677564e-05
#> 487 3.449235737207346e-05
#> 488 3.391931386431679e-05
#> 489 3.374389780219644e-05
#> 490 3.328040838823654e-05
#> 491 3.31329574692063e-05
#> 492 3.259751247242093e-05
#> 493 3.2441555958939716e-05
#> 494 3.1837684218771756e-05
#> 495 3.1491359550273046e-05
#> 496 3.120429755654186e-05
#> 497 3.089967503910884e-05
#> 498 3.059657319681719e-05
#> 499 3.0050463465158828e-05
toc = time.process_time()
print(toc - tic, "seconds")
#> 32.989317715 seconds

12.4 A neural network written in rTorch

This example shows the long, manual way of calculating the forward and backward passes, this time using rTorch. The objective is to become familiar with the rTorch tensor operations.

The following example was converted from PyTorch to rTorch to show the differences and similarities between the two approaches. The original source can be found here: Source.
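Since the goal is familiarity with the rTorch tensor syntax, here is a tiny, self-contained warm-up (an added sketch, not part of the original example) that exercises on small tensors the same methods used below:

library(rTorch)

a <- torch$randn(2L, 3L)        # a random 2x3 tensor
b <- torch$randn(3L, 4L)        # a random 3x4 tensor

h      <- a$mm(b)               # matrix multiplication
h_relu <- h$clamp(min = 0)      # element-wise ReLU
h_t    <- h_relu$t()            # transpose
s      <- h_relu$sum()          # sum of all elements (a 0-dimensional tensor)
s$item()                        # extract the value as a plain R number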

12.4.1 Load the libraries

library(rTorch)
library(ggplot2)

device = torch$device('cpu')
# device = torch.device('cuda')  # Uncomment this to run on GPU
invisible(torch$manual_seed(0))
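
The chunk above pins the computation to the CPU. As a side note (a hedged sketch, assuming rTorch exposes torch$cuda$is_available() the same way PyTorch exposes torch.cuda.is_available()), the device could also be detected automatically; device_auto is a name introduced here only for illustration, and assigning it to device would switch the rest of the chapter to the GPU:

# detect whether a GPU is available and build the corresponding device
device_auto <- if (torch$cuda$is_available()) torch$device("cuda") else torch$device("cpu")
device_auto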
In the code that follows:

  • N is batch size;
  • D_in is input dimension;
  • H is hidden dimension;
  • D_out is output dimension.

12.4.2 Dataset

We will create a random dataset for a two layer neural network.

N <- 64L; D_in <- 1000L; H <- 100L; D_out <- 10L

# Create random Tensors to hold inputs and outputs
x <- torch$randn(N, D_in, device=device)
y <- torch$randn(N, D_out, device=device)
# dimensions of both tensors
dim(x)
dim(y)
#> [1]   64 1000
#> [1] 64 10

12.4.3 Initialize the weights

# Randomly initialize weights
w1 <- torch$randn(D_in, H, device=device)   # layer 1
w2 <- torch$randn(H, D_out, device=device)  # layer 2
dim(w1)
dim(w2)
#> [1] 1000  100
#> [1] 100  10

12.4.4 Iterate through the dataset

Now we are going to train our neural network on the training dataset. The question is: how many times do we have to expose the training data to the algorithm? By looking at a plot of the loss we may get an idea of when we should stop.
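For example, the loss could be tracked like this (a hedged sketch, not part of the original code: it re-initializes its own weight copies, named w1_l and w2_l here, so the chunks that follow are unaffected; x, y, D_in, H, D_out, and device are assumed to be defined as above):

learning_rate <- 1e-6
losses <- numeric(50)                         # loss recorded at each iteration

# fresh weight copies so that w1 and w2 used below stay untouched
w1_l <- torch$randn(D_in, H, device=device)
w2_l <- torch$randn(H, D_out, device=device)

for (t in 1:50) {
  h      <- x$mm(w1_l)
  h_relu <- h$clamp(min = 0)
  y_pred <- h_relu$mm(w2_l)

  loss <- (torch$sub(y_pred, y))$pow(2)$sum()
  losses[t] <- loss$item()                    # keep the scalar loss as an R number

  grad_y_pred <- torch$mul(torch$scalar_tensor(2.0), torch$sub(y_pred, y))
  grad_w2     <- h_relu$t()$mm(grad_y_pred)
  grad_h_relu <- grad_y_pred$mm(w2_l$t())
  grad_h      <- grad_h_relu$clone()
  grad_h      <- grad_h$masked_fill(h$lt(0), 0.0)  # zero gradients where the ReLU input was negative
  grad_w1     <- x$t()$mm(grad_h)

  w1_l <- torch$sub(w1_l, torch$mul(learning_rate, grad_w1))
  w2_l <- torch$sub(w2_l, torch$mul(learning_rate, grad_w2))
}

ggplot(data.frame(iter = seq_along(losses), loss = losses),
       aes(x = iter, y = loss)) +
  geom_line()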

12.4.4.1 Iterate 50 times

For the sake of time, let's run only 50 iterations of the training loop.

learning_rate = 1e-6

# loop
for (t in 1:50) {
  # Forward pass: compute predicted y, y_pred
  h <- x$mm(w1)              # matrix multiplication, x*w1
  h_relu <- h$clamp(min=0)   # make elements greater than zero
  y_pred <- h_relu$mm(w2)    # matrix multiplication, h_relu*w2

  # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
  # of shape (); we can get its value as a Python number with loss.item().
  loss <- (torch$sub(y_pred, y))$pow(2)$sum()   # sum((y_pred-y)^2)
  # cat(t, "\t")
  # cat(loss$item(), "\n")

  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred <- torch$mul(torch$scalar_tensor(2.0), torch$sub(y_pred, y))
  grad_w2 <- h_relu$t()$mm(grad_y_pred)        # compute gradient of w2
  grad_h_relu <- grad_y_pred$mm(w2$t())
  grad_h <- grad_h_relu$clone()
  mask <- h$lt(0)                              # positions where the ReLU input was negative
  grad_h <- grad_h$masked_fill(mask, 0.0)      # zero the gradient at those positions
  grad_w1 <- x$t()$mm(grad_h)                  # compute gradient of w1
   
  # Update weights using gradient descent
  w1 <- torch$sub(w1, torch$mul(learning_rate, grad_w1))
  w2 <- torch$sub(w2, torch$mul(learning_rate, grad_w2))
}
# y vs predicted y
df_50 <- data.frame(y = y$flatten()$numpy(), 
                    y_pred = y_pred$flatten()$numpy(), iter = 50)

ggplot(df_50, aes(x = y, y = y_pred)) +
    geom_point()

We see a lot of dispersion between the predicted values, \(y_{pred}\) and the real values, \(y\). We are far from our goal.
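To put a number on that dispersion (an added check, not part of the original), we can compute the correlation between the observed and predicted values stored in df_50; a value close to 1 would indicate good agreement:

# correlation between y and y_pred after 50 iterations
cor(df_50$y, df_50$y_pred)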

Let’s take a look at the dataframe:

library('DT')
datatable(df_50, options = list(pageLength = 10))

12.4.4.2 A function to train the neural network

Now we convert the script above into a function so that we can reuse it several times. We want to study the effect of the number of iterations on the performance of the algorithm.

This time we create a function, train(), that takes as input the number of iterations we want to run:

train <- function(iterations) {
    # Randomly initialize weights
    w1 <- torch$randn(D_in, H, device=device)   # layer 1
    w2 <- torch$randn(H, D_out, device=device)  # layer 2
    
    learning_rate = 1e-6
    # loop
    for (t in 1:iterations) {
      # Forward pass: compute predicted y
      h <- x$mm(w1)
      h_relu <- h$clamp(min=0)
      y_pred <- h_relu$mm(w2)
    
      # Compute and print loss; loss is a scalar stored in a PyTorch Tensor
      # of shape (); we can get its value as a Python number with loss.item().
      loss <- (torch$sub(y_pred, y))$pow(2)$sum()
      # cat(t, "\t"); cat(loss$item(), "\n")
    
      # Backprop to compute gradients of w1 and w2 with respect to loss
      grad_y_pred <- torch$mul(torch$scalar_tensor(2.0), torch$sub(y_pred, y))
      grad_w2 <- h_relu$t()$mm(grad_y_pred)
      grad_h_relu <- grad_y_pred$mm(w2$t())
      grad_h <- grad_h_relu$clone()
      mask <- h$lt(0)
      grad_h <- grad_h$masked_fill(mask, 0.0)
      grad_w1 <- x$t()$mm(grad_h)
       
      # Update weights using gradient descent
      w1 <- torch$sub(w1, torch$mul(learning_rate, grad_w1))
      w2 <- torch$sub(w2, torch$mul(learning_rate, grad_w2))
    }
    data.frame(y = y$flatten()$numpy(), 
                        y_pred = y_pred$flatten()$numpy(), iter = iterations)
}

12.4.4.3 Run it at 100 iterations

# retrieve the results and store them in a dataframe
df_100 <- train(iterations = 100)
datatable(df_100, options = list(pageLength = 10))
# plot
ggplot(df_100, aes(x = y_pred, y = y)) +
    geom_point()

12.4.4.4 250 iterations

There are still differences between the values and the predictions. Let's try more iterations, say 250:

df_250 <- train(iterations = 250)
datatable(df_250, options = list(pageLength = 25))
# plot
ggplot(df_250, aes(x = y_pred, y = y)) +
    geom_point()

We see the formation of a line between the values and the predictions, which means we are getting closer to finding the right parameters, in this particular case, the weights w1 and w2.

12.4.4.5 500 iterations

Let’s try one more time with 500 iterations:

df_500 <- train(iterations = 500)
datatable(df_500, options = list(pageLength = 25))
ggplot(df_500, aes(x = y_pred, y = y)) +
    geom_point()

12.5 Complete code for neural network in rTorch

library(rTorch)
library(ggplot2)
library(tictoc)
library(DT)

tic()
device = torch$device('cpu')
# device = torch.device('cuda')  # Uncomment this to run on GPU
invisible(torch$manual_seed(0))

# Properties of tensors and neural network
N <- 64L; D_in <- 1000L; H <- 100L; D_out <- 10L

# Create random Tensors to hold inputs and outputs
x <- torch$randn(N, D_in, device=device)
y <- torch$randn(N, D_out, device=device)
# dimensions of both tensors

# initialize the weights
w1 <- torch$randn(D_in, H, device=device)   # layer 1
w2 <- torch$randn(H, D_out, device=device)  # layer 2

learning_rate = 1e-6
# loop
for (t in 1:500) {
  # Forward pass: compute predicted y, y_pred
  h <- x$mm(w1)              # matrix multiplication, x*w1
  h_relu <- h$clamp(min=0)   # make elements greater than zero
  y_pred <- h_relu$mm(w2)    # matrix multiplication, h_relu*w2

  # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
  # of shape (); we can get its value as a Python number with loss.item().
  loss <- (torch$sub(y_pred, y))$pow(2)$sum()   # sum((y_pred-y)^2)
  # cat(t, "\t")
  # cat(loss$item(), "\n")

  # Backprop to compute gradients of w1 and w2 with respect to loss
  grad_y_pred <- torch$mul(torch$scalar_tensor(2.0), torch$sub(y_pred, y))
  grad_w2 <- h_relu$t()$mm(grad_y_pred)        # compute gradient of w2
  grad_h_relu <- grad_y_pred$mm(w2$t())
  grad_h <- grad_h_relu$clone()
  mask <- h$lt(0)                              # positions where the ReLU input was negative
  grad_h <- grad_h$masked_fill(mask, 0.0)      # zero the gradient at those positions
  grad_w1 <- x$t()$mm(grad_h)                  # compute gradient of w1
   
  # Update weights using gradient descent
  w1 <- torch$sub(w1, torch$mul(learning_rate, grad_w1))
  w2 <- torch$sub(w2, torch$mul(learning_rate, grad_w2))
}
# y vs predicted y
df<- data.frame(y = y$flatten()$numpy(), 
                    y_pred = y_pred$flatten()$numpy(), iter = 500)
datatable(df, options = list(pageLength = 25))
ggplot(df, aes(x = y_pred, y = y)) +
    geom_point()

toc()
#> 22.033 sec elapsed

12.6 Exercise

  1. Rewrite the code in rTorch, this time including and plotting the loss at each iteration.

  2. For the neural network written in PyTorch, instead of printing one long table, print the table in pages that can be navigated with vertical and horizontal scroll bars. Tip: read the Python data structure from R and plot it with ggplot2.