Training

# Train `model` for `num_epochs` passes over the batches in
# `iter_train_loader`, running forward pass, loss, backward pass and an
# optimizer step per batch, and printing the loss every 100 steps.
#
# `num_epochs`, `iter_train_loader`, `model`, `criterion`, `optimizer` and
# `total_step` are defined earlier in the document (not visible here).
#
# R loop variables are 1-based, so `epoch` is printed as-is. Printing
# `epoch + 1` (a leftover from Python's 0-based `range`) reported
# "Epoch [6/5]" on the last pass.
#
# NOTE(review): if `iter_train_loader` is an R list materialized once from a
# shuffled DataLoader, every epoch replays the same batch order - confirm.
for (epoch in seq_len(num_epochs)) {
    i <- 0L   # 0-based step counter within the current epoch
    for (obj in iter_train_loader) {

        images <- obj[[1]]   # image batch tensor
        labels <- obj[[2]]   # class-label tensor, labels from 0 to 9
        # NOTE(review): reshaping to (-1, 28*28) gives one row per label only
        # if each image is a single 28x28 channel; confirm the actual shape.

        # Reshape images to (batch_size, input_size)
        images <- images$reshape(-1L, 28L*28L)
        # Alternative if the model expects double precision:
        # images <- torch$as_tensor(images$reshape(-1L, 28L*28L), dtype=torch$double)

        # Forward pass
        outputs <- model(images)
        loss <- criterion(outputs, labels)

        # Backward and optimize
        optimizer$zero_grad()
        loss$backward()
        optimizer$step()

        # Report every 100th step, displayed 1-based.
        if ((i + 1L) %% 100L == 0L) {
            cat(sprintf('Epoch [%d/%d], Step [%d/%d], Loss: %f \n',
                epoch, num_epochs, i + 1L, total_step, loss$item()))
        }
        i <- i + 1L
    }
}
#> Epoch [1/5], Step [100/600], Loss: 2.208986 
#> Epoch [1/5], Step [200/600], Loss: 2.166300 
#> Epoch [1/5], Step [300/600], Loss: 2.048386 
#> Epoch [1/5], Step [400/600], Loss: 1.968113 
#> Epoch [1/5], Step [500/600], Loss: 1.847626 
#> Epoch [1/5], Step [600/600], Loss: 1.836973 
#> Epoch [2/5], Step [100/600], Loss: 1.765670 
#> Epoch [2/5], Step [200/600], Loss: 1.768158 
#> Epoch [2/5], Step [300/600], Loss: 1.657210 
#> Epoch [2/5], Step [400/600], Loss: 1.579294 
#> Epoch [2/5], Step [500/600], Loss: 1.480846 
#> Epoch [2/5], Step [600/600], Loss: 1.498360 
#> Epoch [3/5], Step [100/600], Loss: 1.469377 
#> Epoch [3/5], Step [200/600], Loss: 1.503922 
#> Epoch [3/5], Step [300/600], Loss: 1.401340 
#> Epoch [3/5], Step [400/600], Loss: 1.333385 
#> Epoch [3/5], Step [500/600], Loss: 1.244200 
#> Epoch [3/5], Step [600/600], Loss: 1.270890 
#> Epoch [4/5], Step [100/600], Loss: 1.271359 
#> Epoch [4/5], Step [200/600], Loss: 1.324649 
#> Epoch [4/5], Step [300/600], Loss: 1.227103 
#> Epoch [4/5], Step [400/600], Loss: 1.169355 
#> Epoch [4/5], Step [500/600], Loss: 1.085202 
#> Epoch [4/5], Step [600/600], Loss: 1.112856 
#> Epoch [5/5], Step [100/600], Loss: 1.133963 
#> Epoch [5/5], Step [200/600], Loss: 1.197483 
#> Epoch [5/5], Step [300/600], Loss: 1.103267 
#> Epoch [5/5], Step [400/600], Loss: 1.053915 
#> Epoch [5/5], Step [500/600], Loss: 0.972801 
#> Epoch [5/5], Step [600/600], Loss: 0.998075

Save the model

# Save the model checkpoint: write the model's state dict (its parameters,
# not the model object itself) to "model.ckpt" in the working directory.
torch$save(
    obj = model$state_dict(),
    f = "model.ckpt"
)