Skip to content

Instantly share code, notes, and snippets.

@TomMarius
Last active September 1, 2025 18:57
Show Gist options
  • Select an option
  • Save TomMarius/0ad2be6db90a40f926b936b216c4652f to your computer and use it in GitHub Desktop.
Minimal Candle Rust text inference example with quantized Qwen3
use candle::{Device, Tensor, quantized::gguf_file};
use candle_transformers::{
generation::{LogitsProcessor, Sampling},
models::quantized_qwen3::ModelWeights as Qwen3,
};
use std::{
fs::File,
io::{self, Write},
};
use tokenizers::Tokenizer;
// User prompt fed to the model; main() wraps it in the Qwen chat template.
const PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
// Absolute path to the quantized Qwen model weights in GGUF format.
const MODEL_PATH: &str = "/ai/models/qwen-unsloth-q4.gguf";
// Absolute path to the Hugging Face `tokenizers` JSON tokenizer definition.
const TOKENIZER_PATH: &str = "/ai/models/qwen-tokenizer.json";
/// Streams a chat completion from a quantized Qwen3 GGUF model on CUDA.
///
/// Loads the model weights and tokenizer from disk, encodes [`PROMPT`] with
/// the Qwen chat template, runs one prefill pass over the whole prompt, then
/// samples tokens one at a time (top-k then top-p), printing each decoded
/// token as it is produced. Generation stops at the `<|im_end|>` EOS token or
/// after `MAX_NEW_TOKENS` tokens, whichever comes first.
///
/// # Errors
///
/// Returns an error instead of panicking when the CUDA device, model file,
/// or tokenizer cannot be opened/parsed, or when a forward/sampling/decode
/// step fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hard cap so a model that never emits EOS cannot loop forever.
    const MAX_NEW_TOKENS: usize = 4096;

    let device =
        Device::new_cuda(0).map_err(|e| format!("CUDA device 0 is required: {e}"))?;

    // Read the GGUF container header, then hand the same file handle to the
    // weight loader so tensor data is read from the recorded offsets.
    let mut model_file =
        File::open(MODEL_PATH).map_err(|e| format!("cannot open {MODEL_PATH}: {e}"))?;
    let model_gguf = gguf_file::Content::read(&mut model_file)?;
    let mut model_weights = Qwen3::from_gguf(model_gguf, &mut model_file, &device)?;

    let tokenizer = Tokenizer::from_file(TOKENIZER_PATH)?;

    // Qwen chat template: wrap the user prompt and open the assistant turn.
    let chat_prompt = format!("<|im_start|>user\n{PROMPT}<|im_end|>\n<|im_start|>assistant\n");
    let mut tokens = tokenizer.encode(chat_prompt, true)?.get_ids().to_vec();

    // Prefill: one forward pass over the entire prompt at position 0. The
    // returned logits (for the last position) drive the first sampled token.
    let prompt_input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
    let mut logits = model_weights.forward(&prompt_input, 0)?.squeeze(0)?;

    let mut sampler = LogitsProcessor::from_sampling(
        0, // fixed RNG seed: runs are reproducible
        Sampling::TopKThenTopP {
            temperature: 0.6,
            p: 0.8,
            k: 20,
        },
    );

    let eos_token_id: u32 = *tokenizer
        .get_vocab(true)
        .get("<|im_end|>")
        .ok_or("tokenizer vocabulary is missing the <|im_end|> EOS token")?;

    // Lock stdout once instead of re-locking on every printed token.
    let stdout = io::stdout();
    let mut out = stdout.lock();
    for _ in 0..MAX_NEW_TOKENS {
        let next_token_id = sampler.sample(&logits)?;
        if next_token_id == eos_token_id {
            break;
        }
        tokens.push(next_token_id);
        // NOTE(review): decoding one token at a time can split multi-byte
        // UTF-8 sequences for some vocabularies; acceptable for a minimal
        // streaming example.
        write!(out, "{}", tokenizer.decode(&[next_token_id], true)?)?;
        out.flush()?;

        // Decode step: feed only the new token — the KV cache supplies the
        // context. The position index is the token's offset in the sequence.
        let step_input = Tensor::new(&[next_token_id], &device)?.unsqueeze(0)?;
        logits = model_weights
            .forward(&step_input, tokens.len() - 1)?
            .squeeze(0)?;
    }
    writeln!(out)?;
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment