Skip to content

Instantly share code, notes, and snippets.

@ryan-blunden
Created March 18, 2026 07:30
Show Gist options
  • Select an option

  • Save ryan-blunden/92ba1e17102ffb853ac84fa63564b25e to your computer and use it in GitHub Desktop.

Select an option

Save ryan-blunden/92ba1e17102ffb853ac84fa63564b25e to your computer and use it in GitHub Desktop.
Simple vector search and ingestion
[project]
name = "simple_vector_search"
version = "0.1.0"
requires-python = ">=3.12,<3.13"
dependencies = [
"faiss-cpu",
"numpy",
"sentence-transformers",
"torch",
"huggingface_hub",
]
[dependency-groups]
dev = [
"pylint>=4.0.5",
]
"""Generate embeddings for one or more input files and save them as JSON."""
import argparse
import json
from pathlib import Path
from sentence_transformers import SentenceTransformer
BASE_DIR = Path(__file__).resolve().parent
DEFAULT_OUTPUT_PATH = BASE_DIR / "embeddings.json"
DEFAULT_MODEL = "google/embeddinggemma-300m"
DEFAULT_DEVICE = "mps"
DEFAULT_CHUNK_TOKENS = 256
DEFAULT_OVERLAP_TOKENS = 50
def parse_args():
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Chunk a transcript and generate embeddings.json."
)
parser.add_argument(
"--input",
type=Path,
action="append",
required=True,
help="Markdown or text file to ingest. Pass multiple times for multiple files.",
)
parser.add_argument(
"--output",
type=Path,
default=DEFAULT_OUTPUT_PATH,
help="JSON file to append embeddings to.",
)
parser.add_argument(
"--model",
default=DEFAULT_MODEL,
help="SentenceTransformer model name.",
)
parser.add_argument(
"--device",
default=DEFAULT_DEVICE,
help="Torch device to use, e.g. mps or cpu.",
)
parser.add_argument(
"--chunk-tokens",
type=int,
default=DEFAULT_CHUNK_TOKENS,
help="Approximate token count per chunk.",
)
parser.add_argument(
"--overlap-tokens",
type=int,
default=DEFAULT_OVERLAP_TOKENS,
help="Approximate token overlap between chunks.",
)
parser.add_argument(
"--recreate",
action="store_true",
help="Replace the output file instead of appending to it.",
)
return parser.parse_args()
def read_text(path: Path) -> str:
"""Read a UTF-8 text file."""
return path.read_text(encoding="utf-8").strip()
def split_paragraphs(text: str) -> list[str]:
"""Split text into non-empty paragraphs."""
paragraphs = [part.strip() for part in text.split("\n\n")]
return [part for part in paragraphs if part]
def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> list[str]:
"""Create overlapping chunks using approximate token counts."""
if chunk_tokens <= 0:
raise ValueError("--chunk-tokens must be greater than 0")
if overlap_tokens < 0:
raise ValueError("--overlap-tokens must be 0 or greater")
if overlap_tokens >= chunk_tokens:
raise ValueError("--overlap-tokens must be smaller than --chunk-tokens")
paragraphs = split_paragraphs(text)
words: list[str] = []
for paragraph in paragraphs:
words.extend(paragraph.split())
words.append("\n\n")
if words and words[-1] == "\n\n":
words.pop()
chunks: list[str] = []
start = 0
step = chunk_tokens - overlap_tokens
while start < len(words):
end = min(start + chunk_tokens, len(words))
chunk_words = words[start:end]
chunk = " ".join(chunk_words).replace(" \n\n ", "\n\n").strip()
if chunk:
chunks.append(chunk)
if end >= len(words):
break
start += step
return chunks
def load_existing_embeddings(path: Path) -> list[dict]:
"""Load existing embedding records from disk."""
if not path.exists():
return []
with path.open(encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError(f"Expected {path} to contain a JSON list")
return data
def load_model(model_name: str, device: str) -> SentenceTransformer:
"""Load the embedding model."""
print(f"Loading model {model_name} on {device}...", flush=True)
model = SentenceTransformer(model_name, device=device)
print("Model ready", flush=True)
return model
def generate_embeddings(
model: SentenceTransformer, chunks: list[str], source_path: Path
) -> list[dict]:
"""Embed chunks and attach source metadata."""
print(f"Generating embeddings for {len(chunks)} chunk(s)...", flush=True)
vectors = model.encode(
chunks,
prompt_name="document",
normalize_embeddings=True,
show_progress_bar=True,
)
print("Embeddings generated", flush=True)
records = []
for index, (chunk, vector) in enumerate(zip(chunks, vectors), start=1):
records.append(
{
"text": chunk,
"vector": vector.tolist(),
"source": str(source_path),
"chunk_id": index,
}
)
return records
def save_embeddings(path: Path, records: list[dict], recreate: bool):
"""Write embedding records to the output file."""
existing = [] if recreate else load_existing_embeddings(path)
combined = [*existing, *records]
with path.open("w", encoding="utf-8") as f:
json.dump(combined, f)
action = "Recreated" if recreate else "Updated"
print(f"{action} {path} with {len(records)} new embedding(s).", flush=True)
print(f"Total records in file: {len(combined)}", flush=True)
def main():
"""Run the ingestion flow."""
args = parse_args()
output_path = args.output.resolve()
model = load_model(args.model, args.device)
records = []
for input_path in (path.resolve() for path in args.input):
if not input_path.exists():
raise FileNotFoundError(f"Could not find input file at {input_path}")
print(f"Reading transcript from {input_path}...", flush=True)
text = read_text(input_path)
chunks = chunk_text(text, args.chunk_tokens, args.overlap_tokens)
print(
f"Created {len(chunks)} chunk(s) with target size {args.chunk_tokens} "
f"and overlap {args.overlap_tokens}.",
flush=True,
)
records.extend(generate_embeddings(model, chunks, input_path))
save_embeddings(output_path, records, args.recreate)
if __name__ == "__main__":
main()
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
def load_embeddings(path):
with open(path, encoding="utf-8") as f:
data = json.load(f)
documents = [d["text"] for d in data]
embeddings = np.array([d["vector"] for d in data], dtype="float32")
return documents, embeddings
def search(model, documents, index, top_k=5):
query = input("\nEnter a query: ")
query_embedding = model.encode(
[query], prompt_name="query", normalize_embeddings=True
).astype("float32")
_, ids = index.search(query_embedding, top_k)
return [documents[i] for i in ids[0]]
model = SentenceTransformer("google/embeddinggemma-300m", device="mps")
documents, embeddings = load_embeddings("embeddings.json")
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
while True:
results = search(model, documents, index)
print("\n\n---\n\n".join(results))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment