Skip to content

Can't convert trained AMP model to full precision #349

Closed
@jtiscione

Description

@jtiscione

I have a basic benchmark test where I train a CNN on MNIST data with and without AMP. The problem is that I can't get the f16 types out of the model or export it to a CPU. Calling float() on the model doesn't seem to do anything.

import time
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torch.onnx
import torchvision.datasets as dsets
import torchvision.transforms as trans
from apex import amp

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.lin1 = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.lin1(out)
        return out

trainSet = dsets.MNIST(root='./data', train=True, transform=trans.ToTensor(), download=True)
trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=100, shuffle=True)
trainData = iter(trainLoader)

testSet = dsets.MNIST(root='./data', train=False, transform=trans.ToTensor(), download=True)
testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=100, shuffle=True)
testData = iter(testLoader)

def accuracy(testLoader, model):
    correct, total = 0, 0
    with torch.no_grad():
        for data in testLoader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return (correct / total)

NUM_EPOCHS = 3
device = torch.device('cuda:0')
criterion = nn.CrossEntropyLoss()
model = Net()
model.__init__()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), 0.01)
start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
print('f32: Accuracy: {0:.4f}'.format(accuracy(testLoader, model)))
print('f32: Training time: {0:.2f}'.format(time.time() - start))
torch.save(model.state_dict(), './f32.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy_input, './f32.onnx', verbose=True)

print('*****************************************************************************')

model = Net()
model.__init__()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), 0.01)
amp_model, amp_optimizer = amp.initialize(model, optimizer, opt_level="O1")
start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        amp_optimizer.zero_grad()
        outputs = amp_model(images)
        loss = criterion(outputs, labels)
        # loss.backward()
        with amp.scale_loss(loss, amp_optimizer) as scaled_loss:
            scaled_loss.backward()
        amp_optimizer.step()
print('AMP: Accuracy: {0:.4f}'.format(accuracy(testLoader, amp_model)))
print('AMP: Training time: {0:.2f}'.format(time.time() - start))

amp_model = amp_model.float()  # This line doesn't do anything

torch.save(amp_model.state_dict(), './amp.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(amp_model, dummy_input, './amp.onnx', verbose=True)

This code prints the following:

Epoch 0
Epoch 1
Epoch 2
f32: Accuracy: 0.9803
f32: Training time: 13.65
graph(%0 : Float(1, 1, 28, 28),
      %layer1.0.weight : Float(16, 1, 5, 5),
      %layer1.0.bias : Float(16),
      %layer1.1.weight : Float(16),
      %layer1.1.bias : Float(16),
      %layer1.1.running_mean : Float(16),
      %layer1.1.running_var : Float(16),
      %layer1.1.num_batches_tracked : Long(),
      %layer2.0.weight : Float(32, 16, 5, 5),
      %layer2.0.bias : Float(32),
      %layer2.1.weight : Float(32),
      %layer2.1.bias : Float(32),
      %layer2.1.running_mean : Float(32),
      %layer2.1.running_var : Float(32),
      %layer2.1.num_batches_tracked : Long(),
      %lin1.weight : Float(10, 1568),
      %lin1.bias : Float(10)):
  %17 : Float(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%0, %layer1.0.weight, %layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
  %18 : Float(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%17, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
  %19 : Float(1, 16, 28, 28) = onnx::Relu(%18), scope: Net/Sequential[layer1]/ReLU[2]
  %20 : Float(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%19), scope: Net/Sequential[layer1]/MaxPool2d[3]
  %21 : Float(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%20, %layer2.0.weight, %layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
  %22 : Float(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%21, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
  %23 : Float(1, 32, 14, 14) = onnx::Relu(%22), scope: Net/Sequential[layer2]/ReLU[2]
  %24 : Float(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%23), scope: Net/Sequential[layer2]/MaxPool2d[3]
  %25 : Long() = onnx::Constant[value={0}](), scope: Net
  %26 : Tensor = onnx::Shape(%24), scope: Net
  %27 : Long() = onnx::Gather[axis=0](%26, %25), scope: Net
  %28 : Long() = onnx::Constant[value={-1}](), scope: Net
  %29 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
  %30 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
  %31 : Tensor = onnx::Concat[axis=0](%29, %30)
  %32 : Float(1, 1568) = onnx::Reshape(%24, %31), scope: Net
  %33 : Float(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%32, %lin1.weight, %lin1.bias), scope: Net/Linear[lin1]
  return (%33)

*****************************************************************************
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Epoch 0
Epoch 1
Epoch 2
AMP: Accuracy: 0.9788
AMP: Training time: 19.55
graph(%x.3 : Float(1, 1, 28, 28),
      %layer1.0.weight : Float(16, 1, 5, 5),
      %layer1.0.bias : Float(16),
      %layer1.1.weight : Float(16),
      %layer1.1.bias : Float(16),
      %layer1.1.running_mean : Float(16),
      %layer1.1.running_var : Float(16),
      %layer1.1.num_batches_tracked : Long(),
      %layer2.0.weight : Float(32, 16, 5, 5),
      %layer2.0.bias : Float(32),
      %layer2.1.weight : Float(32),
      %layer2.1.bias : Float(32),
      %layer2.1.running_mean : Float(32),
      %layer2.1.running_var : Float(32),
      %layer2.1.num_batches_tracked : Long(),
      %lin1.weight : Float(10, 1568),
      %lin1.bias : Float(10)):
  %17 : Half(16, 1, 5, 5) = onnx::Cast[to=10](%layer1.0.weight), scope: Net/Sequential[layer1]/Conv2d[0]
  %18 : Half(16) = onnx::Cast[to=10](%layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
  %19 : Half(1, 1, 28, 28) = onnx::Cast[to=10](%x.3), scope: Net/Sequential[layer1]/Conv2d[0]
  %20 : Half(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%19, %17, %18), scope: Net/Sequential[layer1]/Conv2d[0]
  %21 : Half(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%20, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
  %22 : Half(1, 16, 28, 28) = onnx::Relu(%21), scope: Net/Sequential[layer1]/ReLU[2]
  %23 : Half(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%22), scope: Net/Sequential[layer1]/MaxPool2d[3]
  %24 : Half(32, 16, 5, 5) = onnx::Cast[to=10](%layer2.0.weight), scope: Net/Sequential[layer2]/Conv2d[0]
  %25 : Half(32) = onnx::Cast[to=10](%layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
  %26 : Half(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%23, %24, %25), scope: Net/Sequential[layer2]/Conv2d[0]
  %27 : Half(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%26, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
  %28 : Half(1, 32, 14, 14) = onnx::Relu(%27), scope: Net/Sequential[layer2]/ReLU[2]
  %29 : Half(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: Net/Sequential[layer2]/MaxPool2d[3]
  %30 : Long() = onnx::Constant[value={0}](), scope: Net
  %31 : Tensor = onnx::Shape(%29), scope: Net
  %32 : Long() = onnx::Gather[axis=0](%31, %30), scope: Net
  %33 : Long() = onnx::Constant[value={-1}](), scope: Net
  %34 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
  %35 : Tensor = onnx::Unsqueeze[axes=[0]](%33)
  %36 : Tensor = onnx::Concat[axis=0](%34, %35)
  %37 : Half(1, 1568) = onnx::Reshape(%29, %36), scope: Net
  %38 : Half(10, 1568) = onnx::Cast[to=10](%lin1.weight), scope: Net/Linear[lin1]
  %39 : Half(10) = onnx::Cast[to=10](%lin1.bias), scope: Net/Linear[lin1]
  %40 : Half(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%37, %38, %39), scope: Net/Linear[lin1]
  return (%40)

Using mixed precision, the training time per batch went up 43%, but I can still increase the batch size now so that's OK. (This is on an RTX-2060 with 6GB.) What's more concerning is that I seem to be stuck in f16-land.
The expected behavior for model.float() is to convert all parameters and buffers to Float, but it's still riddled with these Half types making it useless in any environment with no f16 support. How do I get them out of there if float() doesn't do anything?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions