
Unleashing the Power of Reasoning: A Four-Phase Text-to-SQL Agent Implementation!
A cool outcome of the DeepSeek R1 release is that LLMs have started surfacing their thinking in responses as <think> tokens, similar to ChatGPT-o1 and o3-mini. Encouraging an LLM to think more deeply has many benefits:
- No more black-box answers! You can see the reasoning behind an LLM's response in real time.
- Users gain insight into how the model reaches its conclusions.
- Prompt errors become easy to spot and fix.
- Transparency makes AI decisions feel more trustworthy.
- Collaboration becomes effortless when humans and AI share their reasoning.
So here we are: I built a RAG setup that brings a similar reasoning process (CoT responses) to a LangGraph SQL agent with tool calling. It is a ReAct agent (Reason + Act) that combines LangGraph's SQL toolkit with graph-based execution. Here's how it works.
Now, let's walk through the thinking process.
The agent starts with a system prompt that structures its thinking:
I've mapped out the exact steps our SQL agent follows, from receiving a question to returning the final query:
## Four-Phase Thinking Process
**Reasoning Phase (`<reasoning>` tag)**
- Explains information needs
- Describes expected outcomes
- Identifies challenges
- Justifies the approach

**Analysis Phase (`<analysis>` tag)**
- Tables and joins needed
- Required columns
- Filters and conditions
- Ordering/grouping logic

**Query Phase (`<query>` tag)**
- Constructs SQL following the rules:
  - SELECT statements only
  - Proper syntax
  - Default LIMIT 10
  - Verified schema

**Verification Phase (`<error_check>` and `<final_check>` tags)**
- Validates reasoning
- Confirms approach
- Checks completeness
- Verifies output
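Put together, a response that follows the four phases might look like this (an illustrative sketch only; the actual content depends on the question and schema):

```
<reasoning>
I need each track's name and its total sales revenue, so I plan to join
the track table with its sales lines and sum the per-line revenue.
</reasoning>
<analysis>
Tables: Track, InvoiceLine; join on TrackId; order by revenue descending.
</analysis>
<query>
SELECT ... LIMIT 10;
</query>
<final_check>
Reasoning, analysis, and query sections are all present.
</final_check>
```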
Here’s a visualization of the process:
Here’s a full prompt template:
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
The main part of our agent’s thinking process is complete — we’ve covered the flow and the detailed prompt that guides its reasoning. Now, let’s move to the next part: Building the LangGraph SQL Agent.
First, let’s look at the graph implementation:
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)
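The `tools_condition` edge routes back to the tool node whenever the model's latest message carries tool calls, and to END otherwise. A pure-Python sketch of that predicate (a conceptual stand-in, not langgraph's actual implementation):

```python
from typing import Any

END = "__end__"  # sentinel playing the role of langgraph.graph.END

def route_after_model(last_message: Any) -> str:
    """Send the graph to the tool node if the model requested a tool, else finish."""
    return "tools" if getattr(last_message, "tool_calls", None) else END

class Msg:
    """Minimal stand-in for a chat message object."""
    def __init__(self, tool_calls=None):
        self.tool_calls = tool_calls

print(route_after_model(Msg(tool_calls=[{"name": "sql_db_query"}])))  # tools
print(route_after_model(Msg()))                                       # __end__
```

This loop (model → tools → model) is what lets the agent inspect the schema with the toolkit before committing to a final query.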
Now, here’s the crucial part — how we extract and process the thinking process from our agent’s responses:
- Extracts each thinking phase from reasoning tags we defined
- Formats the output in a readable way
- Captures the final SQL query when generated
- Shows the agent’s thought process in real-time
def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Try to extract SQL between triple backticks
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    return format_section("", query)
    return None
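Since `extract_section` is a plain regex helper, it can be sanity-checked in isolation; the `re.DOTALL` flag is what allows a section's body to span multiple lines:

```python
import re

def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

# A small sample response using the tags from our prompt
sample = """<reasoning>
I need the Track table.
</reasoning>
<query>
SELECT Name FROM Track LIMIT 10;
</query>"""

print(extract_section(sample, "reasoning"))  # I need the Track table.
print(extract_section(sample, "query"))      # SELECT Name FROM Track LIMIT 10;
print(extract_section(sample, "analysis"))   # missing sections come back empty
```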
To use it, we simply stream events from graph.stream:
def run_query(query_text: str):
    print(f"\nAnalyzing: {query_text}")
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        if sql := process_event(event):
            print(f"\nGenerated SQL: {sql}")
            return sql
Here’s the complete code to make this all work:
import os
from typing import Dict, Any
import re
from typing_extensions import TypedDict
from typing import Annotated, Optional
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
import getpass

def _set_env(key: str):
    # Prompt for the key only if it is not already set in the environment
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}: ")

_set_env("OPENAI_API_KEY")
db_file = "chinook.db"
engine = create_engine(f"sqlite:///{db_file}")
db = SQLDatabase(engine=engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o-mini"))
sql_db_toolkit_tools = toolkit.get_tools()
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)

def format_section(title: str, content: str) -> str:
    if not content:
        return ""
    return f"\n{content}\n"

def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    return format_section("", query)
    return None
def run_query(query_text: str):
    print(f"\nAnalyzing your question: {query_text}")
    final_sql = None
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        sql = process_event(event)
        if sql:
            final_sql = sql
    if final_sql:
        print("\nBased on my analysis, here is the SQL query that answers your question:")
        print(f"\n{final_sql}")
    return final_sql

def interactive_sql():
    print("\nWelcome to the SQL assistant! Type 'exit' to quit.")
    while True:
        try:
            query = input("\nWhat would you like to know? ")
            if query.lower() in ['exit', 'quit']:
                print("\nThanks for using the SQL assistant!")
                break
            run_query(query)
        except KeyboardInterrupt:
            print("\nThanks for using the SQL assistant!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {str(e)}")
            print("Please try a different query.")

if __name__ == "__main__":
    interactive_sql()
Let's run it and see! Here's the agent in action.
I've tested this implementation with several models (gpt-4o, gpt-4o-mini, Claude 3.5 Haiku), and the results are encouraging. Here's a sample thinking output:
What are the top 5 best-selling tracks by revenue?
Analyzing your question: What are the top 5 best-selling tracks by revenue?
To determine the top 5 best-selling tracks by revenue, I need to analyze the relevant tables containing tracks and their sales information. Typically this involves a "tracks" table with track details and a "sales" or "orders" table recording sales transactions.
My expected result is a list of the top 5 tracks ordered by the total revenue their sales generated. The challenge here is making sure I join the tables correctly and aggregate the sales data to compute each track's total revenue.
I will construct the query to:
- Join the "tracks" table with the "sales" table on the track ID.
- Sum the revenue for each track.
- Order the results by total revenue in descending order.
- Limit the results to the top 5 tracks.
I will first inspect the database schema to confirm the table names and their relationships.
- Required tables: "tracks" and "sales" (or their equivalents).
- Important columns: track ID, track name, and sales revenue.
- Specific filters: none needed, but I will aggregate the sales data.
- Proper ordering: total revenue descending, limited to 5 results.
Now I will check the existing tables in the database to confirm their names and structure.
Now that I have confirmed the relevant tables and their structure, I can proceed to construct the SQL query. The "Track" table contains information about each track, including its ID and price. The "InvoiceLine" table records each sale, linked to the "Track" table through TrackId, and includes the quantity sold and the unit price.
To calculate each track's total revenue, I will:
- Join the "Track" table with the "InvoiceLine" table on TrackId.
- Multiply UnitPrice by Quantity to get the revenue of each sale.
- Sum the revenue for each track.
- Order the results by total revenue in descending order.
- Limit the results to the top 5 tracks.
This approach ensures I accurately capture the best-selling tracks by revenue.
- Required tables: "Track" and "InvoiceLine".
- Important columns: TrackId, Name (from Track), UnitPrice, Quantity (from InvoiceLine).
- Specific filters: none, since I want all tracks.
- Proper ordering: total revenue descending, limited to 5 results.
Now I will construct the SQL query based on this analysis.
- I included a clear reasoning section explaining why the query is needed.
- I provided an analysis of the query structure, detailing the tables and columns involved.
- I executed the query and received results without errors.
The query successfully returned the top 5 best-selling tracks by revenue. Here are the results:
- The Woman King - $3.98
- The Fix - $3.98
- Walkabout - $3.98
- Hot Girl - $3.98
- Gay Witch Hunt - $3.98
All tracks generated the same revenue, suggesting they may have sold in equal quantities or at the same price.
Everything is in order, and I have verified all the steps.
Based on my analysis, here is the SQL query:
SELECT
t.TrackId,
t.Name,
SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
FROM
Track t
JOIN
InvoiceLine il ON t.TrackId = il.TrackId
GROUP BY
t.TrackId, t.Name
ORDER BY
TotalRevenue DESC
LIMIT 5;
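To sanity-check the query shape, it can be run against any SQLite database with the same two tables. The sketch below builds a tiny in-memory stand-in with fabricated rows (illustrative only, not the real chinook.db contents):

```python
import sqlite3

# In-memory database mirroring the Track / InvoiceLine relationship
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE InvoiceLine (
    InvoiceLineId INTEGER PRIMARY KEY,
    TrackId INTEGER REFERENCES Track(TrackId),
    UnitPrice REAL,
    Quantity INTEGER
);
INSERT INTO Track VALUES (1, 'Track A'), (2, 'Track B');
INSERT INTO InvoiceLine VALUES (1, 1, 0.99, 2), (2, 1, 0.99, 1), (3, 2, 1.99, 1);
""")

# The agent's generated query, verbatim
rows = conn.execute("""
SELECT t.TrackId, t.Name, SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
FROM Track t
JOIN InvoiceLine il ON t.TrackId = il.TrackId
GROUP BY t.TrackId, t.Name
ORDER BY TotalRevenue DESC
LIMIT 5;
""").fetchall()

for track_id, name, revenue in rows:
    print(track_id, name, round(revenue, 2))
```

With the sample rows above, Track A (3 units at $0.99) outranks Track B (1 unit at $1.99), confirming the join, aggregation, and ordering behave as intended.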
As you can see, the reasoning sections clearly surface every thinking step. The output shows how our agent thinks, presenting its work step by step instead of jumping straight to the answer.