Last active
July 26, 2025 02:08
-
-
Save wassname/04f0c50a68054f0323f62b0da418daec to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import math

import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from datasets import Dataset, load_dataset
from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizerBase
def get_output_ppx(output, input) -> float:
    """Masked perplexity of ``output.logits`` against ``input.input_ids``.

    Standard next-token shift: the logit at position t scores the token at
    t+1, so the first token of the sequence is never scored.

    Args:
        output: model output exposing ``logits`` of shape
            (batch, seq_len, vocab).
        input: tokenized batch exposing ``input_ids`` (batch, seq_len) and
            ``attention_mask``. The mask may be LONGER than ``input_ids``
            (the caller prepends the prompt mask for KV-cache reuse); only
            its trailing positions are used.

    Returns:
        float: exp(mean NLL) over unmasked target tokens, or NaN when the
        mask selects no tokens.
    """
    loss_fn = CrossEntropyLoss(reduction="none")
    # Predict token t+1 from the logits at position t.
    shift_logits = output.logits[:, :-1].contiguous()
    shift_labels = input.input_ids[:, 1:].contiguous()
    # CrossEntropyLoss wants (batch, vocab, seq) for per-token losses.
    loss = loss_fn(shift_logits.transpose(1, 2), shift_labels)
    # BUGFIX: the attention mask can cover prompt + completion while the
    # labels cover only the completion; a fixed [:, 1:] shift then leaves the
    # mask longer than the loss and the multiply below fails. Align by taking
    # the trailing label-length slice instead (identical to [:, 1:] when the
    # mask and input_ids have equal length).
    shift_masks = input.attention_mask[:, -shift_labels.size(1):].contiguous()
    nll = (loss * shift_masks).sum().item()
    count = shift_masks.sum().item()
    if count == 0:
        # No scored tokens (e.g. fully padded batch) — avoid ZeroDivisionError.
        return float("nan")
    return math.exp(nll / count)
# I could get the logprobs of each yep
@torch.no_grad()
def eval_pref_ds_ppx(model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, ds_pref: Dataset, batch_size: int=2, max_new_tokens: int=128) -> pd.DataFrame:
    """
    Evaluate on a preference dataset.
    The relative perplexity of the chosen and rejected completions of a prompt.

    For each batch, the prompt is encoded once and its KV cache is reused to
    score both completions, so each perplexity is conditional on the prompt.

    NOTE(review): `max_new_tokens` is reused as a truncation budget
    (`max_new_tokens // 2` as `max_length` for BOTH prompt and completion),
    not as a generation length — confirm this is intended.
    NOTE(review): inputs are never moved to `model.device`; assumes tokenizer
    output and model share a device (e.g. CPU) — verify against the caller.
    NOTE(review): with right-padded prompts, completion tokens attend across
    prompt padding positions masked only via attention_mask — confirm the
    tokenizer's padding side is what this expects.

    Args:
        model: causal LM whose forward accepts `past_key_values`.
        tokenizer: tokenizer for `model`, with padding enabled.
        ds_pref: dataset with 'prompt', 'chosen' and 'rejected' string columns.
        batch_size: rows tokenized and scored per forward pass.
        max_new_tokens: token budget; half is the `max_length` for the prompt
            and half for each completion.

    Returns:
        DataFrame with one row per batch and float columns 'rejected' and
        'chosen' holding that batch's aggregate perplexities.
    """
    results = []
    for batch in tqdm(ds_pref.batch(batch_size), unit="batch"):
        # first we cache the prompt
        kv_cache = DynamicCache()
        inputs1 = tokenizer(batch['prompt'], return_tensors="pt", padding=True, truncation=True, max_length=max_new_tokens//2, return_token_type_ids=False, return_attention_mask=True)
        # Prime the cache with the prompt; this pass's logits are discarded.
        model.forward(**inputs1, past_key_values=kv_cache)
        # then we evaluate the perplexity of the accepted and rejected completion
        res = {}
        for p in ['rejected', 'chosen']:
            input = tokenizer(batch[p], return_tensors="pt", padding=True, truncation=True, max_length=max_new_tokens//2, return_token_type_ids=False, return_attention_mask=True)
            # we need to update the attention mask to match the kv_cache
            # (with past_key_values set, the mask must span prompt + completion)
            input['attention_mask'] = torch.cat([inputs1['attention_mask'], input['attention_mask']], dim=1)
            # Deep-copy so the second completion starts from the clean
            # prompt-only cache — forward() mutates the cache in place.
            kv_cache2 = copy.deepcopy(kv_cache)
            output = model.forward(**input, past_key_values=kv_cache2)
            ppx = get_output_ppx(output, input)
            res[p] = ppx
        results.append(res)
    df_results = pd.DataFrame(results)
    return df_results
# Score a 100-row slice of the preference dataset and summarise the
# chosen-vs-rejected perplexities.
ds_pref = load_dataset(
    "wassname/genies_preferences", name="illegal_dont_help", split="train"
).select(range(100))
df_results = eval_pref_ds_ppx(
    model, tokenizer, ds_pref, batch_size=batch_size, max_new_tokens=max_new_tokens
)
df_results.describe()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment