Running Llama 3.2 in Rust

15 minute read

The complete implementation is available on GitHub.

Introduction

With the rapid growth of AI and natural language processing, efficient language model deployment has become a key focus for developers. Models like Llama 3.2, known for their performance and flexibility, open the door to sophisticated applications in text generation, chatbots, summarization, and more. But running these models at high speed, especially on GPUs, requires a language that balances performance with control.

In this post, we’ll explore how to set up and run Llama 3.2 in Rust, a language gaining popularity for its system-level access, memory safety, and concurrency features. Using the llama.cpp library as our backend, we’ll implement a flexible language model interface with Rust.

Setup

To get started with Llama 3.2 in Rust, we’ll first clone the project repository and set up the dependencies using Cargo, Rust’s package manager. Since this project depends on the llama.cpp backend, we need to clone it recursively to ensure all submodules are included. Follow these steps to prepare the environment:

Clone the Repository

First, clone the repository recursively to pull in all necessary submodules:

git clone --recursive https://github.com/ramintoosi/llama-rust
cd llama-rust

Note: The --recursive flag is essential for including the llama.cpp bindings, which provide the backend support for model inference.
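
If you already cloned the repository without the flag, you can pull the submodules in afterwards:

git submodule update --init --recursive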

Cargo Dependencies

The project’s Cargo.toml file includes the following dependencies:

[dependencies]
llama-cpp-2 = { path = "llama-cpp-rs/llama-cpp-2", features = ["cuda"] }
hf-hub = "0.3.2"
clap = { version = "4.5.19", features = ["derive"] }
anyhow = "1.0.89"
encoding_rs = "0.8.34"
log = "0.4.22"

[features]
cuda = ["llama-cpp-2/cuda"]
  • llama-cpp-2: Connects the project to the llama.cpp Rust binding, with an optional cuda feature for GPU support.
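
Assuming the CUDA toolkit is installed, building follows the usual Cargo flow; the cuda feature wires GPU support through to llama.cpp (drop the flag for a CPU-only build if the dependency's cuda feature is left optional in your copy of Cargo.toml):

# build with GPU support
cargo build --release --features cuda

# CPU-only build
cargo build --release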

Arguments Parsing

The Args module, built with the clap library, efficiently manages the command-line arguments required for model selection, configuration, and customization. Here’s a breakdown of how each component works, focusing on the design and functionality of the different parameters.

use anyhow::{anyhow, Context};
use std::path::PathBuf;
use clap::{Parser, Subcommand};
use hf_hub::api::sync::ApiBuilder;
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use std::str::FromStr;

#[derive(Subcommand, Debug, Clone)]
pub enum Model {
    /// Use an already downloaded model
    #[clap(name = "local")]
    Local {
        /// The path to the model. e.g. `../hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
        /// or `./llama-3.2-1b-instruct-q8_0.gguf`
        path: PathBuf,
    },
    /// Download a model from huggingface (or use a cached version)
    #[clap(name = "hf-model")]
    HuggingFace {
        /// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF`
        repo: String,
        /// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf`
        model: String,
    },
}

impl Model {
    /// Convert the model to a path - may download from huggingface
    pub fn get_or_load(self) -> anyhow::Result<PathBuf> {
        match self {
            Model::Local { path } => Ok(path),
            Model::HuggingFace { model, repo } => ApiBuilder::new()
                .with_progress(true)
                .build()
                .with_context(|| "unable to create huggingface api")?
                .model(repo)
                .get(&model)
                .with_context(|| "unable to download model"),
        }
    }
}


#[derive(clap::ValueEnum, Clone, Debug)]
pub enum Mode {
    Chat,
    Completion,
}

#[derive(Parser, Debug, Clone)]
pub struct Args {
    /// The path to the model
    #[command(subcommand)]
    pub model: Model,

    /// The mode of the code: completion or chat
    #[clap(value_enum, short = 'm', long, default_value = "chat")]
    pub mode: Mode,

    // /// The prompt to use - valid only if the mode is `completion`
    // #[clap(short = 'p', long, required_if_eq("mode", "completion"))]
    // prompt: Option<String>,

    /// set the length of the prompt + output in tokens
    #[clap(long, default_value_t = 512)]
    pub max_token: u32,

    /// override some parameters of the model
    #[clap(short = 'o', value_parser = parse_key_val)]
    pub key_value_overrides: Vec<(String, ParamOverrideValue)>,

    /// how many layers to keep on the gpu - zero is cpu mode
    #[clap(
        short = 'g',
        long,
        default_value_t = 0,
        help = "how many layers to keep on the gpu - zero is cpu mode (default: 0)"
    )]
    pub n_gpu_layers: u32,

    /// set the seed for the RNG
    #[clap(short = 's', long, default_value_t=561371)]
    pub seed: u32,

    /// number of threads to use during generation
    #[clap(
        long,
        help = "number of threads to use during generation (default: use all available threads)"
    )]
    pub threads: Option<i32>,
    #[clap(
        long,
        help = "number of threads to use during batch and prompt processing (default: use all available threads)"
    )]
    pub threads_batch: Option<i32>,

    // /// size of the prompt context
    // #[clap(
    //     short = 'c',
    //     long,
    //     help = "size of the prompt context (default: loaded from the model)"
    // )]
    // pub ctx_size: Option<NonZeroU32>,
    
    /// show the token/s speed at the end of each turn
    #[clap(short = 'v', long, action)]
    pub verbose: bool,

}

/// Parse a single key-value pair
fn parse_key_val(s: &str) -> anyhow::Result<(String, ParamOverrideValue)> {
    let pos = s
        .find('=')
        .ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?;
    let key = s[..pos].parse()?;
    let value: String = s[pos + 1..].parse()?;
    let value = i64::from_str(&value)
        .map(ParamOverrideValue::Int)
        .or_else(|_| f64::from_str(&value).map(ParamOverrideValue::Float))
        .or_else(|_| bool::from_str(&value).map(ParamOverrideValue::Bool))
        .map_err(|_| anyhow!("must be one of i64, f64, or bool"))?;

    Ok((key, value))
}

Model Enum for Model Selection

The Model enum defines two ways to specify a language model:

  • Local Model (local): This variant lets users provide a path to a locally downloaded model file.

  • Hugging Face Model (hf-model): This variant enables automatic downloading of a model from Hugging Face’s repositories. It takes in the repo and model names and uses hf_hub to fetch the specified model, caching it for future use.

The get_or_load method in Model abstracts the logic for loading models, using Hugging Face’s API if needed.
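
For illustration, the enum can also be used directly outside the CLI flow. This is a minimal sketch reusing the example repo and file names from the doc comments above; resolve_model is a hypothetical helper, not part of the project:

// Hypothetical helper: resolve a Model argument to a local file path.
// The return type matches get_or_load's anyhow::Result<PathBuf>.
fn resolve_model() -> anyhow::Result<std::path::PathBuf> {
    let model = Model::HuggingFace {
        repo: "TheBloke/Llama-2-7B-Chat-GGUF".to_string(),
        model: "llama-2-7b-chat.Q4_K_M.gguf".to_string(),
    };
    // Downloads into the local Hugging Face cache on first use, reuses it afterwards.
    model.get_or_load()
}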

Mode Enum for Operation Mode

The Mode enum controls the inference mode:

  • Chat: Enables a conversational interaction with the model, allowing for an ongoing dialogue.
  • Completion: Generates text completions based on an initial prompt.

The mode argument defaults to Chat, but users can specify either mode by passing -m completion or -m chat.

The Args Struct

The Args struct organizes all of the command-line arguments into a clean and accessible structure. Let’s look at each parameter:

  • model: A subcommand that lets the user specify a model, either by path or by Hugging Face repository, as described in the Model enum.

  • mode: Specifies the operational mode, either chat or completion.

  • max_token: Sets the maximum length for prompt and output tokens. The default value is 512, allowing some control over the model’s generation length.

  • key_value_overrides: Provides flexibility by allowing the user to pass key-value overrides for specific model parameters.

  • n_gpu_layers: Specifies the number of layers to run on the GPU, with a default of 0, which uses the CPU.

  • seed: Sets a seed for the random number generator, allowing for reproducible output across runs. This defaults to 561371.

  • threads and threads_batch: These arguments allow the user to fine-tune performance by specifying the number of threads used for generation and batch processing, respectively. By default, it uses all available threads.

  • verbose: If set, this flag displays token processing speed after each generation, useful for performance monitoring.
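
Putting these options together, typical invocations look like the following; the model files are the examples mentioned earlier, so adjust paths and names to whatever you actually have:

# chat with a local GGUF file, offloading 32 layers to the GPU
cargo run --release --features cuda -- -g 32 local ./llama-3.2-1b-instruct-q8_0.gguf

# completion mode on the CPU with a model fetched (and cached) from Hugging Face
cargo run --release -- -m completion -g 0 hf-model TheBloke/Llama-2-7B-Chat-GGUF llama-2-7b-chat.Q4_K_M.gguf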

Llama 3.2 Inference

The LLM struct is the central component for handling Llama model inference, encapsulating essential components like the model, backend, and configuration parameters.

use std::ffi::CString;
use std::io::Write;
use std::num::NonZeroU32;
use std::pin::pin;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::sync::Arc;
use std::time::Duration;
use llama_cpp_2::ggml_time_us;
use super::args_handler::{Args, Mode};

pub struct LLM {
    pub model: Arc<LlamaModel>,
    pub backend: LlamaBackend,
    pub ctx_params: LlamaContextParams,
    mode: Mode,
    history: String,
    max_token: u32,
    verbose: bool,
}

new Method

pub fn new(args: Args) -> Self {
    // initialize the llama.cpp backend
    let backend = LlamaBackend::init()
        .expect("Could not initialize Llama backend");

    // offload the requested number of layers to the gpu (0 = cpu only)
    let model_params = {
        if args.n_gpu_layers > 0 {
            LlamaModelParams::default().with_n_gpu_layers(args.n_gpu_layers)
        } else {
            LlamaModelParams::default()
        }
    };

    let mut model_params = pin!(model_params);

    // apply any -o KEY=VALUE overrides to the model parameters
    for (k, v) in &args.key_value_overrides {
        let k = CString::new(k.as_bytes()).expect(format!("invalid key: {k}").as_str());
        model_params.as_mut().append_kv_override(k.as_c_str(), *v);
    }

    // resolve the model path (local file or Hugging Face download)
    let model_path = args.model.clone()
        .get_or_load()
        .expect("failed to get model from args");

    // load the model and wrap it in an Arc for shared ownership
    let model = Arc::new(
        LlamaModel::load_from_file(&backend, model_path, &model_params)
            .expect("failed to load model"),
    );

    // initialize the context parameters
    let mut ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(args.max_token))
        .with_seed(args.seed);
    if let Some(threads) = args.threads {
        ctx_params = ctx_params.with_n_threads(threads);
    }
    if let Some(threads_batch) = args.threads_batch.or(args.threads) {
        ctx_params = ctx_params.with_n_threads_batch(threads_batch);
    }

    Self {
        model,
        backend,
        ctx_params,
        mode: args.mode,
        history: String::new(),
        max_token: args.max_token,
        verbose: args.verbose,
    }
}

The new method in the LLM struct initializes the core components required for inference, including backend setup, model loading, and context configuration. It takes in an Args instance, which holds the user-provided configuration parameters, and uses these to fine-tune the model and backend settings. Here’s a step-by-step breakdown of each part:

Initializing the Backend

let backend = LlamaBackend::init()
    .expect("Could not initialize Llama backend");

Configuring Model Parameters

let model_params = {
    if args.n_gpu_layers > 0 {
        LlamaModelParams::default().with_n_gpu_layers(args.n_gpu_layers)
    } else {
        LlamaModelParams::default()
    }
};

The model_params configuration is determined by whether the user has specified GPU layers with n_gpu_layers. This parameter allows the model to use GPU acceleration for a set number of layers, which can significantly improve inference speed. If no GPU layers are specified (i.e., n_gpu_layers is 0), it defaults to using the CPU.

Applying Key-Value Overrides

for (k, v) in &args.key_value_overrides {
    let k = CString::new(k.as_bytes()).expect(format!("invalid key: {k}").as_str());
    model_params.as_mut().append_kv_override(k.as_c_str(), *v);
}

This section iterates over any key-value overrides specified by the user, allowing them to customize specific model parameters.
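
On the command line these are passed as repeatable -o KEY=VALUE pairs; parse_key_val (shown earlier) tries to read the value as an i64, then an f64, then a bool. The keys below are purely illustrative of the format, not known llama.cpp parameters:

cargo run --release -- -o some.gguf.key=42 -o another.key=true local ./llama-3.2-1b-instruct-q8_0.gguf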

Loading the Model

let model_path = args.model.clone()
    .get_or_load()
    .expect("failed to get model from args");

let model = Arc::new(LlamaModel::load_from_file(&backend, model_path, &model_params)
    .expect("failed to load model"));

The model_path is determined by calling get_or_load on the args.model field, which either retrieves the local model path or downloads it from Hugging Face. The model is then loaded into an Arc<LlamaModel> instance, allowing shared ownership of the model.

Setting Up the Context Parameters

// the context window is capped at max_token tokens and the RNG seed is fixed
// for reproducible sampling
let mut ctx_params = LlamaContextParams::default()
    .with_n_ctx(NonZeroU32::new(args.max_token))
    .with_seed(args.seed);
// thread counts are optional; threads_batch falls back to threads when not set
if let Some(threads) = args.threads {
    ctx_params = ctx_params.with_n_threads(threads);
}
if let Some(threads_batch) = args.threads_batch.or(args.threads) {
    ctx_params = ctx_params.with_n_threads_batch(threads_batch);
}

Returning the LLM Instance

Finally, the configured fields are assembled into a new LLM instance:

Self {
    model,
    backend,
    ctx_params,
    mode: args.mode,
    history: String::new(),
    max_token: args.max_token,
    verbose: args.verbose,
}

The history is initialized as an empty string and will accumulate the conversation in Chat mode. The max_token limit and the verbose flag are taken directly from the user's arguments.

Generation

The generate_chat function in this implementation is designed to generate responses in either a chat or text completion mode. It handles various tasks essential for text generation, including formatting inputs, managing tokenization and decoding, and outputting the generated text. Let’s go through each component to see how it works.

Formatting the Input Based on Mode

The function starts by formatting the input based on whether the LLM instance is in Chat or Completion mode.

// holder declared up front so the completion prompt outlives the match below
let completion_input: String;
let input_to_model = match self.mode {
    Mode::Chat => {
        let input_formatted = format!(
            "<|start_header_id|>user<|end_header_id|>{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
            prompt
        );
        self.history.push_str(input_formatted.as_str());
        flag_chat = true;
        &self.history
    }
    Mode::Completion => {
        completion_input = format!("<|begin_of_text|>{}", prompt);
        &completion_input
    }
};
  • Chat Mode: The prompt is wrapped in user/assistant header markers and appended to self.history, so the full running conversation is sent to the model on every turn.
  • Completion Mode: It prepares a straightforward prompt without conversation markers.
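
To make the template concrete: if the user types "What is Rust?", chat mode appends a chunk like the following to self.history (line break added here for readability only):

<|start_header_id|>user<|end_header_id|>What is Rust?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>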

Tokenizing the Input

The next step tokenizes the formatted prompt, converting it into a list of tokens that the model can process.

let tokens_list = ctx.model
    .str_to_token(&input_to_model, AddBos::Never)
    .expect(format!("failed to tokenize {}", input_to_model).as_str());

Here, AddBos::Never ensures that no beginning-of-sequence token is added, as the prompt format is controlled separately.

Validating Token Length

To ensure the prompt fits in the context window (which was sized from max_token), the function checks that the token count doesn't exceed the limit:

if tokens_list.len() >= max_token as usize {
    panic!("the prompt is too long, it has more tokens than max_token ({max_token})")
}

Preparing the LlamaBatch for Decoding

The function then sets up a LlamaBatch, which is used to manage the tokens being processed.

let mut batch = LlamaBatch::new(n_cxt as usize, 1);

Each token in tokens_list is then added to the batch; the flag on the last token tells the backend to compute logits for that position, which is what the sampling step needs.

for (i, token) in tokens_list.into_iter().enumerate() {
    let is_last = i == last_index;
    batch.add(token, i as i32, &[0], is_last).unwrap();
}

Clearing the KV Cache for Completion Mode

To prevent older tokens from influencing new ones in completion mode, the function clears the KV cache:

ctx.clear_kv_cache();
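
Before sampling can start, the queued prompt tokens have to be evaluated once so the model produces logits; in llama-cpp-2 that is a single decode call along these lines (the exact placement relative to the cache clearing lives in the full implementation):

ctx.decode(&mut batch).expect("failed to decode the prompt");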

Decoding Loop

The main loop continues decoding tokens until reaching the max_token limit. During each iteration, it:

  1. Samples a New Token: The model generates a list of candidate tokens, from which the most probable token is chosen.

     let candidates = ctx.candidates();
     let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
     let new_token_id = ctx.sample_token_greedy(candidates_p);
    
  2. Checks for End of Stream: If the token is an end-of-generation (EOG) marker, the loop breaks.

     if ctx.model.is_eog_token(new_token_id) {
         break;
     }
    
  3. Decodes and Outputs the Token: The selected token ID is converted into a string, appended to the output, and flushed to standard output.

     let output_bytes = ctx.model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
     let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
    
     print!("{output_string}");
     llm_output.push_str(output_string.as_str());
     std::io::stdout().flush().unwrap();
    
  4. Prepares for Next Iteration: The LlamaBatch is cleared, and the new token is added so decoding can continue (see the sketch after this list).
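
A minimal sketch of that hand-off, following the pattern of the llama-cpp-2 examples; n_cur is an assumed position counter and n_decode the decoded-token counter used later in the verbose report:

// feed the sampled token back in as the only entry of the next batch
batch.clear();
batch.add(new_token_id, n_cur, &[0], true).unwrap();
n_cur += 1;

// evaluate it so the next iteration has fresh logits to sample from
ctx.decode(&mut batch).expect("failed to decode");
n_decode += 1;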

Finalizing the Output

After decoding, if in chat mode, the output is appended to self.history, preserving context for future exchanges:

if flag_chat {
    // llm_output_formatted: presumably llm_output wrapped with the closing
    // chat-template marker (a trailing <|eot_id|>), built in the full implementation
    self.history.push_str(llm_output_formatted.as_str());
}

Logging Decoding Speed (Optional)

If verbose mode is enabled, the function calculates the speed of token decoding and outputs it.

if self.verbose {
    eprintln!(
        "\n[decoded {} tokens in {:.2} s, speed {:.2} t/s]",
        n_decode,
        duration.as_secs_f32(),
        n_decode as f32 / duration.as_secs_f32()
    );
}

This part is helpful for profiling performance, particularly when optimizing for response time.
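
The n_decode counter and duration are not shown in the snippets above; given the ggml_time_us and Duration imports at the top of the file, the timing is presumably captured around the decode loop roughly like this (a sketch, not the repository's exact code):

let t_start = ggml_time_us(); // microsecond timestamp before decoding starts

// ... decoding loop runs here, incrementing n_decode once per sampled token ...

let t_end = ggml_time_us();
let duration = Duration::from_micros((t_end - t_start) as u64);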

The main function

The main function is the entry point for this Rust-based Llama chatbot application. It initializes the model, creates a context for inference, and then enters a loop to handle continuous user input. Let’s break down its components to understand how it operates.

Parsing Command-Line Arguments

let args: Args = Args::parse();

This line uses the clap library to parse command-line arguments based on the Args struct defined earlier.

Initializing the LLM Instance

let mut rllm: LLM = LLM::new(args);

Here, we create an instance of LLM using the parsed arguments. The LLM::new function does the heavy lifting of loading the model, setting parameters, and initializing the model backend.

Creating the Model Context

let binding = rllm.model.clone();
let mut ctx = binding
    .new_context(&rllm.backend, rllm.ctx_params.clone())
    .expect("failed to create context");

The context is necessary for token-based decoding and inference with Llama.

The context creation here can feel a bit detached from the LLM struct itself. Since the context setup is specific to a particular input session, it could be made an internal component of LLM, initialized within the struct. But I had lifetime issues!

Main Interaction Loop

The function then enters a loop where it continuously takes user input and generates responses until the user exits by submitting an empty input.

let mut input = String::new();
print!("Assistant: How can I help you today?\n");
loop {
    input.clear();
    println!("\nYou: ");
    std::io::stdout().flush().unwrap();
    std::io::stdin().read_line(&mut input).unwrap();
    let input = input.trim();

    if input.is_empty() {
        break;
    }

    rllm.generate_chat(&mut ctx, &input);
}
  • Clearing Input: Before each prompt, input.clear() ensures any leftover data from the previous iteration is removed.
  • Prompting the User: The program prints You: as a prompt to the user to type their input. flush() ensures this prompt displays immediately.
  • Reading and Trimming Input: The user’s input is read from stdin and trimmed to remove extra whitespace.
  • Exiting on Empty Input: If the user submits an empty line, the loop breaks, ending the program.
  • Generating Response: The generate_chat function of LLM is called with the context and user input to generate and display a response from the model.

For future refinements, moving context handling inside the LLM struct could streamline the design, making it more self-contained.

GitHub Repository

The complete implementation is available on GitHub.

References

[1] llama-cpp-rs

[2] ChatGPT