Let’s train a CNN in Rust with libtorch

Introduction

As a newcomer to Rust, I recently took on the challenge of training a CNN to classify images of the sea versus the jungle using the tch-rs crate, which provides Rust bindings to Libtorch. Rust’s promise of memory safety and performance intrigued me, and I was eager to explore how it could handle deep learning tasks. In this post, I’ll walk you through setting up and training a simple CNN model. Whether you’re a seasoned Rustacean or just starting out, I hope this post inspires you to experiment with deep learning in Rust.

Note: I’m not an experienced Rust developer.

Dataset

For the dataset, I used a collection I prepared in a previous blog post. You can use any dataset that is structured as follows:

data/
  ├── train/
  │   ├── class1/
  │   │   ├── img1.jpg
  │   │   ├── img2.jpg
  │   ├── class2/
  │   │   ├── img1.jpg
  │   │   ├── img2.jpg
  ├── val/
      ├── class1/
      │   ├── img1.jpg
      │   ├── img2.jpg
      ├── class2/
          ├── img1.jpg
          ├── img2.jpg

This structure makes it easy to load images for training and validation, with separate folders for each class.

libtorch

We use the tch-rs crate, which provides “Rust bindings for the C++ API of PyTorch.” This crate requires Libtorch. You can download pre-built Libtorch files from PyTorch or build them manually. Then you need to set two environment variables:

export LIBTORCH=[path_to_libtorch]
export LD_LIBRARY_PATH=[path_to_libtorch]:$LD_LIBRARY_PATH

path_to_libtorch is the Libtorch folder. If you downloaded the pre-built files, use that path. If you built it manually, it might look something like /path/to/pytorch/build/lib.linux-x86_64-cpython-310/torch/.
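Before going further, it is worth checking that the tch crate can actually find and link against Libtorch. Here is a minimal sketch of such a check (this snippet is mine, not part of the project code; it assumes tch is already declared as a dependency):

use tch::{Device, Kind, Tensor};

fn main() {
    // If LIBTORCH points at a valid installation, this compiles, links, and runs.
    let t = Tensor::zeros(&[2, 3], (Kind::Float, Device::Cpu));
    t.print();
    println!("Device that training would use: {:?}", Device::cuda_if_available());
}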

Build libtorch manually

Install typing-extensions and pyyaml. Then,

git clone -b v2.3.0 --recurse-submodules --depth 1 https://github.com/pytorch/pytorch.git
cd pytorch
USE_CUDA=ON BUILD_SHARED_LIBS=ON python setup.py build

Choose the PyTorch version according to what the tch-rs release you plan to use expects; the required Libtorch version is listed in the tch-rs README.

Code structure

Here’s a brief overview of the project file structure:

├── src
│   ├── train
│   │   └── utils.rs     # Utility functions for training
│   ├── data.rs          # Data loading and preprocessing
│   ├── inference.rs     # Functions for model inference
│   ├── main.rs          # Main program entry point
│   ├── model.rs         # Definition of the CNN model
│   └── train.rs         # Training loop and logic
  • main.rs: The main entry point of the application, where execution begins.
  • data.rs: Handles data loading and preprocessing, organizing datasets for training and validation.
  • model.rs: Defines the architecture of the CNN model.
  • train.rs: Implements the training loop and logic, coordinating the model training process.
  • train/utils.rs: Contains helper functions used during training.
  • inference.rs: Includes functions for running inference on new images.

Data loader

The data loader module is designed to handle the image dataset, similar to ImageFolderDataset in PyTorch. It assumes that images are divided into training and validation sets, with the images of each class in its own folder.

use kdam::tqdm;
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::HashMap;
use std::{fs::read_dir, path::Path};
use tch::{vision, Tensor};

pub struct Dataset {
    root: String,
    image_path: Vec<(i64, String)>,
    class_to_idx: HashMap<String, i64>,
    total_size: usize,
}

impl Dataset {
    /// This function walks through the root folder and gathers images and creates a Dataset
    pub fn new<T: AsRef<Path>>(root: T) -> Dataset {
        let root = root.as_ref();

        let mut image_path: Vec<(i64, String)> = Vec::new();
        let mut class_to_idx: HashMap<String, i64> = HashMap::new();

        Self::get_images_and_classes(
            &root,
            &mut image_path,
            &mut class_to_idx,
        );

        Dataset {
            root: root.to_str().unwrap().to_string(),
            total_size: image_path.len(),
            image_path,
            class_to_idx,
        }
    }

    /// In the input folder finds the classes and images
    fn get_images_and_classes(
        dir: &Path,
        image_path: &mut Vec<(i64, String)>,
        class_to_idx: &mut HashMap<String, i64>,
    ) {
        for (class_id, root_class) in read_dir(&dir).unwrap().enumerate() {
            let root_class = root_class.unwrap().path().clone();
            if root_class.is_dir() {
                Self::get_images_in_folder(&root_class, image_path, class_id as i64);
                let class_name_str = root_class
                    .file_name()
                    .unwrap()
                    .to_str()
                    .unwrap()
                    .to_string();
                class_to_idx.insert(class_name_str.clone(), class_id as i64);
            }
        }
    }

    /// find images with specific extensions "jpg", "png", "jpeg"
    fn get_images_in_folder(
        dir: &Path,
        image_path: &mut Vec<(i64, String)>,
        class_idx: i64,
    ) {
        let valid_ext = vec!["jpg", "png", "jpeg"];
        for file_path in tqdm!(read_dir(dir).unwrap()) {
            let file_path = file_path.unwrap().path();
            // Keep only regular files whose extension is one of the valid image extensions;
            // entries without an extension are skipped instead of panicking.
            let is_image = file_path.is_file()
                && file_path
                    .extension()
                    .and_then(|ext| ext.to_str())
                    .map_or(false, |ext| valid_ext.contains(&ext.to_lowercase().as_str()));
            if is_image {
                image_path.push((class_idx, file_path.to_str().unwrap().to_string()));
            }
        }
    }

    /// A simple print function for our Dataset
    pub fn print(&self) {
        println!("DATASET ({})", self.root);
        println!("Classes: {:?}", self.class_to_idx);
        println!("Size: {}", self.total_size);
        println!("sample of data\n{:?}", &self.image_path[1..3]);
    }

    /// load the image into a tensor and return (image, label)
    fn get_item(&self, idx: usize) -> (Tensor, i64) {
        let image = vision::imagenet::load_image_and_resize224(&self.image_path[idx].1).unwrap();
        (image, self.image_path[idx].0)
    }
}

/// A struct for our data loader
pub struct DataLoader {
    dataset: Dataset,
    batch_size: i64,
    batch_index: i64,
    shuffle: bool,
}

impl DataLoader {
    pub fn new(dataset: Dataset, batch_size: i64, shuffle: bool) -> DataLoader {
        // let mut rng = thread_rng();
        // dataset.ImagePath.shuffle(rng);
        DataLoader {
            dataset,
            batch_size,
            batch_index: 0,
            shuffle,
        }
    }

    fn shuffle_dataset(&mut self) {
        let mut rng = thread_rng();
        self.dataset.image_path.shuffle(&mut rng)
    }

    /// total number of images in the dataset
    pub fn len(&self) -> usize {
        self.dataset.total_size
    }

    /// number of batches per epoch, based on the dataset size and batch size (rounded up)
    pub fn len_batch(&self) -> usize {
        let bs = self.batch_size as usize;
        (self.dataset.total_size + bs - 1) / bs
    }
}

/// implement iterator for our Dataloader to get batches of images and labels
impl Iterator for DataLoader {
    type Item = (Tensor, Tensor);

    fn next(&mut self) -> Option<Self::Item> {
        let start = (self.batch_index * self.batch_size) as usize;
        let mut end = ((self.batch_index + 1) * self.batch_size) as usize;
        if start >= self.dataset.total_size {
            self.batch_index = 0;
            return None;
        }
        if end > self.dataset.total_size {
            end = self.dataset.total_size;
        }
        if (self.batch_index == 0) && self.shuffle {
            self.shuffle_dataset();
        }
        let mut images: Vec<Tensor> = vec![]; // for preload change this to Vec<&Tensor>
        let mut labels: Vec<Tensor> = vec![];
        for i in start..end {
            let (image_t, label) = self.dataset.get_item(i);
            images.push(image_t);
            labels.push(Tensor::from(label))
        }
        self.batch_index += 1;
        Some((
            Tensor::f_stack(&images, 0).unwrap(),
            Tensor::f_stack(&labels, 0).unwrap(),
        ))
    }
}

Dataset: Manages the dataset structure, storing image paths and class indices. It reads the directory to gather images and organizes them by class.

pub struct Dataset {
    root: String,
    image_path: Vec<(i64, String)>,
    class_to_idx: HashMap<String, i64>,
    total_size: usize,
}

Key Functions

  • new: Initializes the dataset by reading image paths and classes from the root directory.
  • get_images_and_classes: Scans the root directory for class folders and collects the images inside each one.
  • get_images_in_folder: Filters images based on valid extensions (jpg, png, jpeg).
  • print: Outputs dataset information, including classes and sample data.
  • get_item: Loads an image as a tensor and returns it along with its label.

DataLoader: Handles batching of the dataset for training. It can shuffle data and iterates through batches of images and labels.

pub struct DataLoader {
    dataset: Dataset,
    batch_size: i64,
    batch_index: i64,
    shuffle: bool,
}

Key Functions

  • new: Initializes the data loader with the specified batch size and shuffle option.
  • shuffle_dataset: Randomly shuffles the dataset.
  • len: Returns the total number of images.
  • len_batch: Calculates the total number of batches.

Iterator Implementation

The DataLoader implements the Iterator trait to provide batches of images and labels for training:

  • next: Retrieves the next batch, stacking images and labels into tensors. Once all batches have been consumed, it resets the batch index and returns None; at the start of the next epoch the dataset is reshuffled if shuffle is enabled.

This module organizes and processes image data, facilitating smooth training with Rust and Libtorch.
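Before moving on, here is a short sketch of how these two types are meant to be used together (the preview_data function name is mine; the dataset path matches the one used later in main.rs):

use crate::data::{DataLoader, Dataset};

fn preview_data() {
    // Gather image paths and class indices from the training folder.
    let dataset = Dataset::new("data/sea_vs_jungle/train");
    dataset.print();

    // Iterate once over shuffled batches of 32 (image, label) pairs.
    let mut loader = DataLoader::new(dataset, 32, true);
    for (images, labels) in &mut loader {
        // images: [batch, 3, 224, 224] float tensor; labels: [batch] int64 tensor.
        println!("{:?} {:?}", images.size(), labels.size());
    }
}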

Model Definition

Here’s the definition of our CNN model using the tch crate. The model consists of several convolutional layers followed by fully connected layers.

use tch::{nn, nn::Module};

pub fn net(vs: &nn::Path, n_class: i64) -> impl Module {
    nn::seq()
        .add(nn::conv2d(vs, 3, 16, 16, Default::default()))
        .add_fn(|xs| xs.max_pool2d_default(4))
        .add_fn(|xs| xs.relu())
        .add(nn::conv2d(vs, 16, 64, 4, Default::default()))
        .add_fn(|xs| xs.max_pool2d_default(2))
        .add_fn(|xs| xs.relu())
        .add(nn::conv2d(vs, 64, 128, 4, Default::default()))
        .add_fn(|xs| xs.relu())
        .add_fn(|xs| xs.flat_view())
        .add(nn::linear(vs, 56448, 1024, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, 1024, n_class, Default::default()))
}

Model Architecture

  • Convolutional Layers: Three convolutional layers with increasing channel counts (16, 64, 128) to capture spatial features.
    • The first two are each followed by a max-pooling layer and a ReLU activation to reduce dimensionality and introduce non-linearity; the third is followed by a ReLU only.
  • Fully Connected Layers: Two fully connected layers that transform the features into the desired number of classes.
    • The first linear layer maps the 56448 flattened features down to 1024 (see the sanity-check sketch below).
    • The final layer outputs one logit per class.
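If you are wondering where the 56448 comes from, here is a small sanity check (my own throwaway helper, not part of the project code). It traces the spatial size of a 224x224 input through the layers above, assuming stride-1 convolutions without padding and the default max-pooling (kernel equals stride, floor division):

// Output side length of a stride-1 convolution without padding.
fn conv(size: i64, kernel: i64) -> i64 {
    size - kernel + 1
}

// Output side length of max_pool2d_default (kernel == stride, no padding).
fn pool(size: i64, kernel: i64) -> i64 {
    size / kernel
}

fn main() {
    let mut s = 224;          // input side after load_image_and_resize224
    s = pool(conv(s, 16), 4); // conv 3 -> 16 (16x16 kernel), max-pool 4 => 52
    s = pool(conv(s, 4), 2);  // conv 16 -> 64 (4x4 kernel),  max-pool 2 => 24
    s = conv(s, 4);           // conv 64 -> 128 (4x4 kernel)             => 21
    println!("{}", s * s * 128); // 21 * 21 * 128 = 56448
}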

Learning Rate Scheduler

The scheduler dynamically adjusts the learning rate during training based on model performance. This helps optimize convergence.

use num_traits::Float;
use tch::nn::Optimizer;

pub struct Scheduler<'a> {
    pub opt: &'a mut Optimizer,
    patience: i64,
    factor: f64,
    lr: f64,
    step: i64,
    last_val: f64,
}

impl Scheduler<'_> {
    pub fn new(opt: &mut Optimizer, mut patience: i64, lr: f64, mut factor: f64) -> Scheduler {
        if patience < 0 { patience = 5; }
        if factor < 0.0 { factor = 0.95; }

        Scheduler {
            opt,
            patience,
            factor,
            lr,
            step: 0,
            last_val: f64::infinity(),
        }
    }

    /// Adjusts learning rate based on validation performance.
    pub fn step(&mut self, value: f64) {
        if value < self.last_val {
            self.last_val = value;
            self.step = 0;
        } else {
            self.step += 1;
            if self.step == self.patience {
                self.step = 0;
                self.lr *= self.factor;
                self.opt.set_lr(self.lr);
            }
        }
    }

    pub fn get_lr(&self) -> f64 {
        self.lr
    }
}

Scheduler Functionality

  • Initialization: Takes an optimizer, patience, initial learning rate, and a decay factor. Default values are set if negative inputs are given.
  • step Method: Monitors validation loss to decide when to reduce the learning rate. If the validation loss improves, it resets the step counter; otherwise, it increments the counter. After reaching the patience threshold, the learning rate is decreased by the factor.
  • get_lr Method: Returns the current learning rate.

This scheduler helps fine-tune the training process by adapting the learning rate, ensuring efficient convergence.
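To see the effect in isolation, here is a small usage sketch (the loss values are made up, and it assumes the Scheduler type above is in scope): with patience = 2 and factor = 0.5, two consecutive epochs without improvement halve the learning rate.

use tch::{nn, nn::OptimizerConfig, Device};

fn main() {
    // A VarStore with one dummy variable so the optimizer has something to manage.
    let vs = nn::VarStore::new(Device::Cpu);
    let _w = vs.root().zeros("w", &[1]);

    let mut opt = nn::Adam::default().build(&vs, 1e-3).unwrap();
    // Assumes the Scheduler defined above is in scope.
    let mut scheduler = Scheduler::new(&mut opt, 2, 1e-3, 0.5);

    // 0.85 and 0.84 do not improve on 0.80, so after the second bad epoch
    // the learning rate drops from 1e-3 to 5e-4.
    for val_loss in [0.90, 0.80, 0.85, 0.84] {
        scheduler.step(val_loss);
        println!("val_loss = {val_loss}, lr = {}", scheduler.get_lr());
    }
}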

Model Training Step-by-Step

  1. Set Up Directories and Device:
    • Check if the save_dir exists; create it if not.
    • Determine if CUDA is available and select the appropriate device for training.
if !Path::new(save_dir).is_dir() {
    create_dir(save_dir).unwrap();
}

let device = Device::cuda_if_available();
println!("The device is {:?}", device);
  2. Initialize Model and Optimizer:
    • Create the model using model::net with two output classes.
    • Set up an Adam optimizer with a learning rate of 1e-3 and a weight decay of 1e-4.
let vs = nn::VarStore::new(device);
let net = model::net(&vs.root(), 2);
let lr = 1e-3;

let mut opt = nn::Adam::default().build(&vs, lr).unwrap();
opt.set_weight_decay(1e-4);
  3. Learning Rate Scheduler:
    • Initialize a scheduler to adjust the learning rate dynamically based on validation performance, with patience set to 5 epochs and a decay factor of 0.5.
let mut scheduler = utils::Scheduler::new(&mut opt, 5, lr, 0.5);
  4. Progress Bars:
    • Set up progress bars for tracking training and validation progress, as well as the overall epoch.
let total_batch_train = dataloader_train.len_batch();
let total_batch_val = dataloader_val.len_batch();

let mut pbar = tqdm!(
    total = total_batch_train,
    position = 1,
    desc = format!("{:<8}", "Train"),
    force_refresh = true,
    ncols = 100
);

let mut pbar2 = tqdm!(
    total = total_batch_val,
    position = 2,
    desc = format!("{:<8}", "Val"),
    force_refresh = true,
    ncols = 100
);

let n_epochs = 30;

let mut pbar_e = tqdm!(
    total = n_epochs,
    position = 0,
    desc = format!("{:<8}", "Epoch"),
    ncols = 100,
    force_refresh = true
);
  5. Training Loop:
    • Run for a fixed number of epochs (30). Each epoch consists of the phases below.
    • Training Phase:
      • Iterate over the training data loader.
      • For each batch:
        • Zero the gradients.
        • Forward pass: compute predictions.
        • Calculate accuracy and cross-entropy loss.
        • Backward pass: compute gradients and update weights.
        • Update the progress bar with the current loss and accuracy.
    • Validation Phase:
      • Iterate over the validation data loader.
      • For each batch:
        • Forward pass: compute predictions.
        • Calculate accuracy and cross-entropy loss.
        • Update the progress bar with the current loss and accuracy.
    • Adjust Learning Rate:
      • After each epoch, the scheduler checks the validation loss to decide whether the learning rate should be decreased.
    • Save Best Model:
      • If the current epoch’s validation loss is the lowest seen so far, save the model.
let mut best_acc = 0.0;
let mut best_loss: f64 = f64::infinity();

println!("\n\n Start Training \n\n");
for e in 1..=n_epochs {
    pbar_e.set_postfix(
        format!("lr = {:<.7}", scheduler.get_lr())
    );
    let _ = pbar_e.update_to(e);
    let mut epoch_acc_train = 0.0;
    let mut epoch_loss_train = 0.0;
    let mut running_samples = 0;
    for (i, (images, labels)) in (&mut dataloader_train).enumerate() {
        scheduler.opt.zero_grad();
        let out = net
            .forward(&images.to_device(device))
            .to_device(Device::Cpu);
        let acc = out.accuracy_for_logits(&labels);
        let loss = out.cross_entropy_for_logits(&labels);
        epoch_acc_train += f64::try_from(acc).unwrap() * (out.size()[0] as f64);
        epoch_loss_train += f64::try_from(&loss).unwrap() * (out.size()[0] as f64);
        scheduler.opt.backward_step(&loss);
        running_samples += out.size()[0];
        pbar.set_postfix(format!(
            "loss={:<7.4} - accuracy={:<7.4}",
            epoch_loss_train / (running_samples as f64),
            epoch_acc_train / (running_samples as f64) * 100.0
        ));
        let _ = pbar.update_to(i + 1);
    }

    let mut epoch_acc_val = 0.0;
    let mut epoch_loss_val = 0.0;

    running_samples = 0;
    for (i, (images, labels)) in (&mut dataloader_val).enumerate() {
        let out = net
            .forward(&images.to_device(device))
            .to_device(Device::Cpu);
        let loss = out.cross_entropy_for_logits(&labels);
        let acc = out.accuracy_for_logits(&labels);
        epoch_acc_val += f64::try_from(acc).unwrap() * (out.size()[0] as f64);
        epoch_loss_val += f64::try_from(&loss).unwrap() * (out.size()[0] as f64);

        running_samples += out.size()[0];
        pbar2.set_postfix(format!(
            "loss={:<7.4} - accuracy={:<7.4}",
            epoch_loss_val / (running_samples as f64),
            epoch_acc_val / (running_samples as f64) * 100.0
        ));
        let _ = pbar2.update_to(i + 1);
    }
    epoch_acc_val /= dataloader_val.len() as f64;
    epoch_loss_val /= dataloader_val.len() as f64;

    scheduler.step(epoch_loss_val);

    if epoch_loss_val < best_loss {
        best_loss = epoch_loss_val;
        best_acc = epoch_acc_val;
        vs.save(Path::new(save_dir).join("best_model.ot")).unwrap()
    }
}

Key Points

  • Device Management: Utilizes GPU if available for faster computations.
  • Model and Optimizer: Defined using the tch crate, with appropriate configurations.
  • Dynamic Learning Rate: Adjusts based on validation performance to optimize training.
  • Progress Tracking: Uses kdam crate to visually track training and validation progress.
  • Performance Metrics: Tracks accuracy and loss during both training and validation phases.

Complete training code

mod utils;

use crate::data::DataLoader;
use crate::model;
use kdam::{tqdm, BarExt};
use num_traits::float::Float;
use tch::{
    nn,
    nn::{Module, OptimizerConfig},
    Device,
};
use std::path::Path;
use std::fs::create_dir;

/// This function trains the model with train and val data loaders
pub fn train_model(
    mut dataloader_train: DataLoader,
    mut dataloader_val: DataLoader,
    save_dir: &str,
) {

    if !Path::new(save_dir).is_dir() {
        create_dir(save_dir).unwrap();
    }

    let device = Device::cuda_if_available();
    println!("The device is {:?}", device);
    let vs = nn::VarStore::new(device);
    let net = model::net(&vs.root(), 2);
    let lr = 1e-3;

    let mut opt = nn::Adam::default().build(&vs, lr).unwrap();
    opt.set_weight_decay(1e-4);

    let mut scheduler = utils::Scheduler::new(
        &mut opt, 5, lr, 0.5
    );

    let total_batch_train = dataloader_train.len_batch();
    let total_batch_val = dataloader_val.len_batch();

    let mut pbar = tqdm!(
        total = total_batch_train,
        position = 1,
        desc = format!("{:<8}", "Train"),
        force_refresh = true,
        ncols = 100
    );

    let mut pbar2 = tqdm!(
        total = total_batch_val,
        position = 2,
        desc = format!("{:<8}", "Val"),
        force_refresh = true,
        ncols = 100
    );

    let n_epochs = 30;

    let mut pbar_e = tqdm!(
        total = n_epochs,
        position = 0,
        desc = format!("{:<8}", "Epoch"),
        ncols = 100,
        force_refresh=true
    );

    let mut best_acc = 0.0;
    let mut best_loss: f64 = f64::infinity();

    println!("\n\n Start Training \n\n");
    for e in 1..=n_epochs {
        pbar_e.set_postfix(
            format!("lr = {:<.7}", scheduler.get_lr())
        );
        let _ = pbar_e.update_to(e);
        let mut epoch_acc_train = 0.0;
        let mut epoch_loss_train = 0.0;
        let mut running_samples = 0;
        for (i, (images, labels)) in (&mut dataloader_train).enumerate() {
            scheduler.opt.zero_grad();
            let out = net
                .forward(&images.to_device(device))
                .to_device(Device::Cpu);
            let acc = out.accuracy_for_logits(&labels);
            let loss = out.cross_entropy_for_logits(&labels);
            epoch_acc_train += f64::try_from(acc).unwrap() * (out.size()[0] as f64);
            epoch_loss_train += f64::try_from(&loss).unwrap() * (out.size()[0] as f64);
            scheduler.opt.backward_step(&loss);
            running_samples += out.size()[0];
            pbar.set_postfix(format!(
                "loss={:<7.4} - accuracy={:<7.4}",
                epoch_loss_train / (running_samples as f64),
                epoch_acc_train / (running_samples as f64) * 100.0
            ));
            let _ = pbar.update_to(i + 1);
        }

        let mut epoch_acc_val = 0.0;
        let mut epoch_loss_val = 0.0;

        running_samples = 0;
        for (i, (images, labels)) in (&mut dataloader_val).enumerate() {
            let out = net
                .forward(&images.to_device(device))
                .to_device(Device::Cpu);
            let loss = out.cross_entropy_for_logits(&labels);
            let acc = out.accuracy_for_logits(&labels);
            epoch_acc_val += f64::try_from(acc).unwrap() * (out.size()[0] as f64);
            epoch_loss_val += f64::try_from(&loss).unwrap() * (out.size()[0] as f64);

            running_samples += out.size()[0];
            pbar2.set_postfix(format!(
                "loss={:<7.4} - accuracy={:<7.4}",
                epoch_loss_val / (running_samples as f64),
                epoch_acc_val / (running_samples as f64) * 100.0
            ));
            let _ = pbar2.update_to(i + 1);
        }
        epoch_acc_val /= dataloader_val.len() as f64;
        epoch_loss_val /= dataloader_val.len() as f64;

        scheduler.step(epoch_loss_val);

        if epoch_loss_val < best_loss {
            best_loss = epoch_loss_val;
            best_acc = epoch_acc_val;
            vs.save(Path::new(save_dir).join("best_model.ot")).unwrap()
        }
    }
    println!("\n\n\n");
    println!(
        "Best validation loss = {best_loss:.4}, accuracy={:.4}",
        best_acc * 100.0
    );
}

Inference

The inference function predicts the class of an input image using the trained CNN. It first selects the device (CUDA if available), builds the model, and loads the saved weights. The input image is loaded, resized to 224x224 pixels, given a batch dimension, and passed through the model. The predicted class is the index of the maximum output logit, which is returned as the label. Note that this index refers to the class_to_idx mapping built by the Dataset (printed by dataset.print()), so keep that mapping in mind when interpreting the result.

use tch::{Device, nn, vision, nn::Module};
use crate::model;

pub fn inference(image_path: &str) -> i64{
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let net = model::net(&vs.root(), 2);
    vs.load("weights/best_model.ot").unwrap();
    let image = vision::imagenet::load_image_and_resize224(image_path).unwrap().unsqueeze(0);
    let out = net.forward(&image.to_device(device));
    let prediction = out.argmax(1, false);

    i64::try_from(prediction).unwrap()
}

Main function: A Complete Workflow

The main function orchestrates the training and inference processes of our CNN. The datasets for training and validation are loaded from specified directories, and their details are printed, including class mappings and dataset sizes. Data loaders are then initialized for batch processing during training and validation phases. The CNN model is trained using the training dataset, with validation performed using the separate validation dataset to monitor performance. After training, the best-performing model weights are saved. Inference is demonstrated by predicting the class of a sample image from the validation set using the trained model, and the predicted class label is printed as output.

mod data;
mod model;
mod train;
mod inference;


fn main() {
    std::env::set_var("CUDA_LAUNCH_BLOCKING", "1");
    std::env::set_var("TORCH_SHOW_WARNINGS", "0");
    println!("Loading Dataset / train");
    let dataset_train = data::Dataset::new("data/sea_vs_jungle/train");
    println!("Loading Dataset / val");
    let dataset_val = data::Dataset::new("data/sea_vs_jungle/val");
    dataset_train.print();
    dataset_val.print();

    let dataloader_train = data::DataLoader::new(dataset_train, 32, true);
    let dataloader_val = data::DataLoader::new(dataset_val, 32, false);
    println!("{}", dataloader_val.len_batch());
    train::train_model(dataloader_train, dataloader_val, "weights");

    let prediction = inference::inference("data/sea_vs_jungle/val/sea/001c31c29de8a9cd.jpg");

    println!("Prediction is {prediction}")

}

Conclusion

At the end of this project, Rust, together with the Libtorch bindings provided by the tch-rs crate, proved to be a powerful tool for deep learning tasks; still, I found that it might be more straightforward to train a model in Python and then use it from Rust. In a future post, I plan to explore that approach: train a model in Python, export it with TorchScript, and integrate it into a Rust application.

Code Repository

The complete Rust implementation is available on GitHub.
