Created
October 25, 2024 11:37
-
-
Save smrati/4ffc00c87ed9d2fb8560071168a665b4 to your computer and use it in GitHub Desktop.
ChromaDB + GloVe Embeddings: Create a text search engine
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import chromadb | |
| from loguru import logger | |
class ChromaHelper:
    """Thin wrapper around a ChromaDB HTTP client for feeding and querying documents."""

    def __init__(self, host, port):
        # Connection details only; the client itself is created lazily in setup_client().
        self.host = host
        self.port = port
        self.chroma_client = None

    def setup_client(self):
        """Create the HTTP client that talks to the ChromaDB server."""
        logger.info("Chroma client setup")
        self.chroma_client = chromadb.HttpClient(host=self.host, port=self.port)

    def get_collection(self, collection_name):
        """Fetch (or create) a collection configured for cosine similarity."""
        return self.chroma_client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

    def feed_document(self, collection_object, input_data):
        """Insert a batch of records; each record is a dict with
        'document', 'metadata', 'embedding' and 'id' keys."""
        collection_object.add(
            documents=[record["document"] for record in input_data],
            embeddings=[record["embedding"] for record in input_data],
            metadatas=[record["metadata"] for record in input_data],
            ids=[record["id"] for record in input_data],
        )
        # Log the post-insert document count as a quick sanity check.
        logger.info(collection_object.count())

    def fetch_document(self, collection_object, query_embedding, n_results=10):
        """Nearest-neighbour query for a single embedding; returns the raw response."""
        return collection_object.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
        )

    def cleanup_collection(self, collection_name):
        """Drop the named collection entirely."""
        logger.info("Deleting everything from collection....")
        self.chroma_client.delete_collection(collection_name)
        logger.info("Everything deleted from collection")
| if __name__ == "__main__": | |
| ch = ChromaHelper("localhost", 8001) | |
| ch.setup_client() | |
| facts_collection = ch.get_collection("facts") | |
| dummy_data = [ | |
| { | |
| "document": "This is a fine day", | |
| "metadata": {"chapter": 3}, | |
| "embedding": [1.1, 2.2,3.3], | |
| "id": "abc" | |
| }, | |
| { | |
| "document": "Global warming is causing extreme client events", | |
| "metadata": {"chapter": 7}, | |
| "embedding": [0.5, 3.2, 2.17], | |
| "id": "def" | |
| } | |
| ] | |
| ch.feed_document(facts_collection, dummy_data) | |
| resp = ch.fetch_document(facts_collection, [1.1, 2.0, 3.1]) | |
| print(resp) | |
| ch.cleanup_collection("facts") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from loguru import logger | |
| import numpy as np | |
| import re | |
class GloVeEmbedding:
    """Load pre-trained GloVe vectors and turn free text into fixed-size embeddings."""

    def __init__(self, glove_path):
        # Path to a GloVe .txt file (one "word v1 v2 ... vd" entry per line).
        self.glove_path = glove_path
        # word -> np.ndarray(float32); populated by load_glove_embeddings().
        self.embeddings_index = None

    def load_glove_embeddings(self):
        """
        Load the GloVe txt file and build a python dictionary where the key
        is the word (token) and the value is the float32 embedding vector
        for that word.
        """
        self.embeddings_index = {}
        with open(self.glove_path, 'r', encoding='utf-8') as f:
            for line in f:
                values = line.split()
                word = values[0]
                coefs = np.asarray(values[1:], dtype='float32')
                self.embeddings_index[word] = coefs
        logger.info(f'Loaded {len(self.embeddings_index)} word vectors.')

    def compute_embedding(self, input_text, embedding_dim):
        """
        Return an `embedding_dim`-sized embedding for `input_text`, computed
        as the mean of the GloVe vectors of its in-vocabulary words.

        `embedding_dim` must match the dimensionality of the loaded GloVe
        file (e.g. 100 for glove.6B.100d); it is only used to size the
        all-zeros fallback when no word is found in the vocabulary.
        """
        text = input_text.lower()  # case-insensitive vocabulary lookup
        text = re.sub(r'\W+', ' ', text).strip()  # remove non-alphanumeric characters
        words = text.split()
        word_embeddings = [
            self.embeddings_index[word]
            for word in words
            if word in self.embeddings_index
        ]
        if not word_embeddings:
            # Return a zero vector if no words from the text match GloVe vocab.
            # Fix: force float32 so the fallback dtype matches the float32
            # vectors returned on the normal path (np.zeros defaults to float64).
            return np.zeros(embedding_dim, dtype='float32')
        # Average the word vectors to get a sentence/paragraph vector.
        return np.mean(word_embeddings, axis=0)
| if __name__ == "__main__": | |
| ge = GloVeEmbedding('/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/glove.6B.100d.txt') | |
| ge.load_glove_embeddings() | |
| embedding_vector = ge.compute_embedding("This is a pretty fine day", 100) | |
| print(embedding_vector) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import re | |
| import uuid | |
| from helpers.chromaConnect import ChromaHelper | |
| from helpers.embeddings import GloVeEmbedding | |
| from loguru import logger | |
def preprocess_text(text):
    """Normalise text: lowercase, collapse runs of non-word characters
    to a single space, and trim leading/trailing whitespace."""
    collapsed = re.sub(r'\W+', ' ', text.lower())
    return collapsed.strip()
def read_txt_line_by_line(file_path):
    """
    Read a text file and return one preprocessed entry per line.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [preprocess_text(line) for line in f]
| if __name__ == "__main__": | |
| lines = read_txt_line_by_line("/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/placeholder.txt") | |
| ge = GloVeEmbedding('/home/chick/Desktop/coding/nlp_learn/nlp_in_action_2/glove.6B.100d.txt') | |
| ge.load_glove_embeddings() | |
| ch = ChromaHelper("localhost", 8001) | |
| ch.setup_client() | |
| facts_collection = ch.get_collection("facts") | |
| for line in lines: | |
| # since we are using glove.6b.100d vocabulary, so embedding dimension must be 100 only | |
| embedding_vector = ge.compute_embedding(line, 100) | |
| payload = [ | |
| { | |
| "document": line, | |
| "metadata": {"data_type": "demo"}, | |
| "embedding": embedding_vector, | |
| "id": str(uuid.uuid4()) | |
| } | |
| ] | |
| logger.info(payload) | |
| ch.feed_document(facts_collection, payload) | |
| # query chromaDB to verify we fed data properly in vector database | |
| # query_text = "How presidential election take place?" | |
| # query_text = "Which country has biggest army?" | |
| query_text = "Which country has highest income group population?" | |
| query_embedding = ge.compute_embedding(preprocess_text(query_text), 100) | |
| resp = ch.fetch_document(facts_collection, query_embedding, 1) | |
| print("----------------Matching Document---------------------") | |
| print(resp) | |
| # delete collection (cleanup operation) | |
| ch.cleanup_collection("facts") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Global warming is the long-term increase in Earth's average temperature caused by human activities like burning fossil fuels. This increase in temperature is causing significant changes to our planet, including rising sea levels, more extreme weather events, and melting glaciers. These changes are already having a negative impact on people, animals, and plants around the world. | |
| The USSR, or the Union of Soviet Socialist Republics, was a communist country that existed from 1922 to 1991. It was the largest country in the world, spanning across Eurasia from Eastern Europe to the Pacific Ocean. The USSR was a superpower, competing with the United States during the Cold War. However, economic problems and political unrest led to its eventual collapse in 1991, resulting in the formation of 15 independent countries. | |
| The US presidential election is a complex process involving several stages. First, candidates from each major political party compete in primary elections to win delegates who will vote for them at the party's national convention. The candidate who wins the most delegates at the convention becomes the party's nominee. The general election then takes place between the nominees of the two major parties. Voters cast ballots for their preferred candidate, and the candidate who wins the most electoral votes becomes the President of the United States. | |
| The country with the largest standing army is China. Its military, known as the People's Liberation Army (PLA), is one of the most powerful in the world and is composed of millions of active personnel. China's large army is a reflection of its status as a major global power and its need to defend its vast territory and interests. | |
| Luxembourg currently holds the title of the country with the highest per capita income. This is largely due to its strong financial sector, which is a major contributor to its economy. Luxembourg's low corporate tax rates and favorable business environment have attracted numerous multinational corporations, boosting its economic growth. Additionally, the country's substantial natural resources, such as iron ore and steel, have also played a role in its economic success. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment