Last active
September 25, 2024 12:43
-
-
Save segeljakt/25e8f60af5cf2997827a28d88aa1202e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use anyhow::{Error, Result}; | |
| use candle::{DType, Device, Tensor}; | |
| use candle_nn::VarBuilder; | |
| use candle_transformers::generation::{LogitsProcessor, Sampling}; | |
| use candle_transformers::models::llama::Cache; | |
| use candle_transformers::models::llama::Llama; | |
| use candle_transformers::models::llama::LlamaConfig; | |
| use candle_transformers::models::llama::LlamaEosToks; | |
| use hf_hub::{api::sync::Api, Repo, RepoType}; | |
| use tokenizers::Tokenizer; | |
// Fallback end-of-sequence token string; used only when config.json does not
// supply an EOS token id (see `Chat::new`).
const EOS_TOKEN: &str = "</s>";
// Multiplicative penalty (>1.0 discourages repetition) applied to the logits
// of recently generated tokens.
const REPEAT_PENALTY: f32 = 1.1;
// How many trailing tokens of the history the repeat penalty window covers.
const REPEAT_LAST_N: usize = 128;
// RNG seed for the logits processor. Has no practical effect while sampling
// is ArgMax (deterministic), but would matter for stochastic sampling.
const SEED: u64 = 299792458;
// Upper bound on the number of tokens generated per `Chat::run` call.
const SAMPLE_LEN: usize = 10000;
// Tokenizer flags: add special tokens when encoding prompts, and strip them
// when decoding generated output.
const ADD_SPECIAL_TOKENS: bool = true;
const SKIP_SPECIAL_TOKENS: bool = true;
// Enable the key/value cache so each forward pass only feeds unseen tokens.
const USE_KV_CACHE: bool = true;
// Flash attention disabled.
const USE_FLASH_ATTENTION: bool = false;
/// A stateful chat session around a Llama model. The token history and the
/// model's KV cache persist across `run` calls, so successive prompts extend
/// one continuous conversation.
pub struct Chat {
    model: Llama,
    // Turns raw logits into the next token id (configured with ArgMax in `new`).
    logits_processor: LogitsProcessor,
    // Model key/value cache; lets each forward pass feed only new tokens.
    cache: Cache,
    tokenizer: Tokenizer,
    device: Device,
    // EOS id(s) taken from config.json, or the tokenizer's "</s>" as a fallback.
    eos_token_id: Option<LlamaEosToks>,
    // Full token history: every encoded prompt and generated token so far.
    tokens: Vec<u32>,
    // Count of leading tokens already processed by the model (i.e. absorbed
    // into the KV cache); `tokens[index..]` are fed on the next forward pass.
    index: usize,
}
impl Chat {
    /// Builds a chat session: fetches the TinyLlama-1.1B-Chat weights,
    /// config, and tokenizer from the Hugging Face Hub, loads the model onto
    /// the first Metal device in F16, and prepares greedy (ArgMax) sampling.
    ///
    /// # Errors
    /// Fails if no Metal device is available, a Hub download fails, or the
    /// config / weights / tokenizer files cannot be read or parsed.
    pub fn new() -> Result<Self> {
        // NOTE(review): hard-coded to Metal device 0 — this errors on
        // non-macOS hosts; a CPU/CUDA fallback may be worth adding.
        let device = Device::new_metal(0)?;
        let dtype = DType::F16;
        let api = Api::new()?;
        // Pin the repo to the "main" revision of TinyLlama-1.1B-Chat.
        let api = api.repo(Repo::with_revision(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));
        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(USE_FLASH_ATTENTION);
        let filenames = vec![api.get("model.safetensors")?];
        let cache = Cache::new(USE_KV_CACHE, dtype, &config, &device)?;
        // SAFETY (upstream contract): memory-mapping the safetensors file is
        // `unsafe` because the file must not be modified while it is mapped.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Llama::load(vb, &config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?;
        // Prefer the EOS id(s) declared in config.json; fall back to looking
        // up the literal "</s>" token in the tokenizer vocabulary.
        let eos_token_id = config
            .eos_token_id
            .or_else(|| tokenizer.token_to_id(EOS_TOKEN).map(LlamaEosToks::Single));
        // ArgMax sampling is deterministic, so SEED is effectively unused
        // here; it would take effect only with a stochastic Sampling variant.
        let logits_processor = LogitsProcessor::from_sampling(SEED, Sampling::ArgMax);
        Ok(Self {
            model,
            tokenizer,
            logits_processor,
            eos_token_id,
            cache,
            device,
            tokens: Vec::new(),
            index: 0,
        })
    }

    /// Appends `prompt` to the running token history, generates up to
    /// SAMPLE_LEN tokens (stopping early at EOS), and returns the decoded
    /// text.
    ///
    /// NOTE(review): the entire token history is decoded, so on a second call
    /// the returned string also contains all previous turns (see `main`) —
    /// confirm this cumulative-transcript behavior is intended.
    ///
    /// # Errors
    /// Fails on tokenizer encode/decode errors or any model forward error.
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        for _ in 0..SAMPLE_LEN {
            // Feed only tokens the model has not yet seen: the whole new
            // prompt on the first iteration, then one token per step once the
            // prefix has been absorbed into the KV cache.
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            // `self.index` is the position offset of `input` within the full
            // sequence. Presumably `forward` returns logits for the last
            // position only, i.e. (1, vocab) before the squeeze — TODO
            // confirm against the candle-transformers llama implementation.
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            // Penalize tokens seen in the last REPEAT_LAST_N tokens of the
            // history (saturating_sub guards short histories).
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            self.index += tokens_slice.len();
            let next_token = self.logits_processor.sample(&logits)?;
            self.tokens.push(next_token);
            // The EOS token stays in the history; SKIP_SPECIAL_TOKENS hides
            // it from the decoded output below.
            if self.is_eos_token(next_token) {
                break;
            }
        }
        let output = self
            .tokenizer
            .decode(&self.tokens, SKIP_SPECIAL_TOKENS)
            .map_err(Error::msg)?;
        Ok(output)
    }

    /// Returns true when `token` matches the configured end-of-sequence
    /// id (single-id form) or is contained in the multi-id form.
    fn is_eos_token(&self, token: u32) -> bool {
        matches!(self.eos_token_id, Some(LlamaEosToks::Single(id)) if token == id)
            || matches!(self.eos_token_id, Some(LlamaEosToks::Multiple(ref ids)) if ids.contains(&token))
    }
}
| fn main() { | |
| let mut ctx = Chat::new().unwrap(); | |
| println!("{}", ctx.run("Hello my name is").unwrap()); | |
| println!("{}", ctx.run("Today").unwrap()); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment