4.1. SGD and its convergence

The SGD method aims to tackle (4.1) by mimicking the gradient descent method applied to the stochastic optimization problem. In particular, observe that if \(\xi \sim {\cal D}\), then

\[\nabla f( \vw ) \approx \nabla \ell( \vw, \xi ) \]

More precisely, conditioned on \(\vw\), the expected value of the right-hand side equals the left-hand side. This insight motivates the SGD method:

Algorithm 4.1 (Stochastic Gradient (SGD) Method)

  • Input: initial point \(\vw^0\), stepsize sequence \(\{ \eta_t \}_{t \geq 0}\), max. no. of iterations \(T\).

  • For \(t=0,1,2,..., T-1\),

\[\begin{split}\begin{split} & \text{draw sample}~\xi_{t} \sim {\cal D} \\ & \text{update}~\vw^{t+1} = \vw^t - \eta_t \nabla \ell( \vw^t, \xi_{t} ). \end{split}\end{split}\]
  • Output: last iterate \(\vw^T\), or the solution sequence \(\{ \vw^t \}_{t=1}^T\).
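
To make the update rule concrete, below is a minimal sketch of Algorithm 4.1 on a toy problem. The problem (least-squares regression with synthetic Gaussian data), the sample budget, and the constant step size \(\eta_t = 1/\sqrt{T}\) are illustrative choices and not part of the algorithm itself.

import numpy as np

rng = np.random.default_rng(0)
d = 10
w_star = rng.standard_normal(d)      # ground-truth parameter (used only to generate data)

def draw_sample():
    # draw xi = (x, y) ~ D with y = <x, w_star> + noise
    x = rng.standard_normal(d)
    y = x @ w_star + 0.1 * rng.standard_normal()
    return x, y

def stoch_grad(w, xi):
    # gradient of the per-sample loss l(w, xi) = 0.5 * (<x, w> - y)^2
    x, y = xi
    return (x @ w - y) * x

T = 5000
w = np.zeros(d)                      # initial point w^0
for t in range(T):
    eta_t = 1.0 / np.sqrt(T)         # constant step size
    w = w - eta_t * stoch_grad(w, draw_sample())

print(np.linalg.norm(w - w_star))    # the iterate approaches the minimizer w_star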

Unlike the deterministic gradient method, the sequence \(\{ \vw^t \}_{t=1}^T\) generated by the SGD method is itself a random process (in fact, a Markov chain). As such, there are several subtleties in analyzing the convergence properties of the solution sequence. Common notions of convergence include almost-sure convergence and convergence in expectation. We will demonstrate the convergence in expectation of SGD under the simplest setting, i.e., with a smooth (but possibly non-convex) objective function. Interested readers may refer to (ref.) for further discussions.

In addition to Assumption 2.1 on the expected objective function \(f(\vw)\), we need the following condition on the stochastic gradient:

Assumption 4.1 (Stochastic Oracle)

There exists \(\sigma \geq 0\) such that for any \(\vw \in \mathbb{R}^d\), we have

\[\begin{split}\begin{split} & \EE_{ \xi \sim {\cal D} } [ \nabla \ell( \vw, \xi ) ] = \nabla f( \vw ) \\ & \EE_{ \xi \sim {\cal D} } [ \| \nabla \ell( \vw, \xi ) - \nabla f( \vw ) \|^2 ] \leq \sigma^2 \end{split}\end{split}\]

In other words, the stochastic oracle used in the SGD method is unbiased and has bounded variance.
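
As an illustration, the two conditions can be checked numerically for a simple oracle. The snippet below uses a self-contained variant of the toy least-squares model sketched after Algorithm 4.1; there \(f(\vw) = \tfrac{1}{2} \| \vw - \vw^\star \|^2 + \text{const}\), so \(\nabla f(\vw) = \vw - \vw^\star\) is available in closed form. Note that in this particular model the gradient noise grows with \(\| \vw - \vw^\star \|\), so the uniform bound \(\sigma^2\) should be understood over the region visited by the iterates.

import numpy as np

rng = np.random.default_rng(1)
d = 5
w_star = rng.standard_normal(d)
w = rng.standard_normal(d)           # a fixed query point

def stoch_grad(w):
    # single-sample gradient of l(w, (x, y)) = 0.5 * (<x, w> - y)^2,
    # with x ~ N(0, I) and y = <x, w_star> + 0.1 * noise
    x = rng.standard_normal(d)
    y = x @ w_star + 0.1 * rng.standard_normal()
    return (x @ w - y) * x

true_grad = w - w_star               # exact gradient of f at w
grads = np.array([stoch_grad(w) for _ in range(20000)])

print(np.linalg.norm(grads.mean(axis=0) - true_grad))   # close to 0: the oracle is unbiased
print(np.mean(np.sum((grads - true_grad)**2, axis=1)))  # Monte-Carlo estimate of the noise level at w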

We have the following convergence-in-expectation result:

Theorem 4.1 (Convergence of SGD Method (Smooth Case))

Suppose that Assumption 2.1 and Assumption 4.1 hold, and that the step sizes satisfy \(\eta_t \leq 1/L\). Then, for any \(T \geq 1\), the following holds:

\[\min_{ t=0, ..., T-1 } \EE[ \| \nabla f( \vw^t ) \|^2 ] \leq \frac{ 2 \left( f (\vw^0) - \EE[ f(\vw^T) ] + \frac{ \sigma^2 L }{2} \sum_{t=0}^{T-1} \eta_t^2 \right) }{ \sum_{t=0}^{T-1} \eta_t }\]
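
One standard way to obtain this bound is via the descent lemma: by the \(L\)-smoothness of \(f\) (Assumption 2.1) and Assumption 4.1, conditioned on \(\vw^t\),

\[\begin{split}\begin{split} \EE[ f( \vw^{t+1} ) \mid \vw^t ] & \leq f( \vw^t ) + \langle \nabla f( \vw^t ), \EE[ \vw^{t+1} - \vw^t \mid \vw^t ] \rangle + \frac{L}{2} \EE[ \| \vw^{t+1} - \vw^t \|^2 \mid \vw^t ] \\ & \leq f( \vw^t ) - \eta_t \left( 1 - \frac{ L \eta_t }{2} \right) \| \nabla f( \vw^t ) \|^2 + \frac{ L \eta_t^2 \sigma^2 }{2}. \end{split}\end{split}\]

Since \(\eta_t \leq 1/L\), we have \(1 - L \eta_t / 2 \geq 1/2\). Taking total expectation, summing over \(t = 0, \dots, T-1\) and telescoping the \(\EE[ f( \vw^t ) ]\) terms gives \(\tfrac{1}{2} \sum_{t=0}^{T-1} \eta_t \, \EE[ \| \nabla f( \vw^t ) \|^2 ] \leq f( \vw^0 ) - \EE[ f( \vw^T ) ] + \tfrac{\sigma^2 L}{2} \sum_{t=0}^{T-1} \eta_t^2\); lower bounding the left-hand side by \(\tfrac{1}{2} \min_t \EE[ \| \nabla f( \vw^t ) \|^2 ] \sum_t \eta_t\) yields the stated bound.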

From the result in Theorem 4.1, we notice that if we set \(\eta_t = 1 / \sqrt{T}\) for all \(t=0,...,T-1\), then for sufficiently large \(T\) (so that \(\eta_t \leq 1/L\) is satisfied), one has

\[\min_{ t=0, ..., T-1 } \EE[ \| \nabla f( \vw^t ) \|^2 ] = {\cal O} \left( \frac{1 + \log T}{ \sqrt{T} } \right),\]

showing the convergence (in expectation) of the SGD method.
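
As a quick sanity check, one can evaluate the right-hand side of the bound in Theorem 4.1 for the constant step size \(\eta_t = 1/\sqrt{T}\), where \(\sum_{t} \eta_t = \sqrt{T}\) and \(\sum_{t} \eta_t^2 = 1\); this is consistent with (and in fact slightly tighter than) the rate displayed above. The constants below (the smoothness constant, the noise level, and the initial gap) are placeholders chosen only for illustration.

import numpy as np

L_smooth, sigma2, f_gap = 1.0, 1.0, 1.0       # hypothetical L, sigma^2 and f(w^0) - E[f(w^T)]
for T in [10, 100, 1000, 10000]:
    eta = 1.0 / np.sqrt(T)                    # constant step size eta_t = 1/sqrt(T)
    sum_eta, sum_eta_sq = T * eta, T * eta**2
    bound = 2 * (f_gap + 0.5 * sigma2 * L_smooth * sum_eta_sq) / sum_eta
    print(T, bound)                           # the bound decays at the rate ~ 1/sqrt(T)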

4.1.1. A simple example of running SGD in PyTorch

In the following code blocks, we use SGD to train a ResNet-18 model on the CIFAR-10 dataset. PyTorch provides the torch.optim module, from which the SGD optimizer can be loaded directly. You can change torch.optim.SGD to torch.optim.Adagrad, torch.optim.RMSprop, or torch.optim.AdamW to see the effect of different optimizers.

# Simple PyTorch training script for image classification (ResNet-18 on CIFAR-10) using SGD

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models import resnet18

import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# typical augmentation for CIFAR
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
nclasses = 10 # CIFAR10 has 10 classes

batch_size = 128
epochs = 10

train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8) 
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=8) 

model = resnet18(num_classes=nclasses)
model.to(device)

# SGD with momentum and weight decay
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# cosine-annealing learning rate schedule; the restart period T_0 equals the total
# number of epochs, i.e. a single cosine decay over the whole run
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, epochs, T_mult=1)


def train(epoch, train_loader, model, optimizer, scheduler):
    nbatches = len(train_loader) 
    running_loss = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = F.cross_entropy(output, target)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3)  # clip the gradient norm at 3 to stabilize training
        optimizer.step()

        scheduler.step((epoch - 1) + batch_idx / nbatches)  # fractional epoch: update the cosine schedule per mini-batch
        running_loss += loss.item() * data.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch: {epoch} \tLoss: {epoch_loss}')
    return epoch_loss
    
    
def test(epoch, test_loader, model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradient tracking needed during evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # predicted class index
            correct += pred.eq(target.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\n({}) - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        epoch, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, 100. * correct / len(test_loader.dataset)
     
train_losses = []
test_losses = []
test_accs = []
for epoch in range(1, epochs + 1):     
    
    train_loss = train(epoch, train_loader, model, optimizer, scheduler)
    
    if epoch % 1 == 0:  # evaluate after every epoch (increase 1 to test less frequently)
        test_loss, test_acc = test(epoch, test_loader, model)
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|██████████| 170498071/170498071 [00:07<00:00, 22756360.38it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch: 1 	Loss: 1.9615433359909058

(1) - Test set: Average loss: 1.5024, Accuracy: 4596/10000 (45.96%)

Epoch: 2 	Loss: 1.40224799369812

(2) - Test set: Average loss: 1.2430, Accuracy: 5525/10000 (55.25%)

Epoch: 3 	Loss: 1.1913091995429992

(3) - Test set: Average loss: 1.1020, Accuracy: 6149/10000 (61.49%)

Epoch: 4 	Loss: 1.044463742198944

(4) - Test set: Average loss: 1.0641, Accuracy: 6293/10000 (62.93%)

Epoch: 5 	Loss: 0.936025900554657

(5) - Test set: Average loss: 0.9054, Accuracy: 6854/10000 (68.54%)

Epoch: 6 	Loss: 0.8328752424430848

(6) - Test set: Average loss: 0.7783, Accuracy: 7290/10000 (72.90%)

Epoch: 7 	Loss: 0.7448515315628051

(7) - Test set: Average loss: 0.6759, Accuracy: 7628/10000 (76.28%)

Epoch: 8 	Loss: 0.6620122916030884

(8) - Test set: Average loss: 0.6234, Accuracy: 7860/10000 (78.60%)

Epoch: 9 	Loss: 0.5950784549331665

(9) - Test set: Average loss: 0.5929, Accuracy: 7958/10000 (79.58%)

Epoch: 10 	Loss: 0.5526298828697205

(10) - Test set: Average loss: 0.5829, Accuracy: 8012/10000 (80.12%)
# Plot the results
epochs_range = range(1, epochs + 1)

plt.figure(figsize=(10,4))

# Plot both losses on the same plot
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label='Train Loss')
plt.plot(epochs_range, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Test Loss')
plt.grid(True)
plt.legend()

# Plot test accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs_range, test_accs, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Test Accuracy')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.show()
(Figure: training and test loss (left) and test accuracy (right) over the 10 epochs.)
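
As mentioned earlier, trying a different optimizer only requires changing the line that constructs it; everything else in the training loop stays the same. For example (the AdamW hyperparameters below are common defaults, not tuned for this task):

# Swap SGD for AdamW (or torch.optim.Adagrad / torch.optim.RMSprop)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, epochs, T_mult=1)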