Last active
September 1, 2025 18:57
-
-
Save TomMarius/0ad2be6db90a40f926b936b216c4652f to your computer and use it in GitHub Desktop.
Minimal Candle Rust text inference example with quantized Qwen3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use candle::{Device, Tensor, quantized::gguf_file}; | |
| use candle_transformers::{ | |
| generation::{LogitsProcessor, Sampling}, | |
| models::quantized_qwen3::ModelWeights as Qwen3, | |
| }; | |
| use std::{ | |
| fs::File, | |
| io::{self, Write}, | |
| }; | |
| use tokenizers::Tokenizer; | |
// User prompt; it is wrapped in the Qwen ChatML chat template inside main().
const PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
// Path to the GGUF-quantized Qwen3 model weights on disk.
const MODEL_PATH: &str = "/ai/models/qwen-unsloth-q4.gguf";
// Path to the matching HuggingFace `tokenizers` JSON file for this model.
const TOKENIZER_PATH: &str = "/ai/models/qwen-tokenizer.json";
| fn main() { | |
| let device = Device::new_cuda(0).expect("CUDA required"); | |
| let mut model_file = File::open(MODEL_PATH).unwrap(); | |
| let model_gguf = gguf_file::Content::read(&mut model_file).unwrap(); | |
| let mut model_weights = Qwen3::from_gguf(model_gguf, &mut model_file, &device).unwrap(); | |
| let tokenizer = Tokenizer::from_file(TOKENIZER_PATH).unwrap(); | |
| let mut tokens = tokenizer | |
| .encode( | |
| format!("<|im_start|>user\n{PROMPT}<|im_end|>\n<|im_start|>assistant\n"), | |
| true, | |
| ) | |
| .unwrap() | |
| .get_ids() | |
| .to_vec(); | |
| let logits = model_weights | |
| .forward( | |
| &Tensor::new(tokens.as_slice(), &device) | |
| .unwrap() | |
| .unsqueeze(0) | |
| .unwrap(), | |
| 0, | |
| ) | |
| .unwrap() | |
| .squeeze(0) | |
| .unwrap(); | |
| let mut sampler = LogitsProcessor::from_sampling( | |
| 0, | |
| Sampling::TopKThenTopP { | |
| temperature: 0.6, | |
| p: 0.8, | |
| k: 20, | |
| }, | |
| ); | |
| let eos_token_id: u32 = *tokenizer.get_vocab(true).get("<|im_end|>").unwrap(); | |
| let mut next_token_id = sampler.sample(&logits).unwrap(); | |
| while next_token_id != eos_token_id { | |
| tokens.push(next_token_id); | |
| print!("{}", tokenizer.decode(&[next_token_id], true).unwrap()); | |
| io::stdout().flush().unwrap(); | |
| let logits = model_weights | |
| .forward( | |
| &Tensor::new(&[next_token_id], &device) | |
| .unwrap() | |
| .unsqueeze(0) | |
| .unwrap(), | |
| tokens.len() - 1, | |
| ) | |
| .unwrap() | |
| .squeeze(0) | |
| .unwrap(); | |
| next_token_id = sampler.sample(&logits).unwrap(); | |
| } | |
| println!(); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment