Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save davidathompson/837e394de112f6accbf876d456ac9fce to your computer and use it in GitHub Desktop.

Select an option

Save davidathompson/837e394de112f6accbf876d456ac9fce to your computer and use it in GitHub Desktop.

Revisions

  1. @virattt virattt revised this gist Jan 21, 2024. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion rag-reranking-gpt-colbert.ipynb
    Original file line number Diff line number Diff line change
    @@ -1,5 +1,15 @@
    {
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
    },
    "source": [
    "<a href=\"https://colab.research.google.com/gist/virattt/b140fb4bf549b6125d53aa153dc53be6/rag-reranking-gpt-colbert.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
    },
    {
    "cell_type": "markdown",
    "source": [
    @@ -326,7 +336,8 @@
    },
    "orig_nbformat": 4,
    "colab": {
    "provenance": []
    "provenance": [],
    "include_colab_link": true
    }
    },
    "nbformat": 4,
  2. @virattt virattt created this gist Jan 21, 2024.
    334 changes: 334 additions & 0 deletions rag-reranking-gpt-colbert.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,334 @@
    {
    "cells": [
    {
    "cell_type": "markdown",
    "source": [
    "# Install dependencies"
    ],
    "metadata": {
    "id": "S2mGQxA958dW"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "pip install --quiet openai"
    ],
    "metadata": {
    "id": "2bY0NapN_z98"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "id": "lEQQJHH9gufm"
    },
    "outputs": [],
    "source": [
    "!pip install --quiet chromadb"
    ]
    },
    {
    "cell_type": "code",
    "source": [
    "!pip install --quiet langchain"
    ],
    "metadata": {
    "id": "ygccK6lm54VT"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "!pip install --quiet tiktoken"
    ],
    "metadata": {
    "id": "K5KyVC5O7Elw"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "!pip install --quiet pypdf"
    ],
    "metadata": {
    "id": "_o1MOUo07GBO"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "# Set your OpenAI API key\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
    ],
    "metadata": {
    "id": "tavToGb_MJrc"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Download and prepare SEC filing"
    ],
    "metadata": {
    "id": "sz639zFf6JoK"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from langchain.document_loaders import PyPDFLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n",
    "sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n",
    "\n",
    "# Create your PDF loader\n",
    "loader = PyPDFLoader(sec_filing_pdf)\n",
    "\n",
    "# Load the PDF document\n",
    "documents = loader.load()\n",
    "\n",
    "# Chunk the financial report\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)"
    ],
    "metadata": {
    "id": "rIO5t-j7611h"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Load the SEC filing into vector store"
    ],
    "metadata": {
    "id": "iaYSqxiMLUGb"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "\n",
    "# Load the document into Chroma\n",
    "embedding_function = OpenAIEmbeddings()\n",
    "db = Chroma.from_documents(docs, embedding_function)"
    ],
    "metadata": {
    "id": "QVZevdc-Md4N"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Query the vector store"
    ],
    "metadata": {
    "id": "m8HqBNyYrDHb"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "query = \"What are the specific factors contributing to Airbnb's increased operational expenses in the last fiscal year?\"\n",
    "docs = db.similarity_search(query)"
    ],
    "metadata": {
    "id": "3qZTrAtXLPl1"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Re-rank the results using GPT-4"
    ],
    "metadata": {
    "id": "0UMU-ogKM6w8"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "from openai import OpenAI\n",
    "import time\n",
    "import json\n",
    "\n",
    "start = time.time()\n",
    "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
    "response = client.chat.completions.create(\n",
    " model='gpt-4-1106-preview',\n",
    " response_format={\"type\": \"json_object\"},\n",
    " temperature=0,\n",
    " messages=[\n",
    " {\"role\": \"system\", \"content\": \"You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score.\"},\n",
    " {\"role\": \"user\", \"content\": f\"Query: {query} Docs: {docs}\"}\n",
    " ]\n",
    " )\n",
    "\n",
    "print(f\"Took {time.time() - start} seconds to re-rank documents with GPT-4.\")"
    ],
    "metadata": {
    "id": "Z83h16UuMlMt"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "# Sort the scores by highest to lowest and print\n",
    "scores = json.loads(response.choices[0].message.content)[\"documents\"]\n",
    "sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n",
    "print(json.dumps(sorted_data, indent=2))"
    ],
    "metadata": {
    "id": "8VZMWffzm0-i"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "source": [
    "# Re-rank the results using ColBERT"
    ],
    "metadata": {
    "id": "wnViL4XQg1FE"
    }
    },
    {
    "cell_type": "code",
    "source": [
    "!pip install --quiet transformers torch"
    ],
    "metadata": {
    "id": "fXPdarpEiN65"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# Load the tokenizer and the model\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"colbert-ir/colbertv2.0\")\n",
    "model = AutoModel.from_pretrained(\"colbert-ir/colbertv2.0\")"
    ],
    "metadata": {
    "id": "4jE_MXfelFyv"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "import torch\n",
    "import time\n",
    "\n",
    "start = time.time()\n",
    "scores = []\n",
    "\n",
    "# Function to compute MaxSim\n",
    "def maxsim(query_embedding, document_embedding):\n",
    " # Expand dimensions for broadcasting\n",
    " # Query: [batch_size, query_length, embedding_size] -> [batch_size, query_length, 1, embedding_size]\n",
    " # Document: [batch_size, doc_length, embedding_size] -> [batch_size, 1, doc_length, embedding_size]\n",
    " expanded_query = query_embedding.unsqueeze(2)\n",
    " expanded_doc = document_embedding.unsqueeze(1)\n",
    "\n",
    " # Compute cosine similarity across the embedding dimension\n",
    " sim_matrix = torch.nn.functional.cosine_similarity(expanded_query, expanded_doc, dim=-1)\n",
    "\n",
    " # Take the maximum similarity for each query token (across all document tokens)\n",
    " # sim_matrix shape: [batch_size, query_length, doc_length]\n",
    " max_sim_scores, _ = torch.max(sim_matrix, dim=2)\n",
    "\n",
    " # Average these maximum scores across all query tokens\n",
    " avg_max_sim = torch.mean(max_sim_scores, dim=1)\n",
    " return avg_max_sim\n",
    "\n",
    "# Get score for each document\n",
    "for document in docs:\n",
    " document_encoding = tokenizer(document.page_content, return_tensors='pt', truncation=True, max_length=512)\n",
    " document_embedding = model(**document_encoding).last_hidden_state\n",
    "\n",
    " # Calculate MaxSim score\n",
    " score = maxsim(query_embedding.unsqueeze(0), document_embedding)\n",
    " scores.append({\n",
    " \"score\": score.item(),\n",
    " \"document\": document.page_content,\n",
    " })\n",
    "\n",
    "print(f\"Took {time.time() - start} seconds to re-rank documents with ColBERT.\")"
    ],
    "metadata": {
    "id": "u84ePIKtjrtg"
    },
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "source": [
    "# Sort the scores by highest to lowest and print\n",
    "sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n",
    "print(json.dumps(sorted_data, indent=2))"
    ],
    "metadata": {
    "id": "jBotW5sxjT6W"
    },
    "execution_count": null,
    "outputs": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "base",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.10.12"
    },
    "orig_nbformat": 4,
    "colab": {
    "provenance": []
    }
    },
    "nbformat": 4,
    "nbformat_minor": 0
    }