Chapter 7 Creating PyTorch classes
Last update: Thu Oct 22 16:46:28 2020 -0500 (54a46ea04)
7.1 Build a PyTorch model class
PyTorch classes cannot yet be instantiated directly from R. We need an intermediate step to create a class. For this, we use reticulate functions such as py_run_string(), which runs the class implementation written in Python code; the result is then assigned to an R object.
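As a loose Python-side analogy for this intermediate step, the sketch below executes a source string in a fresh namespace and then looks the class up by name, much as the R examples in this chapter pull a class out of the object returned by py_run_string(). The class Greeter and all variable names are hypothetical, and this is an illustration of the idea, not a description of reticulate's internals.

```python
# Hypothetical source string, standing in for the class definitions
# that this chapter passes to py_run_string().
source = """
class Greeter:
    def __init__(self, name):
        self.name = name

    def greet(self):
        return "hello " + self.name
"""

# Run the source in a fresh namespace (a plain dict).
namespace = {}
exec(source, namespace)

# Look the class up by name, then instantiate it.
Greeter = namespace["Greeter"]
g = Greeter("R")
message = g.greet()
print(message)
```

The lookup namespace["Greeter"] plays the role that main$Net plays in the R code: the class only exists after the string has been executed.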
7.1.1 Example 1: One layer NN
library(reticulate)

py_run_string("import torch")
main <- py_run_string(
"
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # a single linear layer: 1 input feature, 1 output
        self.layer = torch.nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer(x)
        return x
")

# build a Linear Regression model
net <- main$Net()
The R object net now holds an instance of the PyTorch class Net, with all of its attributes and methods.
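To make concrete what such an instance can do, here is a minimal sketch on the Python side, assuming PyTorch is installed. The class is the same Net as above; the input tensor x, the seed, and the names y and n_params are made up for illustration.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # a single linear layer: 1 input feature, 1 output
        self.layer = nn.Linear(1, 1)

    def forward(self, x):
        return self.layer(x)


torch.manual_seed(0)  # arbitrary seed, only for repeatable initial weights
net = Net()

# Hypothetical input: a batch of 3 samples with 1 feature each.
x = torch.tensor([[1.0], [2.0], [3.0]])
y = net(x)  # calling the module runs forward()

# One weight and one bias make up the whole model.
n_params = sum(p.numel() for p in net.parameters())
print(tuple(y.shape), n_params)
```

The output has one value per input row, and the model holds exactly two trainable numbers, which is why a single Linear(1, 1) layer amounts to a linear regression.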
7.1.2 Example 2: Logistic Regression
main <- py_run_string(
"
import torch.nn as nn

class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.linear(x)
        return out
")

# get the Logistic Regression model class
LogisticRegressionModel <- main$LogisticRegressionModel
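The R object LogisticRegressionModel now refers to the PyTorch class LogisticRegressionModel. Unlike net in Example 1, it is the class itself and not an instance; calling it with input_dim and output_dim creates a model.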
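As a minimal sketch of how this class might be used once it is available, the Python code below instantiates it and runs a forward pass, assuming PyTorch is installed. The dimensions (4 inputs, 3 classes), the random batch, and the softmax step are illustrative assumptions; the class itself returns only the raw linear output, as in the definition above.

```python
import torch
import torch.nn as nn


class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # raw scores (logits), no activation applied here
        return self.linear(x)


# Hypothetical sizes: 4 input features, 3 output classes.
model = LogisticRegressionModel(input_dim=4, output_dim=3)

# Hypothetical batch of 5 random samples.
x = torch.rand(5, 4)
logits = model(x)

# Assumption for illustration: turn the logits into class
# probabilities with a softmax over the class dimension.
probs = torch.softmax(logits, dim=1)
print(tuple(logits.shape), probs.sum(dim=1))
```

Because forward() returns logits, any conversion to probabilities happens outside the model, either explicitly as here or inside a loss function that expects logits.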