Unlocking Conversational AI’s Potential: A Deep Dive into Multi-modal RAG for Large PDFs
In the ever-evolving landscape of artificial intelligence, the quest to create more human-like conversational agents has been a central focus. While text-based interactions have seen significant advancements, integrating multiple modalities such as text, images, and documents can further enrich the user experience and enhance the agent’s ability to understand and respond effectively. One promising approach in this domain is Multi-modal Retrieval-Augmented Generation (RAG), which combines the power of multi-modal representations with the capabilities of generative models.
This article contains a reimplementation and in-depth explanation of the Multi-modal RAG example from LangChain’s cookbook.
Understanding Multi-modal RAG
The concept of Retrieval-Augmented Generation, or RAG, has gained attention for its ability to combine the strengths of generative models such as transformers with retrieval-based techniques. By augmenting generative models with retriever components, RAG models can effectively leverage pre-existing knowledge from large-scale text corpora to enhance their responses.
Now, extending this paradigm to a multi-modal setting, Multi-modal RAG integrates various modalities such as text, images, and documents into the retrieval and generation processes. This allows conversational agents to comprehend and generate responses based not only on textual input but also on accompanying visual or contextual information.
Let’s dive into the technicalities
Steps taken to build a Multi-modal RAG based conversational agent:
- Extract text, tables, and images from PDF files using partitioning techniques and document structure analysis.
- Categorize extracted elements into text and tables based on their type.
- Generate summaries for text elements using an OpenAI model, optionally splitting long texts into manageable chunks.
- Encode images as base64 strings and summarize them using an OpenAI Vision model.
- Create a multi-vector retriever to index summaries and raw contents of text, tables, and images.
- Initialize a vector store using the Chroma vector store with OpenAI embeddings.
- Construct a multi-modal RAG chain for processing user questions with both textual and visual context.
- Retrieve relevant documents based on a user query using the multi-vector retriever.
- Invoke the multi-modal RAG chain to generate a response to the user query.
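Before walking through the code, here is a minimal sketch of the environment this walkthrough assumes. Package names are the standard PyPI ones (version pins are up to you), and unstructured’s PDF support additionally needs system tools such as poppler and tesseract for image and table extraction.

```python
# Assumed dependency set for the code below (not pinned; adjust to your setup):
#
#   pip install "unstructured[all-docs]" langchain langchain-community \
#       langchain-openai chromadb tiktoken pillow tqdm
#
# Quick smoke test that the key libraries import cleanly:
import langchain
import unstructured.partition.pdf  # noqa: F401  (pulls in the PDF extras)
import chromadb

print("langchain", langchain.__version__)
print("chromadb", chromadb.__version__)
```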
```python
from langchain.text_splitter import CharacterTextSplitter
from unstructured.partition.pdf import partition_pdf


# Extract elements from PDF
def extract_pdf_elements(path, fname):
    """
    Extract images, tables, and chunked text from a PDF file.
    path: File path; extracted images (.jpg) are dumped to a local figures/ directory
    fname: File name
    """
    return partition_pdf(
        filename=fname,
        extract_images_in_pdf=True,
        infer_table_structure=True,
        chunking_strategy="by_title",
        max_characters=4000,
        new_after_n_chars=3800,
        combine_text_under_n_chars=2000,
    )


# Categorize elements by type
def categorize_elements(raw_pdf_elements):
    """
    Categorize extracted elements from a PDF into tables and texts.
    raw_pdf_elements: List of unstructured.documents.elements
    """
    tables = []
    texts = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            tables.append(str(element))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            texts.append(str(element))
    return texts, tables


# File path
fpath = "/New-Creta-Brochure.pdf"
fname = "New-Creta-Brochure.pdf"

# Get elements
raw_pdf_elements = extract_pdf_elements(fpath, fname)

# Get text, tables
texts, tables = categorize_elements(raw_pdf_elements)

# Optional: Enforce a specific token size for texts
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=4000, chunk_overlap=0
)
joined_texts = " ".join(texts)
texts_4k_token = text_splitter.split_text(joined_texts)
```
- The `extract_pdf_elements` function utilizes the `partition_pdf` method from the `unstructured.partition.pdf` module to extract images, tables, and chunked text from a PDF file.
- The `categorize_elements` function categorizes the extracted elements into text and tables based on their type.
- Optionally, the extracted text can be split into chunks of a specific token size using the `CharacterTextSplitter` class. This step involves setting up the text splitter with parameters such as chunk size and overlap and then splitting the joined texts into chunks.
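Before moving on, a quick sanity check (my own aside, not part of the original cookbook) confirms that partitioning produced what we expect:

```python
# Illustrative inspection only; variable names follow the snippet above.
print(f"Extracted {len(texts)} text chunks and {len(tables)} tables")
print(f"After token-based re-splitting: {len(texts_4k_token)} chunks of up to ~4k tokens")

if texts_4k_token:
    # Preview the start of the first chunk to verify the text is readable
    print(texts_4k_token[0][:300])
```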
```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import os

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


# Generate summaries of text elements
def generate_text_summaries(texts, tables, summarize_texts=False):
    """
    Summarize text elements
    texts: List of str
    tables: List of str
    summarize_texts: Bool to summarize texts
    """
    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
These summaries will be embedded and used to retrieve the raw text or table elements. \
Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Text summary chain
    model = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k")
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Initialize empty summaries
    text_summaries = []
    table_summaries = []

    # Apply to text if texts are provided and summarization is requested
    if texts and summarize_texts:
        text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
    elif texts:
        text_summaries = texts

    # Apply to tables if tables are provided
    if tables:
        table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})

    return text_summaries, table_summaries


# Get text, table summaries
text_summaries, table_summaries = generate_text_summaries(
    texts_4k_token, tables, summarize_texts=True
)
```
- The `generate_text_summaries` function takes lists of text and table elements as input along with a boolean flag `summarize_texts`, indicating whether to summarize text elements. It sets up a prompt template for summarization tasks.
- The prompt template contains instructions for the assistant, guiding it to provide concise summaries optimized for retrieval.
- The code initializes a chat-based interaction with the OpenAI model, specifying parameters such as temperature and the model variant ("gpt-3.5-turbo-16k"). The summarization chain is constructed, which consists of the prompt template, the OpenAI model, and an output parser to handle the model’s response.
- If text elements are provided and summarization is requested, the summarization chain is applied to the text elements in batches. The `max_concurrency` parameter controls the maximum number of concurrent requests made to the OpenAI API.
- If table elements are provided, the summarization chain is similarly applied to them.
- Finally, the `generate_text_summaries` function is called with the preprocessed text elements (`texts_4k_token`), tables, and the flag to summarize text elements. This generates summaries for both text and table elements, utilizing the LangChain library and OpenAI’s GPT-3.5 model.
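One detail worth flagging here (my own aside, not part of the original walkthrough): the multi-vector retriever built later pairs each summary with a raw element by list position, so the summary lists need to line up with whichever raw lists you index.

```python
# Sanity check (illustrative). Summaries were generated from texts_4k_token,
# so if the retriever below indexes `texts` as the raw content instead, make
# sure the lengths (and ordering) actually correspond, or index texts_4k_token.
print(len(texts), len(texts_4k_token), len(text_summaries), len(table_summaries))
assert len(text_summaries) == len(texts_4k_token)
assert len(table_summaries) == len(tables)
```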
Export OPENAI_API_KEY as an environment variable or declare it in the code itself; you can create an API key from your OpenAI account dashboard.
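For example, a minimal way to make the key available (assuming you keep it in your shell environment rather than hard-coding it):

```python
import os

# Option 1: export the key in your shell before launching the notebook/script:
#   export OPENAI_API_KEY="sk-..."
#
# Option 2: set it programmatically (avoid committing real keys to source control):
# os.environ["OPENAI_API_KEY"] = "sk-..."

# langchain_openai reads the key from the environment automatically.
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY is not set"
```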
```python
import base64
import os
import time

from tqdm import tqdm
from langchain_core.messages import HumanMessage


def encode_image(image_path):
    """Getting the base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def image_summarize(img_base64, prompt):
    """Make image summary"""
    chat = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=200)
    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content


def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images
    path: Path to the directory of .jpg files extracted by Unstructured
    """
    # Store base64 encoded images
    img_base64_list = []
    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""

    for img_file in tqdm(sorted(os.listdir(path)), desc="Processing images"):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(path, img_file)
            try:
                base64_image = encode_image(img_path)
                img_base64_list.append(base64_image)
                image_summaries.append(image_summarize(base64_image, prompt))
            except Exception as e:
                print(f"Error processing image {img_file}: {e}")
                print("Waiting for 60 seconds before continuing...")
                time.sleep(60)  # Back off briefly (e.g. on rate limits) before the next image

    return img_base64_list, image_summaries


# Image summaries
img_base64_list, image_summaries = generate_img_summaries("figures/")
```
- The `encode_image` function reads an image file from a given path and encodes it into a base64 string. The encoded string is returned after decoding it into UTF-8 format.
- The `image_summarize` function takes a base64-encoded image and a prompt as input. It initializes a chat session with the GPT-4 Vision model and constructs a message containing both the prompt and the base64-encoded image as a data URL.
- The `generate_img_summaries` function processes a directory containing JPEG images. It iterates through each image file, encoding it into base64 format and generating a summary using the `image_summarize` function.
- A prompt is defined within the `generate_img_summaries` function, instructing the assistant to provide concise summaries optimized for retrieval. The tqdm library is used to display a progress bar during image processing.
- Finally, the `generate_img_summaries` function is called with the directory path containing the images. This generates base64-encoded images and summaries for each image, utilizing the GPT-4 Vision model for image summarization.
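To eyeball what the vision model is actually summarizing, a small notebook helper (my own addition, mirroring a utility from the LangChain cookbook) can render one of the base64 images inline alongside its summary:

```python
from IPython.display import HTML, display


def plt_img_base64(img_base64):
    """Render a base64-encoded image inline in a Jupyter notebook."""
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))


# Show the first extracted image next to the summary generated for it
if img_base64_list:
    plt_img_base64(img_base64_list[0])
    print(image_summaries[0])
```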
```python
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings


def create_multi_vector_retriever(
    vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts
    """
    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables, and images
    # Check that text_summaries is not empty before adding
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    # Check that table_summaries is not empty before adding
    if table_summaries:
        add_documents(retriever, table_summaries, tables)
    # Check that image_summaries is not empty before adding
    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever


# The vectorstore to use to index the summaries
vectorstore = Chroma(
    collection_name="rag-storage", embedding_function=OpenAIEmbeddings()
)

# Create retriever
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore,
    text_summaries,
    texts,
    table_summaries,
    tables,
    image_summaries,
    img_base64_list,
)
```
- The `create_multi_vector_retriever` function takes various summaries and corresponding raw contents as input, including text summaries, text contents, table summaries, table contents, image summaries, and image base64-encoded strings.
- It initializes an in-memory document store (`store`) and sets up a multi-vector retriever (`retriever`) with the specified vector store (`vectorstore`). The `id_key` parameter determines the key used to identify documents within the retriever.
- Inside the function, a helper function `add_documents` is defined to add documents to both the vector store and the document store. It generates unique UUIDs for each document, creates document objects with summaries as page content, and adds them to the vector store. It also adds raw document contents to the document store using the UUIDs as keys.
- The code checks if each type of summary is non-empty before adding documents of that type to the retriever.
- We initialize a vector store (`vectorstore`) using the Chroma vector store with an OpenAI embedding function.
- Using the `create_multi_vector_retriever` function, a multi-vector retriever (`retriever_multi_vector_img`) is created by providing the vector store and summaries/contents for texts, tables, and images.
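To make the summary-to-raw indirection concrete, here is a short illustrative sketch (not from the cookbook) of the two-step lookup the retriever performs: the vector store matches a summary, and the `doc_id` in its metadata points back to the raw element in the docstore. The example query string is arbitrary.

```python
# Step 1: similarity search over the *summaries* in the vector store
hits = vectorstore.similarity_search("engine specifications", k=1)
if hits:
    summary_doc = hits[0]
    print("Matched summary:", summary_doc.page_content[:120])

    # Step 2: the doc_id metadata links back to the raw text/table/image
    raw = retriever_multi_vector_img.docstore.mget(
        [summary_doc.metadata["doc_id"]]
    )[0]
    print("Raw content preview:", str(raw)[:120])
```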
```python
import io
import re

from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image


def looks_like_base64(sb):
    """Check if the string looks like base64"""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None


def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the data
    """
    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
        for sig, format in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False


def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize an image encoded as a Base64 string
    """
    # Decode the Base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))

    # Resize the image
    resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to a bytes buffer
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)

    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def split_image_text_types(docs):
    """
    Split base64-encoded images and texts
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)
    return {"images": b64_images, "texts": texts}


def img_prompt_func(data_dict):
    """
    Join the context into a single string
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            "You are a car dealer and salesperson.\n"
            "You will be given a mix of text, tables, and image(s), usually of charts.\n"
            "Use this information to provide quality advice related to the user question.\n"
            f"User-provided question: {data_dict['question']}\n\n"
            "Text and / or tables:\n"
            f"{formatted_texts}"
        ),
    }
    messages.append(text_message)
    return [HumanMessage(content=messages)]


def multi_modal_rag_chain(retriever):
    """
    Multi-modal RAG chain
    """
    # Multi-modal LLM
    model = ChatOpenAI(temperature=0, model="gpt-4-vision-preview", max_tokens=1024)

    # RAG pipeline
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

    return chain


# Create RAG chain
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)
```
- The `looks_like_base64` function checks if a given string resembles a base64-encoded string.
- The `is_image_data` function checks if a base64-encoded string represents image data by examining its header.
- The `resize_base64_image` function decodes a base64-encoded image string, resizes the image, and re-encodes it as a base64 string.
- The `split_image_text_types` function splits a list of documents into image and text types based on their content. It checks each document for image-like base64 strings and resizes them accordingly.
- The `img_prompt_func` function constructs a message containing the user-provided question, contextual text, and image(s) in a format suitable for input to the RAG model. It formats the context text and adds image messages if images are present.
- The `multi_modal_rag_chain` function constructs a multi-modal RAG pipeline. It consists of a retriever for fetching relevant context data, lambda functions for preprocessing, an image prompt function, the GPT-4 Vision model, and an output parser.
- Finally, the `multi_modal_rag_chain` function is called with the retriever created earlier to create the RAG chain (`chain_multimodal_rag`).
```python
# Check retrieval
query = "Can you give me a brief description on the document."
docs = retriever_multi_vector_img.get_relevant_documents(query, limit=6)
len(docs)
```
A retrieval query is executed using the multi-vector retriever (`retriever_multi_vector_img`). The query asks for a brief description of the document. The `get_relevant_documents` method of the retriever is used to retrieve relevant documents based on the query, with a limit of 6 documents specified.
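As a quick aside of my own, the `split_image_text_types` helper defined above can show how many of those retrieved items are base64 images versus text or table chunks:

```python
# Separate the retrieved items into images and text/table chunks
retrieved = split_image_text_types(docs)
print(f"{len(retrieved['images'])} image(s), {len(retrieved['texts'])} text/table chunk(s) retrieved")
```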
## Input
```python
response = chain_multimodal_rag.invoke(query)

# Split the response into sentences; the lookahead avoids breaking on
# decimal points such as "1.5L"
sentences = re.split(r"\.(?!\d)\s*", response)

# Print each sentence on a new line
for line in sentences:
    print(line)
```
## Output
The document appears to be a brochure or specification sheet for a Hyundai vehicle, possibly a model lineup that includes different variants of a car
The first two images contain detailed tables listing various features and specifications for different trim levels of the vehicle
These tables compare the features available in each variant, such as engine options, interior features, safety equipment, wheels, lighting, and infotainment systems
The variants listed include "E," "EX," "S," "SX," and "SX(O)," with each subsequent variant typically offering more features than the previous
The document outlines the technical specifications for three engine types: a 1.5L MPi petrol engine, a 1.5L U2 CRDi diesel engine, and a 1.5L Turbo GDi petrol engine
It also provides information on dimensions, transmission types, suspension, brakes, and tire sizes
The third image shows a Hyundai vehicle driving on a scenic road with mountains in the background, which is likely part of the brochure's visual appeal to showcase the vehicle in an attractive setting
The last image is a small, low-resolution picture of a blue car, which seems to be part of the brochure but is not clear enough to provide specific details about the model or its features
Overall, the document is designed to inform potential buyers about the options and features available for a particular Hyundai car model, allowing them to compare different variants and make an informed decision based on their preferences and needs
This code snippet invokes the multi-modal RAG chain (`chain_multimodal_rag`) to generate a response to a given query. The query is passed to the chain’s `invoke` method, which triggers the generation of a response based on the provided context and question.
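If you would rather watch the answer arrive token by token instead of waiting for the full response, LCEL chains also expose a `stream` method; a minimal sketch:

```python
# Stream the answer; with StrOutputParser at the end, each chunk is a string
for chunk in chain_multimodal_rag.stream(query):
    print(chunk, end="", flush=True)
```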
Advantages of Multi-modal RAG for Conversational Agents:
- Enhanced Understanding: By incorporating multiple modalities, multi-modal RAG models can better grasp the nuances and context of user queries. Visual cues from images or additional information from documents can provide valuable context for generating more relevant and accurate responses.
- Richer Responses: The ability to draw upon diverse modalities enables conversational agents to generate responses that are more informative and engaging. Whether it’s providing visual explanations, referencing relevant documents, or incorporating multimedia content, multi-modal RAG can enrich the conversation experience.
- Improved User Interaction: Multi-modal interactions mimic real-world communication more closely, making the user experience more intuitive and natural. Users can communicate with the agent using a combination of text, images, or documents, enabling a more fluid and expressive interaction.
- Broader Knowledge Integration: Leveraging multi-modal representations allows conversational agents to tap into a wider range of knowledge sources. Instead of relying solely on textual data, agents can incorporate information from visual sources and documents, expanding their knowledge base and improving the quality of responses.
Challenges and Future Directions:
While multi-modal RAG holds great promise for advancing conversational AI, several challenges must be addressed to realize its full potential:
- Data Quality and Diversity: Ensuring the availability of high-quality, diverse multi-modal datasets is crucial for training robust models. Gathering and curating such datasets pose significant challenges, particularly in domains where multi-modal data is scarce.
- Model Complexity and Scalability: Integrating multiple modalities into a single model increases complexity and computational demands. Efficient architectures and training procedures must be developed to enable scalable deployment of multi-modal RAG models.
- Ethical and Privacy Concerns: As conversational agents become more adept at processing diverse data types, ensuring user privacy and handling sensitive information ethically become paramount. Robust mechanisms for data anonymization and consent management are essential.
- Evaluation Metrics and Benchmarks: Establishing appropriate evaluation metrics and benchmarks for multi-modal conversational AI systems is essential for assessing their performance accurately. Metrics should account for factors such as response relevance, coherence, and multi-modality integration.
Conclusion
Multi-modal RAG represents a significant step forward in the evolution of conversational AI, offering the potential to create more immersive and contextually rich interactions. By seamlessly integrating text, images, and documents, multi-modal RAG models can comprehend user queries more effectively and generate responses that are not only informative but also engaging. While challenges remain, continued research and development in this area hold the promise of unlocking new frontiers in human-like communication between machines and humans.
GitHub Repository
Explore the implementation and usage of Multi-modal RAG for conversational AI in my GitHub repository.
References:
- LangChain Documentation
- LangChain Cookbook (check it out for more interesting use cases)
- OpenAI