I have a basic benchmark where I train a CNN on MNIST with and without AMP. The problem is that I can't get the f16 types out of the model afterwards, or export a model that runs on CPU. Calling float() on the model doesn't seem to do anything.
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.onnx
import torchvision.datasets as dsets
import torchvision.transforms as trans
from apex import amp


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.lin1 = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.lin1(out)
        return out


trainSet = dsets.MNIST(root='./data', train=True, transform=trans.ToTensor(), download=True)
trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=100, shuffle=True)
testSet = dsets.MNIST(root='./data', train=False, transform=trans.ToTensor(), download=True)
testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=100, shuffle=True)


def accuracy(testLoader, model):
    correct, total = 0, 0
    with torch.no_grad():
        for data in testLoader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


NUM_EPOCHS = 3
device = torch.device('cuda:0')
criterion = nn.CrossEntropyLoss()

# ----- f32 baseline -----
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), 0.01)

start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

print('f32: Accuracy: {0:.4f}'.format(accuracy(testLoader, model)))
print('f32: Training time: {0:.2f}'.format(time.time() - start))

torch.save(model.state_dict(), './f32.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy_input, './f32.onnx', verbose=True)

print('*****************************************************************************')

# ----- AMP (O1) run -----
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), 0.01)
amp_model, amp_optimizer = amp.initialize(model, optimizer, opt_level="O1")

start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        amp_optimizer.zero_grad()
        outputs = amp_model(images)
        loss = criterion(outputs, labels)
        with amp.scale_loss(loss, amp_optimizer) as scaled_loss:
            scaled_loss.backward()
        amp_optimizer.step()

print('AMP: Accuracy: {0:.4f}'.format(accuracy(testLoader, amp_model)))
print('AMP: Training time: {0:.2f}'.format(time.time() - start))

amp_model = amp_model.float()  # This line doesn't do anything
torch.save(amp_model.state_dict(), './amp.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(amp_model, dummy_input, './amp.onnx', verbose=True)
This code prints the following:
Epoch 0
Epoch 1
Epoch 2
f32: Accuracy: 0.9803
f32: Training time: 13.65
graph(%0 : Float(1, 1, 28, 28),
%layer1.0.weight : Float(16, 1, 5, 5),
%layer1.0.bias : Float(16),
%layer1.1.weight : Float(16),
%layer1.1.bias : Float(16),
%layer1.1.running_mean : Float(16),
%layer1.1.running_var : Float(16),
%layer1.1.num_batches_tracked : Long(),
%layer2.0.weight : Float(32, 16, 5, 5),
%layer2.0.bias : Float(32),
%layer2.1.weight : Float(32),
%layer2.1.bias : Float(32),
%layer2.1.running_mean : Float(32),
%layer2.1.running_var : Float(32),
%layer2.1.num_batches_tracked : Long(),
%lin1.weight : Float(10, 1568),
%lin1.bias : Float(10)):
%17 : Float(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%0, %layer1.0.weight, %layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
%18 : Float(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%17, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
%19 : Float(1, 16, 28, 28) = onnx::Relu(%18), scope: Net/Sequential[layer1]/ReLU[2]
%20 : Float(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%19), scope: Net/Sequential[layer1]/MaxPool2d[3]
%21 : Float(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%20, %layer2.0.weight, %layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
%22 : Float(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%21, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
%23 : Float(1, 32, 14, 14) = onnx::Relu(%22), scope: Net/Sequential[layer2]/ReLU[2]
%24 : Float(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%23), scope: Net/Sequential[layer2]/MaxPool2d[3]
%25 : Long() = onnx::Constant[value={0}](), scope: Net
%26 : Tensor = onnx::Shape(%24), scope: Net
%27 : Long() = onnx::Gather[axis=0](%26, %25), scope: Net
%28 : Long() = onnx::Constant[value={-1}](), scope: Net
%29 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
%30 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
%31 : Tensor = onnx::Concat[axis=0](%29, %30)
%32 : Float(1, 1568) = onnx::Reshape(%24, %31), scope: Net
%33 : Float(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%32, %lin1.weight, %lin1.bias), scope: Net/Linear[lin1]
return (%33)
*****************************************************************************
Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.
Defaults for this optimization level are:
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled : True
opt_level : O1
cast_model_type : None
patch_torch_functions : True
keep_batchnorm_fp32 : None
master_weights : None
loss_scale : dynamic
Epoch 0
Epoch 1
Epoch 2
AMP: Accuracy: 0.9788
AMP: Training time: 19.55
graph(%x.3 : Float(1, 1, 28, 28),
%layer1.0.weight : Float(16, 1, 5, 5),
%layer1.0.bias : Float(16),
%layer1.1.weight : Float(16),
%layer1.1.bias : Float(16),
%layer1.1.running_mean : Float(16),
%layer1.1.running_var : Float(16),
%layer1.1.num_batches_tracked : Long(),
%layer2.0.weight : Float(32, 16, 5, 5),
%layer2.0.bias : Float(32),
%layer2.1.weight : Float(32),
%layer2.1.bias : Float(32),
%layer2.1.running_mean : Float(32),
%layer2.1.running_var : Float(32),
%layer2.1.num_batches_tracked : Long(),
%lin1.weight : Float(10, 1568),
%lin1.bias : Float(10)):
%17 : Half(16, 1, 5, 5) = onnx::Cast[to=10](%layer1.0.weight), scope: Net/Sequential[layer1]/Conv2d[0]
%18 : Half(16) = onnx::Cast[to=10](%layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
%19 : Half(1, 1, 28, 28) = onnx::Cast[to=10](%x.3), scope: Net/Sequential[layer1]/Conv2d[0]
%20 : Half(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%19, %17, %18), scope: Net/Sequential[layer1]/Conv2d[0]
%21 : Half(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%20, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
%22 : Half(1, 16, 28, 28) = onnx::Relu(%21), scope: Net/Sequential[layer1]/ReLU[2]
%23 : Half(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%22), scope: Net/Sequential[layer1]/MaxPool2d[3]
%24 : Half(32, 16, 5, 5) = onnx::Cast[to=10](%layer2.0.weight), scope: Net/Sequential[layer2]/Conv2d[0]
%25 : Half(32) = onnx::Cast[to=10](%layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
%26 : Half(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%23, %24, %25), scope: Net/Sequential[layer2]/Conv2d[0]
%27 : Half(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%26, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
%28 : Half(1, 32, 14, 14) = onnx::Relu(%27), scope: Net/Sequential[layer2]/ReLU[2]
%29 : Half(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: Net/Sequential[layer2]/MaxPool2d[3]
%30 : Long() = onnx::Constant[value={0}](), scope: Net
%31 : Tensor = onnx::Shape(%29), scope: Net
%32 : Long() = onnx::Gather[axis=0](%31, %30), scope: Net
%33 : Long() = onnx::Constant[value={-1}](), scope: Net
%34 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
%35 : Tensor = onnx::Unsqueeze[axes=[0]](%33)
%36 : Tensor = onnx::Concat[axis=0](%34, %35)
%37 : Half(1, 1568) = onnx::Reshape(%29, %36), scope: Net
%38 : Half(10, 1568) = onnx::Cast[to=10](%lin1.weight), scope: Net/Linear[lin1]
%39 : Half(10) = onnx::Cast[to=10](%lin1.bias), scope: Net/Linear[lin1]
%40 : Half(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%37, %38, %39), scope: Net/Linear[lin1]
return (%40)
With mixed precision, the training time per batch went up 43%, but I can now increase the batch size, so that part is acceptable. (This is on an RTX 2060 with 6 GB.) What concerns me more is that I seem to be stuck in f16-land. The expected behavior of model.float() is to convert all parameters and buffers to Float, yet the exported graph is still riddled with Half types, which makes it useless in any environment without f16 support. How do I get them out of there if float() doesn't do anything?
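For reference, here is the kind of diagnostic I've been using to check this (a minimal sketch, not part of the script above; dump_dtypes is just a helper I wrote for this issue). It prints the storage dtype of every parameter and buffer before and after the float() call, assuming amp_model from the script above:

def dump_dtypes(m, tag):
    # Print the dtype actually stored for every parameter and buffer.
    for name, t in list(m.named_parameters()) + list(m.named_buffers()):
        print('{}: {} -> {}'.format(tag, name, t.dtype))

dump_dtypes(amp_model, 'before float()')
amp_model = amp_model.float()
dump_dtypes(amp_model, 'after float()')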