Skip to content

Instantly share code, notes, and snippets.

@smrati
Created October 25, 2024 11:37
Show Gist options
  • Select an option

  • Save smrati/4ffc00c87ed9d2fb8560071168a665b4 to your computer and use it in GitHub Desktop.

Select an option

Save smrati/4ffc00c87ed9d2fb8560071168a665b4 to your computer and use it in GitHub Desktop.
ChromaDB + Glove Embeddings : Create a text search engine
import chromadb
from loguru import logger
class ChromaHelper:
    """Thin wrapper around a ChromaDB HTTP client for collection CRUD."""

    def __init__(self, host, port):
        # Connection details only; the client is created lazily in setup_client().
        self.host = host
        self.port = port
        self.chroma_client = None

    def setup_client(self):
        """Connect to the ChromaDB server over HTTP."""
        logger.info("Chroma client setup")
        self.chroma_client = chromadb.HttpClient(host=self.host, port=self.port)

    def get_collection(self, collection_name):
        """Fetch (or create) a collection configured for cosine similarity."""
        return self.chroma_client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

    def feed_document(self, collection_object, input_data):
        """Insert a batch of records; each record carries document/metadata/embedding/id keys."""
        collection_object.add(
            documents=[record["document"] for record in input_data],
            embeddings=[record["embedding"] for record in input_data],
            metadatas=[record["metadata"] for record in input_data],
            ids=[record["id"] for record in input_data],
        )
        # Log the total record count so the caller can see the insert landed.
        logger.info(collection_object.count())

    def fetch_document(self, collection_object, query_embedding, n_results=10):
        """Run a nearest-neighbour query and return the raw ChromaDB response."""
        return collection_object.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
        )

    def cleanup_collection(self, collection_name):
        """Drop the named collection entirely."""
        logger.info("Deleting everything from collection....")
        self.chroma_client.delete_collection(collection_name)
        logger.info("Everything deleted from collection")
if __name__ == "__main__":
    # Smoke-test the helper end-to-end against a local ChromaDB server.
    ch = ChromaHelper("localhost", 8001)
    ch.setup_client()
    facts_collection = ch.get_collection("facts")
    # Toy records with hand-written 3-d embeddings (real embeddings come from GloVe).
    dummy_data = [
        {
            "document": "This is a fine day",
            "metadata": {"chapter": 3},
            "embedding": [1.1, 2.2, 3.3],
            "id": "abc",
        },
        {
            # Fixed typo in the sample text: "client" -> "climate".
            "document": "Global warming is causing extreme climate events",
            "metadata": {"chapter": 7},
            "embedding": [0.5, 3.2, 2.17],
            "id": "def",
        },
    ]
    ch.feed_document(facts_collection, dummy_data)
    # Query with a vector close to the first document's embedding.
    resp = ch.fetch_document(facts_collection, [1.1, 2.0, 3.1])
    print(resp)
    # Drop the collection so repeated runs start from a clean slate.
    ch.cleanup_collection("facts")
from loguru import logger
import numpy as np
import re
class GloVeEmbedding:
    """Load pre-trained GloVe vectors and average them into text embeddings."""

    def __init__(self, glove_path):
        # Path to a GloVe .txt file (one "word v1 v2 ... vN" record per line).
        self.glove_path = glove_path
        # dict[str, np.ndarray]; populated by load_glove_embeddings().
        self.embeddings_index = None

    def load_glove_embeddings(self):
        """
        Load glove txt file path and create a python dictionary
        where key will be word(token) and value will be embedding for
        that word (a float32 numpy vector).
        """
        self.embeddings_index = {}
        with open(self.glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                self.embeddings_index[word] = coefs
        logger.info(f'Loaded {len(self.embeddings_index)} word vectors.')

    def compute_embedding(self, input_text, embedding_dim):
        """
        Return an `embedding_dim`-dimensional vector for `input_text`:
        the mean of the GloVe vectors of its in-vocabulary words, or a
        zero vector when no word matches the vocabulary.

        `embedding_dim` must match the dimension of the loaded GloVe
        file (e.g. 100 for glove.6B.100d).
        """
        text = input_text.lower()  # GloVe vocabulary is lowercase
        text = re.sub(r'\W+', ' ', text).strip()  # remove non-alphanumeric characters
        words = text.split()
        word_embeddings = [
            self.embeddings_index[word]
            for word in words
            if word in self.embeddings_index
        ]
        if not word_embeddings:
            # Return a zero vector if no words from the text match GloVe vocab.
            # dtype fixed to float32 so both return paths agree (the mean
            # branch yields float32; np.zeros defaults to float64).
            return np.zeros(embedding_dim, dtype='float32')
        # Average the word vectors to get a sentence/paragraph vector.
        return np.mean(word_embeddings, axis=0)
if __name__ == "__main__":
    # Manual check: embed a sample sentence with the glove.6B.100d vectors.
    glove = GloVeEmbedding('/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/glove.6B.100d.txt')
    glove.load_glove_embeddings()
    print(glove.compute_embedding("This is a pretty fine day", 100))
import re
import uuid
from helpers.chromaConnect import ChromaHelper
from helpers.embeddings import GloVeEmbedding
from loguru import logger
def preprocess_text(text):
    """Lowercase *text*, collapse non-word runs to single spaces, and trim."""
    cleaned = re.sub(r'\W+', ' ', text.lower())  # remove non-alphanumeric characters
    return cleaned.strip()
def read_txt_line_by_line(file_path):
    """
    Read txt file line by line, returning one preprocessed
    entry per input line.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [preprocess_text(raw_line) for raw_line in f]
if __name__ == "__main__":
    lines = read_txt_line_by_line("/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/placeholder.txt")
    ge = GloVeEmbedding('/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/glove.6B.100d.txt')
    ge.load_glove_embeddings()
    ch = ChromaHelper("localhost", 8001)
    ch.setup_client()
    facts_collection = ch.get_collection("facts")
    # Build the whole batch first, then insert once: one ChromaDB round trip
    # instead of one add() call per line.
    payload = []
    for line in lines:
        # since we are using glove.6b.100d vocabulary, so embedding dimension must be 100 only
        embedding_vector = ge.compute_embedding(line, 100)
        payload.append({
            "document": line,
            "metadata": {"data_type": "demo"},
            "embedding": embedding_vector,
            "id": str(uuid.uuid4())
        })
    logger.info(payload)
    ch.feed_document(facts_collection, payload)
    # query chromaDB to verify we fed data properly in vector database
    # query_text = "How presidential election take place?"
    # query_text = "Which country has biggest army?"
    query_text = "Which country has highest income group population?"
    query_embedding = ge.compute_embedding(preprocess_text(query_text), 100)
    resp = ch.fetch_document(facts_collection, query_embedding, 1)
    print("----------------Matching Document---------------------")
    print(resp)
    # delete collection (cleanup operation)
    ch.cleanup_collection("facts")
Global warming is the long-term increase in Earth's average temperature caused by human activities like burning fossil fuels. This increase in temperature is causing significant changes to our planet, including rising sea levels, more extreme weather events, and melting glaciers. These changes are already having a negative impact on people, animals, and plants around the world.
The USSR, or the Union of Soviet Socialist Republics, was a communist country that existed from 1922 to 1991. It was the largest country in the world, spanning across Eurasia from Eastern Europe to the Pacific Ocean. The USSR was a superpower, competing with the United States during the Cold War. However, economic problems and political unrest led to its eventual collapse in 1991, resulting in the formation of 15 independent countries.
The US presidential election is a complex process involving several stages. First, candidates from each major political party compete in primary elections to win delegates who will vote for them at the party's national convention. The candidate who wins the most delegates at the convention becomes the party's nominee. The general election then takes place between the nominees of the two major parties. Voters cast ballots for their preferred candidate, and the candidate who wins the most electoral votes becomes the President of the United States.
The country with the largest standing army is China. Its military, known as the People's Liberation Army (PLA), is one of the most powerful in the world and is composed of millions of active personnel. China's large army is a reflection of its status as a major global power and its need to defend its vast territory and interests.
Luxembourg currently holds the title of the country with the highest per capita income. This is largely due to its strong financial sector, which is a major contributor to its economy. Luxembourg's low corporate tax rates and favorable business environment have attracted numerous multinational corporations, boosting its economic growth. Additionally, the country's substantial natural resources, such as iron ore and steel, have also played a role in its economic success.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment