Boosting Model Efficiency with Quantization and Pruning in PyTorch
The complete implementation is available on GitHub.
Introduction:
In the world of deep learning, deploying models on resource-constrained devices such as mobile phones, IoT devices, or edge computing platforms presents unique challenges. To ensure efficient performance, it’s crucial to reduce model size and enhance inference speed without significantly compromising accuracy. Two popular techniques to achieve this are quantization and pruning.
Quantization refers to reducing the precision of the numbers used to represent a model’s weights and activations, typically from 32-bit floating-point numbers to lower precision formats like 8-bit integers. This results in faster computation and lower memory usage. On the other hand, pruning reduces the number of parameters in a model by removing redundant or less important connections, leading to a more compact and faster model.
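To make the int8 idea concrete, here is a tiny, self-contained sketch (not part of the project code) that maps an FP32 tensor onto the signed 8-bit grid with PyTorch’s per-tensor quantization and inspects the rounding error:

import torch

# A toy FP32 tensor standing in for a layer's weights.
x = torch.randn(4, 4)

# Symmetric affine mapping: real_value ≈ int_value * scale (zero_point = 0).
scale = x.abs().max().item() / 127

# Quantize to signed 8-bit integers, then dequantize back to float to inspect the error.
x_int8 = torch.quantize_per_tensor(x, scale=scale, zero_point=0, dtype=torch.qint8)
x_back = x_int8.dequantize()

print("int8 stores 1 byte per value instead of 4")
print("max absolute rounding error:", (x - x_back).abs().max().item())

Real quantization pipelines choose the scale and zero-point per tensor (or per channel) from observed statistics, which is exactly what the calibration step later in this post does.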
This project showcases a practical implementation of quantization and pruning on a ResNet model using PyTorch. By combining these techniques, we can demonstrate how to create lightweight, efficient models suited for deployment, all while maintaining high accuracy. Whether you are a machine learning practitioner aiming to optimize your models or just exploring model compression techniques, this project serves as an educational guide to help you get started with quantization and pruning in PyTorch.
Setup
- Clone the repository:
git clone https://github.com/ramintoosi/resnet-quantization-pruning.git
cd resnet-quantization-pruning
- Install dependencies:
pip install torch torchvision torchaudio
or
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
Training a Simple Base Model
Before diving into quantization and pruning, the first step is to train a simple base model without applying any optimization techniques. In this project, we use a standard ResNet architecture to train on the CIFAR-10 dataset. This base model will serve as the foundation for comparing the effects of quantization and pruning later on.
The training script, train_model_simple.py, initializes the ResNet model and trains it using typical components such as an Adam optimizer, a cross-entropy loss function, and a learning rate scheduler. By running this script, you will obtain a well-trained model that can later be optimized using various compression techniques.
The training process involves the following key steps:
- Model: A ResNet model with 10 output classes (suitable for CIFAR-10) is instantiated.
- Device: The model is trained on a GPU if available, or falls back to CPU.
- Optimizer: The Adam optimizer is used to adjust the model parameters.
- Learning Rate Scheduler: A ReduceLROnPlateau scheduler is applied to reduce the learning rate when the validation loss stops improving.
To train the model from scratch, simply run the following command:
python train_model_simple.py
This will begin the training process and save the best model to weights/simple_best_model.pt. The dataset loading and training loop details are abstracted in the train() and load_data() functions, which you can refer to in my previous post on weak supervision; a simplified sketch of such a training loop is shown below.
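The train() and load_data() helpers live in the repository and are not reproduced in this post. As a rough mental model only — the signature and checkpoint naming below are assumptions inferred from how train() is called later — the loop looks something like this:

import torch

def train(model, dataloaders, optimizer, criterion, scheduler, device, name, num_epochs, resume=False):
    """Simplified sketch of the training loop (the real train() also handles resume/checkpoint logic)."""
    model.to(device)
    best_loss = float("inf")
    for epoch in range(num_epochs):
        model.train()
        for images, labels in dataloaders["train"]:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        # A validation pass drives the LR scheduler and best-model checkpointing.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in dataloaders["val"]:
                images, labels = images.to(device), labels.to(device)
                val_loss += criterion(model(images), labels).item()
        val_loss /= len(dataloaders["val"])
        scheduler.step(val_loss)  # ReduceLROnPlateau expects the monitored metric
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({"model_state_dict": model.state_dict()}, f"weights/{name}_best_model.pt")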
After training this base model, we’ll be ready to explore how quantization and pruning can be applied to improve model efficiency.
Model Validation
Once the model is trained, it’s important to evaluate its performance on a validation dataset to measure accuracy, loss, and inference speed. The validation process, implemented in the validate() function, ensures that the model generalizes well to unseen data. Here is the validation script.
"""
This module validates and calculates the accuracy of a model on the CIFAR-10 validation data.
"""
import time
import torch
from tqdm import tqdm
from data import load_data
def validate(model, device, n_total=2000):
"""
Validate the model on the validation data.
:param n_total: the number of images to validate.
:param device: cuda or cpu.
:param model: Model to validate.
:return: Tuple of accuracy, loss, and average inference time (ms).
"""
dataloaders = load_data(batch_size=1, num_workers=0)
model.eval()
model.to(device)
correct = 0
total = 0
running_loss = 0.0
start_time = time.time()
i_data = 0
with torch.no_grad():
for data in tqdm(dataloaders['val'], total=n_total, desc='Validating model', unit=' image'):
images, labels = data[0].to(device), data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
running_loss += torch.nn.CrossEntropyLoss()(outputs, labels).item()
i_data += 1
if i_data > n_total:
break
elapsed_time = time.time() - start_time
accuracy = correct / total
loss = running_loss / total
avg_inference_time = elapsed_time / total
return accuracy, loss, avg_inference_time * 1000
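As a short usage sketch, validating the trained base model on the CPU could look like this (the checkpoint path and keys follow the scripts later in this post):

import torch
from model.resnet import get_model
from validation import validate

model = get_model(num_classes=10)
model.load_state_dict(torch.load("weights/simple_best_model.pt")["model_state_dict"])

accuracy, loss, avg_ms = validate(model, torch.device("cpu"), n_total=500)
print(f"accuracy={accuracy:.4f}, loss={loss:.4f}, avg inference={avg_ms:.2f} ms")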
Post-Training Quantization in PyTorch
Post-training quantization (PTQ) is a popular technique for reducing the size and improving the efficiency of deep learning models. It allows us to convert a pre-trained floating-point model into a quantized version without requiring retraining. In PyTorch, this is achieved using several quantization strategies, such as dynamic quantization and static quantization. Below is a detailed explanation of the functions used in this module to perform quantization.
"""
This module implements post-training quantization of a PyTorch model.
"""
import copy
import torch
from torch.ao import quantization as quan
import torch.ao.quantization.quantize_fx as quantize_fx
from data import load_data
from tqdm import tqdm
# Dynamic Quantization: weights are quantized ahead of time, while activations are quantized
# on the fly during inference. No calibration pass is required, but quantizing activations
# adds a small overhead to each forward pass.
# Dynamic quantization is used for situations where the model execution time is dominated by
# loading weights from memory rather than computing the matrix multiplications.
def quantize_dynamic(model_fp32: torch.nn.Module, dtype=torch.qint8):
"""
Quantize a PyTorch model using dynamic quantization.
:param model_fp32: model to quantize
:param dtype: target dtype for quantized weights
:return: quantized model
"""
# create a quantized model instance
model_quantized = torch.ao.quantization.quantize_dynamic(
model_fp32, # the original model
{torch.nn.Linear}, # a set of layers to dynamically quantize
dtype=dtype) # the target dtype for quantized weights
return model_quantized
# The ModelWrapper class is a custom PyTorch module designed to facilitate the quantization process.
# It wraps an existing model and adds QuantStub and DeQuantStub modules to handle the conversion of tensors
# between floating point and quantized formats. The forward method specifies where these conversions occur
# during the forward pass of the model. This setup is essential for static quantization, where the model needs
# to be prepared and calibrated before being converted to a quantized version.
class ModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
# QuantStub converts tensors from floating point to quantized
self.quant = torch.ao.quantization.QuantStub()
self.model = model
# DeQuantStub converts tensors from quantized to floating point
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
# manually specify where tensors will be converted from floating
# point to quantized in the quantized model
x = self.quant(x)
x = self.model(x)
# manually specify where tensors will be converted from quantized
# to floating point in the quantized model
x = self.dequant(x)
return x
def quantize_static(model: torch.nn.Module):
"""
Quantize a PyTorch model using static quantization.
:param model: model to quantize
:return: quantized model
"""
# wrap the model with the ModelWrapper to include the quant and dequant stubs
model_fp32 = ModelWrapper(model)
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
qconfig = quan.get_default_qconfig('x86')
model_fp32.qconfig = qconfig
# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
    # Calibrate the prepared model to determine quantization parameters for activations.
    # Calibration should be done with a representative dataset; here we simply use the
    # CIFAR-10 training set.
dataloader = load_data()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_fp32_prepared.to(device)
with torch.no_grad():
for data, _ in tqdm(dataloader['train'], desc='Calibrating model', unit=' batch'):
model_fp32_prepared(data.to(device))
model_fp32_prepared.cpu()
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
return model_int8
def quantize_static_fx(model_fp: torch.nn.Module):
"""
Quantize a PyTorch model using static quantization with FX graph mode.
:param model_fp: model to quantize
:return: quantized model
"""
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = quan.get_default_qconfig_mapping("x86")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (torch.rand((1, 3, 224, 224)),))
# calibrate
dataloader = load_data()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_prepared.to(device)
with torch.no_grad():
for data, _ in tqdm(dataloader['train'], desc='Calibrating model FX', unit=' batch'):
model_prepared(data.to(device))
model_prepared.cpu()
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
return model_quantized
1. Dynamic Quantization
Dynamic quantization quantizes the weights ahead of time and quantizes the activations on the fly during inference. Because no calibration pass is required, it is the simplest technique to apply, at the cost of a small per-inference overhead for quantizing activations. As the PyTorch documentation notes, it is intended for situations where the model execution time is dominated by loading weights from memory rather than computing the matrix multiplications.
Function: quantize_dynamic
- Purpose: This function converts a pre-trained model into a dynamically quantized model by reducing the precision of weights (typically from FP32 to int8).
- How it works:
  - It takes the floating-point model (model_fp32) and the target data type (dtype, usually torch.qint8).
  - It quantizes only specific layer types, such as torch.nn.Linear, where matrix multiplications dominate the computation.

model_quantized = torch.ao.quantization.quantize_dynamic(
    model_fp32,          # original model
    {torch.nn.Linear},   # layers to quantize
    dtype=dtype)         # target dtype for quantized weights
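A quick way to see the effect on disk footprint is to serialize both state dicts and compare file sizes. Note that for a ResNet, dynamic quantization only covers the final Linear layer, so most of the size reduction comes from the static approaches that also quantize Conv2d layers. The helper below is a convenience sketch, not part of the repository:

import os
import torch

def state_dict_size_mb(model, path="tmp_size_check.pt"):
    """Serialize the state_dict and report the file size in MB."""
    torch.save(model.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    os.remove(path)
    return size_mb

# e.g. print(f"fp32: {state_dict_size_mb(model):.1f} MB, dynamic int8: {state_dict_size_mb(model_quantized):.1f} MB")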
2. Static Quantization
Static quantization quantizes both weights and activations before inference. It requires calibration to estimate the dynamic range of activations by running representative data through the model.
Function: quantize_static
- Purpose: This function implements static quantization by converting the model into a fully quantized version with both weights and activations quantized.
- Steps:
  - Model Wrapper: The model is wrapped using ModelWrapper, which introduces QuantStub and DeQuantStub. These stubs manage the conversion between floating-point and quantized tensors during the forward pass.

    model_fp32 = ModelWrapper(model)

  - Quantization Configuration: The qconfig specifies how the model will be quantized. For example, 'x86' is used for server inference, while 'qnnpack' is suited for mobile devices.

    qconfig = quan.get_default_qconfig('x86')
    model_fp32.qconfig = qconfig

  - Prepare for Quantization: The model is prepared for quantization by inserting observers that track the ranges of activations.

    model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)

  - Calibration: The prepared model is calibrated by passing a subset of the training data through it to estimate the ranges of activations. This step is crucial for determining the scaling factors used for quantization.

    for data, _ in tqdm(dataloader['train'], desc='Calibrating model', unit=' batch'):
        model_fp32_prepared(data.to(device))

  - Convert to Quantized Model: After calibration, the model is converted to a fully quantized version. This step replaces the floating-point operations with quantized implementations (a quick way to inspect the resulting quantization parameters is shown after this list).

    model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
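After convert(), each quantized module carries the output scale and zero-point chosen during calibration. The snippet below is an illustrative sanity check (it assumes model_int8 is the result of quantize_static() above) that prints the parameters of the first quantized convolution it finds:

from torch.ao.nn.quantized import Conv2d as QuantizedConv2d

for name, module in model_int8.named_modules():
    if isinstance(module, QuantizedConv2d):
        # Output activation scale / zero-point determined during calibration.
        print(name, "scale:", module.scale, "zero_point:", module.zero_point)
        break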
3. FX Graph Mode Quantization
The FX Graph Mode provides a more flexible approach to quantization by allowing users to modify the model’s computational graph.
Function: quantize_static_fx
- Purpose: This function applies static quantization using PyTorch’s FX Graph Mode API. It allows fine-tuning of the quantization process by manipulating the model’s computational graph.
- Steps:
  - Model Preparation: The model is deep-copied and prepared for quantization. The qconfig_mapping specifies how the layers are quantized, while prepare_fx() converts the model into a graph form suitable for quantization.

    model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (torch.rand((1, 3, 224, 224)),))

  - Calibration: Like static quantization, the model is calibrated by running a sample of training data through it to estimate the quantization parameters.

    for data, _ in tqdm(dataloader['train'], desc='Calibrating model FX', unit=' batch'):
        model_prepared(data.to(device))

  - Conversion: After calibration, the model is quantized by replacing floating-point operators with their quantized counterparts using convert_fx() (the rewritten graph can be inspected as shown after this list).

    model_quantized = quantize_fx.convert_fx(model_prepared)
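Because FX mode rewrites the model into a torch.fx.GraphModule, you can inspect exactly where observers and quantize/dequantize operations were inserted. This is just a sanity-check sketch using the model_prepared and model_quantized objects produced by quantize_static_fx():

# Observer nodes inserted for calibration appear in the prepared graph.
print(model_prepared.graph)

# The generated forward() of the converted module shows the quantize/dequantize ops and quantized kernels.
print(model_quantized.code)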
Summary
- Dynamic Quantization: Quantizes only weights, with activations quantized on-the-fly during inference. It’s suitable when memory operations dominate.
- Static Quantization: Quantizes both weights and activations but requires calibration with representative data. It’s ideal for a more compact, efficient model at inference time.
- FX Graph Mode Quantization: Offers a flexible way to manipulate the model’s computational graph and quantize it with better control.
Each quantization method has its own use case depending on the deployment environment and model performance goals.
Post-Training Quantization Script
This section describes the implementation of post-training quantization in PyTorch, where different quantization techniques are applied to an already trained ResNet model. The primary goal is to compare the accuracy, loss, and inference time of the original and quantized models.
The script quantizes the ResNet model using three different PTQ techniques: dynamic quantization, static quantization, and FX static quantization, and then validates the models’ performance.
"""
This module quantizes a PyTorch model using post-training quantization.
Let's save the quantized models and compare the results with the original model.
"""
from os.path import isfile
import torch
from model.resnet import get_model
from quantization.post_training import quantize_dynamic, quantize_static, quantize_static_fx
from validation import validate
model = get_model(num_classes=10)
checkpoint = "weights/original_model.pt"
if isfile(checkpoint):
model.load_state_dict(torch.load(checkpoint))
else:
model.load_state_dict(torch.load("weights/simple_best_model.pt")["model_state_dict"])
torch.save(model.state_dict(), "weights/original_model.pt")
model_quantized = quantize_dynamic(model, dtype=torch.qint8)
checkpoint_quantized = "weights/quantized_dynamic_model.pt"
if not isfile(checkpoint_quantized):
torch.save(model_quantized.state_dict(), checkpoint_quantized)
checkpoint_quantized_static = "weights/quantized_static_model.pt"
if isfile(checkpoint_quantized_static):
model_quantized_static = torch.jit.load(checkpoint_quantized_static)
else:
model_quantized_static = quantize_static(model)
traced = torch.jit.trace(model_quantized_static, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint_quantized_static)
checkpoint_quantized_static_fx = "weights/quantized_static_fx_model.pt"
if isfile(checkpoint_quantized_static_fx):
model_quantized_static_fx = torch.jit.load(checkpoint_quantized_static_fx)
else:
model_quantized_static_fx = quantize_static_fx(model)
traced = torch.jit.trace(model_quantized_static_fx, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint_quantized_static_fx)
# device
device = torch.device("cpu")
# validate models
accuracy, loss, inference_time = validate(model, device)
accuracy_quantized, loss_quantized, inference_time_quantized = validate(model_quantized, device)
accuracy_quantized_static, loss_quantized_static, inference_time_quantized_static = (
validate(model_quantized_static, device))
accuracy_quantized_static_fx, loss_quantized_static_fx, inference_time_quantized_static_fx = (
validate(model_quantized_static_fx, device))
# print the results
print(f"Original model accuracy: {accuracy:.4f}, loss: {loss:.4f}, inference time: {inference_time:.2f}ms")
print(f"Quantized dynamic model accuracy: {accuracy_quantized:.2f}, loss: {loss_quantized:.2f}, "
f"inference time: {inference_time_quantized:.2f}ms")
print(f"Quantized static model accuracy: {accuracy_quantized_static:.2f}, loss: {loss_quantized_static:.2f}, "
f"inference time: {inference_time_quantized_static:.2f}ms")
print(f"Quantized static model with FX accuracy: {accuracy_quantized_static_fx:.2f}, "
f"loss: {loss_quantized_static_fx:.2f}, inference time: {inference_time_quantized_static_fx:.2f}ms")
Step-by-Step Explanation of the Code
- Loading the Pre-Trained Model: First, we load a pre-trained ResNet model that has been trained on the CIFAR-10 dataset.

    model = get_model(num_classes=10)
    checkpoint = "weights/original_model.pt"
    if isfile(checkpoint):
        model.load_state_dict(torch.load(checkpoint))
    else:
        model.load_state_dict(torch.load("weights/simple_best_model.pt")["model_state_dict"])
        torch.save(model.state_dict(), "weights/original_model.pt")

- Dynamic Quantization: The model is quantized dynamically, and the quantized weights are saved if no checkpoint exists yet.

    model_quantized = quantize_dynamic(model, dtype=torch.qint8)
    checkpoint_quantized = "weights/quantized_dynamic_model.pt"
    if not isfile(checkpoint_quantized):
        torch.save(model_quantized.state_dict(), checkpoint_quantized)

- Static Quantization: The statically quantized model is traced with TorchScript and saved, or loaded directly if it already exists.

    checkpoint_quantized_static = "weights/quantized_static_model.pt"
    if isfile(checkpoint_quantized_static):
        model_quantized_static = torch.jit.load(checkpoint_quantized_static)
    else:
        model_quantized_static = quantize_static(model)
        traced = torch.jit.trace(model_quantized_static, torch.rand((1, 3, 224, 224)))
        torch.jit.save(traced, checkpoint_quantized_static)

- Static Quantization with FX (Graph Mode): The same pattern is followed for the FX-quantized model.

    checkpoint_quantized_static_fx = "weights/quantized_static_fx_model.pt"
    if isfile(checkpoint_quantized_static_fx):
        model_quantized_static_fx = torch.jit.load(checkpoint_quantized_static_fx)
    else:
        model_quantized_static_fx = quantize_static_fx(model)
        traced = torch.jit.trace(model_quantized_static_fx, torch.rand((1, 3, 224, 224)))
        torch.jit.save(traced, checkpoint_quantized_static_fx)

- Model Validation: The original and quantized models are validated on the CIFAR-10 validation dataset using the validate() function, which calculates the accuracy, loss, and average inference time for each model.

    accuracy, loss, inference_time = validate(model, device)
    ...

- Results Comparison: Finally, the script prints out the results of the original and quantized models, including accuracy, loss, and inference time.

    print(f"Original model accuracy: {accuracy:.4f}, loss: {loss:.4f}, inference time: {inference_time:.2f}ms")
    ...
Model Pruning Function
This section covers the model pruning function, which is used to reduce the number of parameters in a neural network by making the model weights sparse. In this particular implementation, L1 unstructured pruning is applied to both Conv2d and Linear layers of the model.
Pruning helps reduce the computational cost and memory footprint of a model, which can be useful for deploying models on resource-constrained environments. Here’s a breakdown of how the pruning function works.
import torch
import torch.nn.utils.prune as prune
def make_sparse(model_to_prune, rate=0.5):
"""
This function prunes the model by making the weights sparse.
:param model_to_prune: model to prune
:param rate: the percentage of weights to prune
"""
for name, module in model_to_prune.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=rate)
prune.remove(module, 'weight')
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=rate)
prune.remove(module, 'weight')
Code Explanation
- Iterating through the Model’s Layers: The function uses named_modules() to loop through all the layers of the model and check their types. If a layer is a Conv2d or Linear layer, pruning is applied to its weights.
- L1 Unstructured Pruning: prune.l1_unstructured() applies L1-norm pruning, zeroing out the specified fraction of weights with the smallest absolute values.

    for name, module in model_to_prune.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=rate)

- Making the Pruning Permanent: prune.remove() is called afterward to finalize the pruning process and remove the pruning mask from the module. This makes the sparsity permanent, baking the zeroed-out weights directly into the weight tensor.

    prune.remove(module, 'weight')

- Applying the Same Pruning to Linear Layers: The same pruning process is applied to Linear layers to make them sparse as well.

    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=rate)
        prune.remove(module, 'weight')
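To see what l1_unstructured and remove actually do to a module, here is a tiny standalone demo (illustrative only, independent of the ResNet code):

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(8, 4)

# l1_unstructured reparametrizes the module: weight_orig keeps the original values,
# weight_mask holds the binary mask, and .weight is recomputed as their product.
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(hasattr(layer, "weight_orig"), hasattr(layer, "weight_mask"))  # True True
print("sparsity:", (layer.weight == 0).float().mean().item())        # ~0.5

# remove() bakes the mask into .weight and drops the reparametrization, making the zeros permanent.
prune.remove(layer, "weight")
print(hasattr(layer, "weight_orig"))  # False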
Pruning and Quantization Script Explanation
This script combines pruning and post-training static quantization on a ResNet model to optimize its size and inference performance. Pruning reduces the number of parameters by making the weights sparse, and quantization further compresses the model by reducing the precision of the weights and activations.
"""
This module quantizes a PyTorch model using post-training quantization and pruning.
"""
import copy
from os.path import isfile
import torch
from validation import validate
from model.resnet import get_model
from quantization.post_training import quantize_static_fx
from prune import make_sparse
model = get_model(num_classes=10)
checkpoint = "weights/original_model.pt"
model.load_state_dict(torch.load(checkpoint))
model_orig = copy.deepcopy(model)
checkpoint_quantized_prune = "weights/quantized_prune_model.pt"
if isfile(checkpoint_quantized_prune):
model_quantized_prune = torch.jit.load(checkpoint_quantized_prune)
else:
make_sparse(model)
model_quantized_prune = quantize_static_fx(model)
traced = torch.jit.trace(model_quantized_prune, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint_quantized_prune)
# validate models
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_orig, device, n_total=100)
accuracy_quantized, loss_quantized, inference_time_quantized = validate(model_quantized_prune, device, n_total=100)
# print the results
print(f"Original model accuracy: {accuracy:.4f}, loss: {loss:.4f}, inference time: {inference_time:.2f}ms")
print(f"Quantized static model accuracy: {accuracy_quantized:.2f}, loss: {loss_quantized:.2f}, "
f"inference time: {inference_time_quantized:.2f}ms")
Here’s a breakdown of each part of the script:
Loading the Original Model
model = get_model(num_classes=10)
checkpoint = "weights/original_model.pt"
model.load_state_dict(torch.load(checkpoint))
model_orig = copy.deepcopy(model)
- A ResNet model for 10 classes (e.g., CIFAR-10) is loaded using the get_model() function.
- The model’s weights are loaded from a saved checkpoint (original_model.pt), ensuring it starts from a pre-trained state.
- model_orig is a deep copy of the original model, used later for comparison purposes.
Pruning and Quantization
checkpoint_quantized_prune = "weights/quantized_prune_model.pt"
if isfile(checkpoint_quantized_prune):
model_quantized_prune = torch.jit.load(checkpoint_quantized_prune)
else:
make_sparse(model)
model_quantized_prune = quantize_static_fx(model)
traced = torch.jit.trace(model_quantized_prune, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint_quantized_prune)
- The script checks if a quantized and pruned model (quantized_prune_model.pt) already exists.
- If it exists, the pruned and quantized model is loaded using TorchScript (torch.jit.load()).
- If not, the following steps occur:
  - Pruning: make_sparse(model) prunes the weights of the model, making them sparse by zeroing out a portion of them.
  - Quantization: quantize_static_fx(model) applies static quantization using FX graph mode, further optimizing the model.
  - Saving: The model is traced using TorchScript (torch.jit.trace) for optimization and portability, then saved for future use.
Model Validation
# validate models
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_orig, device, n_total=100)
accuracy_quantized, loss_quantized, inference_time_quantized = validate(model_quantized_prune, device, n_total=100)
Quantization Aware Training (QAT) Script Explanation
This script demonstrates how to perform Quantization Aware Training (QAT) on a ResNet model using PyTorch. QAT simulates quantization during training, allowing the model to learn to adjust to the quantized weights and activations. This often results in better accuracy when converting the model to a quantized version compared to post-training quantization.
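The key mechanism behind QAT is “fake quantization”: in the forward pass values are rounded to the int8 grid, but everything stays in floating point so gradients can still flow (PyTorch uses a straight-through estimator for the rounding step). A minimal standalone illustration, separate from the project code:

import torch

w = torch.randn(5, requires_grad=True)
scale = w.detach().abs().max().item() / 127

# The forward pass sees int8-rounded values; the backward pass lets gradients flow through.
w_fq = torch.fake_quantize_per_tensor_affine(w, scale, 0, -128, 127)
loss = (w_fq ** 2).sum()
loss.backward()

print("fake-quantized values:", w_fq)
print("gradients still flow:", w.grad)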
Here’s a breakdown of the QAT function and the corresponding QAT training script:
QAT Function
from torch import nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
import copy
def prepare_model_qat(model_fp: nn.Module, example_inputs):
"""
Prepare a model for quantization-aware training (QAT).
:param model_fp: The floating point model to prepare.
:param example_inputs: Example inputs for the model, used during preparation.
:return: A model ready for QAT.
"""
model_to_quantize = copy.deepcopy(model_fp) # Deep copy the model to avoid modifying the original one.
# Get the default QAT configuration for the target platform (x86 in this case).
qconfig_mapping = get_default_qat_qconfig_mapping("x86")
model_to_quantize.train() # Set the model to training mode.
# Prepare the model for QAT by adding necessary observers and quantization logic.
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
return model_prepared
- Deep Copy of the Model: The function creates a copy of the original floating-point model (model_fp) to avoid modifying the original model directly.
- QAT Config: It retrieves the default QAT configuration for the target platform (x86), specifying how the model’s layers will be quantized during training.
- Training Mode: The model is set to training mode, since QAT requires forward and backward passes to simulate quantization during training.
- QAT Preparation: The function uses prepare_qat_fx() from the torch.ao.quantization module to prepare the model for quantization-aware training. This adds observers and prepares the model to handle simulated quantized operations during training (see the quick check after this list).
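A quick way to confirm that preparation worked is to count the fake-quantization modules inserted into the prepared graph. Exact class names vary between PyTorch versions, so this hedged check simply matches on the class name (it assumes model_prepared comes from prepare_model_qat()):

# Count the inserted fake-quantization modules by class name.
n_fake_quant = sum("FakeQuantize" in type(m).__name__ for m in model_prepared.modules())
print(f"fake-quantization modules inserted: {n_fake_quant}")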
QAT Training Script
"""
this module is used to train the model with QAT
"""
import os
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from data import load_data
from model.resnet import get_model
from train import train
from quantization.qat import prepare_model_qat
from validation import validate
def train_model_qat(resume=True):
"""
    Train the model with Quantization Aware Training (QAT).
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model.load_state_dict(torch.load("weights/original_model.pt", weights_only=True))
model_prepared = prepare_model_qat(model, example_inputs = next(iter(dataloaders['train']))[0])
criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5)
train(model_prepared, dataloaders, optimizer, criterion, scheduler,
device, "qat", 1, resume=resume)
if __name__ == '__main__':
# train_model_qat(resume=False)
checkpoint = 'weights/qat_fx_model.pt'
if os.path.isfile(checkpoint):
model_quantized = torch.jit.load(checkpoint)
else:
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
model_prepared.load_state_dict(torch.load("weights/qat_best_model.pt", weights_only=True)["model_state_dict"])
model_prepared.eval()
model_quantized = quantize_fx.convert_fx(model_prepared)
traced = torch.jit.trace(model_quantized, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint)
# validation
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_quantized, device)
print(f"Quantized model (QAT) accuracy: {accuracy:.2f}, loss: {loss:.2f}, inference time: {inference_time:.2f}ms")
Model Training with QAT
def train_model_qat(resume=True):
"""
Train a ResNet model with Quantization Aware Training (QAT).
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model.load_state_dict(torch.load("weights/original_model.pt", weights_only=True))
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
criterion = torch.nn.CrossEntropyLoss() # Loss function
    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=0.001)  # Adam optimizer
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5) # Learning rate scheduler
train(model_prepared, dataloaders, optimizer, criterion, scheduler,
device, "qat", 1, resume=resume)
- The model is loaded with the original pre-trained weights.
- The model is prepared for QAT using the prepare_model_qat() function with example inputs from the training data.
- The script sets up the loss function, optimizer, and learning rate scheduler to adjust the learning rate when the model’s performance plateaus.
Loading and Quantizing the Model
if __name__ == '__main__':
checkpoint = 'weights/qat_fx_model.pt'
if os.path.isfile(checkpoint):
model_quantized = torch.jit.load(checkpoint) # Load the pre-quantized model
else:
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
model_prepared.load_state_dict(torch.load("weights/qat_best_model.pt", weights_only=True)["model_state_dict"])
model_prepared.eval() # Set the model to evaluation mode before quantizing.
model_quantized = quantize_fx.convert_fx(model_prepared) # Convert the model to quantized format.
traced = torch.jit.trace(model_quantized, torch.rand((1, 3, 224, 224))) # Trace the model with TorchScript.
torch.jit.save(traced, checkpoint) # Save the quantized model.
- If a pre-quantized model checkpoint exists, it is loaded directly.
- Otherwise, the script prepares the model for QAT and loads the trained QAT model’s weights.
- It then converts the QAT-trained model to a fully quantized format using quantize_fx.convert_fx().
- The quantized model is traced with TorchScript and saved to be loaded later.
Model Validation
# validation
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_quantized, device)
print(f"Quantized model (QAT) accuracy: {accuracy:.2f}, loss: {loss:.2f}, inference time: {inference_time:.2f}ms")
Pruning and QAT
"""
this module is used to train the model with QAT
"""
import os
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from data import load_data
from model.resnet import get_model
from train import train
from quantization.qat import prepare_model_qat
from validation import validate
def train_model_qat(resume=True):
"""
Train a simple model without quantization and pruning.
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model.load_state_dict(torch.load("weights/original_model.pt", weights_only=True))
model_prepared = prepare_model_qat(model, example_inputs = next(iter(dataloaders['train']))[0])
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5)
train(model_prepared, dataloaders, optimizer, criterion, scheduler,
device, "qat", 1, resume=resume)
if __name__ == '__main__':
# train_model_qat(resume=False)
checkpoint = 'weights/qat_fx_model.pt'
if os.path.isfile(checkpoint):
model_quantized = torch.jit.load(checkpoint)
else:
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
model_prepared.load_state_dict(torch.load("weights/qat_best_model.pt", weights_only=True)["model_state_dict"])
model_prepared.eval()
model_quantized = quantize_fx.convert_fx(model_prepared)
traced = torch.jit.trace(model_quantized, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint)
# validation
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_quantized, device)
print(f"Quantized model (QAT) accuracy: {accuracy:.2f}, loss: {loss:.2f}, inference time: {inference_time:.2f}ms")
This script performs both pruning and Quantization Aware Training (QAT) on a ResNet model using PyTorch. By combining pruning, which reduces the number of weights, and QAT, which trains the model to adjust for quantization, we achieve both model compression and faster inference while maintaining high accuracy.
Here’s a breakdown of the Pruning and QAT function and the corresponding training script:
Code Breakdown
Model Training with Pruning and QAT
def train_model_qat_prune(resume=True):
"""
Train a ResNet model with Quantization Aware Training (QAT) and pruning.
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
model.load_state_dict(torch.load("weights/original_model.pt", weights_only=True))
# Apply pruning to make the model sparse
make_sparse(model)
# Prepare the pruned model for Quantization Aware Training (QAT)
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
criterion = torch.nn.CrossEntropyLoss() # Define loss function
    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=0.0005)  # Adam optimizer with a smaller learning rate
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5) # Learning rate scheduler
# Train the pruned and QAT-prepared model
train(model_prepared, dataloaders, optimizer, criterion, scheduler,
device, "qat_prune", 20, resume=resume)
- Pruning: The make_sparse(model) function prunes the ResNet model by setting a percentage of the model’s weights to zero, based on the pruning rate. This function is applied to the model before QAT (a quick sparsity check is sketched after this list).
- QAT Preparation: The pruned model is passed through the prepare_model_qat() function, which prepares it for Quantization Aware Training by inserting quantization-aware operations into the model. This prepares the model to handle simulated quantization during training.
- Training Setup: The script defines the training setup using CrossEntropyLoss, the Adam optimizer, and a learning rate scheduler. The learning rate is reduced when the model’s performance plateaus.
- Training: The training loop is executed for 20 epochs ("qat_prune", 20), adjusting the weights of the pruned and quantization-aware model.
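Since unstructured pruning only zeroes weights (the tensors keep their shape and count), it is worth verifying the sparsity level right after make_sparse(). The helper below is an illustrative check, not part of the repository:

import torch

def global_sparsity(model):
    """Fraction of zero-valued weights across Conv2d and Linear layers."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            weight = module.weight.detach()
            zeros += (weight == 0).sum().item()
            total += weight.numel()
    return zeros / total

# e.g. after make_sparse(model): print(f"global sparsity: {global_sparsity(model):.1%}")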
Model Loading and Quantization
if __name__ == '__main__':
train_model_qat_prune(resume=False)
checkpoint = 'weights/qat_prune_model.pt'
if os.path.isfile(checkpoint):
model_quantized = torch.jit.load(checkpoint) # Load the pre-quantized model if it exists
else:
dataloaders = load_data(batch_size=128, num_workers=0)
model = get_model(num_classes=10)
# Apply pruning
make_sparse(model)
# Prepare the pruned model for QAT
model_prepared = prepare_model_qat(model, example_inputs=next(iter(dataloaders['train']))[0])
model_prepared.load_state_dict(torch.load("weights/qat_prune_best_model.pt", weights_only=True)["model_state_dict"])
# Convert the QAT-prepared model to a quantized format
model_prepared.eval()
model_quantized = quantize_fx.convert_fx(model_prepared)
# Trace and save the quantized model
traced = torch.jit.trace(model_quantized, torch.rand((1, 3, 224, 224)))
torch.jit.save(traced, checkpoint)
- Loading the Pre-Trained Model: If a pruned and quantized checkpoint (weights/qat_prune_model.pt) already exists, it is loaded directly with torch.jit.load().
- Pruning and QAT: If no checkpoint exists, the model is pruned and then prepared for Quantization Aware Training with prepare_model_qat(). After loading the trained weights (weights/qat_prune_best_model.pt), the model is converted to its fully quantized form using quantize_fx.convert_fx().
- Model Tracing and Saving: The quantized model is traced with torch.jit.trace and saved with torch.jit.save so it can be reloaded later without repeating these steps.
Model Validation
# Validation step
device = torch.device("cpu")
accuracy, loss, inference_time = validate(model_quantized, device)
print(f"Quantized model (QAT) accuracy: {accuracy:.2f}, loss: {loss:.2f}, inference time: {inference_time:.2f}ms")
Results Summary
These results compare the various model optimization techniques applied to a ResNet model trained on CIFAR-10. Training was performed on an NVIDIA RTX 2080 Ti GPU, while inference times were measured on the CPU, since the quantized models target CPU backends. The results focus on three metrics: accuracy, loss, and average inference time per image.
| Model Type | Accuracy | Loss | Inference Time |
|---|---|---|---|
| Original model | 0.95 | 0.27 | 54.28ms |
| PTQ dynamic model | 0.96 | 0.27 | 53.95ms |
| PTQ static model | 0.95 | 0.28 | 22.96ms |
| PTQ static model with FX | 0.95 | 0.28 | 21.37ms |
| Pruned 50% + PTQ static | 0.93 | 0.19 | 20.02ms |
| QAT | 0.95 | 0.27 | 19.87ms |
| Pruned 50% + QAT | 0.95 | 0.25 | 20.61ms |
Key Observations
- Accuracy:
- The original model achieves 95% accuracy, and both PTQ (Post-Training Quantization) and QAT (Quantization Aware Training) retain similar levels of accuracy (95%-96%), demonstrating minimal accuracy loss.
- Pruned 50% + PTQ static slightly reduces accuracy to 93%, but applying QAT along with pruning helps recover accuracy to 95%.
- Inference Time:
- PTQ static models achieve a significant reduction in inference time, cutting it by over 50% (from 54.28ms to ~21ms).
- Pruning and QAT further reduce the inference time, with QAT alone achieving the fastest inference at 19.87ms.
- Pruned 50% + PTQ/QAT models maintain a low inference time of around 20ms, which is a substantial improvement over the original model.
- Loss:
- The original model, PTQ, and QAT maintain similar loss values (~0.27-0.28).
- Pruned models show a lower loss (0.19-0.25), suggesting that pruning can help reduce overfitting and potentially improve generalization.
Conclusion
This experiment demonstrates the effectiveness of post-training quantization (PTQ), quantization-aware training (QAT), and pruning techniques in optimizing the performance of a deep learning model. By applying these techniques to a ResNet model trained on CIFAR-10, we observed substantial improvements in inference time and model efficiency while maintaining accuracy levels close to the original, unoptimized model.
Both static and dynamic quantization techniques, particularly static quantization with FX, significantly reduced inference time without sacrificing accuracy. Similarly, pruning, especially when combined with PTQ or QAT, further enhanced model efficiency while maintaining strong performance metrics.
QAT emerged as a standout technique for producing highly efficient models, yielding minimal accuracy loss and achieving the fastest inference time of 19.87ms. The pruned models also performed well, showing that combining QAT or PTQ with pruning is an effective strategy for improving both model speed and size.
Takeaways
- Quantization is highly effective in optimizing model performance: every quantized variant kept accuracy within one point of the original model.
- QAT achieves the best balance between speed and accuracy: it matched the original accuracy (0.95) with the fastest inference time (19.87ms).
- Pruning enhances the benefits of quantization: combined with PTQ or QAT, it keeps inference time around 20ms while making the weights sparse.
- Pruning alone can reduce overfitting, as seen in the lower loss values of the pruned models.
- Static quantization offers superior speedup compared to dynamic quantization: dynamic quantization barely changed inference time on this model (it only covers the Linear layer of the ResNet), while static quantization cut it by more than half.
- Pruned models show a slight drop in accuracy (0.93 with PTQ static) but perform well overall, and combining pruning with QAT recovers the lost accuracy.
Final Thought:
For real-world applications requiring model optimization, QAT combined with pruning offers the best balance of speed, accuracy, and efficiency. When ease of implementation is a concern, PTQ static is a strong alternative that still provides significant benefits in terms of inference time.
GitHub Repository
The complete implementation is available on GitHub.