
@wassname
Last active November 10, 2025 12:34
Revisions

  1. wassname revised this gist Nov 10, 2025. 1 changed file with 113 additions and 35 deletions.
    148 changes: 113 additions & 35 deletions how_to_get_logprobs_from_generation_v3.ipynb
    @@ -2,7 +2,7 @@
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "execution_count": 10,
    "id": "129fa913",
    "metadata": {},
    "outputs": [],
    @@ -18,14 +18,14 @@
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "execution_count": 11,
    "id": "4a6bfe95",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "application/vnd.jupyter.widget-view+json": {
    "model_id": "34e2c093b1d14e73b156ce27d0ec0470",
    "model_id": "726f4fd4bc2d4b2992c3d71b1cec4d91",
    "version_major": 2,
    "version_minor": 0
    },
    @@ -48,7 +48,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "execution_count": 12,
    "id": "fa35dd6d",
    "metadata": {},
    "outputs": [],
    @@ -93,7 +93,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "execution_count": 13,
    "id": "55970cd4",
    "metadata": {},
    "outputs": [],
    @@ -138,7 +138,7 @@
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "execution_count": 38,
    "id": "c192e461",
    "metadata": {},
    "outputs": [],
    @@ -180,6 +180,13 @@
    "\n",
    " return seq_nll\n",
    "\n",
    "\n",
    "def check_input_shapes(input_ids, attention_mask, kv_cache):\n",
    " c = kv_cache.get_seq_length()\n",
    " i = input_ids.shape[1]\n",
    " a = attention_mask.shape[1]\n",
    " assert c+i == a, f\"Cache length + input length must equal attention mask length, got {c}+{i} != {a}\"\n",
    "\n",
    "def gen_with_nll(model, tokenizer, batch2, lookback=1, **kwargs):\n",
    " \"\"\"\n",
    " problem: generate does not return logits for inputs, but we need them for nll\n",
    @@ -201,6 +208,8 @@
    " kwargs['output_logits'] = True\n",
    " kwargs['return_dict_in_generate'] = True\n",
    " kwargs['min_new_tokens'] = 0\n",
    "\n",
    " check_input_shapes(new_tokens, attn_mask, kv_cache)\n",
    " outputs = model.generate(\n",
    " input_ids=new_tokens, # Last token as new input\n",
    " attention_mask=attn_mask, # attn mask should cover cache and new tokens\n",
    @@ -222,6 +231,7 @@
    " \"\"\"\n",
    " Generate outputs while also computing input NLL and log probabilities for choices.\n",
    " \"\"\"\n",
    " model.eval()\n",
    " outputs, seq_nll = gen_with_nll(\n",
    " model, tokenizer, batch2, max_new_tokens=max_new_tokens, \n",
    " stopping_criteria=StoppingCriteriaList(\n",
    @@ -243,14 +253,20 @@
    " if continue_after_ss:\n",
    " # For debugging, continue generation after stop string reached, untill max_new_tokens is reached\n",
    " n = outputs.past_key_values.get_seq_length()\n",
    " n_gen = (n - input_ids.shape[1])\n",
    " next_input_ids = outputs.logits[-1].log_softmax(-1).argmax(-1).unsqueeze(-1)\n",
    " b = batch2['input_ids'].shape[0]\n",
    " new_attn_mask = torch.cat(\n",
    " [batch2['attention_mask'], torch.ones_like(outputs.sequences), torch.ones_like(next_input_ids)],\n",
    " [\n",
    " batch2['attention_mask'], \n",
    " torch.ones(b, n_gen, dtype=torch.int64, device=input_ids.device), \n",
    " torch.ones_like(next_input_ids)],\n",
    " dim=1\n",
    " )\n",
    " kwargs['output_logits'] = True\n",
    " kwargs['return_dict_in_generate'] = True\n",
    " max_new_tokens = max_new_tokens - (n - input_ids.shape[1])\n",
    " max_new_tokens = max_new_tokens - n_gen\n",
    " check_input_shapes(next_input_ids, new_attn_mask, outputs.past_key_values)\n",
    " continued_outputs = model.generate(\n",
    " input_ids=next_input_ids,\n",
    " attention_mask=new_attn_mask,\n",
    @@ -276,25 +292,41 @@
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 46,
    "id": "04b0445d",
    "metadata": {},
    "outputs": [
    {
    "name": "stderr",
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8, 'bos_token_id': 151643}. If this is not desired, please set these values explicitly.\n",
    "The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n",
    "A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> will take precedence. Please check the docstring of <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> to see related `.generate()` flags.\n",
    "A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> will take precedence. Please check the docstring of <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> to see related `.generate()` flags.\n"
    "torch.Size([1, 78]) 33\n",
    "<|im_start|>user\n",
    "\n",
    "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
    "Q: Would you kill a process? \n",
    "<|im_end|>\n",
    "<|im_start|>assistant\n",
    "My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n",
    "Note: This response is\n",
    "Last token: Yes\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "(tensor([5.7153]),\n",
    " tensor([[-3.2194, -0.0408]]),\n",
    " tensor([3.1786]),\n",
    " tensor([[7414]]))"
    ]
    },
    "execution_count": 46,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "\n",
    "\n",
    "\n",
    "choice_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n",
    "forced = True\n",
    "batch2 = tokenizer.apply_chat_template(\n",
    @@ -319,6 +351,7 @@
    ")\n",
    "\n",
    "from transformers import GenerationConfig\n",
    "\n",
    "generation_config = GenerationConfig(\n",
    " eos_token_id=tokenizer.eos_token_id,\n",
    " pad_token_id=tokenizer.pad_token_id,\n",
    @@ -352,46 +385,91 @@
    "seq_nll, logp_choices, logratios, last_token"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "a07afced",
    "metadata": {},
    "source": [
    "## Compare to straight generate"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 47,
    "id": "19f45b72",
    "metadata": {},
    "outputs": [],
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "<|im_start|>user\n",
    "\n",
    "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
    "Q: Would you kill a process? \n",
    "<|im_end|>\n",
    "<|im_start|>assistant\n",
    "My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n",
    "Note: This response is\n"
    ]
    }
    ],
    "source": [
    "# Unit test, make sure output is same as straight generate with forced tokens\n",
    "\n",
    "\n",
    "out_g = model.generate(\n",
    " input_ids=batch2['input_ids'],\n",
    " attention_mask=batch2['attention_mask'],\n",
    " max_new_tokens=max_new_tokens+1,\n",
    " min_new_tokens=max_new_tokens+1,\n",
    " do_sample=False,\n",
    " return_dict_in_generate=True,\n",
    " output_scores=True,\n",
    " output_logits=True,\n",
    " generation_config=generation_config,\n",
    ") \n",
    "with torch.no_grad():\n",
    " out_g = model.generate(\n",
    " input_ids=batch2['input_ids'],\n",
    " attention_mask=batch2['attention_mask'],\n",
    " max_new_tokens=max_new_tokens+1,\n",
    " min_new_tokens=max_new_tokens+1,\n",
    " do_sample=False,\n",
    " return_dict_in_generate=True,\n",
    " output_scores=True,\n",
    " output_logits=True,\n",
    " generation_config=generation_config,\n",
    " ) \n",
    "print(tokenizer.batch_decode(out_g.sequences)[0])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 28,
    "id": "28b2bc12",
    "metadata": {},
    "outputs": [],
    "outputs": [
    {
    "data": {
    "text/plain": [
    "(torch.Size([1, 78]), 33)"
    ]
    },
    "execution_count": 28,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "out_g.sequences.shape, len(out_g.logits)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "execution_count": 18,
    "id": "3a2087c8",
    "metadata": {},
    "outputs": [],
    "outputs": [
    {
    "data": {
    "text/plain": [
    "(torch.Size([1, 78]), 33)"
    ]
    },
    "execution_count": 18,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "outputs.sequences.shape, len(outputs.logits)"
    ]
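
    The substantive change in this revision is the new check_input_shapes helper plus the rebuilt attention mask when continuing generation from a KV cache. A minimal sketch of the invariant being asserted, assuming a recent transformers version where the forward pass returns a DynamicCache (model and tokenizer as loaded in the notebook; the snippet is illustrative, not part of the gist):

        import torch

        enc = tokenizer("My choice:", return_tensors="pt")
        # fill the cache with all but the last prompt token
        out = model(enc["input_ids"][:, :-1], attention_mask=enc["attention_mask"][:, :-1], use_cache=True)
        kv_cache = out.past_key_values

        new_tokens = enc["input_ids"][:, -1:]   # the held-back last token goes to generate
        attn_mask = enc["attention_mask"]       # must cover cached positions + new tokens
        assert kv_cache.get_seq_length() + new_tokens.shape[1] == attn_mask.shape[1]
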
  2. wassname created this gist Nov 10, 2025.
    429 changes: 429 additions & 0 deletions how_to_get_logprobs_from_generation_v3.ipynb
    @@ -0,0 +1,429 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "id": "129fa913",
    "metadata": {},
    "outputs": [],
    "source": [
    "from transformers.cache_utils import DynamicCache\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from collections import defaultdict\n",
    "from typing import Optional"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "id": "4a6bfe95",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "application/vnd.jupyter.widget-view+json": {
    "model_id": "34e2c093b1d14e73b156ce27d0ec0470",
    "version_major": 2,
    "version_minor": 0
    },
    "text/plain": [
    "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
    ]
    },
    "metadata": {},
    "output_type": "display_data"
    }
    ],
    "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_id = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
    "model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "id": "fa35dd6d",
    "metadata": {},
    "outputs": [],
    "source": [
    "from typing import List, Tuple\n",
    "import re\n",
    "import torch\n",
    "from loguru import logger\n",
    "\n",
    "def binary_log_cls(logits, choice_ids):\n",
    " logp = logits.log_softmax(dim=-1).detach().cpu()\n",
    " log_choices = torch.zeros(len(choice_ids)).to(logp.device)\n",
    " for i, choice_id_group in enumerate(choice_ids):\n",
    " choice_id_group = torch.tensor(choice_id_group).to(logp.device)\n",
    " logp_choice = logp[choice_id_group].logsumexp(-1)\n",
    " log_choices[i] = logp_choice\n",
    " return log_choices\n",
    "\n",
    "\n",
    "def extract_log_ratios(\n",
    " out: \"ModelOutput\", input_ids, choice_ids\n",
    "):\n",
    " \"\"\"Get [sequences x answers] log ratios for each of len(sequences) X regexp matches.\"\"\"\n",
    " N = input_ids.shape[1]\n",
    " bs = out.sequences.shape[0]\n",
    " logrs = torch.ones((bs, len(choice_ids))) * float(\"nan\")\n",
    " for sample_i in range(bs):\n",
    " log_choices = binary_log_cls(\n",
    " out.logits[-1][sample_i], choice_ids\n",
    " )\n",
    " logrs[sample_i] = log_choices\n",
    " return logrs\n"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "d55e5ab9",
    "metadata": {},
    "source": [
    "## Get choice token ids"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "id": "55970cd4",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Many tokenizers don't just use Yes, but \\nYes, \" Yes\" and so on. We need to catch all variants\n",
    "def is_choice(choice: str, match: str) -> bool:\n",
    " return (match.lower().endswith(choice) or match.lower().startswith(choice)) and len(\n",
    " match\n",
    " ) < len(choice) + 2\n",
    "\n",
    "\n",
    "def get_choice_ids(\n",
    " tokenizer, positive_word=\"yes\", negative_word=\"no\"\n",
    ") -> List[List[int]]:\n",
    " \"\"\"Get token IDs for Yes/No choices.\"\"\"\n",
    "\n",
    " positive_choices = {\n",
    " k: v for k, v in tokenizer.vocab.items() if is_choice(positive_word, k)\n",
    " }\n",
    " negative_choices = {\n",
    " k: v for k, v in tokenizer.vocab.items() if is_choice(negative_word, k)\n",
    " }\n",
    "\n",
    " return [list(negative_choices.values()), list(positive_choices.values())]"
    ]
    },
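    A quick sanity check of what get_choice_ids collects, as a sketch assuming the tokenizer loaded above (the exact variant set depends on the vocab):

        no_ids, yes_ids = get_choice_ids(tokenizer)        # returns [negative group, positive group]
        print([tokenizer.decode([i]) for i in yes_ids])    # roughly: 'Yes', ' Yes', 'yes', ' yes', ...
        print([tokenizer.decode([i]) for i in no_ids])     # roughly: 'No', ' No', 'no', ' no', ...
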
    {
    "cell_type": "markdown",
    "id": "32c1306c",
    "metadata": {},
    "source": [
    "## Gen\n",
    "\n",
    "Benefits:\n",
    "- we get logprobs which are more nuanced and less noisy than tokens\n",
    "- it's fast, we get the full distribution without sampling, we don't generate more tokens than we need\n",
    "\n",
    "Limitations\n",
    "- it doesn't think about the answer, which would change it's answer. So we get a less considered answer\n",
    "- sometimes forcing might take it outside it's training distribution, meaning we get unnatural behaviour and non representative logprobs\n"
    ]
    },
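    For orientation, the idea in its simplest form, as a minimal sketch assuming model, tokenizer and get_choice_ids from the cells above: a single forward pass already yields the full next-token distribution, and the Yes/No log-probs fall straight out of it; the cells below add KV-cache reuse, input NLL, stop strings and continuation on top of this.

        import torch

        # illustrative prompt only; the notebook's own prompt and chat template apply
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "Q: Would you kill a process? Answer Yes or No."},
             {"role": "assistant", "content": "My choice:"}],
            return_tensors="pt", return_dict=True, continue_final_message=True,
        )
        with torch.no_grad():
            logits = model(**prompt).logits[:, -1]         # next-token logits, shape [batch, vocab]
        logp = logits.log_softmax(-1)
        no_ids, yes_ids = get_choice_ids(tokenizer)
        logp_yes = logp[:, yes_ids].logsumexp(-1)
        logp_no = logp[:, no_ids].logsumexp(-1)
        logratio = logp_yes - logp_no                      # > 0 leans Yes, < 0 leans No
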
    {
    "cell_type": "code",
    "execution_count": 5,
    "id": "c192e461",
    "metadata": {},
    "outputs": [],
    "source": [
    "\"\"\"\n",
    "\n",
    "ref:\n",
    "- https://huggingface.co/docs/transformers/v4.56.1/en/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria\n",
    "- https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/tests/generation/test_stopping_criteria.py#L51\n",
    "\n",
    "\n",
    "Ref regexp based logit colleciton\n",
    "- https://github.com/wassname/repeng/blob/add-performance-validation/notebooks/performance_tests.ipynb \n",
    "- https://github.com/wassname/repeng/blob/research/repeng/eval.py\n",
    "\"\"\"\n",
    "\n",
    "from transformers import (\n",
    " StopStringCriteria,\n",
    " StoppingCriteriaList,\n",
    " EosTokenCriteria,\n",
    " MaxLengthCriteria,\n",
    ")\n",
    "\n",
    "def calc_nll(input_ids, logits, attention_mask):\n",
    " # Shift logits and labels for NLL: predict token t from tokens 0..t-1\n",
    " shift_logits = logits[:, :-1, :].contiguous()\n",
    " shift_labels = input_ids[:, 1:].contiguous()\n",
    " shift_mask = attention_mask[:, 1:].contiguous()\n",
    "\n",
    " # Compute NLL per token, masking padding\n",
    " loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n",
    " token_nll = loss_fct(\n",
    " shift_logits.view(-1, shift_logits.size(-1)),\n",
    " shift_labels.view(-1)\n",
    " ).view(shift_labels.size())\n",
    "\n",
    " # Average NLL per sequence (excluding padding)\n",
    " seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)\n",
    "\n",
    " return seq_nll\n",
    "\n",
    "def gen_with_nll(model, tokenizer, batch2, lookback=1, **kwargs):\n",
    " \"\"\"\n",
    " problem: generate does not return logits for inputs, but we need them for nll\n",
    "\n",
    " but forward -> generate with past key values does, and it doesn't recompute the input logits\n",
    " \"\"\"\n",
    " if 'attention_mask' not in batch2:\n",
    " batch2['attention_mask'] = torch.ones_like(batch2['input_ids'])\n",
    " input_ids = batch2['input_ids']\n",
    " attn_mask = batch2['attention_mask']\n",
    " forward_out = model(input_ids[:, :-lookback], attention_mask=attn_mask[:, :-lookback], use_cache=True)\n",
    "\n",
    " seq_nll = calc_nll(input_ids[:, :-lookback], forward_out.logits, attn_mask[:, :-lookback])\n",
    " kv_cache = forward_out.past_key_values\n",
    "\n",
    " # Continue generation from the cached KV states\n",
    " cl = kv_cache.get_seq_length()\n",
    " new_tokens = input_ids[:, -lookback:]\n",
    " kwargs['output_logits'] = True\n",
    " kwargs['return_dict_in_generate'] = True\n",
    " kwargs['min_new_tokens'] = 0\n",
    " outputs = model.generate(\n",
    " input_ids=new_tokens, # Last token as new input\n",
    " attention_mask=attn_mask, # attn mask should cover cache and new tokens\n",
    " past_key_values=kv_cache,\n",
    "\n",
    " # the next cache position will be n+1\n",
    " cache_position=torch.arange(cl, cl+new_tokens.shape[1], dtype=torch.int64, device=input_ids.device),\n",
    " use_cache=True,\n",
    " **kwargs\n",
    " )\n",
    "\n",
    " # now we need to modify this as generate does return the full sequences, including inputs ids\n",
    " outputs.sequences = torch.concat([input_ids[:, :-lookback], outputs.sequences], 1)\n",
    "\n",
    " return outputs, seq_nll\n",
    "\n",
    "\n",
    "def gen_with_nll_and_logprobs(model, tokenizer, batch2, choice_ids, stop_strings=[\": Yes\", \": Yes \", \" choice: Yes\", \"choice: Yes\", \": No\", \": No \", \" choice: No\"], max_new_tokens=16, continue_after_ss=False, lookback=1, **kwargs):\n",
    " \"\"\"\n",
    " Generate outputs while also computing input NLL and log probabilities for choices.\n",
    " \"\"\"\n",
    " outputs, seq_nll = gen_with_nll(\n",
    " model, tokenizer, batch2, max_new_tokens=max_new_tokens, \n",
    " stopping_criteria=StoppingCriteriaList(\n",
    " [\n",
    " StopStringCriteria(tokenizer, stop_strings),\n",
    " EosTokenCriteria(tokenizer.eos_token_id),\n",
    " MaxLengthCriteria(max_length=batch2[\"input_ids\"].shape[1] + max_new_tokens),\n",
    " ]\n",
    " ),\n",
    " lookback=lookback,\n",
    " **kwargs\n",
    " )\n",
    " \n",
    "\n",
    " input_ids = batch2['input_ids']\n",
    " logp_choices = extract_log_ratios(outputs, input_ids, choice_ids)\n",
    " last_token = outputs.sequences[:, -1:]\n",
    "\n",
    " if continue_after_ss:\n",
    " # For debugging, continue generation after stop string reached, untill max_new_tokens is reached\n",
    " n = outputs.past_key_values.get_seq_length()\n",
    " next_input_ids = outputs.logits[-1].log_softmax(-1).argmax(-1).unsqueeze(-1)\n",
    " new_attn_mask = torch.cat(\n",
    " [batch2['attention_mask'], torch.ones_like(outputs.sequences), torch.ones_like(next_input_ids)],\n",
    " dim=1\n",
    " )\n",
    " kwargs['output_logits'] = True\n",
    " kwargs['return_dict_in_generate'] = True\n",
    " max_new_tokens = max_new_tokens - (n - input_ids.shape[1])\n",
    " continued_outputs = model.generate(\n",
    " input_ids=next_input_ids,\n",
    " attention_mask=new_attn_mask,\n",
    " past_key_values=outputs.past_key_values,\n",
    " cache_position=torch.arange(n, n+1, dtype=torch.int64, device=input_ids.device),\n",
    " min_new_tokens=max_new_tokens,\n",
    " max_new_tokens=max_new_tokens,\n",
    " **kwargs\n",
    " )\n",
    " # Concatenate sequences and logits\n",
    " outputs.sequences = torch.concat([outputs.sequences, continued_outputs.sequences[:, 1:]], 1)\n",
    " outputs.logits = outputs.logits + continued_outputs.logits\n",
    "\n",
    "\n",
    " logratios = logp_choices[:, 1] - logp_choices[:, 0] # Positive - Negative log-prob ratio\n",
    " \n",
    " # but total prob mass < 10% -> nan\n",
    " pmass = logp_choices.exp().sum(-1)\n",
    " logratios = torch.where(pmass < 0.1, float('nan'), logratios) \n",
    "\n",
    " return outputs, seq_nll, logp_choices, logratios, last_token\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "04b0445d",
    "metadata": {},
    "outputs": [
    {
    "name": "stderr",
    "output_type": "stream",
    "text": [
    "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8, 'bos_token_id': 151643}. If this is not desired, please set these values explicitly.\n",
    "The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n",
    "A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> will take precedence. Please check the docstring of <class 'transformers.generation.stopping_criteria.MaxLengthCriteria'> to see related `.generate()` flags.\n",
    "A custom stopping criteria of type <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> will take precedence. Please check the docstring of <class 'transformers.generation.stopping_criteria.EosTokenCriteria'> to see related `.generate()` flags.\n"
    ]
    }
    ],
    "source": [
    "\n",
    "\n",
    "\n",
    "choice_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n",
    "forced = True\n",
    "batch2 = tokenizer.apply_chat_template(\n",
    " [\n",
    " {\n",
    " \"role\": \"user\",\n",
    " \"content\": \"\"\"\n",
    "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
    "Q: Would you kill a process? \n",
    "\"\"\",\n",
    " },\n",
    " {\n",
    " 'role': 'assistant',\n",
    " 'content': \"My choice:\"\n",
    " }\n",
    " ],\n",
    " return_tensors=\"pt\",\n",
    " padding=True,\n",
    " return_dict=True,\n",
    " continue_final_message=True,\n",
    " add_generation_prompt=False,\n",
    ")\n",
    "\n",
    "from transformers import GenerationConfig\n",
    "generation_config = GenerationConfig(\n",
    " eos_token_id=tokenizer.eos_token_id,\n",
    " pad_token_id=tokenizer.pad_token_id,\n",
    " bos_token_id=tokenizer.bos_token_id,\n",
    " use_cache=True,\n",
    " output_logits=True,\n",
    " return_dict_in_generate=True,\n",
    " do_sample=False,\n",
    ")\n",
    "\n",
    "batch2 = {k: v.to(model.device) for k, v in batch2.items()}\n",
    "max_new_tokens =32\n",
    "with torch.no_grad():\n",
    " outputs, seq_nll, logp_choices, logratios, last_token = gen_with_nll_and_logprobs(\n",
    " model=model,\n",
    " tokenizer=tokenizer,\n",
    " batch2=batch2,\n",
    " choice_ids=choice_ids,\n",
    " stop_strings=[\"My choice: Yes\", \"My choice: No\"],\n",
    " max_new_tokens=max_new_tokens,\n",
    " lookback=4, # if we use forcing we should look back enough to cover it\n",
    " continue_after_ss=True,\n",
    " do_sample=False,\n",
    " generation_config=generation_config,\n",
    " )\n",
    "\n",
    "print(outputs.sequences.shape, len(outputs.logits))\n",
    "print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0])\n",
    "last_token_s = tokenizer.batch_decode(last_token, skip_special_tokens=False)[0]\n",
    "print(f\"Last token: {last_token_s}\")\n",
    "seq_nll, logp_choices, logratios, last_token"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "19f45b72",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Unit test, make sure output is same as straight generate with forced tokens\n",
    "\n",
    "\n",
    "out_g = model.generate(\n",
    " input_ids=batch2['input_ids'],\n",
    " attention_mask=batch2['attention_mask'],\n",
    " max_new_tokens=max_new_tokens+1,\n",
    " min_new_tokens=max_new_tokens+1,\n",
    " do_sample=False,\n",
    " return_dict_in_generate=True,\n",
    " output_scores=True,\n",
    " output_logits=True,\n",
    " generation_config=generation_config,\n",
    ") \n",
    "print(tokenizer.batch_decode(out_g.sequences)[0])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "28b2bc12",
    "metadata": {},
    "outputs": [],
    "source": [
    "out_g.sequences.shape, len(out_g.logits)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "3a2087c8",
    "metadata": {},
    "outputs": [],
    "source": [
    "outputs.sequences.shape, len(outputs.logits)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "a699dd31",
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": ".venv",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.10.16"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }