Efficient Similarity Search with FAISS and SQLite in Python

Summary

This is another component of SmartAnswer, an enhanced LLM interface.

In this blog post, we introduce a wrapper class, FaissDB, which integrates FAISS with SQLite or any database to manage document embeddings and enable efficient similarity search. This approach combines FAISS’s vector search capabilities with the storage and querying power of a database, making it ideal for applications such as Retrieval-Augmented Generation (RAG) and recommendation systems.

It is one of the building blocks of the PaperSearch tool.

I have also written an introductory post on FAISS.

Advantages of Using FAISS

FAISS provides several advantages when dealing with large-scale similarity searches:

  • High-Speed Search: FAISS is optimized for fast nearest-neighbor search, enabling efficient retrieval of similar vectors even in massive datasets.
  • Scalability: Designed for large-scale applications, FAISS can index millions of embeddings and, with a suitable index type and hardware, return results in milliseconds.
  • Multiple Distance Metrics: Supports various similarity measures, including L2 distance and inner product, making it versatile for different use cases.
  • Efficient Indexing Techniques: FAISS offers different indexing strategies, such as IVF (Inverted File Index) and HNSW (Hierarchical Navigable Small World), allowing for further speed optimizations.
  • GPU Acceleration: Supports GPU-based computation, significantly speeding up large-scale searches and embedding comparisons.
  • Seamless Integration: FAISS integrates well with popular ML and AI frameworks, making it easy to use in AI-powered applications.
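
The choice of metric affects how scores should be read. As a small NumPy-only illustration (this is not FAISS itself, and the vectors are made up), the sketch below checks that for unit-length vectors the squared L2 distance and the inner product are tied together by d² = 2 - 2·(a·b), so both metrics rank neighbors the same way once embeddings are normalized.

```python
import numpy as np

def normalize(v: np.ndarray) -> np.ndarray:
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v)

# Two made-up vectors, normalized to unit length
a = normalize(np.array([1.0, 2.0, 3.0]))
b = normalize(np.array([2.0, 1.0, 0.5]))

inner_product = float(a @ b)               # equals cosine similarity for unit vectors
squared_l2 = float(np.sum((a - b) ** 2))   # squared Euclidean distance

# For unit vectors: ||a - b||^2 = 2 - 2 * (a . b)
print(np.isclose(squared_l2, 2 - 2 * inner_product))
```

Without normalization the two metrics can disagree, so it is worth deciding up front which one the index should use.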

Overview of the FaissDB Class

The FaissDB class is designed to handle document embeddings, perform similarity searches, and manage the FAISS index efficiently. Below are its main functionalities:

  • Building an Index: Converts documents into embeddings and stores them in a FAISS index.
  • Searching for Similar Documents: Retrieves the most similar documents based on a given query.
  • Storing and Retrieving Embeddings: Saves embeddings in an SQLite database for persistence.
  • Loading Documents: Fetches and cleans text data from the database for processing.

Let’s explore each functionality in detail.

Implementation Details

1. Setting Up the Database Connection

The class initializes an SQLite connection and loads configurations from appConfig.

class FaissDB:
    def __init__(self, config=appConfig):
        self.db = Database()
        self.conn = sqlite3.connect(config["DB_NAME"])
        self.cursor = self.conn.cursor()
        self.config = config
        self.embedding_model = config["EMBEDDING_MODEL"]
        self.document_index = None
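
The queries in this post touch two tables, document_page_split and document_embeddings, but the schema itself is not shown. The following is a guess at a minimal layout, inferred only from the column names used in those queries; the column types and constraints are assumptions, and an in-memory database stands in for the real one.

```python
import sqlite3

# Hypothetical schema inferred from the queries in this post.
SCHEMA = """
CREATE TABLE IF NOT EXISTS document_page_split (
    id   INTEGER PRIMARY KEY,
    text TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS document_embeddings (
    faiss_index_id         INTEGER NOT NULL,
    document_page_split_id INTEGER NOT NULL REFERENCES document_page_split(id),
    embedding              BLOB NOT NULL
);
"""

conn = sqlite3.connect(":memory:")  # in-memory database, for illustration only
conn.executescript(SCHEMA)

# List the tables that now exist
tables = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
]
print(tables)  # → ['document_embeddings', 'document_page_split']
```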

2. Generating and Storing Embeddings

FAISS relies on embeddings to perform similarity searches. We generate one embedding per document with an Ollama embedding model and store each vector in the database as raw float32 bytes.

def get_embeddings(self, documents) -> np.ndarray:
    embeddings = []
    # FAISS numbers added vectors from 0, so faiss_index_id starts at 0 as well;
    # otherwise get_doc(idx) would look up the wrong row.
    for order, doc in enumerate(documents):
        # doc is an (id, text) tuple
        response = ollama.embeddings(model=self.embedding_model, prompt=doc[1])
        embedding_vector = np.array(response["embedding"], dtype="float32")
        self.cursor.execute(
            """INSERT INTO document_embeddings (faiss_index_id, document_page_split_id, embedding)
            VALUES (?, ?, ?)""",
            (order, doc[0], embedding_vector.tobytes()),
        )
        embeddings.append(embedding_vector)
    self.conn.commit()
    return np.array(embeddings, dtype="float32")
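
Because the vectors are written with tobytes(), the BLOB carries no dtype or shape information, and reading one back requires np.frombuffer with the same dtype. A minimal round-trip sketch follows; the table here is a cut-down, hypothetical version of document_embeddings and the vector is made up.

```python
import sqlite3
import numpy as np

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE document_embeddings (faiss_index_id INTEGER, embedding BLOB)"
)

# Store a made-up embedding as raw float32 bytes
original = np.array([0.1, 0.2, 0.3], dtype="float32")
conn.execute(
    "INSERT INTO document_embeddings (faiss_index_id, embedding) VALUES (?, ?)",
    (0, original.tobytes()),
)

blob = conn.execute(
    "SELECT embedding FROM document_embeddings WHERE faiss_index_id = ?", (0,)
).fetchone()[0]

# tobytes() drops dtype and shape, so the reader must supply the same dtype
restored = np.frombuffer(blob, dtype="float32")
print(np.array_equal(original, restored))  # → True
```

Reading the stored vectors back this way would allow the FAISS index to be rebuilt without calling the embedding model again.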

3. Building the FAISS Index

Once embeddings are generated, they are stored in a FAISS index for fast retrieval.

def build_index(self, embeddings: np.ndarray) -> faiss.IndexFlatL2:
    dimension = embeddings.shape[1]
    self.document_index = faiss.IndexFlatL2(dimension)
    self.document_index.add(embeddings)
    return self.document_index

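4. Searching for Similar Documents

The search method embeds the query text, asks the FAISS index for the k nearest vectors, and returns the matching documents together with their distances.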

def search(self, text, k=appConfig["top_k"]):
    if self.document_index is None:
        raise ValueError("Index not set. Please set the index before searching.")
    query_embedding = self.get_embedding(text)
    distances, indices = self.document_index.search(query_embedding, k)
    results = [SearchResult(idx, distances[0][i], self.get_doc(idx)) for i, idx in enumerate(indices[0])]
    return results
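
One detail worth handling in search: when the index holds fewer than k vectors, FAISS fills the unused slots in indices with -1. The helper below is a hypothetical addition (valid_hits is not part of the class) that skips those slots and converts NumPy scalars to plain Python types. The arrays simulate a search result rather than calling FAISS, and the padded distance values are placeholders.

```python
import numpy as np

def valid_hits(distances: np.ndarray, indices: np.ndarray) -> list[tuple[int, float]]:
    """Return (index, distance) pairs for the first query row, skipping -1 padding."""
    return [
        (int(idx), float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if idx != -1
    ]

# Simulated search output for k=4 against an index holding only two vectors
indices = np.array([[1, 0, -1, -1]], dtype="int64")
distances = np.array([[0.12, 0.58, np.inf, np.inf]], dtype="float32")

hits = valid_hits(distances, indices)
print([i for i, _ in hits])  # → [1, 0]
```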

5. Retrieving Documents from the Database

The retrieved indices are used to fetch the corresponding documents from SQLite.

@staticmethod
def get_doc(faiss_index, config=appConfig):
    conn = sqlite3.connect(config["DB_NAME"])
    try:
        cursor = conn.cursor()
        # Bind the id as a parameter; int() converts the numpy.int64 that FAISS returns
        cursor.execute(
            """SELECT text FROM document_page_split
               WHERE id IN (
                   SELECT document_page_split_id FROM document_embeddings
                   WHERE faiss_index_id = ?)""",
            (int(faiss_index),),
        )
        document = cursor.fetchone()
    finally:
        conn.close()  # do not leak one connection per lookup
    return document[0] if document else ""
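
The lookup from a FAISS position to the document text can also be written as a single JOIN with a bound parameter. This is a sketch against a tiny in-memory database with made-up rows; get_doc_text is a hypothetical helper, not a method of the class.

```python
import sqlite3
import numpy as np

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE document_page_split (id INTEGER PRIMARY KEY, text TEXT);
CREATE TABLE document_embeddings (faiss_index_id INTEGER, document_page_split_id INTEGER);
INSERT INTO document_page_split VALUES (10, 'first chunk'), (11, 'second chunk');
INSERT INTO document_embeddings VALUES (0, 10), (1, 11);
""")

def get_doc_text(conn: sqlite3.Connection, faiss_index) -> str:
    """Look up the text for a FAISS position using a bound parameter."""
    row = conn.execute(
        """SELECT s.text
           FROM document_embeddings AS e
           JOIN document_page_split AS s ON s.id = e.document_page_split_id
           WHERE e.faiss_index_id = ?""",
        (int(faiss_index),),  # FAISS returns numpy.int64; convert to a plain int
    ).fetchone()
    return row[0] if row else ""

print(get_doc_text(conn, np.int64(1)))  # → second chunk
```

Binding the value keeps the SQL text constant and avoids building queries through string formatting.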

6. Saving and Loading the Index

To avoid recomputing the FAISS index every time, we save and reload it when needed.

@staticmethod
def save(index, filename):
    faiss.write_index(index, filename)

@staticmethod
def load(filename):
    return faiss.read_index(filename)

Complete code

The complete class currently looks like this:


import sqlite3
import re
import numpy as np
import faiss
import ollama
import logging

from config import appConfig
from database import Database

logger = logging.getLogger(__name__)


class SearchResult:
    def __init__(self, index, score, doc):
        self.index = index
        self.score = score
        self.doc = doc

    def __repr__(self):
        return f"Score: {self.score}, Index: {self.index}, Doc: {self.doc[:15]} ..."

    def __str__(self):
        return self.doc

    @staticmethod
    def from_tuple(t):
        return SearchResult(t[0], t[1], t[2])

    @staticmethod
    def from_tuples(tuples):
        return [SearchResult.from_tuple(t) for t in tuples]

class FaissDB:
    def __init__(self, config=appConfig):
        self.db = Database()
        self.conn = sqlite3.connect(config["DB_NAME"])
        self.cursor = self.conn.cursor()
        self.config = config
        self.embedding_model = config["EMBEDDING_MODEL"]
        self.document_index = None

    def set_index(self, document_index):
        self.document_index = document_index

    def search(self, text, k = appConfig["top_k"]):
        if self.document_index is None:
            raise ValueError("Index not set. Please set the index before searching.")
        logger.info(f"Searching for similar items to: {text}")
        query_embedding = self.get_embedding(text)
        # Search for the top-k most similar embeddings
        distances, indices = self.document_index.search(query_embedding, k)

        results = []
        for (i, idx) in enumerate(indices[0]):
            doc = self.get_doc(idx)
            # Convert NumPy scalars to plain Python types before binding them in SQLite
            self.cursor.execute(
                """INSERT INTO search_results (query, distance, faiss_index, document_text, document_index, document_page_id)
                VALUES (?, ?, ?, ?, ?, ?)""",
                (text, float(distances[0][i]), int(idx), doc, int(idx), 0),
            )
            results.append(SearchResult(int(idx), float(distances[0][i]), doc))
        self.conn.commit()
        return results


    def build_index(self, embeddings: np.ndarray) -> faiss.IndexFlatL2:
        dimension = embeddings.shape[1]  # Number of features in each embedding
        logger.info(f"Building index with model:{self.embedding_model} with dimension: {dimension}")
        # Create a FAISS index
        self.document_index = faiss.IndexFlatL2(dimension)  # L2 distance for similarity search
        # Add embeddings to the index
        self.document_index.add(embeddings)
        return self.document_index
    
    def as_retriever(self, query_text, top_k=appConfig["top_k"]):
        return self.search(query_text, top_k)
    

    def get_embeddings(self, documents) -> np.ndarray:
        embeddings = []
        # FAISS numbers added vectors from 0, so faiss_index_id starts at 0 as well
        for order, doc in enumerate(documents):
            logger.info(f"Generating embedding for document id: {doc[0]}")
            response = ollama.embeddings(model=self.embedding_model, prompt=doc[1])
            embeddings.append(response["embedding"])
            embedding_vector = np.array(response["embedding"], dtype="float32")
            self.cursor.execute(
                """INSERT INTO document_embeddings (faiss_index_id, document_page_split_id, embedding) 
                VALUES (?, ?, ?)""",
                (order, doc[0], embedding_vector.tobytes()),
            )
        self.conn.commit()
        # Convert embeddings to a NumPy array
        embeddings = np.array(embeddings).astype("float32")
        return embeddings

    @staticmethod
    def get_embedding(text, model=appConfig["EMBEDDING_MODEL"]) -> np.ndarray:
        logger.info(f"Generating embedding for text: {text}")
        response = ollama.embeddings(model=model, prompt=text)
        embeddings = []
        embeddings.append(response["embedding"])
        embeddings = np.array(embeddings).astype("float32")
        return embeddings


    @staticmethod
    def get_doc(faiss_index, config=appConfig):
        logger.info("Loading documents from database for faiss_index {}".format(faiss_index))
        conn = sqlite3.connect(config["DB_NAME"])
        try:
            cursor = conn.cursor()
            # Bind the id as a parameter; int() converts the numpy.int64 that FAISS returns
            cursor.execute(
                """SELECT text FROM document_page_split
                    WHERE id IN (
                    SELECT document_page_split_id
                    FROM document_embeddings
                    WHERE faiss_index_id = ?)""",
                (int(faiss_index),),
            )
            document = cursor.fetchone()
        finally:
            conn.close()
        if document is None:
            logger.warning(f"Document not found for faiss_index {faiss_index}")
            return ""
        return document[0]


    @staticmethod
    def save(index, filename):
        logger.info("Saving index to file")
        faiss.write_index(index, filename)
        logger.info("Index saved successfully")

    @staticmethod
    def load(filename):
        logger.info("Loading index from file")
        loaded_index = faiss.read_index(filename)
        logger.info("Index loaded successfully")
        return loaded_index

    @staticmethod
    def load_documents_from_db(config, max_documents: int = 1000) -> list:
        """Load up to max_documents (id, text) rows from the document_page_split table."""
        conn = sqlite3.connect(config["DB_NAME"])
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, text FROM document_page_split LIMIT ?", (max_documents,)
        )
        documents = cursor.fetchall()
        conn.close()
        return [(doc[0], FaissDB.clean_text(doc[1])) for doc in documents]

    @staticmethod
    def clean_text(md_text):
        """Remove Markdown syntax and extract clean text."""
        md_text = re.sub(r"\[.*?\]\(.*?\)", "", md_text)  # Remove links
        md_text = re.sub(r"#{1,6}\s*", "", md_text)  # Remove headers
        md_text = re.sub(
            r"(```.*?```|`.*?`)", "", md_text, flags=re.DOTALL
        )  # Remove code blocks
        md_text = re.sub(
            r"\*{1,2}|\_{1,2}", "", md_text
        )  # Remove bold/italic formatting
        md_text = re.sub(r">\s*", "", md_text)  # Remove block quotes
        md_text = re.sub(r"[-+*]\s+", "", md_text)  # Remove bullet points
        md_text = re.sub(r"\d+\.\s+", "", md_text)  # Remove numbered lists
        return md_text.strip()
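
One caveat with clean_text: the bullet and numbered-list patterns are not anchored to the start of a line, so they also remove text such as the "- " in "FAISS - fast search" or the "3. " in "section 3. Then". A possible refinement, shown as a hypothetical strip_list_markers helper rather than a change to the class, is to anchor both patterns with re.MULTILINE.

```python
import re

def strip_list_markers(md_text: str) -> str:
    """Remove bullet and numbered-list markers only at the start of a line."""
    # [ \t] instead of \s keeps the patterns from swallowing newlines
    md_text = re.sub(r"^[ \t]*[-+*][ \t]+", "", md_text, flags=re.MULTILINE)
    md_text = re.sub(r"^[ \t]*\d+\.[ \t]+", "", md_text, flags=re.MULTILINE)
    return md_text

sample = "- FAISS - fast search\n1. See section 3. Then continue"
print(strip_list_markers(sample))
# → FAISS - fast search
# → See section 3. Then continue
```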

Using the class


db = Database()
vector_store = FaissDB()
docs = vector_store.load_documents_from_db(appConfig, 30)
embeddings = vector_store.get_embeddings(docs)
index = vector_store.build_index(embeddings)
vector_store.set_index(index)

# Query embedding (generate an embedding for the query)
query_text = "RAG Database"
query_embedding = vector_store.get_embedding(query_text)

# Search for the top-k most similar embeddings
k = 5  # Number of nearest neighbors to retrieve
distances, indices = index.search(query_embedding, k)

# Print results
print("Indices of similar documents:", indices)
print("Distances to similar documents:", distances)
print("Similar documents:")
for i, idx in enumerate(indices[0]):
    print(f"Document {idx}: {docs[idx][1]}")
    doc = FaissDB.get_doc(idx)
    print("---------------------")
    print(doc)
    print("---------------------")
    # IndexFlatL2 returns squared L2 distances, so smaller means more similar
    print("L2 distance:", distances[0][i])