Created
April 15, 2026 03:15
-
-
Save ddh0/5e30297d684282b6af1561d1938a650d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# chat_template_txt.py
# Python 3.12.3
import sys
import argparse
import datetime
from typing import TextIO, Optional, Literal

from transformers import AutoTokenizer

# ...logger code...
class SupportsWriteAndFlush(TextIO):
    """A file, stream, or buffer that supports writing and flushing.

    Marker type used only in annotations (see ``Logger``'s ``stdout`` /
    ``stderr`` parameters); it adds no behavior beyond ``typing.TextIO``.
    """
class ANSI:
    """ANSI codes for terminal emulators.

    String constants to be concatenated into output text to control color
    and style on VT100-compatible terminals (SGR escape sequences).
    """
    #
    # Standard colors
    #
    FG_BLACK = '\x1b[30m'
    FG_RED = '\x1b[31m'
    FG_GREEN = '\x1b[32m'
    FG_YELLOW = '\x1b[33m'
    FG_BLUE = '\x1b[34m'
    FG_MAGENTA = '\x1b[35m'
    FG_CYAN = '\x1b[36m'
    FG_WHITE = '\x1b[37m'
    BG_BLACK = '\x1b[40m'
    BG_RED = '\x1b[41m'
    BG_GREEN = '\x1b[42m'
    BG_YELLOW = '\x1b[43m'
    BG_BLUE = '\x1b[44m'
    BG_MAGENTA = '\x1b[45m'
    BG_CYAN = '\x1b[46m'
    BG_WHITE = '\x1b[47m'
    #
    # Bright colors
    #
    FG_BRIGHT_BLACK = '\x1b[90m'
    FG_BRIGHT_RED = '\x1b[91m'
    FG_BRIGHT_GREEN = '\x1b[92m'
    FG_BRIGHT_YELLOW = '\x1b[93m'
    FG_BRIGHT_BLUE = '\x1b[94m'
    FG_BRIGHT_MAGENTA = '\x1b[95m'
    FG_BRIGHT_CYAN = '\x1b[96m'
    FG_BRIGHT_WHITE = '\x1b[97m'
    BG_BRIGHT_BLACK = '\x1b[100m'
    BG_BRIGHT_RED = '\x1b[101m'
    BG_BRIGHT_GREEN = '\x1b[102m'
    BG_BRIGHT_YELLOW = '\x1b[103m'
    BG_BRIGHT_BLUE = '\x1b[104m'
    BG_BRIGHT_MAGENTA = '\x1b[105m'
    BG_BRIGHT_CYAN = '\x1b[106m'
    BG_BRIGHT_WHITE = '\x1b[107m'
    #
    # Text modes
    #
    MODE_RESET_ALL = '\x1b[0m'
    MODE_BOLD_SET = '\x1b[1m'
    MODE_DIM_SET = '\x1b[2m'
    MODE_ITALIC_SET = '\x1b[3m'
    MODE_UNDERLINE_SET = '\x1b[4m'
    MODE_BLINKING_SET = '\x1b[5m'
    MODE_REVERSE_SET = '\x1b[7m'
    MODE_HIDDEN_SET = '\x1b[8m'
    MODE_STRIKETHROUGH_SET = '\x1b[9m'
    # NOTE: BOLD_RESET and DIM_RESET share the same code on purpose —
    # SGR 22 ("normal intensity") cancels both bold and dim.
    MODE_BOLD_RESET = '\x1b[22m'
    MODE_DIM_RESET = '\x1b[22m'
    MODE_ITALIC_RESET = '\x1b[23m'
    MODE_UNDERLINE_RESET = '\x1b[24m'
    MODE_BLINKING_RESET = '\x1b[25m'
    MODE_REVERSE_RESET = '\x1b[27m'
    MODE_HIDDEN_RESET = '\x1b[28m'
    MODE_STRIKETHROUGH_RESET = '\x1b[29m'
    #
    # Special
    #
    BELL = '\a'
    TERMINAL_RESET = '\x1bc'
    SCROLLBACK_CLEAR = '\x1b[3J'
    # Full wipe: reset the terminal, clear scrollback, reset all modes.
    CLEAR = TERMINAL_RESET + SCROLLBACK_CLEAR + MODE_RESET_ALL
def timestamp() -> str:
    """Return the current local time as ``YYYY-MM-DD Day HH:MM:SS.mmm``.

    Milliseconds are truncated (not rounded) from the microsecond field.
    """
    now = datetime.datetime.now()
    millis = now.microsecond // 1000
    return now.strftime("%Y-%m-%d %a %H:%M:%S.") + f"{millis:03d}"
class Logger:
    """Tiny leveled logger with a colored ``[timestamp] LEVEL:`` prefix.

    DEBUG and INFO messages go to ``stdout``; WARNING and ERROR messages
    go to ``stderr``. Calling the instance directly is shorthand for
    ``.info()``.
    """

    def __init__(
        self,
        stdout: Optional[SupportsWriteAndFlush] = None,
        stderr: Optional[SupportsWriteAndFlush] = None
    ):
        # Fall back to the process-wide streams when none are supplied.
        self.stdout = sys.stdout if stdout is None else stdout
        self.stderr = sys.stderr if stderr is None else stderr

    @staticmethod
    def _get_prefix(lvl: Literal[0, 1, 2, 3]) -> str:
        """Build the colored log prefix for numeric level *lvl* (0-3)."""
        labels = {
            0: f"{ANSI.FG_BRIGHT_WHITE}DEBUG",
            1: f"{ANSI.FG_BRIGHT_GREEN}INFO",
            2: f"{ANSI.FG_BRIGHT_YELLOW}WARNING",
            3: f"{ANSI.FG_BRIGHT_RED}ERROR",
        }
        if lvl not in labels:
            raise ValueError
        lvltxt = labels[lvl]
        return (
            f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}"
            f"[{timestamp()}]{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET} {lvltxt}"
            f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}:"
            f"{ANSI.MODE_RESET_ALL}"
        )

    def debug(self, msg: str) -> None:
        print(self._get_prefix(0), msg, file=self.stdout)

    def info(self, msg: str) -> None:
        print(self._get_prefix(1), msg, file=self.stdout)

    def warn(self, msg: str) -> None:
        print(self._get_prefix(2), msg, file=self.stderr)

    def error(self, msg: str) -> None:
        print(self._get_prefix(3), msg, file=self.stderr)

    def __call__(self, msg: str) -> None:
        # `log(msg)` behaves exactly like `log.info(msg)`.
        self.info(msg)
| # actual code starts here | |
def main():
    """Format a raw text corpus into pseudo-chat turns.

    The corpus is tokenized, split into fixed-size token chunks, and each
    chunk becomes one message of a synthetic conversation (a "system" turn
    followed by alternating "user"/"assistant" turns). The conversation is
    rendered through the tokenizer's chat template and written to the
    output file — e.g. for llama.cpp imatrix calibration data.
    """
    parser = argparse.ArgumentParser(description="format raw text into pseudo-chat turns (e.g. for imatrix calibration)")
    parser.add_argument("--model", required=True, help="path to tokenizer or HF repo")
    parser.add_argument("--jinja", help="path to optional .jinja template file, to override the tokenizer's chat template")
    parser.add_argument("--input", required=True, help="path to plaintext file")
    parser.add_argument("--output", required=True, help="path to save the formatted output file")
    parser.add_argument("--chunk-size", type=int, default=512, help="number of tokens used per message")
    args = parser.parse_args()

    # A zero or negative chunk size would otherwise crash with an obscure
    # ZeroDivisionError (or slice nonsensically) further down.
    if args.chunk_size <= 0:
        parser.error("--chunk-size must be a positive integer")

    log = Logger()
    log.info(f"loading tokenizer from: {args.model}")
    # NOTE(review): trust_remote_code executes model-repo code; only point
    # --model at repos you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if args.jinja:
        log.info(f"overwriting template with jinja from: {args.jinja}")
        with open(args.jinja, "r", encoding="utf-8") as f:
            tokenizer.chat_template = f.read()

    # Echo the template and a rendered example so the user can eyeball it.
    log.info(">> chat template contents <<")
    print(tokenizer.chat_template)
    log.info(">> end of chat template contents <<")
    example_chat = [
        {"role": "system", "content": "example system prompt"},
        {"role": "user", "content": "example user turn 1"},
        {"role": "assistant", "content": "example assistant turn 1"},
        {"role": "user", "content": "example user turn 2"},
        {"role": "assistant", "content": "example assistant turn 2"},
        {"role": "user", "content": "example user turn 3"},
    ]
    log.info(">> rendered chat template example <<")
    print(tokenizer.apply_chat_template(example_chat, add_generation_prompt=False, tokenize=False), end="")
    log.info(">> end of rendered chat template example <<")

    log.info(f"tokenizing text corpus: {args.input}")
    with open(args.input, "r", encoding="utf-8") as f:
        text = f.read()
    all_tokens = tokenizer.encode(text, add_special_tokens=False)
    total_tokens = len(all_tokens)
    log.info(f"total tokens in corpus: {total_tokens}")

    # Split the token stream into chunk-size messages. Chunk 0 is the
    # "system" turn; subsequent chunks alternate user/assistant.
    messages = []
    for chunk_idx, start in enumerate(range(0, total_tokens, args.chunk_size)):
        chunk_tokens = all_tokens[start:start + args.chunk_size]
        if chunk_idx == 0:
            role = "system"
        elif chunk_idx % 2 == 1:
            role = "user"
        else:
            role = "assistant"
        # NOTE(review): chunks are decoded independently, so token
        # sequences split across a chunk boundary may not round-trip
        # byte-exactly — acceptable for calibration data.
        chunk_text = tokenizer.decode(chunk_tokens, clean_up_tokenization_spaces=False)
        messages.append({"role": role, "content": chunk_text})

    if not messages:
        # An empty corpus would hand apply_chat_template an empty chat and
        # fail confusingly (or write an empty file); bail out clearly.
        log.error("input corpus produced no tokens; nothing to write")
        sys.exit(1)

    log.info(f"applying chat template over {len(messages)} turns")
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )
    if tokenizer.bos_token and formatted_text.startswith(tokenizer.bos_token):
        # llama.cpp adds BOS token as needed; avoid double BOS tokens
        formatted_text = formatted_text[len(tokenizer.bos_token):]
    with open(args.output, "w", encoding="utf-8") as f:
        f.write(formatted_text)
    log.info(f"successfully wrote formatted corpus to: {args.output}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment