Running Llama 3.2 in Rust
The complete implementation is available on GitHub.
Introduction
With the rapid growth of AI and natural language processing, efficient language model deployment has become a key focus for developers. Models like Llama 3.2, known for their performance and flexibility, open doors for sophisticated applications in text generation, chatbots, summarization, and more. But, running these models at high speeds—especially on GPUs—requires a language that balances performance with control.
In this post, we’ll explore how to set up and run Llama 3.2 in Rust, a language gaining popularity for its system-level access, memory safety, and concurrency features. Using the llama.cpp library as our backend, we’ll implement a flexible language model interface in Rust.
Setup
To get started with Llama 3.2 in Rust, we’ll first clone the project repository and set up the dependencies using Cargo, Rust’s package manager. Since this project depends on the llama.cpp backend, we need to clone it recursively to ensure all submodules are included. Follow these steps to prepare the environment:
Clone the Repository
First, clone the repository recursively to pull in all necessary submodules:
git clone --recursive https://github.com/ramintoosi/llama-rust
cd llama-rust
Note: The --recursive flag is essential for including the llama.cpp bindings, which provide the backend support for model inference.
Cargo Dependencies
The project’s Cargo.toml file includes the following dependencies:
[dependencies]
llama-cpp-2 = { path = "llama-cpp-rs/llama-cpp-2", features = ["cuda"] }
hf-hub = "0.3.2"
clap = { version = "4.5.19", features = ["derive"] }
anyhow = "1.0.89"
encoding_rs = "0.8.34"
log = "0.4.22"
[features]
cuda = ["llama-cpp-2/cuda"]
llama-cpp-2: Connects the project to the llama.cpp Rust bindings, with an optional cuda feature for GPU support.
Arguments Parsing
The Args
module, built with the clap
library, efficiently manages the command-line arguments required for model selection, configuration, and customization. Here’s a breakdown of how each component works, focusing on the design and functionality of the different parameters.
use anyhow::{anyhow, Context};
use std::path::PathBuf;
use clap::{Parser, Subcommand};
use hf_hub::api::sync::ApiBuilder;
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use std::str::FromStr;
#[derive(Subcommand, Debug, Clone)]
pub enum Model {
    /// Use an already downloaded model
    #[clap(name = "local")]
    Local {
        /// The path to the model. e.g. `../hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa`
        /// or `./llama-3.2-1b-instruct-q8_0.gguf`
        path: PathBuf,
    },
    /// Download a model from huggingface (or use a cached version)
    #[clap(name = "hf-model")]
    HuggingFace {
        /// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF`
        repo: String,
        /// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf`
        model: String,
    },
}
impl Model {
    /// Convert the model to a path - may download from huggingface
    pub fn get_or_load(self) -> anyhow::Result<PathBuf> {
        match self {
            Model::Local { path } => Ok(path),
            Model::HuggingFace { model, repo } => ApiBuilder::new()
                .with_progress(true)
                .build()
                .with_context(|| "unable to create huggingface api")?
                .model(repo)
                .get(&model)
                .with_context(|| "unable to download model"),
        }
    }
}
#[derive(clap::ValueEnum, Clone, Debug)]
pub enum Mode {
    Chat,
    Completion,
}
#[derive(Parser, Debug, Clone)]
pub struct Args {
    /// The path to the model
    #[command(subcommand)]
    pub model: Model,
    /// The mode of the code: completion or chat
    #[clap(value_enum, short = 'm', long, default_value = "chat")]
    pub mode: Mode,
    // /// The prompt to use - valid only if the mode is `completion`
    // #[clap(short = 'p', long, required_if_eq("mode", "completion"))]
    // prompt: Option<String>,
    /// set the length of the prompt + output in tokens
    #[clap(long, default_value_t = 512)]
    pub max_token: u32,
    /// override some parameters of the model
    #[clap(short = 'o', value_parser = parse_key_val)]
    pub key_value_overrides: Vec<(String, ParamOverrideValue)>,
    /// how many layers to keep on the gpu - zero is cpu mode
    #[clap(
        short = 'g',
        long,
        help = "how many layers to keep on the gpu - zero is cpu mode (default: 0)"
    )]
    pub n_gpu_layers: u32,
    /// set the seed for the RNG
    #[clap(short = 's', long, default_value_t = 561371)]
    pub seed: u32,
    /// number of threads to use during generation
    #[clap(
        long,
        help = "number of threads to use during generation (default: use all available threads)"
    )]
    pub threads: Option<i32>,
    #[clap(
        long,
        help = "number of threads to use during batch and prompt processing (default: use all available threads)"
    )]
    pub threads_batch: Option<i32>,
    // /// size of the prompt context
    // #[clap(
    //     short = 'c',
    //     long,
    //     help = "size of the prompt context (default: loaded from the model)"
    // )]
    // pub ctx_size: Option<NonZeroU32>,
    /// show the token/s speed at the end of each turn
    #[clap(short = 'v', long, action)]
    pub verbose: bool,
}
/// Parse a single key-value pair
fn parse_key_val(s: &str) -> anyhow::Result<(String, ParamOverrideValue)> {
    let pos = s
        .find('=')
        .ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?;
    let key = s[..pos].parse()?;
    let value: String = s[pos + 1..].parse()?;
    let value = i64::from_str(&value)
        .map(ParamOverrideValue::Int)
        .or_else(|_| f64::from_str(&value).map(ParamOverrideValue::Float))
        .or_else(|_| bool::from_str(&value).map(ParamOverrideValue::Bool))
        .map_err(|_| anyhow!("must be one of i64, f64, or bool"))?;
    Ok((key, value))
}
Model Enum for Model Selection
The Model enum defines two ways to specify a language model:
- Local Model (local): This variant lets users provide a path to a locally downloaded model file.
- Hugging Face Model (hf-model): This variant enables automatic downloading of a model from Hugging Face’s repositories. It takes in the repo and model names and uses hf_hub to fetch the specified model, caching it for future use.
The get_or_load method in Model abstracts the logic for loading models, using Hugging Face’s API if needed.
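To make the two variants concrete, here is a minimal usage sketch. The file, repository, and model names are just the placeholder examples from the doc comments above, and the snippet is assumed to live in a function returning anyhow::Result<()>.

// Both variants resolve to a local file path via get_or_load().
let local = Model::Local {
    path: PathBuf::from("./llama-3.2-1b-instruct-q8_0.gguf"),
};
let hf = Model::HuggingFace {
    repo: "TheBloke/Llama-2-7B-Chat-GGUF".to_string(),
    model: "llama-2-7b-chat.Q4_K_M.gguf".to_string(),
};
let local_path = local.get_or_load()?; // returns the given path unchanged
let hf_path = hf.get_or_load()?;       // downloads (or reuses the hf-hub cache) and returns the blob path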
Mode Enum for Operation Mode
The Mode enum controls the inference mode:
- Chat: Enables a conversational interaction with the model, allowing for an ongoing dialogue.
- Completion: Generates text completions based on an initial prompt.
The mode argument defaults to Chat, but users can specify either mode by passing -m completion or -m chat.
The Args Struct
The Args struct organizes all of the command-line arguments into a clean and accessible structure. Let’s look at each parameter:
- model: A subcommand that lets the user specify a model, either by path or by Hugging Face repository, as described in the Model enum.
- mode: Specifies the operational mode, either chat or completion.
- max_token: Sets the maximum length for prompt and output tokens. The default value is 512, allowing some control over the model’s generation length.
- key_value_overrides: Provides flexibility by allowing the user to pass key-value overrides for specific model parameters (see the sketch after this list).
- n_gpu_layers: Specifies the number of layers to run on the GPU, with a default of 0, which uses the CPU.
- seed: Sets a seed for the random number generator, allowing for reproducible output across runs. This defaults to 561371.
- threads and threads_batch: These arguments allow the user to fine-tune performance by specifying the number of threads used for generation and batch processing, respectively. By default, it uses all available threads.
- verbose: If set, this flag displays token processing speed after each generation, useful for performance monitoring.
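To illustrate how the -o overrides are interpreted, here is a small sketch of parse_key_val in action. The key names are arbitrary placeholders, not validated model metadata keys, and the snippet is assumed to sit in a function returning anyhow::Result<()>.

// Each `-o KEY=VALUE` argument becomes a (String, ParamOverrideValue) pair.
let (key, value) = parse_key_val("some.int.key=42")?;    // value is ParamOverrideValue::Int(42)
let (key, value) = parse_key_val("some.float.key=0.5")?; // value is ParamOverrideValue::Float(0.5)
let (key, value) = parse_key_val("some.bool.key=true")?; // value is ParamOverrideValue::Bool(true)
let err = parse_key_val("no-equals-sign");               // Err: invalid KEY=value, no `=` found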
Llama 3.2 Inference
The LLM struct is the central component for handling Llama model inference, encapsulating essential components like the model, backend, and configuration parameters.
use std::ffi::CString;
use std::io::Write;
use std::num::NonZeroU32;
use std::pin::pin;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::sync::Arc;
use std::time::Duration;
use llama_cpp_2::ggml_time_us;
use super::args_handler::{Args, Mode};
pub struct LLM {
    pub model: Arc<LlamaModel>,
    pub backend: LlamaBackend,
    pub ctx_params: LlamaContextParams,
    mode: Mode,
    history: String,
    max_token: u32,
    verbose: bool,
}
The new Method
pub fn new(args: Args) -> Self {
    // init LLM
    let backend = LlamaBackend::init()
        .expect("Could not initialize Llama backend");

    // offload all layers to the gpu
    let model_params = {
        if args.n_gpu_layers > 0 {
            LlamaModelParams::default().with_n_gpu_layers(args.n_gpu_layers)
        } else {
            LlamaModelParams::default()
        }
    };

    let mut model_params = pin!(model_params);

    for (k, v) in &args.key_value_overrides {
        let k = CString::new(k.as_bytes()).expect(format!("invalid key: {k}").as_str());
        model_params.as_mut().append_kv_override(k.as_c_str(), *v);
    }

    let model_path = args.model.clone()
        .get_or_load()
        .expect("failed to get model from args");

    // Load the model and wrap it in an Arc for shared ownership
    let model = Arc::new(LlamaModel::load_from_file(&backend, model_path, &model_params)
        .expect("failed to load model"));

    // initialize the context
    let mut ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(args.max_token))
        .with_seed(args.seed);

    if let Some(threads) = args.threads {
        ctx_params = ctx_params.with_n_threads(threads);
    }
    if let Some(threads_batch) = args.threads_batch.or(args.threads) {
        ctx_params = ctx_params.with_n_threads_batch(threads_batch);
    }

    Self {
        model,
        backend,
        ctx_params,
        mode: args.mode,
        history: String::new(),
        max_token: args.max_token,
        verbose: args.verbose,
    }
}
The new method in the LLM struct initializes the core components required for inference, including backend setup, model loading, and context configuration. It takes in an Args instance, which holds the user-provided configuration parameters, and uses these to fine-tune the model and backend settings. Here’s a step-by-step breakdown of each part:
Initializing the Backend
let backend = LlamaBackend::init()
    .expect("Could not initialize Llama backend");
Configuring Model Parameters
let model_params = {
    if args.n_gpu_layers > 0 {
        LlamaModelParams::default().with_n_gpu_layers(args.n_gpu_layers)
    } else {
        LlamaModelParams::default()
    }
};
The model_params configuration is determined by whether the user has specified GPU layers with n_gpu_layers. This parameter allows the model to use GPU acceleration for a set number of layers, which can significantly improve inference speed. If no GPU layers are specified (i.e., n_gpu_layers is 0), it defaults to using the CPU.
Applying Key-Value Overrides
for (k, v) in &args.key_value_overrides {
    let k = CString::new(k.as_bytes()).expect(format!("invalid key: {k}").as_str());
    model_params.as_mut().append_kv_override(k.as_c_str(), *v);
}
This section iterates over any key-value overrides specified by the user, allowing them to customize specific model parameters.
Loading the Model
let model_path = args.model.clone()
    .get_or_load()
    .expect("failed to get model from args");
let model = Arc::new(LlamaModel::load_from_file(&backend, model_path, &model_params)
    .expect("failed to load model"));
The model_path is determined by calling get_or_load on the args.model field, which either retrieves the local model path or downloads it from Hugging Face. The model is then loaded into an Arc<LlamaModel> instance, allowing shared ownership of the model.
Setting Up the Context Parameters
let mut ctx_params = LlamaContextParams::default()
    .with_n_ctx(NonZeroU32::new(args.max_token))
    .with_seed(args.seed);
if let Some(threads) = args.threads {
    ctx_params = ctx_params.with_n_threads(threads);
}
if let Some(threads_batch) = args.threads_batch.or(args.threads) {
    ctx_params = ctx_params.with_n_threads_batch(threads_batch);
}
Returning the LLM Instance
Finally, the configured fields are assembled into a new LLM instance:
Self {
    model,
    backend,
    ctx_params,
    mode: args.mode,
    history: String::new(),
    max_token: args.max_token,
    verbose: args.verbose,
}
The history is initialized as an empty string, which is useful for conversational interactions in Chat mode. The max_token and verbose values are set from the user input.
Generation
The generate_chat function in this implementation is designed to generate responses in either chat or text completion mode. It handles various tasks essential for text generation, including formatting inputs, managing tokenization and decoding, and outputting the generated text. Let’s go through each component to see how it works.
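The rest of this section walks through fragments of generate_chat. As a reading aid, here is a minimal sketch of the signature and local variables those fragments assume (flag_chat, llm_output, decoder, and so on); the exact names and layout in the repository may differ.

pub fn generate_chat(&mut self, ctx: &mut LlamaContext, prompt: &str) {
    let max_token = self.max_token;                     // token budget checked below
    let mut flag_chat = false;                          // set to true in Chat mode
    let mut llm_output = String::new();                 // accumulates the generated text
    let mut decoder = encoding_rs::UTF_8.new_decoder(); // converts raw token bytes to UTF-8
    // (`output_string` is a small per-token buffer created inside the decoding loop)

    // ... the steps described below ...
}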
Formatting the Input Based on Mode
The function starts by formatting the input based on whether the LLM instance is in Chat or Completion mode.
let input_to_model = match self.mode {
    Mode::Chat => {
        let input_formatted = format!(
            "<|start_header_id|>user<|end_header_id|>{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
            prompt
        );
        self.history.push_str(input_formatted.as_str());
        flag_chat = true;
        // use an owned copy of the history so no borrow of `self.history` is held
        self.history.clone()
    }
    // own the String here too: a `&format!(...)` reference would point at a
    // temporary that is dropped at the end of the match arm
    Mode::Completion => format!("<|begin_of_text|>{}", prompt),
};
- Chat Mode: It appends the prompt to a conversation history that includes user and assistant markers. This structured approach keeps track of the conversation context and appends new responses to self.history.
- Completion Mode: It prepares a straightforward prompt without conversation markers.
Tokenizing the Input
The next step tokenizes the formatted prompt, converting it into a list of tokens that the model can process.
let tokens_list = ctx.model
    .str_to_token(&input_to_model, AddBos::Never)
    .expect(format!("failed to tokenize {}", input_to_model).as_str());
Here, AddBos::Never ensures that no beginning-of-sequence token is added, as the prompt format is controlled separately.
Validating Token Length
To prevent the model from running out of memory, the function checks that the token count doesn’t exceed the max_token limit:
if tokens_list.len() >= max_token as usize {
    panic!("the prompt is too long, it has more tokens than max_token ({max_token})")
}
Preparing the LlamaBatch for Decoding
The function then sets up a LlamaBatch, which is used to manage the tokens being processed.
let mut batch = LlamaBatch::new(n_cxt as usize, 1);
Each token in tokens_list is added to this batch, with a flag indicating the last token in the sequence for decoding.
for (i, token) in tokens_list.into_iter().enumerate() {
    let is_last = i == last_index;
    batch.add(token, i as i32, &[0], is_last).unwrap();
}
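This fragment relies on a few values set up around it: n_cxt and last_index are derived from the context and the token list, and once the batch is filled the whole prompt is evaluated with a single decode call. A sketch of that surrounding setup, following the pattern of the llama-cpp-2 examples (the variable names are assumptions of this sketch):

let n_cxt = ctx.n_ctx() as i32;         // the context size configured in new()
let last_index = tokens_list.len() - 1; // only the last prompt token needs logits

// after the batch has been filled with the prompt tokens:
ctx.decode(&mut batch).expect("llama_decode() failed");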
Clearing the KV Cache for Completion Mode
To prevent older tokens from influencing new ones in completion mode, the function clears the KV cache:
ctx.clear_kv_cache();
Decoding Loop
The main loop continues decoding tokens until reaching the max_token limit. During each iteration, it:
- Samples a New Token: The model generates a list of candidate tokens, from which the most probable token is chosen.

let candidates = ctx.candidates();
let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
let new_token_id = ctx.sample_token_greedy(candidates_p);

- Checks for End of Stream: If the token is an end-of-generation (EOG) marker, the loop breaks.

if ctx.model.is_eog_token(new_token_id) {
    break;
}

- Decodes and Outputs the Token: The selected token ID is converted into a string, appended to the output, and flushed to standard output.

let output_bytes = ctx.model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
llm_output.push_str(output_string.as_str());
std::io::stdout().flush().unwrap();

- Prepares for Next Iteration: The LlamaBatch is cleared, and the new token is added to continue decoding (a sketch of this step follows below).
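The post does not show the code for this last step, so here is a sketch of what it typically looks like with the llama-cpp-2 batch API. The counters n_cur and n_decode are assumptions of this sketch (n_decode is the value reported in the verbose output below, and n_cur would start at the prompt length, tokens_list.len() as i32).

// feed the sampled token back in as the input for the next decode step
batch.clear();
batch.add(new_token_id, n_cur, &[0], true).unwrap();
ctx.decode(&mut batch).expect("failed to eval");

n_cur += 1;    // position of the next token in the sequence
n_decode += 1; // number of tokens generated so far (used for the t/s report)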
Finalizing the Output
After decoding, if in chat mode, the output is appended to self.history, preserving context for future exchanges:
if flag_chat {
    self.history.push_str(llm_output_formatted.as_str());
}
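The llm_output_formatted value is not shown in the fragments above. The assumption in this walkthrough is that the raw output is wrapped with the assistant end-of-turn marker so the stored history stays a well-formed chat template, roughly:

// Assumption: close the assistant turn before storing it in the history.
let llm_output_formatted = format!("{}<|eot_id|>", llm_output);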
Logging Decoding Speed (Optional)
If verbose mode is enabled, the function calculates the speed of token decoding and outputs it.
if self.verbose {
    eprintln!(
        "\n[decoded {} tokens in {:.2} s, speed {:.2} t/s]",
        n_decode,
        duration.as_secs_f32(),
        n_decode as f32 / duration.as_secs_f32()
    );
}
This part is helpful for profiling performance, particularly when optimizing for response time.
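The duration and n_decode values come from bookkeeping around the decoding loop. A sketch of how they can be collected with the ggml_time_us helper imported at the top of the module (the exact placement in the repository may differ):

let t_start = ggml_time_us(); // immediately before the decoding loop
let mut n_decode = 0;

// ... decoding loop: n_decode += 1 for every sampled token ...

let t_end = ggml_time_us();
let duration = Duration::from_micros((t_end - t_start) as u64);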
The main function
The main function is the entry point for this Rust-based Llama chatbot application. It initializes the model, creates a context for inference, and then enters a loop to handle continuous user input. Let’s break down its components to understand how it operates.
Parsing Command-Line Arguments
let args: Args = Args::parse();
This line uses the clap library to parse command-line arguments based on the Args struct defined earlier.
Initializing the LLM Instance
let mut rllm: LLM = LLM::new(args);
Here, we create an instance of LLM using the parsed arguments. The LLM::new function does the heavy lifting of loading the model, setting parameters, and initializing the model backend.
Creating the Model Context
let binding = rllm.model.clone();
let mut ctx = binding
    .new_context(&rllm.backend, rllm.ctx_params.clone())
    .expect("failed to create context");
The context is necessary for token-based decoding and inference with Llama.
The context creation here can feel a bit detached from the LLM struct itself. Since the context setup is specific to a particular input session, it could be made an internal component of LLM, initialized within the struct. But I had lifetime issues!
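For context on that lifetime problem: in llama-cpp-2 the context borrows the model it was created from, so a struct that owned both would be self-referential. A rough sketch of the shape that runs into trouble (a hypothetical type, not part of the project):

// Hypothetical: storing the context next to the model it borrows from makes the
// struct self-referential, which safe Rust does not let you construct directly.
struct LlmWithContext<'a> {
    model: Arc<LlamaModel>,
    ctx: LlamaContext<'a>, // 'a would need to borrow from `model` in the same struct
}

This is also why main clones the Arc into a local binding and creates the context from that binding, keeping the borrow outside the struct.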
Main Interaction Loop
The function then enters a loop where it continuously takes user input and generates responses until the user exits by submitting an empty input.
let mut input = String::new();
print!("Assistant: How can I help you today?\n");
loop {
    input.clear();
    println!("\nYou: ");
    std::io::stdout().flush().unwrap();
    std::io::stdin().read_line(&mut input).unwrap();
    let input = input.trim();
    if input.is_empty() {
        break;
    }
    rllm.generate_chat(&mut ctx, &input);
}
- Clearing Input: Before each prompt, input.clear() ensures any leftover data from the previous iteration is removed.
- Prompting the User: The program prints You: as a prompt for the user to type their input; flush() ensures this prompt displays immediately.
- Reading and Trimming Input: The user’s input is read from stdin and trimmed to remove extra whitespace.
- Exiting on Empty Input: If the user submits an empty line, the loop breaks, ending the program.
- Generating Response: The generate_chat function of LLM is called with the context and user input to generate and display a response from the model.
For future refinements, moving context handling inside the LLM struct could streamline the design, making it more self-contained.
GitHub Repository
The complete implementation is available on GitHub.
Reference
[1] llama-cpp-rs
[2] ChatGPT