Implementing K-Means in Rust: A 7x Speed Boost Over Python

Introduction

As a machine learning enthusiast, I’ve always been fascinated by the speed and efficiency of algorithms. Recently, I decided to learn Rust, a language known for its performance and safety guarantees, by implementing the K-means clustering algorithm. Despite not being an experienced Rust developer, I was amazed by the results: my Rust implementation outperformed the popular scikit-learn package in Python by a factor of 7. In this blog post, I’ll share my journey of implementing K-means in Rust. Let’s dive in!

Data

This Python code generates 1 million samples of 10-dimensional Gaussian noise, split across four clusters whose means are offset from one another by 2 along every dimension, and saves them to a CSV file named “data.csv”. Here’s the code:

import numpy as np

n_cluster = 4
number_of_samples = 1_000_000
dim = 10

n_sample_per_class = int(number_of_samples / n_cluster)
samples = np.zeros((n_sample_per_class*n_cluster , dim))

for i in range(n_cluster):
    samples[i*n_sample_per_class: (i+1)*n_sample_per_class] = np.random.randn(n_sample_per_class, dim) + (i+1) * 2
    
np.savetxt("data.csv", samples, delimiter=",", fmt='%.3f')

This code creates 250,000 samples for each cluster, totaling 1 million samples. Each sample has 10 dimensions, and the means of the clusters are 2, 4, 6, and 8 along each dimension, respectively.

Rust Implementation

In this section, I’ll walk you through the Rust implementation of the K-means algorithm step by step. However, since I’m not a seasoned Rust developer, I’ll keep the explanations high-level to avoid potential errors.

Setting Up Dependencies

This code snippet imports the standard library items and external crates the implementation relies on: error handling from std, the csv and ndarray-csv crates for CSV reading and writing, ndarray for array operations, rand for random number generation, and clap for command-line argument parsing.

use std::error::Error;
use std::ops::IndexMut;
use csv::{ReaderBuilder, Writer};
use ndarray::{Array1, Array2, ArrayView1, AssignElem};
use ndarray_csv::Array2Reader;
use rand::seq::index::sample;
use rand::{Rng, thread_rng};
use clap::Parser;
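
For reference, here is a plausible [dependencies] section for Cargo.toml. The exact versions are my assumption (the repository may pin different ones); the main constraint is that ndarray-csv must match the ndarray major version, and clap needs the derive feature for #[derive(Parser)]:

[dependencies]
csv = "1"
ndarray = "0.15"
ndarray-csv = "0.5"
rand = "0.8"
clap = { version = "4", features = ["derive"] }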

Defining Command-Line Arguments

This section defines the command-line arguments using the clap crate. It includes arguments for the path to the CSV file, the number of clusters, whether to use Kmeans++ for center initialization, the maximum number of iterations, the tolerance for center change, and the path to save the indices as a CSV file.

// define arguments using clap
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(short, long, help = "Path to the csv file")]
    data_path: String,

    #[arg(short, long, help = "Number of clusters")]
    num_cluster: usize,

    #[arg(short, long, help = "Use Kmeans++ to initialize centers")]
    kpp: bool,

    #[arg(short, long, help = "Maximum number of iterations", default_value = "1000")]
    max_iter: i32,

    #[arg(short, long, help = "Maximum center change tolerance", default_value = "1e-4")]
    tolerance: f32,

    #[arg(short, long, help = "Path to save indices as csv", default_value = "indices.csv")]
    output_path: String

}

Loading Data from CSV

This function load_data takes a file path as input and loads data from a CSV file into a 2D array of 32-bit floating-point numbers (Array2). It uses the csv crate to read the file and the ndarray-csv crate to deserialize its contents into the array.

fn load_data(file_path: &str) -> Result<Array2<f32>, Box<dyn Error>> {
    // this function loads data from the csv file
    let reader = ReaderBuilder::new().has_headers(false).from_path(file_path);
    let array_read: Array2<f32> = reader?.deserialize_array2_dynamic()?;
    Ok(array_read)
}

Initializing Cluster Centers

The random_centers function takes a reference to a 2D array data and the number of clusters n_cluster as input. It selects n_cluster random rows from data as initial cluster centers and returns them in a new 2D array. The kmeans_pp function, in turn, selects initial centers with a greedy farthest-point variant of KMeans++: after a random first center, each subsequent center is the sample with the largest summed distance to the centers chosen so far.

fn random_centers(data: &Array2<f32>, n_cluster: usize) -> Array2<f32> {
    // centers are randomly selected from the current set of samples
    let n_rows = data.nrows();
    let mut rng = thread_rng();
    let mut selected_rows =
        Array2::<f32>::zeros((n_cluster, data.ncols()));
    let indices: Array1<usize> = sample(&mut rng, n_rows, n_cluster).into_iter().collect();
    for (i, &index) in indices.iter().enumerate() {
        selected_rows.row_mut(i).assign(&data.row(index));
    }

    selected_rows
}

fn kmeans_pp(data: &Array2<f32>, n_cluster: usize) -> Array2<f32> {
    // centers are selected with a greedy farthest-point variant of KMeans++
    let n_rows = data.nrows();
    let mut chosen_points: Vec<usize> = Vec::new();
    let mut centers: Array2<f32> = Array2::zeros((n_cluster, data.ncols()));

    // first center: a uniformly random sample
    chosen_points.push(thread_rng().gen_range(0..n_rows));
    centers.row_mut(0).assign(&data.row(chosen_points[0]));

    // other centers
    for i_center in 1..n_cluster {

        let mut max_dist: f32 = -1.0;
        let mut max_index: usize = 0;
        for (i_sample, sample) in data.rows().into_iter().enumerate() {
            if chosen_points.contains(&i_sample) {
                continue
            }
            let mut c_dist: f32 = 0.0;
            for i_prev_centers in 0..i_center {
                c_dist += euclidean_distance(&sample, &centers.row(i_prev_centers));
            }
            if c_dist > max_dist {
                max_dist = c_dist;
                max_index = i_sample
            }
        }

        chosen_points.push(max_index);
        centers.row_mut(i_center).assign(&data.row(max_index))
    }
    centers
}
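
As an aside, what I implemented above is a deterministic farthest-point scheme. Textbook KMeans++ instead samples each new center with probability proportional to the squared distance to its nearest already-chosen center. Here is a minimal sketch of that weighted variant, reusing the euclidean_distance helper defined in the next section (kmeans_pp_weighted is my name for it, not something from the repository):

use rand::distributions::{Distribution, WeightedIndex};

// Sketch of standard D^2-weighted KMeans++ seeding (an alternative to the
// greedy variant above, not the one used in this post).
fn kmeans_pp_weighted(data: &Array2<f32>, n_cluster: usize) -> Array2<f32> {
    let mut rng = thread_rng();
    let mut centers = Array2::<f32>::zeros((n_cluster, data.ncols()));
    // first center: a uniformly random sample
    let first = rng.gen_range(0..data.nrows());
    centers.row_mut(0).assign(&data.row(first));
    for i_center in 1..n_cluster {
        // D(x)^2: squared distance from each sample to its nearest chosen center
        let weights: Vec<f32> = data.rows().into_iter().map(|row| {
            (0..i_center)
                .map(|c| euclidean_distance(&row, &centers.row(c)))
                .fold(f32::INFINITY, f32::min)
                .powi(2)
        }).collect();
        // sample the next center index proportionally to the weights
        let chosen = WeightedIndex::new(&weights)
            .expect("weights must be non-negative with a positive sum")
            .sample(&mut rng);
        centers.row_mut(i_center).assign(&data.row(chosen));
    }
    centers
}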

Calculating Euclidean Distance

The euclidean_distance function calculates the Euclidean distance between two 1D arrays a and b represented as ArrayView1.

fn euclidean_distance(a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> f32 {
    // calculate Euclidean distance between two arrays
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}
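
To sanity-check this helper, a small unit test (my addition, not part of the original code) can pin the classic 3-4-5 triangle:

#[cfg(test)]
mod distance_tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn three_four_five_triangle() {
        let a = array![0.0_f32, 0.0];
        let b = array![3.0_f32, 4.0];
        // sqrt(3^2 + 4^2) = 5
        assert!((euclidean_distance(&a.view(), &b.view()) - 5.0).abs() < 1e-6);
    }
}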

Assigning Samples to Clusters

The assign_cluster_to_sample function assigns each sample in the input data (Array2) to the closest cluster center based on Euclidean distance. It takes three arguments: the input data, the cluster centers, and a mutable array of indices representing the assigned cluster for each sample.

fn assign_cluster_to_sample(data: &Array2<f32>, centers: &Array2<f32>, indices: &mut Array1<usize>) {
    // finds the closest cluster to each sample using centers
    for (index_sample, row) in data.rows().into_iter().enumerate(){
        let mut min_distance = f32::INFINITY;
        let mut min_index: usize = 0;

        for (index, center_row) in centers.rows().into_iter().enumerate(){
            let distance = euclidean_distance(&row, &center_row);
            if distance < min_distance{
                min_distance = distance;
                min_index = index;
            }
        }
        indices.index_mut(index_sample).assign_elem(min_index)
    }
}
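
A quick test (again my addition) shows the expected behavior on a toy example: each sample should land in the cluster whose center is closest.

#[cfg(test)]
mod assignment_tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn assigns_samples_to_nearest_center() {
        let data = array![[0.0_f32, 0.0], [10.0, 10.0]];
        let centers = array![[1.0_f32, 1.0], [9.0, 9.0]];
        let mut indices = Array1::<usize>::zeros(2);
        assign_cluster_to_sample(&data, &centers, &mut indices);
        // sample 0 is nearest center 0; sample 1 is nearest center 1
        assert_eq!(indices, array![0_usize, 1]);
    }
}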

Updating Cluster Centers

The update_centers function recomputes each cluster center as the mean of the samples assigned to it. It takes the input data (data), the current cluster centers (centers), the per-sample cluster assignments (indices), and the number of clusters (n_clusters) as input, and returns the largest distance any center moved, which fit uses as the convergence signal.

fn update_centers(data: &Array2<f32>,
                  centers: &mut Array2<f32>,
                  indices: &Array1<usize>, n_clusters: usize) -> f32 {
    // update centers as the average of the samples within the cluster
    let mut max_change: f32 = 0.0;
    for index in 0..n_clusters {
        let matched_indices: Vec<usize> = indices.iter()
            .enumerate()
            .filter(|&(_, &value)| value == index)
            .map(|(i, _)| i)
            .collect();
        // skip empty clusters so the average below never divides by zero
        if matched_indices.is_empty() {
            continue;
        }
        let mut c: Array1<f32> = Array1::zeros(data.ncols());
        for m_index in &matched_indices {
            c += &data.row(*m_index);
        }
        c /= matched_indices.len() as f32;
        let distance = euclidean_distance(&centers.row(index), &c.view());
        centers.row_mut(index).assign(&c);
        // track the largest center movement; the loop in fit() stops only
        // when every center moves less than the tolerance
        if distance > max_change {
            max_change = distance;
        }
    }
    max_change
}
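
A tiny worked example as a test (my addition): with samples (0,0) and (2,2) in cluster 0 and (10,10) in cluster 1, the centers should move to the cluster means (1,1) and (10,10).

#[cfg(test)]
mod update_tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn centers_move_to_cluster_means() {
        let data = array![[0.0_f32, 0.0], [2.0, 2.0], [10.0, 10.0]];
        let indices = array![0_usize, 0, 1];
        let mut centers = array![[5.0_f32, 5.0], [5.0, 5.0]];
        let change = update_centers(&data, &mut centers, &indices, 2);
        // cluster 0 averages to (1, 1); cluster 1 collapses onto (10, 10)
        assert_eq!(centers, array![[1.0_f32, 1.0], [10.0, 10.0]]);
        // the returned value is the largest center displacement, sqrt(50) here
        assert!(change > 5.0);
    }
}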

Writing Results to CSV

The write_csv function writes the cluster assignments (indices) to a CSV file specified by output_path.

fn write_csv(indices: &Array1<usize>, output_path: &str) -> Result<(), Box<dyn Error>>{
    // write the result into a csv file
    let mut writer = Writer::from_path(output_path)?;

    for &value in indices.iter() {
        writer.write_record(&[value.to_string()])?;
    }

    writer.flush()?;
    println!("Results wrote to {output_path}");
    Ok(())
}

K-Means Struct and Methods

The KMeans struct represents a K-means clustering model with configurable parameters. It includes fields for the number of clusters (num_cluster), whether to use KMeans++ for center initialization (kpp), the convergence tolerance (tolerance), the maximum number of iterations (max_iter), the current iteration number (iter), and the maximum change in cluster centers (max_change).

The fit method fits the K-means model to the input data (data). It initializes the cluster centers, assigns samples to clusters, and updates the centers iteratively until convergence or the maximum number of iterations is reached. It returns an array of cluster indices indicating the cluster assignment for each sample.

#[derive(Default)]
struct KMeans {
    num_cluster: usize,
    kpp: bool,
    tolerance: f32,
    max_iter: i32,
    iter: Option<i32>,
    max_change: Option<f32>

}

impl KMeans {
    fn fit(&mut self, data: &Array2<f32>) -> Array1<usize> {
        // initiate centers random or using kmeans++
        let mut centers = if self.kpp {
            kmeans_pp(data, self.num_cluster)
        } else {
            random_centers(data, self.num_cluster)
        };

        // main loop, assign indices and update centers
        let mut indices: Array1<usize> = Array1::zeros(data.nrows());
        let mut iter = 0;
        let mut max_change = f32::INFINITY;
        while (max_change > self.tolerance) && (iter < self.max_iter) {
            iter += 1;
            assign_cluster_to_sample(data, &centers, &mut indices);
            max_change = update_centers(data, &mut centers, &indices, self.num_cluster);
        }
        self.iter = Some(iter);
        self.max_change = Some(max_change);
        indices
    }

    fn get_iter(&self) -> i32 {
        self.iter.unwrap_or(0)
    }

    fn get_max_change(&self) -> f32 {
        self.max_change.unwrap_or(f32::INFINITY)
    }
    
}

Main Function

The main function serves as the entry point for the program. It parses command-line arguments using clap, loads data from a CSV file, initializes a KMeans struct with the specified parameters, fits the K-means model to the data, prints the number of iterations and maximum change in cluster centers, and writes the cluster assignments to a CSV file.

fn main() {

    // parse and get arguments
    let cli = Args::parse();
    let file_path = cli.data_path;

    // load data
    let data = load_data(&file_path).expect("Error reading csv");

    let mut kmeans = KMeans {
        num_cluster: cli.num_cluster,
        kpp: cli.kpp,
        tolerance: cli.tolerance,
        max_iter: cli.max_iter,
        ..Default::default()
    };

    let indices = kmeans.fit(&data);

    println!("Number of Iters: {:?} with max change: {:?}",
             kmeans.get_iter(), kmeans.get_max_change());

    // write to csv file
    write_csv(&indices, &cli.output_path).expect("Error writing csv");

}

Usage
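
After compiling the optimized binary with cargo build --release, the executable lands in target/release/kmeans and prints its options with -h: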

>> target/release/kmeans -h

A simple implementation of KMeans algorithm.

Usage: kmeans [OPTIONS] --data-path <DATA_PATH> --num-cluster <NUM_CLUSTER>

Options:
  -d, --data-path <DATA_PATH>      Path to the csv file
  -n, --num-cluster <NUM_CLUSTER>  Number of clusters
  -k, --kpp                        Use Kmeans++ to initialize centers
  -m, --max-iter <MAX_ITER>        Maximum number of iterations [default: 1000]
  -t, --tolerance <TOLERANCE>      Maximum center change tolerance [default: 1e-4]
  -o, --output-path <OUTPUT_PATH>  Path to save indices as csv [default: indices.csv]
  -h, --help                       Print help
  -V, --version                    Print version
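
For example, to cluster the generated dataset into four clusters with KMeans++ initialization (the assignments go to indices.csv by default):

>> target/release/kmeans --data-path data.csv --num-cluster 4 --kpp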

A Speed Showdown: Rust vs. Python for K-Means Clustering

This comparison pits the performance of Rust against Python for K-means clustering. The Python code uses numpy and scikit-learn to read data from a CSV file, perform K-means clustering, and save the cluster centers to another CSV file.

import time
import numpy as np
from sklearn.cluster import KMeans
start = time.time()
data = np.genfromtxt('data.csv', delimiter=',')
k = KMeans(n_clusters=4, init='k-means++', n_init=1, max_iter=100, tol=1e-5)
k.fit(data)
np.savetxt("res.csv", k.cluster_centers_, delimiter=",")
print(f'n_iter {k.n_iter_}')
print(time.time() - start)
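
To time the Rust side with the same parameters as the Python run above (max_iter=100, tol=1e-5), a simple wall-clock measurement works; the exact output format depends on your shell:

>> time target/release/kmeans -d data.csv -n 4 -k -m 100 -t 1e-5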

Even though I’m not a seasoned Rust developer, the Rust code outperforms the Python code by a factor of 7, showcasing Rust’s speed and efficiency.

Code Repository

The complete Rust implementation of the K-means algorithm is available on GitHub.
