让你的 AI 不再"凭空瞎编",而是真正能查数据库、分析数据、搜文档。
假设你有一个日志数据库,用户想查"昨天所有的错误日志"。传统做法是自己写 SQL,但如果你让 AI 来做,用户只需要说一句话,AI 就能生成正确的 SQL。
听起来很酷,但有个大问题:AI 生成的 SQL 可能是错的。怎么办?答案是:让数据库帮你验证。
AI 写 SQL 有两种结果:成功生成了 SQL,或者用户说的话太含糊没法生成。我们用联合类型来表示:
from pydantic import BaseModel, Field
from typing import Annotated
from annotated_types import MinLen
from typing import TypeAlias
class Success(BaseModel):
"""SQL 生成成功"""
sql_query: Annotated[str, MinLen(1)] # 至少 1 个字符
explanation: str = Field('', description='SQL 解释')
class InvalidRequest(BaseModel):
"""用户输入不够明确"""
error_message: str
# 联合类型:要么成功,要么失败
Response: TypeAlias = Success | InvalidRequest
AI 不是神仙,你得告诉它数据库长什么样,才能写出对的 SQL。这就用到了动态系统提示词:
from datetime import date
from pydantic_ai import Agent, format_as_xml
DB_SCHEMA = """
CREATE TABLE records (
created_at timestamptz,
level log_level,
message text,
attributes jsonb,
tags text[]
);
"""
# 几个示例,让 AI 学学怎么写
SQL_EXAMPLES = [
{
'request': '查看 foobar 为 false 的记录',
'response': "SELECT * FROM records WHERE attributes->>'foobar' = false",
},
{
'request': '查看昨天的记录',
'response': "SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'",
},
]
agent = Agent(
'openai:gpt-4o',
output_type=Response, # type: ignore
deps_type=Deps,
)
@agent.system_prompt
async def system_prompt() -> str:
return f"""\
你的工作是根据用户请求生成 SQL 查询。
数据库结构:
{DB_SCHEMA}
今天的日期 = {date.today()}
参考示例:
{format_as_xml(SQL_EXAMPLES)}
"""
format_as_xml() 把示例转成 XML 格式,AI 更容易理解这是整个例子最精华的部分 -- 用数据库的 EXPLAIN 命令验证 AI 生成的 SQL:
from pydantic_ai import RunContext, ModelRetry
from dataclasses import dataclass
import asyncpg
@dataclass
class Deps:
conn: asyncpg.Connection
@agent.output_validator
async def validate_output(ctx: RunContext[Deps], output: Response) -> Response:
# 如果是无效请求,直接放行
if isinstance(output, InvalidRequest):
return output
# 清理 SQL 中多余的反斜杠
output.sql_query = output.sql_query.replace('\\', '')
# 必须是 SELECT 查询(安全考虑)
if not output.sql_query.upper().startswith('SELECT'):
raise ModelRetry('请生成 SELECT 查询')
# 关键!用 EXPLAIN 验证 SQL 是否合法
try:
await ctx.deps.conn.execute(f'EXPLAIN {output.sql_query}')
except asyncpg.exceptions.PostgresError as e:
raise ModelRetry(f'SQL 有误: {e}')
return output
EXPLAIN 是 PostgreSQL 的一个命令,它不会真的执行 SQL,而是分析 SQL 是否语法正确、能否执行。就像写完作文让老师先检查一下有没有错别字,再交给阅卷老师。
EXPLAIN 报错了,我们抛出 ModelRetry,AI 就会根据错误信息修正 SQL,再试一次。
output_validator 调用 EXPLAINSuccessModelRetry → AI 修正 → 再验证import asyncio
async def main():
# 连接数据库
conn = await asyncpg.connect('postgresql://...')
deps = Deps(conn=conn)
result = await agent.run(
'查看昨天 level 为 error 的日志', deps=deps
)
if isinstance(result.output, Success):
print(f"SQL: {result.output.sql_query}")
print(f"说明: {result.output.explanation}")
else:
print(f"无法生成: {result.output.error_message}")
asyncio.run(main())
你有一堆数据(比如一个 CSV 文件或 HuggingFace 上的数据集),想让 AI 帮你分析。比如:"这个电影评论数据集里有多少条负面评论?"
这个场景的难点在于:数据可能很大,不能直接扔给 AI(大模型有 token 限制)。解决方案是:让 AI 通过工具来操作数据,而不是直接看数据。
先来看一个非常巧妙的设计 -- Out[N] 引用系统:
from dataclasses import dataclass, field
import pandas as pd
from pydantic_ai import ModelRetry
@dataclass
class AnalystAgentDeps:
output: dict[str, pd.DataFrame] = field(
default_factory=dict[str, pd.DataFrame]
)
def store(self, value: pd.DataFrame) -> str:
"""存储 DataFrame,返回引用名(如 Out[1])"""
ref = f'Out[{len(self.output) + 1}]'
self.output[ref] = value
return ref
def get(self, ref: str) -> pd.DataFrame:
"""根据引用名获取 DataFrame"""
if ref not in self.output:
raise ModelRetry(
f'错误: {ref} 不是有效的引用,请检查之前的消息'
)
return self.output[ref]
Out[1]、Out[2]......后面的代码可以用这些标记引用之前的结果。这里的设计完全一样!
Out[1],然后说"对 Out[1] 执行这个 SQL",结果存为 Out[2],再说"显示 Out[2] 的内容"。
from pydantic_ai import Agent, RunContext
import datasets
import duckdb
analyst_agent = Agent(
'openai:gpt-4o',
deps_type=AnalystAgentDeps,
instructions='你是数据分析师,根据用户需求分析数据。',
)
# 工具一:从 HuggingFace 加载数据集
@analyst_agent.tool
def load_dataset(
ctx: RunContext[AnalystAgentDeps],
path: str,
split: str = 'train',
) -> str:
"""从 HuggingFace 加载数据集。
Args:
ctx: Pydantic AI agent RunContext
path: 数据集名称,格式为 `<用户名>/<数据集名>`
split: 加载的数据分割(默认: "train")
"""
builder = datasets.load_dataset_builder(path)
splits = builder.info.splits or {}
if split not in splits:
raise ModelRetry(
f'{split} 不存在,可用的分割: {",".join(splits.keys())}'
)
builder.download_and_prepare()
dataset = builder.as_dataset(split=split)
dataframe = dataset.to_pandas()
# 存储并获取引用
ref = ctx.deps.store(dataframe)
return f'数据集已加载为 `{ref}`,共 {len(dataframe)} 行'
# 工具二:用 DuckDB 执行 SQL 查询
@analyst_agent.tool
def run_duckdb(
ctx: RunContext[AnalystAgentDeps],
dataset: str,
sql: str,
) -> str:
"""对 DataFrame 执行 SQL 查询。
注意:SQL 中的表名必须使用 `dataset`。
Args:
ctx: Pydantic AI agent RunContext
dataset: DataFrame 的引用名(如 Out[1])
sql: 要执行的 SQL 查询
"""
data = ctx.deps.get(dataset)
result = duckdb.query_df(
df=data, virtual_table_name='dataset', sql_query=sql
)
ref = ctx.deps.store(result.df())
return f'查询完成,结果为 `{ref}`'
# 工具三:显示数据(最多 5 行)
@analyst_agent.tool
def display(
ctx: RunContext[AnalystAgentDeps],
name: str,
) -> str:
"""显示 DataFrame 的前 5 行。"""
dataset = ctx.deps.get(name)
return dataset.head().to_string()
当你问"数据集 rotten_tomatoes 里有多少负面评论?"时,AI 会自动串联工具:
load_dataset("cornell-movie-review-data/rotten_tomatoes")run_duckdb("Out[1]", "SELECT COUNT(*) FROM dataset WHERE label = 0")display("Out[2]")deps = AnalystAgentDeps()
result = analyst_agent.run_sync(
'统计 rotten_tomatoes 数据集中有多少负面评论',
deps=deps,
)
print(result.output)
# 你还可以查看所有中间结果
for ref, df in deps.output.items():
print(f"{ref}: {df.shape}")
想象你有一句话"机器学习很有趣"。电脑不认识文字,但如果我们能把这句话变成一串数字(比如 [0.12, -0.34, 0.56, ...]),而且意思相近的句子数字也相近,那电脑就能"理解"语义了。
这就是嵌入 -- 把文字变成向量(一串数字),保留语义信息。
"机器学习很有趣" → [0.12, -0.34, 0.56, 0.78, ...] ←┐
│ 距离近
"深度学习很酷" → [0.11, -0.32, 0.54, 0.80, ...] ←┘
"今天天气不错" → [0.89, 0.23, -0.67, 0.12, ...] ← 距离远
Pydantic AI 提供了统一的 Embedder 接口,支持多个提供商:
from pydantic_ai import Embedder
# OpenAI(最常用)
embedder = Embedder('openai:text-embedding-3-small')
# Google
embedder = Embedder('google-gla:gemini-embedding-001')
# Cohere
embedder = Embedder('cohere:embed-v4.0')
# VoyageAI(擅长代码、法律、金融领域)
embedder = Embedder('voyageai:voyage-3.5')
# 本地运行(不需要 API,数据不出本机)
embedder = Embedder('sentence-transformers:all-MiniLM-L6-v2')
async def main():
embedder = Embedder('openai:text-embedding-3-small')
# 方法一:嵌入查询(用于搜索时的查询语句)
result = await embedder.embed_query('什么是机器学习?')
print(f'向量维度: {len(result.embeddings[0])}')
# 向量维度: 1536
# 方法二:嵌入文档(用于建索引时的大批量文本)
docs = [
'机器学习是 AI 的子集',
'深度学习使用神经网络',
'Python 是编程语言',
]
result = await embedder.embed_documents(docs)
print(f'嵌入了 {len(result.embeddings)} 个文档')
# 嵌入了 3 个文档
高维向量更精确,但占用更多存储空间。你可以减小维度来节省空间:
from pydantic_ai.embeddings import EmbeddingSettings
embedder = Embedder(
'openai:text-embedding-3-small',
settings=EmbeddingSettings(dimensions=256), # 从 1536 降到 256
)
就像第一章学的自定义模型 URL 一样,嵌入也支持:
from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel
from pydantic_ai.providers.openai import OpenAIProvider
# 使用任何兼容 OpenAI 的嵌入 API
model = OpenAIEmbeddingModel(
'your-model-name',
provider=OpenAIProvider(
base_url='https://your-provider.com/v1',
api_key='your-api-key',
),
)
embedder = Embedder(model)
# 或者用内置的 Provider 快捷方式
embedder = Embedder('ollama:nomic-embed-text') # 本地 Ollama
写测试时不想花钱调 API?用 TestEmbeddingModel:
from pydantic_ai.embeddings import TestEmbeddingModel
async def test_my_rag():
embedder = Embedder('openai:text-embedding-3-small')
test_model = TestEmbeddingModel()
with embedder.override(model=test_model):
result = await embedder.embed_query('测试查询')
# TestEmbeddingModel 返回固定的向量
assert result.embeddings[0] == [1.0] * 8
RAG(Retrieval-Augmented Generation)中文叫检索增强生成,是目前最实用的 AI 应用模式之一。
核心思想很简单:先从你的文档库里找到相关内容,再让 AI 基于这些内容回答。
在搜索之前,需要先建立向量索引。就像图书馆的图书目录,有了目录才能快速找书。
用 pgvector 扩展让 PostgreSQL 支持向量存储:
-- 启用向量扩展
CREATE EXTENSION IF NOT EXISTS vector;
-- 创建文档表
CREATE TABLE IF NOT EXISTS doc_sections (
id serial PRIMARY KEY,
url text NOT NULL UNIQUE,
title text NOT NULL,
content text NOT NULL,
-- text-embedding-3-small 返回 1536 维向量
embedding vector(1536) NOT NULL
);
-- 创建 HNSW 索引,加速向量搜索
CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding
ON doc_sections
USING hnsw (embedding vector_l2_ops);
from openai import AsyncOpenAI
import asyncpg
import pydantic_core
async def build_search_db():
"""下载文档 → 分块 → 嵌入 → 存入数据库"""
# 1. 下载文档(这里用的是 Logfire 的文档 JSON)
sections = download_and_parse_docs()
# 2. 对每个文档段落生成嵌入向量
openai = AsyncOpenAI()
pool = await asyncpg.create_pool('postgresql://...')
for section in sections:
# 生成嵌入
embedding = await openai.embeddings.create(
input=section.embedding_content(),
model='text-embedding-3-small',
)
embedding_json = pydantic_core.to_json(
embedding.data[0].embedding
).decode()
# 存入数据库
await pool.execute(
'INSERT INTO doc_sections (url, title, content, embedding) '
'VALUES ($1, $2, $3, $4)',
section.url, section.title,
section.content, embedding_json,
)
索引建好后,就可以用 Agent + 工具来实现 RAG 了:
from dataclasses import dataclass
from pydantic_ai import Agent, RunContext
import asyncpg
from openai import AsyncOpenAI
import pydantic_core
@dataclass
class Deps:
openai: AsyncOpenAI
pool: asyncpg.Pool
agent = Agent('openai:gpt-4o', deps_type=Deps)
@agent.tool
async def retrieve(
context: RunContext[Deps],
search_query: str,
) -> str:
"""根据查询检索相关文档段落。
Args:
context: 调用上下文
search_query: 搜索关键词
"""
# 1. 把搜索词变成向量
embedding = await context.deps.openai.embeddings.create(
input=search_query,
model='text-embedding-3-small',
)
embedding_json = pydantic_core.to_json(
embedding.data[0].embedding
).decode()
# 2. 在数据库中找最相似的 8 个段落
rows = await context.deps.pool.fetch(
'SELECT url, title, content '
'FROM doc_sections '
'ORDER BY embedding <-> $1 '
'LIMIT 8',
embedding_json,
)
# 3. 拼接成上下文返回给 AI
return '\n\n'.join(
f'# {row["title"]}\n'
f'文档链接: {row["url"]}\n\n'
f'{row["content"]}\n'
for row in rows
)
<-> 操作符: 这是 pgvector 的向量距离操作符。ORDER BY embedding <-> $1 意思是"按照和查询向量的距离排序",距离越小越相似。就像在地图上找离你最近的 8 家餐厅。
retrieve 工具async def main():
openai = AsyncOpenAI()
pool = await asyncpg.create_pool('postgresql://...')
deps = Deps(openai=openai, pool=pool)
answer = await agent.run(
'Logfire 怎么和 FastAPI 配合使用?',
deps=deps,
)
print(answer.output)
恭喜!你已经掌握了让 AI 操作真实数据的四大模式:
| 模式 | 适用场景 | 核心技术 | 验证方式 |
|---|---|---|---|
| SQL 生成 | 自然语言查数据库 | Union 输出 + format_as_xml | EXPLAIN 验证 |
| 数据分析 | 分析 DataFrame | 引用系统 + DuckDB | ModelRetry |
| 嵌入向量 | 语义搜索基础 | Embedder API | TestEmbeddingModel |
| RAG | 基于文档问答 | pgvector + retrieve 工具 | 向量距离排序 |
ModelRetry不用 HuggingFace,直接分析本地 CSV 文件:
import pandas as pd
from pydantic_ai import Agent, RunContext
from dataclasses import dataclass, field
@dataclass
class CsvDeps:
data: dict[str, pd.DataFrame] = field(default_factory=dict)
def store(self, name: str, df: pd.DataFrame) -> str:
self.data[name] = df
return f'已加载 {name},共 {len(df)} 行 {len(df.columns)} 列'
agent = Agent('openai:gpt-4o', deps_type=CsvDeps)
@agent.tool
def load_csv(ctx: RunContext[CsvDeps], file_path: str) -> str:
"""加载 CSV 文件"""
df = pd.read_csv(file_path)
return ctx.deps.store(file_path, df)
@agent.tool
def query(ctx: RunContext[CsvDeps], name: str, sql: str) -> str:
"""用 SQL 查询 DataFrame"""
import duckdb
df = ctx.deps.data[name]
result = duckdb.query_df(df, 'data', sql)
return result.df().head(10).to_string()
不需要 PostgreSQL 和 pgvector,用 Python 实现简易版:
import numpy as np
from pydantic_ai import Agent, Embedder, RunContext
from dataclasses import dataclass
@dataclass
class SimpleRAGDeps:
embedder: Embedder
documents: list[str] # 原始文档
embeddings: list[list[float]] # 文档嵌入
agent = Agent('openai:gpt-4o', deps_type=SimpleRAGDeps)
@agent.tool
async def search(
ctx: RunContext[SimpleRAGDeps],
query: str,
) -> str:
"""搜索相关文档"""
# 获取查询嵌入
result = await ctx.deps.embedder.embed_query(query)
query_vec = np.array(result.embeddings[0])
# 计算和每个文档的相似度
scores = []
for i, doc_emb in enumerate(ctx.deps.embeddings):
similarity = np.dot(query_vec, np.array(doc_emb))
scores.append((similarity, i))
# 返回最相似的 3 个文档
scores.sort(reverse=True)
top_docs = [ctx.deps.documents[i] for _, i in scores[:3]]
return '\n\n---\n\n'.join(top_docs)