
构建高效文本到sql管道的5个关键步骤:可靠性提升与故障排除指南 Pt.2
本博客是创建可靠的文本到SQL系统的第二部分。在第一部分中,我展示了如何从零开始构建一个具有重试和反射功能的SQL代理。您不需要查看第一部分就可以跟随这里的内容——我希望即使您使用的是不同的文本到SQL工具,仍然可以学到一些有用的东西。
为什么文本到SQL查询会失败?
在我开始展示如何迭代改进你的SQL代理之前,了解查询为什么会失败是很重要的。以下是最常见的原因。
-
无效解析:当你要求大型语言模型(LLM)给你代码、Markdown或任何特定类型的输出时,你需要正确处理它。对模型来说,它生成的一切都是文本标记。然而,它被训练使用某些标记,比如在结构化代码之前添加
python或
sql,以便你知道它正在生成那种输出。然而,有时你的解析器可能无法很好地处理它,或者模型可能由于上下文限制而没有正确包含这些标记,这可能导致错误。 -
不正确的SQL语法:当你要求模型为特定数据库系统(如Postgres)提供SQL代码时,但实际上你使用的是不同的系统,这种情况经常发生。不同版本之间也可能存在问题。LLM在大量互联网数据上进行训练,因此它们对流行的SQL系统(如Postgres、MySQL和Oracle)更为熟悉。在你的提示中提到你使用的SQL数据库类型是非常有帮助的!
-
未提供/检索相关上下文:LLM非常出色,但它们不知道你的数据库的细节!它们不知道你有哪些表、列或行值。如果你不提供正确的信息,它们可能会编造一些像表名这样的东西,这可能导致你的查询失败。这实际上是文本到SQL查询不工作的最常见原因,当你处理更大的数据库时,这个问题会变得更加棘手!
-
LLM无法理解你:这与缺失上下文的前一个问题有关,但我想单独强调这一点。正确的上下文应该包括更多细节——就像你给新分析师第一天所需的所有信息。这包括所有缩写的含义、公司偏好的分析方式,以及关于数据存储和处理的任何“注意事项”或重要信息。
现在已经解释了查询失败的原因,我们来看看如何减轻这些问题!
修复这些问题
以下是您应该如何处理每个问题:
- 确定问题
- 分析根本原因
- 开发解决方案
- 实施解决方案
- 审查结果
示例代码
def example_function():
print("This is an example function.")
表格示例
列 1 | 列 2 |
---|---|
行 1 | 数据 1 |
行 2 | 数据 2 |
请记住仔细遵循这些步骤以确保流程顺利。
错误解析
这是一个“暴力破解”问题,您需要做的就是覆盖所有情况。只需几个 if 条件即可覆盖您可能获得的所有错误输出。检查当结尾 ``` 不存在或 SQL 位于大量文本中间时会发生什么。
错误的语法
您可以采取以下步骤来减轻这个问题:
-
在提示中明确添加您使用的版本 / 数据库管理系统(DBMS)。
-
在上一篇文章中,我使用了一个错误修复代理。如果您在文本到SQL解决方案中使用类似的东西,可以添加常见的SQL版本翻译。例如,日期解析在Postgres等中的不同之处。
相关上下文未检索到
要使文本到SQL工具在大型数据库中良好工作,您需要一个“检索器”。这是一个需要填充您想要运行的SQL查询示例的系统。根据我的经验,语言模型生成准确查询的最大原因是缺乏足够的训练示例。以下是您可以准备系统以生成更可靠结果的方法。
在这篇博客中,我使用Qdrant作为检索器,但您可以使用任何适合您的检索器。我还使用了gretelai的开源文本到SQL数据集,这非常好,因为它已经包含了各种SQL查询和示例。如果您正在为自己的数据库构建某些内容,只需确保向您的检索器添加类似的示例:
- 问题/SQL对
- 表列的元数据
- 用例/领域特定信息。
在检索器的第一次尝试中,您可能会有一些遗漏,但您可以创建反馈系统,如后面所述。
from qdrant_client import QdrantClient, models
from qdrant_client.models import Distance, VectorParams, PointStruct
from openai import OpenAI
import numpy as np
def get_embedding(text):
"""获取OpenAI文本的嵌入"""
response = openai_client.embeddings.create(
model="text-embedding-ada-002",
input=text
)
return response.data[0].embedding
qdrant = QdrantClient(":memory:")
openai_client = OpenAI()
sql_pairs_collection = "sql_question_pairs"
metadata_collection = "domain_metadata"
db_info_collection = "db_info"
collections = [
(sql_pairs_collection, "用于存储SQL查询和问题对"),
(metadata_collection, "用于存储领域/业务知识"),
(db_info_collection, "用于存储数据库架构和表信息")
]
for collection_name, description in collections:
if not qdrant.collection_exists(collection_name):
qdrant.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=1536, distance=Distance.COSINE, on_disk=True)
)
用于构建此数据集的示例来自不同领域的许多示例。
为了公正评估,我将从这个较大的数据集中创建两个示例数据集。训练集将包含每个领域的10个查询,而测试集将包含5个(不包括训练集中的那些)。
train_df = df.groupby('domain').apply(lambda x: x.sample(n=min(len(x), 10))).reset_index(drop=True)
remaining_df = df[~df.index.isin(train_df.index)]
test_df = remaining_df.groupby('domain').apply(lambda x: x.sample(n=min(len(x), 5))).reset_index(drop=True)
print(f"原始数据集大小: {len(df)}")
print(f"训练集大小: {len(train_df)}")
print(f"测试集大小: {len(test_df)}")
print("\n每个领域的训练样本:")
print(train_df['domain'].value_counts())
print("\n每个领域的测试样本:")
print(test_df['domain'].value_counts())
下一步是将训练集“更新插入”到Qdrant集合中!
for _, row in train_df.iterrows():
domain = row['domain']
sql = row['sql']
tables = re.findall(r'(?:FROM|JOIN)\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql, re.IGNORECASE)
if domain not in domain_tables:
domain_tables[domain] = set()
domain_tables[domain].update(tables)
for idx, row in train_df.iterrows():
prompt_embedding = get_embedding(row["sql_prompt"])
domain_point = models.PointStruct(
id=idx,
vector=prompt_embedding,
payload={
"domain": row["domain"],
"domain_description": domain_tables[row['domain']]
}
)
qdrant.upsert(
collection_name="domain_metadata",
points=[domain_point]
)
db_point = models.PointStruct(
id=idx,
vector=prompt_embedding,
payload={
"sql_context": row["sql_context"]
}
)
qdrant.upsert(
collection_name="db_info",
points=[db_point]
)
sql_point = models.PointStruct(
id=idx,
vector=prompt_embedding,
payload={
"sql_prompt": row["sql_prompt"],
"sql": row["sql"]
}
)
qdrant.upsert(
collection_name="sql_question_pairs",
points=[sql_point]
)
print("训练数据成功更新插入到集合中")
您可以使用任何文本到SQL系统与此检索器,只要您传递检索到的上下文。因此
def search_similar_examples(query_text, top_k=3):
"""
在所有集合中使用嵌入相似性搜索相似示例
参数:
query_text (str): 要搜索的查询文本
top_k (int): 每个集合返回的结果数量
返回:
包含每个集合结果的字典及其有效载荷和分数
"""
query_embedding = get_embedding(query_text)
collections = ["domain_metadata", "db_info", "sql_question_pairs"]
all_results = {}
for collection in collections:
search_result = qdrant.search(
collection_name=collection,
query_vector=query_embedding,
limit=top_k
)
results = []
for scored_point in search_result:
results.append({
'payload': scored_point.payload,
'score': scored_point.score
})
all_results[collection] = results
return all_results
sql_system =
context = search_similar_examples(query)
sql_system.query(query=query, relevant_context=str(context), db_engine=conn)
现在我们有一个准备好的检索器,我们应该衡量它生成查询的效果。
使用测试集我们得到这些指标!
作者提供的图像 — 显示测试集上的初始正确与错误
正如预期的那样,系统在测试集上的表现不佳。这意味着在上下文之外询问系统任何内容都不会很好地工作。
在构建全面解决方案之前,让我们看看单个查询以理解问题。
作者提供的图像 — 显示系统错误的查询
作者提供的图像 — 显示检索到的上下文
查看上述示例,有几点变得清晰,系统在表名和列名之间感到困惑。检索到的上下文显示有两个表Satellites
和SatelliteInfo
。总体而言,查询逻辑是100%正确的,但它将两个语义上相似的表名搞错了。
作者提供的图像 — 系统错误的另一个示例查询
检索到的第二个示例的上下文
再一次,从逻辑上讲查询是正确的,但系统使用了developers
表与smart_contracts
连接,而正确的查询直接查询smart_contracts
表。
作者提供的图像 — 第三个错误查询的示例
查询三的检索上下文。
这一次,系统的错误在于它不知道程序表中的location变量已经
解决方案管道
解决方案流程图
解决方案实际上是“简单”的,系统失败是因为缺乏正确SQL查询的上下文。您可以使用LM程序生成更好的上下文,以供检索器使用。虽然人类可以检查或观察整个过程,并作为编辑来增强系统。
以下是如何使用DSPy创建LM程序来实现这一点。
class sql_example_generator(dspy.Signature):
"""
A synthetic SQL example generator that takes a user query, correct SQL, incorrect SQL and context
to generate additional similar examples to improve the system's context.
The generator creates a new SQL/query pairs that maintain similar patterns but vary in complexity
and specific details.
You can use the retrieved context to see how the database is structured.
"""
user_query = dspy.InputField(desc="Original user query that describes what kind of SQL is needed")
correct_sql = dspy.InputField(desc="The correct SQL query for the user's request")
incorrect_sql = dspy.InputField(desc="An incorrect SQL query attempt")
retrieved_context = dspy.InputField(desc="The context information about tables and schema")
new_query = dspy.OutputField(desc="create one new query that maintains similar patterns but vary in complexity and specific details")
new_sql = dspy.OutputField(desc="create one new SQL that matches the new query")
sql_context_gen = dspy.ChainOfThought(sql_example_generator)
这里是一些生成的合成示例
将此添加到上下文中,让我们重新测量在测试集上的准确性。
注意:是的,从技术上讲,这是在测试集上训练,但与传统的机器学习不同,您的数据收集是有限的,在这里您可以合成创建无限相似的示例。
作者提供的图像 — 增强后准确性
经过一次迭代或轮次,正确查询的数量现在为43%。使用与之前相同的过程,您可以通过在失败的查询上增强合成示例。经过几轮的运行,结果如下。
作者提供的图像 — 图表显示查询的测试集和训练集的准确性,随着您在错误的SQL(测试)查询上创建合成示例,测试集的准确性提高,但过拟合使训练集的准确性下降
根据该数据集,最佳的增强步骤/轮次约为3。这可能会根据您的数据库在文本到SQL方面的“挑战性”而有所不同。