"""Token-by-token text generation with a static KV cache.

The prompt is prefilled with the eager model; every subsequent token is
produced by a ``torch.compile``-d single-token decode step.
"""
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

device = "cuda"

# Sampling settings shared by the prefill step and the compiled decode step.
TEMPERATURE = 0.6
TOP_K = 5


# Copied from the gpt-fast repo
def multinomial_sample_one_no_sync(probs_sort):
    """Sample one index per row from ``probs_sort`` along the last dimension.

    Does multinomial sampling without a cuda synchronization: the
    probabilities are divided by exponential(1) noise and the argmax is taken.

    Returns:
        Tensor of dtype ``torch.int`` with the last dimension kept (size 1).
    """
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into probabilities with temperature and optional top-k.

    Args:
        logits: Tensor whose last dimension indexes the candidates.
        temperature: Divisor applied to the logits; floored at ``1e-5`` so a
            temperature of 0 does not divide by zero.
        top_k: If given, every logit strictly below the k-th largest one is
            set to ``-inf`` before the softmax.

    Returns:
        Softmax of the (scaled, optionally filtered) logits over the last dim.
    """
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        # topk returns values in descending order, so the last one is the
        # smallest logit that is still allowed to survive.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits at the last sequence position.

    Only ``logits[:, -1]`` is used.

    Returns:
        Tuple ``(idx_next, probs)``: the sampled indices and the probabilities
        they were drawn from.
    """
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_tokens(model, cur_token, cache_position):
    """Run one forward pass on ``cur_token`` and sample the following token.

    Args:
        model: Callable model; invoked with ``return_dict=False`` and
            ``use_cache=True`` so the first element of its output is the logits.
        cur_token: Token ids fed to the model for this step.
        cache_position: Passed through to the model as ``cache_position``.

    Returns:
        The sampled token ids (see ``sample``).
    """
    logits = model(
        cur_token,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = sample(logits, temperature=TEMPERATURE, top_k=TOP_K)[0]
    return new_token


model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
)
model = model.to(device).eval()

# Only the single-token decode step is compiled; the prefill below runs eagerly.
decode_one_tokens = torch.compile(
    decode_one_tokens, mode="reduce-overhead", fullgraph=True
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompt = "My favourite condiment is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
batch_size, sequence_length = input_ids.shape

max_cache_length = 2048
max_new_tokens = 100
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)

# Output buffer: the prompt followed by one slot per generated token.
generated_ids = torch.zeros(
    (batch_size, max_new_tokens + sequence_length), dtype=torch.int, device=device
)
generated_ids[:, :sequence_length] = input_ids

cache_position = torch.tensor([sequence_length], device=device)

with torch.no_grad():
    # Drive the loop from max_new_tokens so it can never write past the end of
    # generated_ids, which is sized from the same value.
    for step in range(max_new_tokens):
        if step == 0:  # prefill uses vanilla model
            logits = model(
                input_ids,
                cache_position=torch.arange(sequence_length, device=device),
            )[0]
            input_id = sample(logits, temperature=TEMPERATURE, top_k=TOP_K)[0]
            generated_ids[:, sequence_length] = input_id[:, 0]
        else:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            ):
                # clone() hands the compiled step its own copy of the previous
                # token rather than the tensor returned by the previous call.
                input_id = decode_one_tokens(model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, cache_position, input_id)
        # NOTE(review): the increment is applied after the prefill as well, so
        # the first decode step receives cache_position == sequence_length + 1
        # while the token it is fed was stored at index sequence_length.
        # Confirm this matches the model's cache_position semantics.
        cache_position += 1

print(tokenizer.batch_decode(generated_ids.long()))
# Example output:
# [" My favourite condiment is ketchup. I know, I know, it's a bit cliche, but there's just something about the sweet and tangy flavour that I can't get enough of. I put it on everything from fries to scrambled eggs to grilled meats. And let's be real, it's the perfect accompaniment to a good old-fashioned burger and fries.\n\nBut ketchup isn't just delicious"]