Simple vector search and ingestion

The gist contains a pyproject.toml and two Python scripts: an ingestion script that chunks one or more text files and appends their embeddings to a JSON file, and a search script that loads those embeddings into a FAISS index and answers queries interactively.
pyproject.toml

[project]
name = "simple_vector_search"
version = "0.1.0"
requires-python = ">=3.12,<3.13"
dependencies = [
    "faiss-cpu",
    "numpy",
    "sentence-transformers",
    "torch",
    "huggingface_hub",
]

[dependency-groups]
dev = [
    "pylint>=4.0.5",
]
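The [dependency-groups] table is the PEP 735 format for dev-only dependencies (here, pylint) and is recognized by tools such as uv; how the environment is actually created is not stated in the gist. Both scripts below default to the mps device, which assumes an Apple Silicon Mac. A minimal sketch, not part of the gist, for checking whether that backend is available and falling back to the CPU otherwise:

import torch

# "mps" is Apple's Metal GPU backend; other machines should fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Pass --device {device} to the ingestion script on this machine.")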
| """Generate embeddings for one or more input files and save them as JSON.""" | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from sentence_transformers import SentenceTransformer | |
| BASE_DIR = Path(__file__).resolve().parent | |
| DEFAULT_OUTPUT_PATH = BASE_DIR / "embeddings.json" | |
| DEFAULT_MODEL = "google/embeddinggemma-300m" | |
| DEFAULT_DEVICE = "mps" | |
| DEFAULT_CHUNK_TOKENS = 256 | |
| DEFAULT_OVERLAP_TOKENS = 50 | |
| def parse_args(): | |
| """Parse command-line arguments.""" | |
| parser = argparse.ArgumentParser( | |
| description="Chunk a transcript and generate embeddings.json." | |
| ) | |
| parser.add_argument( | |
| "--input", | |
| type=Path, | |
| action="append", | |
| required=True, | |
| help="Markdown or text file to ingest. Pass multiple times for multiple files.", | |
| ) | |
| parser.add_argument( | |
| "--output", | |
| type=Path, | |
| default=DEFAULT_OUTPUT_PATH, | |
| help="JSON file to append embeddings to.", | |
| ) | |
| parser.add_argument( | |
| "--model", | |
| default=DEFAULT_MODEL, | |
| help="SentenceTransformer model name.", | |
| ) | |
| parser.add_argument( | |
| "--device", | |
| default=DEFAULT_DEVICE, | |
| help="Torch device to use, e.g. mps or cpu.", | |
| ) | |
| parser.add_argument( | |
| "--chunk-tokens", | |
| type=int, | |
| default=DEFAULT_CHUNK_TOKENS, | |
| help="Approximate token count per chunk.", | |
| ) | |
| parser.add_argument( | |
| "--overlap-tokens", | |
| type=int, | |
| default=DEFAULT_OVERLAP_TOKENS, | |
| help="Approximate token overlap between chunks.", | |
| ) | |
| parser.add_argument( | |
| "--recreate", | |
| action="store_true", | |
| help="Replace the output file instead of appending to it.", | |
| ) | |
| return parser.parse_args() | |
| def read_text(path: Path) -> str: | |
| """Read a UTF-8 text file.""" | |
| return path.read_text(encoding="utf-8").strip() | |
| def split_paragraphs(text: str) -> list[str]: | |
| """Split text into non-empty paragraphs.""" | |
| paragraphs = [part.strip() for part in text.split("\n\n")] | |
| return [part for part in paragraphs if part] | |
| def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> list[str]: | |
| """Create overlapping chunks using approximate token counts.""" | |
| if chunk_tokens <= 0: | |
| raise ValueError("--chunk-tokens must be greater than 0") | |
| if overlap_tokens < 0: | |
| raise ValueError("--overlap-tokens must be 0 or greater") | |
| if overlap_tokens >= chunk_tokens: | |
| raise ValueError("--overlap-tokens must be smaller than --chunk-tokens") | |
| paragraphs = split_paragraphs(text) | |
| words: list[str] = [] | |
| for paragraph in paragraphs: | |
| words.extend(paragraph.split()) | |
| words.append("\n\n") | |
| if words and words[-1] == "\n\n": | |
| words.pop() | |
| chunks: list[str] = [] | |
| start = 0 | |
| step = chunk_tokens - overlap_tokens | |
| while start < len(words): | |
| end = min(start + chunk_tokens, len(words)) | |
| chunk_words = words[start:end] | |
| chunk = " ".join(chunk_words).replace(" \n\n ", "\n\n").strip() | |
| if chunk: | |
| chunks.append(chunk) | |
| if end >= len(words): | |
| break | |
| start += step | |
| return chunks | |
| def load_existing_embeddings(path: Path) -> list[dict]: | |
| """Load existing embedding records from disk.""" | |
| if not path.exists(): | |
| return [] | |
| with path.open(encoding="utf-8") as f: | |
| data = json.load(f) | |
| if not isinstance(data, list): | |
| raise ValueError(f"Expected {path} to contain a JSON list") | |
| return data | |
| def load_model(model_name: str, device: str) -> SentenceTransformer: | |
| """Load the embedding model.""" | |
| print(f"Loading model {model_name} on {device}...", flush=True) | |
| model = SentenceTransformer(model_name, device=device) | |
| print("Model ready", flush=True) | |
| return model | |
| def generate_embeddings( | |
| model: SentenceTransformer, chunks: list[str], source_path: Path | |
| ) -> list[dict]: | |
| """Embed chunks and attach source metadata.""" | |
| print(f"Generating embeddings for {len(chunks)} chunk(s)...", flush=True) | |
| vectors = model.encode( | |
| chunks, | |
| prompt_name="document", | |
| normalize_embeddings=True, | |
| show_progress_bar=True, | |
| ) | |
| print("Embeddings generated", flush=True) | |
| records = [] | |
| for index, (chunk, vector) in enumerate(zip(chunks, vectors), start=1): | |
| records.append( | |
| { | |
| "text": chunk, | |
| "vector": vector.tolist(), | |
| "source": str(source_path), | |
| "chunk_id": index, | |
| } | |
| ) | |
| return records | |
| def save_embeddings(path: Path, records: list[dict], recreate: bool): | |
| """Write embedding records to the output file.""" | |
| existing = [] if recreate else load_existing_embeddings(path) | |
| combined = [*existing, *records] | |
| with path.open("w", encoding="utf-8") as f: | |
| json.dump(combined, f) | |
| action = "Recreated" if recreate else "Updated" | |
| print(f"{action} {path} with {len(records)} new embedding(s).", flush=True) | |
| print(f"Total records in file: {len(combined)}", flush=True) | |
| def main(): | |
| """Run the ingestion flow.""" | |
| args = parse_args() | |
| output_path = args.output.resolve() | |
| model = load_model(args.model, args.device) | |
| records = [] | |
| for input_path in (path.resolve() for path in args.input): | |
| if not input_path.exists(): | |
| raise FileNotFoundError(f"Could not find input file at {input_path}") | |
| print(f"Reading transcript from {input_path}...", flush=True) | |
| text = read_text(input_path) | |
| chunks = chunk_text(text, args.chunk_tokens, args.overlap_tokens) | |
| print( | |
| f"Created {len(chunks)} chunk(s) with target size {args.chunk_tokens} " | |
| f"and overlap {args.overlap_tokens}.", | |
| flush=True, | |
| ) | |
| records.extend(generate_embeddings(model, chunks, input_path)) | |
| save_embeddings(output_path, records, args.recreate) | |
| if __name__ == "__main__": | |
| main() |
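After a run, embeddings.json holds one record per chunk with the fields written by generate_embeddings. A quick sanity check, not part of the gist and assuming the default output path, that loads the file back and inspects the first record:

import json

with open("embeddings.json", encoding="utf-8") as f:
    records = json.load(f)

# Each record stores the chunk text, its embedding vector, and source metadata.
first = records[0]
print(len(records), "record(s) on disk")
print(sorted(first))                       # ['chunk_id', 'source', 'text', 'vector']
print(len(first["vector"]), "dimensions")  # 768 expected for google/embeddinggemma-300m
print(first["text"][:80])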
Search script

"""Interactive vector search over embeddings.json using a FAISS index."""

import json

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


def load_embeddings(path):
    """Load chunk texts and vectors from the JSON file written by the ingestion script."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    documents = [d["text"] for d in data]
    embeddings = np.array([d["vector"] for d in data], dtype="float32")
    return documents, embeddings


def search(model, documents, index, top_k=5):
    """Prompt for a query, embed it, and return the top_k most similar chunks."""
    query = input("\nEnter a query: ")
    query_embedding = model.encode(
        [query], prompt_name="query", normalize_embeddings=True
    ).astype("float32")
    _, ids = index.search(query_embedding, top_k)
    return [documents[i] for i in ids[0]]


model = SentenceTransformer("google/embeddinggemma-300m", device="mps")
documents, embeddings = load_embeddings("embeddings.json")

# Inner-product index; the stored vectors are normalized, so scores are cosine similarities.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Interactive loop; exit with Ctrl+C.
while True:
    results = search(model, documents, index)
    print("\n\n---\n\n".join(results))
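A note on the index choice: both scripts call encode with normalize_embeddings=True, so every stored and query vector has unit length, and the inner product computed by faiss.IndexFlatIP is exactly the cosine similarity between query and chunk. A minimal sketch, not part of the gist, of that equivalence:

import numpy as np

# For unit-length vectors, the dot product and cosine similarity coincide,
# which is why an inner-product index behaves as a cosine-similarity index here.
a = np.array([0.3, 0.7, 0.2], dtype="float32")
b = np.array([0.1, 0.9, 0.4], dtype="float32")

cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
inner = float((a / np.linalg.norm(a)) @ (b / np.linalg.norm(b)))
print(cosine, inner)  # identical up to floating-point rounding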