Skip to content

Instantly share code, notes, and snippets.

@timspannzilliz
Forked from janakiramm/RAG_with_NIM.ipynb
Created September 1, 2024 02:05
Show Gist options
  • Select an option

  • Save timspannzilliz/d7f788b35a76362dca485971ecbff6fd to your computer and use it in GitHub Desktop.

Select an option

Save timspannzilliz/d7f788b35a76362dca485971ecbff6fd to your computer and use it in GitHub Desktop.

Revisions

  1. @janakiramm janakiramm created this gist Aug 22, 2024.
    348 changes: 348 additions & 0 deletions RAG_with_NIM.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,348 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "id": "9bbd56b4-079b-4658-9690-8db19c602dd5",
    "metadata": {},
    "outputs": [],
    "source": [
    "from pymilvus import MilvusClient\n",
    "from pymilvus import connections\n",
    "from openai import OpenAI\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import ast"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "id": "68c521a0-e52d-48f6-b2b0-d9b78c010799",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "True"
    ]
    },
    "execution_count": 2,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "# Read settings from a local .env file into the environment before they are looked up below.\n",
    "load_dotenv()"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "id": "85b861a4-53a8-400b-ad7d-81279d4a660b",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Service endpoints, supplied through the environment (see load_dotenv above).\n",
    "LLM_URI = os.getenv(\"LLM_URI\")\n",
    "EMBED_URI = os.getenv(\"EMBED_URI\")\n",
    "VECTORDB_URI = os.getenv(\"VECTORDB_URI\")\n",
    "\n",
    "# Fail fast on a missing endpoint: with base_url=None the OpenAI client falls back to\n",
    "# its default endpoint, so requests (and the API key) would go to the wrong service.\n",
    "_missing_uris = [\n",
    "    name\n",
    "    for name, value in (\n",
    "        (\"LLM_URI\", LLM_URI),\n",
    "        (\"EMBED_URI\", EMBED_URI),\n",
    "        (\"VECTORDB_URI\", VECTORDB_URI),\n",
    "    )\n",
    "    if not value\n",
    "]\n",
    "if _missing_uris:\n",
    "    raise RuntimeError(\"Missing required environment variables: \" + \", \".join(_missing_uris))"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "id": "a187cff2-91c5-4b60-819d-d5abb806bd95",
    "metadata": {},
    "outputs": [],
    "source": [
    "# API credentials are taken from the environment; each is None when its variable is unset.\n",
    "NIM_API_KEY = os.environ.get(\"NIM_API_KEY\")\n",
    "ZILIZ_API_KEY = os.environ.get(\"ZILIZ_API_KEY\")"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "id": "9976b4b6-adec-44ee-a719-1e3b866ff509",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Client for the chat-completion endpoint at LLM_URI, reached through the OpenAI SDK.\n",
    "llm_client = OpenAI(base_url=LLM_URI, api_key=NIM_API_KEY)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 6,
    "id": "01805361-4500-43ad-b043-5064ab1311f9",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Second client, pointed at the embedding endpoint; it reuses the same API key.\n",
    "embedding_client = OpenAI(base_url=EMBED_URI, api_key=NIM_API_KEY)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "id": "9423cf23-d14d-456d-8c94-50c905bc52a2",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Vector database connection; the token is sent to the service at VECTORDB_URI.\n",
    "vectordb_client = MilvusClient(token=ZILIZ_API_KEY, uri=VECTORDB_URI)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "id": "35a799db-667d-47b8-af07-b22205966765",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Start from an empty collection on every run: drop any earlier copy, then recreate it.\n",
    "if vectordb_client.has_collection(collection_name=\"india_facts\"):\n",
    "    vectordb_client.drop_collection(collection_name=\"india_facts\")\n",
    "\n",
    "# dimension has to equal the length of the vectors inserted into this collection.\n",
    "vectordb_client.create_collection(collection_name=\"india_facts\", dimension=1024)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 9,
    "id": "9133c393-805c-47aa-af09-81b4368fae5e",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Small in-memory corpus: each string is embedded and stored as one record further down.\n",
    "docs = [\n",
    " \"India is the seventh-largest country by land area in the world.\",\n",
    " \"The Indus Valley Civilization, one of the world's oldest, originated in India around 3300 BCE.\",\n",
    " \"The game of chess, originally called 'Chaturanga,' was invented in India during the Gupta Empire.\",\n",
    " \"India is home to the world's largest democracy, with over 900 million eligible voters.\",\n",
    " \"The Indian mathematician Aryabhata was the first to explain the concept of zero in the 5th century.\",\n",
    " \"India has the second-largest population in the world, with over 1.4 billion people.\",\n",
    " \"The Kumbh Mela, held every 12 years, is the largest religious gathering in the world, attracting millions of devotees.\",\n",
    " \"India is the birthplace of four major world religions: Hinduism, Buddhism, Jainism, and Sikhism.\",\n",
    " \"The Indian Space Research Organisation (ISRO) successfully sent a spacecraft to Mars on its first attempt in 2014.\",\n",
    " \"India's Varanasi is considered one of the world's oldest continuously inhabited cities, with a history dating back over 3,000 years.\"\n",
    "]"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 10,
    "id": "7ccfb283-15e6-43f7-8f1e-8ed61d0c2f28",
    "metadata": {},
    "outputs": [],
    "source": [
    "def embed(docs):\n",
    " response = embedding_client.embeddings.create(\n",
    " input=docs,\n",
    " model=\"nvidia/nv-embedqa-e5-v5\",\n",
    " encoding_format=\"float\",\n",
    " extra_body={\"input_type\": \"query\", \"truncate\": \"NONE\"}\n",
    " )\n",
    " vectors = [embedding_data.embedding for embedding_data in response.data]\n",
    " return vectors"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "id": "efd2f98a-7534-4eb8-a1bd-e932e8f756b2",
    "metadata": {},
    "outputs": [],
    "source": [
    "vectors=embed(docs)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 12,
    "id": "ab81f2e7-7240-414d-a35a-f1b5775e8bbd",
    "metadata": {},
    "outputs": [],
    "source": [
    "# One record per embedding: the id is the position in docs, and every record\n",
    "# carries the same subject tag.\n",
    "data = [\n",
    "    {\"id\": idx, \"vector\": vec, \"text\": docs[idx], \"subject\": \"history\"}\n",
    "    for idx, vec in enumerate(vectors)\n",
    "]"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 13,
    "id": "773869a6-d5e0-480a-9baf-b69191f0629f",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "{'insert_count': 10, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'cost': 0}"
    ]
    },
    "execution_count": 13,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "# Insert all records in one call; the returned summary is displayed as the cell output.\n",
    "vectordb_client.insert(collection_name=\"india_facts\", data=data)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 14,
    "id": "7c7f9a08-0a09-4323-afff-406d9ce1a7f2",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Sanity check: embed a one-word query and fetch the two closest records,\n",
    "# returning their text and subject fields.\n",
    "query_vectors = embed([\"ISRO\"])\n",
    "\n",
    "res = vectordb_client.search(\n",
    " collection_name=\"india_facts\", \n",
    " data=query_vectors, \n",
    " limit=2, \n",
    " output_fields=[\"text\", \"subject\"],\n",
    ")\n",
    "\n",
    "#print(res)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 15,
    "id": "b37a0650-eb8a-4995-b8a9-d64260f2441b",
    "metadata": {},
    "outputs": [],
    "source": [
    "def retrieve(query, limit=10):\n",
    "    \"\"\"Return the text of the records most similar to ``query`` as one string.\n",
    "\n",
    "    Args:\n",
    "        query: Question or search phrase to look up.\n",
    "        limit: Maximum number of records to request from the vector store.\n",
    "            Defaults to 10, the value the search call applied implicitly\n",
    "            before this became a parameter.\n",
    "\n",
    "    Returns:\n",
    "        The ``text`` field of every usable hit, joined with single spaces;\n",
    "        an empty string when no hit carries an entity with a text field.\n",
    "    \"\"\"\n",
    "    query_vectors = embed([query])\n",
    "\n",
    "    search_results = vectordb_client.search(\n",
    "        collection_name=\"india_facts\",\n",
    "        data=query_vectors,\n",
    "        limit=limit,\n",
    "        output_fields=[\"text\", \"subject\"],\n",
    "    )\n",
    "\n",
    "    def _has_text(hit):\n",
    "        # True when the hit is a dict that carries entity -> text.\n",
    "        return isinstance(hit, dict) and 'entity' in hit and 'text' in hit['entity']\n",
    "\n",
    "    all_texts = []\n",
    "    for item in search_results:\n",
    "        # Defensive: if an item arrives as a string, try to parse it as a Python literal.\n",
    "        if isinstance(item, str):\n",
    "            try:\n",
    "                item = ast.literal_eval(item)\n",
    "            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):\n",
    "                # Not a parseable literal: keep the raw string, which the checks below skip.\n",
    "                pass\n",
    "\n",
    "        # An item is either a list of hits or a single hit; treat both uniformly.\n",
    "        candidates = item if isinstance(item, list) else [item]\n",
    "        all_texts.extend(hit['entity']['text'] for hit in candidates if _has_text(hit))\n",
    "\n",
    "    return \" \".join(all_texts)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 16,
    "id": "c4d75523-9538-48b1-93d7-764fe4d321e0",
    "metadata": {},
    "outputs": [],
    "source": [
    "def generate(context, question):\n",
    "    \"\"\"Ask the chat model to answer ``question`` from the supplied ``context``.\n",
    "\n",
    "    Args:\n",
    "        context: Text the answer should be based on.\n",
    "        question: The question to answer.\n",
    "\n",
    "    Returns:\n",
    "        The message content of the first choice in the completion response.\n",
    "    \"\"\"\n",
    "    prompt = f'''\n",
    " Based on the context: {context}\n",
    " \n",
    " Please answer the question: {question}\n",
    " ''' \n",
    "    system_prompt='''\n",
    " You are a helpful assistant that answers questions based on the given context.\\n\n",
    " Don't add anything to the response. \\n\n",
    " If you cannot find the answer within the context, say I do not know. \n",
    " '''\n",
    "    # The system message sets the rules; the user message carries context and question.\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": prompt},\n",
    "    ]\n",
    "    completion = llm_client.chat.completions.create(\n",
    "        model=\"meta/llama3-8b-instruct\",\n",
    "        messages=messages,\n",
    "        temperature=0,\n",
    "        top_p=1,\n",
    "        max_tokens=1024,\n",
    "    )\n",
    "    first_choice = completion.choices[0]\n",
    "    return first_choice.message.content"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 17,
    "id": "22cd1e12-5ac4-4620-8e1b-7dfe6a3e608c",
    "metadata": {},
    "outputs": [],
    "source": [
    "def chat(prompt):\n",
    "    \"\"\"Answer ``prompt`` by retrieving context for it and handing both to generate().\"\"\"\n",
    "    return generate(retrieve(prompt), prompt)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 22,
    "id": "a5a047ba-087a-4e15-aee3-e26c06dce782",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Question to run; the commented-out lines are alternative questions to try.\n",
    "#prompt=\"What is ISRO?\"\n",
    "#prompt=\"What is chess originally called?\"\n",
    "#prompt=\"When did Indus Valley Civilization originate?\"\n",
    "prompt=\"what are the four major world religions?\""
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 23,
    "id": "6a6cf3ab-77de-45ee-89d5-bfba76cadafb",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "The four major world religions are: Hinduism, Buddhism, Jainism, and Sikhism.\n"
    ]
    }
    ],
    "source": [
    "# Run retrieval and generation for the selected prompt and print the answer.\n",
    "res=chat(prompt)\n",
    "print(res)"
    ]
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.10.13"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }