Skip to content

Instantly share code, notes, and snippets.

@segeljakt
Last active September 25, 2024 12:43
Show Gist options
  • Select an option

  • Save segeljakt/25e8f60af5cf2997827a28d88aa1202e to your computer and use it in GitHub Desktop.

Select an option

Save segeljakt/25e8f60af5cf2997827a28d88aa1202e to your computer and use it in GitHub Desktop.
use anyhow::{Error, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use candle_transformers::models::llama::Llama;
use candle_transformers::models::llama::LlamaConfig;
use candle_transformers::models::llama::LlamaEosToks;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
/// Fallback end-of-sequence token string, looked up in the tokenizer when the
/// model config does not provide an EOS token id (see `Chat::new`).
const EOS_TOKEN: &str = "</s>";
/// Penalty factor applied to logits of recently generated tokens; values > 1.0
/// discourage repetition.
const REPEAT_PENALTY: f32 = 1.1;
/// How many of the most recent tokens the repeat penalty looks back over.
const REPEAT_LAST_N: usize = 128;
/// Fixed RNG seed for the logits processor. Sampling is `ArgMax` here (see
/// `Chat::new`), so the seed has no effect on the output; kept for easy
/// switching to stochastic sampling. (The value is the speed of light in m/s.)
const SEED: u64 = 299792458;
/// Maximum number of tokens generated per `Chat::run` call.
const SAMPLE_LEN: usize = 10000;
/// Whether the tokenizer adds special tokens (e.g. BOS) when encoding prompts.
const ADD_SPECIAL_TOKENS: bool = true;
/// Whether special tokens are stripped when decoding generated ids to text.
const SKIP_SPECIAL_TOKENS: bool = true;
/// Enable the key/value cache so each generation step only feeds the model the
/// tokens it has not seen yet (see the slicing logic in `Chat::run`).
const USE_KV_CACHE: bool = true;
/// Flash attention toggle passed into the model config.
/// NOTE(review): presumably requires CUDA support in candle — confirm before enabling.
const USE_FLASH_ATTENTION: bool = false;
/// A stateful chat session around a TinyLlama model.
///
/// The token history and the model's KV cache persist across `run` calls, so
/// each subsequent prompt continues the earlier conversation rather than
/// starting fresh.
pub struct Chat {
/// The loaded Llama model used for the forward passes.
model: Llama,
/// Chooses the next token from the logits (configured as `ArgMax` in `new`).
logits_processor: LogitsProcessor,
/// Key/value attention cache, reused and extended across generation steps.
cache: Cache,
/// Encodes prompts to token ids and decodes ids back to text.
tokenizer: Tokenizer,
/// Device tensors are placed on (Metal GPU 0, set in `new`).
device: Device,
/// End-of-sequence token id(s); generation stops when one is produced.
/// `None` if neither the config nor the tokenizer provides one.
eos_token_id: Option<LlamaEosToks>,
/// Full token history: every prompt and generated token so far.
tokens: Vec<u32>,
/// Number of tokens already fed to the model, i.e. the start of the next
/// input window into `tokens`.
index: usize,
}
impl Chat {
/// Builds a ready-to-use chat session: fetches the TinyLlama-1.1B-Chat
/// weights, config, and tokenizer from the Hugging Face Hub (or the local
/// hub cache) and loads the model in F16 onto Metal device 0.
///
/// # Errors
/// Fails if the Metal device is unavailable, a hub download fails, or the
/// config / weights / tokenizer cannot be parsed.
pub fn new() -> Result<Self> {
// NOTE(review): hard-wired to Metal device 0 — this only runs on macOS.
// Consider Device::cuda_if_available / Device::Cpu fallback for portability.
let device = Device::new_metal(0)?;
let dtype = DType::F16;
let api = Api::new()?;
let api = api.repo(Repo::with_revision(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
RepoType::Model,
"main".to_string(),
));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(USE_FLASH_ATTENTION);
let filenames = vec![api.get("model.safetensors")?];
let cache = Cache::new(USE_KV_CACHE, dtype, &config, &device)?;
// SAFETY (upstream contract): from_mmaped_safetensors memory-maps the
// weight files; the files must not be mutated while the model is alive.
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Llama::load(vb, &config)?;
// tokenizers errors don't implement std::error::Error conversion we need,
// so wrap them through anyhow via Error::msg.
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?;
// Prefer the EOS id(s) from the model config; fall back to looking up the
// literal "</s>" token in the tokenizer.
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN).map(LlamaEosToks::Single));
// Greedy decoding: ArgMax ignores SEED, but keeps the plumbing in place
// for swapping in temperature/top-p sampling.
let logits_processor = LogitsProcessor::from_sampling(SEED, Sampling::ArgMax);
Ok(Self {
model,
tokenizer,
logits_processor,
eos_token_id,
cache,
device,
tokens: Vec::new(),
index: 0,
})
}
/// Appends `prompt` to the conversation, generates up to `SAMPLE_LEN`
/// tokens (stopping early at EOS), and returns the decoded text.
///
/// NOTE(review): the return value decodes the *entire* token history, so a
/// second call's output also contains the first call's text — confirm the
/// cumulative transcript is intended.
/// NOTE(review): ADD_SPECIAL_TOKENS is true, so encoding a follow-up
/// prompt presumably inserts a BOS token mid-sequence — verify this is the
/// desired multi-turn behavior for this tokenizer.
///
/// # Errors
/// Fails on tokenizer encode/decode errors or tensor/model errors.
pub fn run(&mut self, prompt: &str) -> Result<String> {
self.tokens.extend(
self.tokenizer
.encode(prompt, ADD_SPECIAL_TOKENS)
.map_err(Error::msg)?
.get_ids(),
);
for _ in 0..SAMPLE_LEN {
// Only feed tokens the model hasn't seen: on the first iteration this
// is the whole new prompt; afterwards it's just the last sampled token
// (earlier positions live in the KV cache).
let tokens_slice = &self.tokens[self.index..];
// unsqueeze(0) adds the batch dimension expected by the model.
let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
let logits = self
.model
.forward(&input, self.index, &mut self.cache)?
.squeeze(0)?;
// Penalize tokens seen in the last REPEAT_LAST_N positions;
// saturating_sub keeps the window start at 0 for short histories.
let logits = candle_transformers::utils::apply_repeat_penalty(
&logits,
REPEAT_PENALTY,
&self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
)?;
// Advance BEFORE pushing the sampled token, so the next iteration's
// slice contains exactly the one new token.
self.index += tokens_slice.len();
let next_token = self.logits_processor.sample(&logits)?;
self.tokens.push(next_token);
if self.is_eos_token(next_token) {
break;
}
}
// Decode the full history; SKIP_SPECIAL_TOKENS strips BOS/EOS from the text.
let output = self
.tokenizer
.decode(&self.tokens, SKIP_SPECIAL_TOKENS)
.map_err(Error::msg)?;
Ok(output)
}
/// Returns true if `token` is an end-of-sequence id, handling both the
/// single-id and multiple-id forms of `LlamaEosToks`. Always false when no
/// EOS id is known.
fn is_eos_token(&self, token: u32) -> bool {
matches!(self.eos_token_id, Some(LlamaEosToks::Single(id)) if token == id)
|| matches!(self.eos_token_id, Some(LlamaEosToks::Multiple(ref ids)) if ids.contains(&token))
}
}
fn main() {
let mut ctx = Chat::new().unwrap();
println!("{}", ctx.run("Hello my name is").unwrap());
println!("{}", ctx.run("Today").unwrap());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment