
使用 pydanticai 和 postgresql 创建一个 rag ai 代理:开发者的全面分步指南
使用 PydanticAI 创建 RAG 应用程序
在本文中,我将逐步向您展示如何使用 PydanticAI 创建 RAG (Retrieval Augmented Generation) 应用程序。与手动实现 RAG 相比,代码更简单、更清晰。
先决条件:
PydanticAI RAG Agent
在开始之前,您需要以下内容:
- Python 3.9+
- 在线文档存储,至少包含一个 PDF 文件,或者您可以使用此 URL 包含一个虚构宇宙示例 PDF 文件:https://skolo-ai-agent.ams3.cdn.digitaloceanspaces.com/pydantic/the_seven_realms.pdf
- OpenAI API 密钥
- 一个 IDE
数据库设置
我们需要设置数据库,获取连接字符串,并使用以下模式创建一个干净的表:
DB_SCHEMA = """
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS text_chunks (
id serial PRIMARY KEY,
chunk text NOT NULL,
embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_text_chunks_embedding ON text_chunks USING hnsw (embedding vector_l2_ops);
"""
数据库的其余代码如下所示:
from __future__ import annotations as _annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import List
import pydantic_core
import asyncpg
import httpx
import fitz
import json
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from openai import AsyncOpenAI
DB_DSN = "database-dsn-goes-here"
OPENAI_API_KEY = "sk-proj-your-api-key-goes-here"
@asynccontextmanager
async def database_connect(create_db: bool = False):
"""Manage database connection pool."""
pool = await asyncpg.create_pool(DB_DSN)
try:
if create_db:
async with pool.acquire() as conn:
await conn.execute(DB_SCHEMA)
yield pool
finally:
await pool.close()
class Chunk(BaseModel):
chunk: str
async def split_text_into_chunks(text: str, max_words: int = 400, overlap: float = 0.2) -> List[Chunk]:
"""Split long text into smaller chunks based on word count with overlap."""
words = text.split()
chunks = []
step_size = int(max_words * (1 - overlap))
for start in range(0, len(words), step_size):
end = start + max_words
chunk_words = words[start:end]
if chunk_words:
chunks.append(Chunk(chunk=" ".join(chunk_words)))
return chunks
async def insert_chunks(pool: asyncpg.Pool, chunks: List[Chunk], openai_client: AsyncOpenAI):
"""Insert text chunks into the database with embeddings."""
for chunk in chunks:
embedding_response = await openai_client.embeddings.create(
input=chunk.chunk,
model="text-embedding-3-small"
)
# Extract embedding data and convert to JSON format
assert len(embedding_response.data) == 1, f"Expected 1 embedding, got {len(embedding_response.data)}"
embedding_data = json.dumps(embedding_response.data[0].embedding)
# Insert into the database
await pool.execute(
'INSERT INTO text_chunks (chunk, embedding) VALUES ($1, $2)',
chunk.chunk,
embedding_data
)
async def download_pdf(url: str) -> bytes:
"""Download PDF from a given URL."""
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
return response.content
def extract_text_from_pdf(pdf_content: bytes) -> str:
"""Extract text from PDF content."""
document = fitz.open(stream=pdf_content, filetype="pdf")
text = ""
for page_num in range(document.page_count):
page = document.load_page(page_num)
text += page.get_text()
return text
async def add_pdf_to_db(url: str):
"""Download PDF, extract text, and add to the embeddings database."""
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
pdf_content = await download_pdf(url)
text = extract_text_from_pdf(pdf_content)
async with database_connect(create_db=True) as pool:
chunks = await split_text_into_chunks(text)
await insert_chunks(pool, chunks, openai_client)
async def update_db_with_pdf(url: str):
"""Download PDF, extract text, and update the embeddings database."""
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
pdf_content = await download_pdf(url)
text = extract_text_from_pdf(pdf_content)
async with database_connect() as pool:
chunks = await split_text_into_chunks(text)
await insert_chunks(pool, chunks, openai_client)
async def execute_url_pdf(url: str):
"""
Check if the database table exists, and call the appropriate function
to handle the PDF URL.
"""
async with database_connect() as pool:
table_exists = await pool.fetchval("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'text_chunks'
)
""")
if table_exists:
# If the table exists, update the database
print("Table exists. Updating database with PDF content.")
await update_db_with_pdf(url)
else:
# If the table does not exist, add the PDF and create the table
print("Table does not exist. Adding PDF and creating the table.")
await add_pdf_to_db(url)
以上代码将执行以下操作:
- 使用您提供的 DB_DSN 字符串连接到数据库,如果表尚不存在,则根据模式创建一个新表。
- 获取 PDF 文档并从文档中提取文本。
- 将文档分成 20% 重叠的块。
- 使用 OpenAI 嵌入模型获取块并创建嵌入。
- 将创建的嵌入与块一起保存到数据库中。
此代码还允许您根据需要将新的 PDF 文档添加到同一表中,因此您可以上传多个文档。
💡 提示:此代码目前仅允许上传 PDF 文件。您可以通过添加函数来扩展它,以提取其他文档类型的内容,例如:Word 文档、Excel 电子表格、Powerpoint 等。
🚀 通过使用更复杂的块方法来保留上下文并优化 RAG 输出,改进代码。
PydanticAI RAG Agent 代码
添加以下代码以创建一个带有检索工具的 PydanticAI agent。RAG 仅作为检索工具添加到 agent 中,这意味着您可以将此工具与其他许多工具一起添加,例如我们在上一篇文章中探讨的 AI Agent CRUD 工具。
@dataclass
class Deps:
pool: asyncpg.Pool
openai: AsyncOpenAI
### 初始化 agent
model = OpenAIModel("gpt-4o", api_key=OPENAI_API_KEY)
rag_agent = Agent(model, deps_type=Deps)
@rag_agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
"""根据搜索查询检索文档部分。
参数:
context: 调用上下文。
search_query: 搜索查询。
"""
print("正在检索..............")
embedding = await context.deps.openai.embeddings.create(
input=search_query,
model='text-embedding-3-small',
)
assert (
len(embedding.data) == 1
), f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
embedding = embedding.data[0].embedding
embedding_json = pydantic_core.to_json(embedding).decode()
rows = await context.deps.pool.fetch(
'SELECT chunk FROM text_chunks ORDER BY embedding <-> $1 LIMIT 5',
embedding_json,
)
from_db = '\n\n'.join(
f'# Chunk:\n{row["chunk"]}\n'
for row in rows
)
return from_db
async def run_agent(question: str):
"""运行 agent 并执行基于 RAG 的问答的入口点。"""
## 设置 agent 和依赖项
async with database_connect() as pool:
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
async with database_connect(False) as pool:
deps = Deps(openai=openai_client, pool=pool)
base_instruction = f"使用 'retrieve' 工具获取信息以帮助您回答这个问题:{question}"
answer = await rag_agent.run(base_instruction, deps=deps)
return answer.data
我们的 agent 称为 rag_agent
,我们使用一个包含数据库连接和 OpenAI 客户端的 Dependency 类对其进行初始化。此依赖项将允许 rag_agent
连接到数据库并使用数据库中的信息回答问题,只需使用该工具即可。
这很棒,因为 AI agent 将决定是否应该调用该工具,并且工作将在后台进行,我们将获得完整的响应,而无需执行多个 AI 调用来首先过滤数据库,然后回答问题。
前端应用程序:Streamlit
可以通过前端 streamlit 应用程序访问 agent。该应用程序的代码如下所示:
import streamlit as st
import asyncio
from aiagent import execute_url_pdf, run_agent
### Streamlit 页面配置
st.set_page_config(
page_title="AI 助手 📚🤖",
page_icon="📚",
layout="wide"
)
### 标题
st.title("AI 助手 📚🤖")
st.write("与您的基于 PDF 的 AI 助手互动。使用以下选项上传 PDF 或提问。")
### 带有两列的布局
col1, col2 = st.columns(2)
### 第 1 列:通过 URL 上传 PDF
with col1:
st.subheader("📄 上传 PDF")
pdf_url = st.text_input("输入 PDF 文档的 URL:", placeholder="https://example.com/document.pdf")
if st.button("📥 将 PDF 添加到数据库"):
if pdf_url:
with st.spinner("正在处理 PDF 并更新数据库..."):
try:
asyncio.run(execute_url_pdf(pdf_url))
st.success("PDF 已成功处理并添加到数据库!")
except Exception as e:
st.error(f"处理 PDF 时出错:{e}")
else:
st.warning("请输入有效的 URL。")
### 第 2 列:提问
with col2:
st.subheader("❓ 提问")
question = st.text_input("输入您的问题:", placeholder="全栈开发人员的职责是什么?")
if st.button("🔍 获取答案"):
if question:
with st.spinner("思考中..."):
try:
answer = asyncio.run(run_agent(question))
st.success("这是答案:")
st.write(answer)
except Exception as e:
st.error(f"获取答案时出错:{e}")
else:
st.warning("请输入有效的问题。")
### 页脚
st.markdown("---")
st.write("✨ 由 [Skolo Online](https://skolo.online) 和 Pydantic AI 提供支持")
从前端,您可以:
- 上传 PDF 文档
- 向 PydanticAI Agent 提问有关您上传的 PDF 的问题
前端应该如下所示:
Streamlit PydanticAI RAG 应用程序