Pydantic AI 函数工具全面详解
1. 函数工具基础概念
1.1 什么是函数工具
函数工具是 Pydantic AI 中的关键机制,它为大型语言模型提供了与现实世界互动的能力。简而言之,函数工具就像是为 AI 模型安装的“手和脚”,使模型不仅能思考和回答问题,还能执行具体操作、获取外部数据、进行计算等实际任务。
核心价值:
- 扩展模型能力:超越纯文本生成的局限,使模型能够执行具体操作
- 访问实时数据:获取模型训练时未知的最新信息
- 执行确定性操作:将不确定的 AI 推理与确定性的程序逻辑结合
- 集成现有系统:连接数据库、API、文件系统等现有基础设施
1.2 函数工具与相关技术对比
函数工具 vs. RAG(检索增强生成):
- RAG 主要用于向量搜索和信息检索
- 函数工具更为通用,可以执行任意操作
- 两者可以结合使用,RAG 负责信息查找,函数工具负责执行操作
函数工具 vs. 结构化输出:
- 结构化输出定义模型的最终响应格式
- 函数工具定义模型在生成响应过程中可以调用的操作
- 一个模型可以同时使用多个工具,其中一些用于中间操作,一些用于最终输出
2. 工具注册机制详解
2.1 装饰器注册方式
2.1.1 @agent.tool 装饰器
@agent.tool
这是主要的装饰器,适用于需要访问代理运行上下文(RunContext)的工具。RunContext 提供了对依赖项、模型信息和其他运行时数据的访问。
from pydantic_ai import Agent, RunContext
import asyncio
import random
# 创建代理实例,指定依赖项类型为字符串(用户ID)
agent = Agent(
'qwen-plus', # 使用阿里云通义千问模型
deps_type=str, # 依赖项类型为用户ID字符串
system_prompt=(
"你是一个智能客服助手,可以查询用户信息和订单状态。"
"请根据用户请求调用适当的工具来获取信息。"
),
)
@agent.tool
def get_user_profile(ctx: RunContext[str]) -> dict:
"""
获取用户个人信息
通过运行上下文中的用户ID来查询用户资料
"""
user_id = ctx.deps # 从上下文中获取依赖项(用户ID)
# 模拟数据库查询 - 实际应用中这里可能是真实的数据库操作
user_profiles = {
"user_001": {"name": "张三", "level": "VIP", "join_date": "2023-01-15"},
"user_002": {"name": "李四", "level": "普通", "join_date": "2023-03-20"},
}
if user_id in user_profiles:
return user_profiles[user_id]
else:
return {"error": "用户不存在"}
@agent.tool
def get_order_status(ctx: RunContext[str], order_id: str) -> dict:
"""
查询订单状态
Args:
order_id: 订单编号,格式为 ORDER_XXXX
"""
user_id = ctx.deps
# 模拟订单数据查询
orders = {
"ORDER_1001": {"status": "已发货", "product": "智能手机", "amount": 2999},
"ORDER_1002": {"status": "处理中", "product": "笔记本电脑", "amount": 5999},
}
if order_id in orders:
order_info = orders[order_id]
return {
"order_id": order_id,
"status": order_info["status"],
"product": order_info["product"],
"amount": order_info["amount"],
"user_id": user_id
}
else:
return {"error": "订单不存在"}
async def main():
"""
演示工具调用的完整流程
"""
# 使用用户ID "user_001" 作为依赖项
result = await agent.run(
"请查询我的个人信息和订单ORDER_1001的状态",
deps="user_001" # 传递用户ID作为依赖项
)
print("=== 模型响应 ===")
print(result.output)
print("\n=== 工具调用详情 ===")
for i, message in enumerate(result.all_messages()):
print(f"消息 {i}: {message}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- 代理初始化:创建 Agent 时指定了
,这意味着每个工具都可以通过deps_type=str
访问字符串类型的依赖项。RunContext - 工具函数定义:
不需要额外参数,直接从上下文中获取用户ID;get_user_profile
需要get_order_status
参数,由模型在调用时提供。order_id - 上下文访问:通过
访问在ctx.deps
中传递的依赖项数据。agent.run() - 执行流程:模型分析用户请求,决定需要调用哪些工具,然后依次执行并整合结果。
2.1.2 @agent.tool_plain 装饰器
@agent.tool_plain
用于不需要访问运行上下文的简单工具。这些工具更加纯粹,只依赖于传入的参数。
from pydantic_ai import Agent
import asyncio
import requests
from datetime import datetime, timedelta
agent = Agent('qwen-plus')
@agent.tool_plain
def get_weather(city: str, date: str = None) -> dict:
"""
获取城市天气信息
Args:
city: 城市名称,如"北京"、"上海"
date: 日期字符串,格式YYYY-MM-DD,默认为今天
"""
if date is None:
date = datetime.now().strftime("%Y-%m-%d")
# 模拟天气API调用 - 实际应用中这里调用真实天气API
weather_data = {
"北京": {"temperature": "22°C", "condition": "晴", "humidity": "45%"},
"上海": {"temperature": "25°C", "condition": "多云", "humidity": "65%"},
"广州": {"temperature": "28°C", "condition": "雨", "humidity": "80%"},
}
if city in weather_data:
return {
"city": city,
"date": date,
"weather": weather_data[city]
}
else:
return {"error": f"未找到城市 {city} 的天气信息"}
@agent.tool_plain
def calculate_expression(expression: str) -> dict:
"""
计算数学表达式
Args:
expression: 数学表达式,如 "2 + 3 * 4"
"""
try:
# 安全地计算数学表达式
result = eval(expression) # 注意:实际生产环境需要更安全的方式
return {
"expression": expression,
"result": result,
"type": type(result).__name__
}
except Exception as e:
return {"error": f"计算失败: {str(e)}"}
@agent.tool_plain
def get_current_time(timezone: str = "UTC") -> dict:
"""
获取当前时间
Args:
timezone: 时区,如"UTC", "Asia/Shanghai"
"""
from datetime import datetime
import pytz # 需要安装pytz库
try:
tz = pytz.timezone(timezone)
current_time = datetime.now(tz)
return {
"timezone": timezone,
"current_time": current_time.strftime("%Y-%m-%d %H:%M:%S"),
"iso_format": current_time.isoformat()
}
except Exception as e:
return {"error": f"时区错误: {str(e)}"}
async def main():
"""
演示多个工具的组合使用
"""
result = await agent.run(
"请告诉我北京今天的天气,并计算(15 + 25) * 2的结果,还有上海当前时间"
)
print("=== 综合查询结果 ===")
print(result.output)
# 显示详细的工具调用序列
print("\n=== 执行轨迹 ===")
for i, msg in enumerate(result.all_messages()):
print(f"步骤 {i}: {type(msg).__name__}")
if hasattr(msg, 'parts'):
for part in msg.parts:
part_type = type(part).__name__
if hasattr(part, 'content'):
print(f" - {part_type}: {part.content}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- 工具独立性:
装饰的工具不依赖运行上下文,更加纯粹。@agent.tool_plain - 参数处理:每个工具都明确定义了参数和返回类型,模型会根据这些信息正确调用。
- 错误处理:工具内部包含完整的错误处理逻辑,返回结构化的错误信息。
- 多工具协作:模型可以在一个对话中调用多个不同的工具来满足复杂需求。
2.2 通过 Agent 构造函数注册
除了装饰器,还可以在创建 Agent 时通过
tools 参数注册工具,这种方式更适合工具复用和模块化设计。
from pydantic_ai import Agent, RunContext, Tool
import asyncio
import json
# 独立的工具函数 - 可以在多个代理间共享
def currency_converter(amount: float, from_currency: str, to_currency: str) -> dict:
"""
货币兑换计算
Args:
amount: 金额数量
from_currency: 源货币代码,如"USD", "CNY"
to_currency: 目标货币代码,如"EUR", "JPY"
"""
# 模拟汇率数据
exchange_rates = {
"USD": {"CNY": 7.2, "EUR": 0.92, "JPY": 150},
"CNY": {"USD": 0.14, "EUR": 0.13, "JPY": 21},
"EUR": {"USD": 1.09, "CNY": 7.8, "JPY": 163},
}
if from_currency in exchange_rates and to_currency in exchange_rates[from_currency]:
rate = exchange_rates[from_currency][to_currency]
converted = amount * rate
return {
"original": f"{amount} {from_currency}",
"converted": f"{converted:.2f} {to_currency}",
"exchange_rate": rate
}
else:
return {"error": "不支持的货币兑换"}
def get_stock_price(symbol: str) -> dict:
"""
获取股票价格
Args:
symbol: 股票代码,如"AAPL", "TSLA"
"""
# 模拟股票数据
stock_prices = {
"AAPL": {"price": 185.32, "change": "+1.25", "change_percent": "+0.68%"},
"TSLA": {"price": 245.18, "change": "-3.42", "change_percent": "-1.38%"},
"GOOGL": {"price": 138.45, "change": "+0.85", "change_percent": "+0.62%"},
}
if symbol.upper() in stock_prices:
return stock_prices[symbol.upper()]
else:
return {"error": f"未找到股票 {symbol} 的信息"}
def create_user_session(ctx: RunContext[str]) -> dict:
"""
创建用户会话
需要访问运行上下文来获取用户信息
"""
user_id = ctx.deps
return {
"session_id": f"SESS_{user_id}_{hash(user_id)}",
"user_id": user_id,
"created_at": "2024-01-01T10:00:00Z",
"status": "active"
}
# 方式1:直接传递函数列表
# Agent会自动检测哪些函数需要上下文
agent_simple = Agent(
'qwen-plus',
deps_type=str,
tools=[currency_converter, get_stock_price, create_user_session],
system_prompt="金融信息查询助手,可以处理货币兑换和股票查询。"
)
# 方式2:使用Tool类进行精细控制
# 明确指定每个工具是否需要上下文
agent_advanced = Agent(
'qwen-plus',
deps_type=str,
tools=[
Tool(currency_converter, takes_ctx=False),
Tool(get_stock_price, takes_ctx=False),
Tool(create_user_session, takes_ctx=True),
],
system_prompt="高级金融助手,提供货币兑换和股票查询服务。"
)
async def main():
"""
比较两种注册方式的效果
"""
print("=== 简单注册方式 ===")
result1 = await agent_simple.run(
"将100美元换成人民币,并查看苹果股票价格",
deps="user_123"
)
print(result1.output)
print("\n=== 高级注册方式 ===")
result2 = await agent_advanced.run(
"创建用户会话并查询特斯拉股票",
deps="user_456"
)
print(result2.output)
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- 工具函数独立性:工具函数被定义为独立的函数,不依赖于特定的代理实例。
- 两种注册方式对比:简单方式——直接传递函数列表,Agent 自动检测上下文需求;高级方式——使用
类明确指定每个工具的配置。Tool - 工具复用:相同的工具函数可以在多个代理实例间共享。
- 明确性:高级方式虽然代码稍多,但意图更加明确,便于维护。
3. 工具模式与参数系统
3.1 自动模式生成机制
Pydantic AI 会自动从函数签名和文档字符串中提取信息,生成完整的工具模式(Schema)。这个模式告诉模型如何调用工具,包括参数类型、描述、约束条件等。
from pydantic_ai import Agent
import asyncio
from typing import List, Optional
from pydantic import Field
agent = Agent('qwen-plus')
@agent.tool_plain(
docstring_format='google', # 指定文档字符串格式
require_parameter_descriptions=True # 要求必须提供参数描述
)
def book_hotel(
city: str,
check_in_date: str,
check_out_date: str,
guests: int = Field(1, description="入住人数", ge=1, le=10),
room_type: str = Field("standard", description="房间类型: standard, deluxe, suite"),
amenities: Optional[List[str]] = Field(None, description="需要的设施")
) -> dict:
"""
预订酒店房间
根据用户需求在指定城市预订酒店,支持多种房型和设施选择。
Args:
city: 城市名称,必须是支持服务的城市
check_in_date: 入住日期,格式YYYY-MM-DD
check_out_date: 退房日期,格式YYYY-MM-DD
guests: 入住人数,范围1-10人
room_type: 房间类型,可选标准间、豪华间、套房
amenities: 额外设施需求,如wifi、parking、breakfast等
Returns:
包含预订详情的字典,包括预订ID和价格信息
Raises:
ModelRetry: 当参数无效或预订失败时抛出重试异常
"""
# 参数验证逻辑
if check_in_date >= check_out_date:
raise ModelRetry("退房日期必须晚于入住日期,请检查日期输入")
supported_cities = ["北京", "上海", "广州", "深圳", "杭州"]
if city not in supported_cities:
raise ModelRetry(
f"暂不支持城市 {city},请选择以下城市: {', '.join(supported_cities)}"
)
# 模拟预订逻辑
import random
booking_id = f"HTL{random.randint(10000, 99999)}"
base_prices = {"standard": 300, "deluxe": 500, "suite": 800}
base_price = base_prices.get(room_type, 300)
total_price = base_price * (guests if guests <= 2 else guests - 1)
return {
"booking_id": booking_id,
"city": city,
"check_in": check_in_date,
"check_out": check_out_date,
"guests": guests,
"room_type": room_type,
"amenities": amenities or [],
"total_price": f"?{total_price}",
"status": "confirmed"
}
@agent.tool_plain
def search_flights(
departure: str,
arrival: str,
date: str,
passengers: int = 1,
class_type: str = "economy"
) -> List[dict]:
"""
搜索航班信息
Args:
departure: 出发城市机场代码
arrival: 到达城市机场代码
date: 出发日期,格式YYYY-MM-DD
passengers: 乘客人数
class_type: 舱位类型,economy/business/first
Returns:
可用航班列表,包含时间和价格信息
"""
# 模拟航班搜索
flights = [
{
"flight_no": "CA1234",
"departure_time": "08:00",
"arrival_time": "11:00",
"price": 1200,
"airline": "中国国际航空"
},
{
"flight_no": "MU5678",
"departure_time": "14:30",
"arrival_time": "17:30",
"price": 980,
"airline": "中国东方航空"
}
]
return [
{
**flight,
"class_type": class_type,
"passengers": passengers,
"total_price": flight["price"] * passengers
}
for flight in flights
]
async def main():
"""
演示自动生成的工具模式如何工作
"""
# 测试复杂参数的工具调用
result = await agent.run(
"我想在北京预订酒店,3月20日入住,3月25日退房,2个人,要豪华间,需要wifi和早餐"
)
print("=== 酒店预订结果 ===")
print(result.output)
# 航班搜索测试
result2 = await agent.run(
"搜索3月15日从北京到上海的航班,2个乘客,经济舱"
)
print("\n=== 航班搜索结果 ===")
print(result2.output)
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- 文档字符串解析:系统会解析 Google 风格的文档字符串,提取参数描述。
- 类型注解:Python 类型注解被转换为 JSON Schema 类型约束。
- 参数验证:Pydantic 的
类提供额外的验证规则(如范围约束)。Field - 错误处理:使用
异常让模型知道需要重新尝试。ModelRetry - 复杂返回类型:支持返回字典、列表等复杂数据结构。
3.2 使用 Pydantic 模型作为参数
对于复杂的工具参数,可以使用 Pydantic 模型来定义数据结构,这样既能获得更好的类型安全,也能生成更清晰的工具模式。
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from datetime import datetime
from pydantic_ai import Agent, ModelRetry
import asyncio
class Address(BaseModel):
"""地址信息"""
street: str = Field(description="街道地址")
city: str = Field(description="城市")
state: str = Field(description="省/州")
postal_code: str = Field(description="邮政编码")
country: str = Field(default="中国", description="国家")
class CustomerInfo(BaseModel):
"""客户信息"""
name: str = Field(description="客户姓名")
email: str = Field(description="邮箱地址")
phone: str = Field(description="手机号码")
address: Address = Field(description="联系地址")
@validator('email')
def validate_email(cls, v):
if '@' not in v:
raise ValueError('邮箱格式无效')
return v
@validator('phone')
def validate_phone(cls, v):
if not v.replace('+', '').replace(' ', '').isdigit():
raise ValueError('手机号码格式无效')
return v
class OrderItem(BaseModel):
"""订单项"""
product_id: str = Field(description="产品ID")
product_name: str = Field(description="产品名称")
quantity: int = Field(description="数量", ge=1)
unit_price: float = Field(description="单价", gt=0)
class CreateOrderRequest(BaseModel):
"""创建订单请求"""
customer: CustomerInfo = Field(description="客户信息")
items: List[OrderItem] = Field(description="订单项目列表")
shipping_method: str = Field(
default="standard",
description="配送方式: standard/express/overnight"
)
notes: Optional[str] = Field(None, description="订单备注")
agent = Agent('qwen-plus')
@agent.tool_plain
def create_order(request: CreateOrderRequest) -> dict:
"""
创建新订单
接收完整的订单信息,验证后创建订单并返回订单详情。
Args:
request: 包含客户信息、订单项目和配送方式的完整请求
Returns:
创建的订单详情,包括订单ID和总金额
"""
try:
# 计算总金额
total_amount = sum(item.quantity * item.unit_price for item in request.items)
# 生成订单ID
import random
order_id = f"ORD{datetime.now().strftime('%Y%m%d')}{random.randint(1000, 9999)}"
# 模拟订单创建
return {
"order_id": order_id,
"customer_name": request.customer.name,
"customer_email": request.customer.email,
"items": [
{
"product_name": item.product_name,
"quantity": item.quantity,
"unit_price": item.unit_price,
"subtotal": item.quantity * item.unit_price
}
for item in request.items
],
"total_amount": total_amount,
"shipping_method": request.shipping_method,
"status": "created",
"created_at": datetime.now().isoformat()
}
except Exception as e:
raise ModelRetry(f"订单创建失败: {str(e)}")
@agent.tool_plain
def update_order_status(order_id: str, status: str, notes: Optional[str] = None) -> dict:
"""
更新订单状态
Args:
order_id: 订单ID
status: 新状态: processing/shipped/delivered/cancelled
notes: 状态更新备注
"""
valid_statuses = ["processing", "shipped", "delivered", "cancelled"]
if status not in valid_statuses:
raise ModelRetry(f"无效状态: {status},可选: {', '.join(valid_statuses)}")
return {
"order_id": order_id,
"old_status": "created", # 模拟原有状态
"new_status": status,
"updated_at": datetime.now().isoformat(),
"notes": notes
}
async def main():
"""
演示复杂参数工具的使用
"""
result = await agent.run(
"创建新订单:"
"客户张三,邮箱zhangsan@example.com,电话13800138000,"
"地址:北京市海淀区中关村大街100号,北京,北京,100080;"
"订购产品:笔记本电脑2台单价5000元,鼠标1个单价100元;"
"使用快递配送,备注尽快发货"
)
print("=== 订单创建结果 ===")
print(result.output)
# 显示工具调用的详细信息
print("\n=== 执行详情 ===")
for i, message in enumerate(result.all_messages()):
if hasattr(message, 'parts'):
print(f"消息 {i}:")
for part in message.parts:
if hasattr(part, 'tool_name'):
print(f" 工具调用: {part.tool_name}")
elif hasattr(part, 'content'):
content = part.content
if len(str(content)) > 100:
content = str(content)[:100] + "..."
print(f" 内容: {content}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- Pydantic 模型定义:使用
定义复杂的数据结构。BaseModel - 字段验证:通过
设置字段描述和约束条件。Field - 自定义验证器:使用
装饰器实现自定义验证逻辑。@validator - 嵌套模型:模型可以嵌套其他模型,构建复杂的数据结构。
- 自动模式生成:Pydantic AI 会自动将这些模型转换为 JSON Schema。
4. 高级工具功能
4.1 多模态工具输出
Pydantic AI 支持工具返回多种类型的数据,包括文本、图像、文档等,这对于多模态模型特别有用。
from pydantic_ai import Agent, ImageUrl, DocumentUrl, ToolReturn, BinaryContent
from pydantic_ai.messages import TextContent
import asyncio
import base64
# 使用多模态模型
agent = Agent('qwen-vl-plus')
@agent.tool_plain
def generate_chart(data_type: str) -> ToolReturn:
"""
生成数据图表
Args:
data_type: 数据类型: sales/users/performance
"""
# 模拟图表生成 - 实际应用中这里会调用图表生成库
chart_data = {
"sales": {
"title": "月度销售趋势",
"data": [120, 150, 180, 200, 240, 300, 280],
"labels": ["1月", "2月", "3月", "4月", "5月", "6月", "7月"]
},
"users": {
"title": "用户增长统计",
"data": [1000, 1500, 2200, 3000, 4000, 5200, 6500],
"labels": ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Q7"]
}
}
if data_type not in chart_data:
return ToolReturn(
return_value={"error": f"未知数据类型: {data_type}"},
content=[f"无法生成 {data_type} 类型的图表"]
)
data = chart_data[data_type]
# 模拟生成图表图像数据(base64编码)
# 实际应用中这里会使用matplotlib、plotly等库生成真实图表
fake_image_data = base64.b64encode(f"chart_{data_type}".encode()).decode()
return ToolReturn(
return_value={
"chart_type": data_type,
"title": data["title"],
"data_points": len(data["data"]),
"max_value": max(data["data"])
},
content=[
f"已生成 {data['title']} 图表",
"图表数据摘要:",
TextContent(text=str(data)),
"生成的图表:",
BinaryContent(data=fake_image_data, media_type="image/png"),
"请分析图表趋势并提供见解。"
],
metadata={
"chart_type": data_type,
"generated_at": "2024-01-01T10:00:00Z",
"data_source": "internal_system"
}
)
@agent.tool_plain
def get_product_catalog() -> ToolReturn:
"""
获取产品目录,包含产品图片和描述
"""
products = [
{
"id": "P001",
"name": "智能手机",
"price": 2999,
"image": "https://example.com/images/phone.jpg",
"description": "最新款智能手机,高性能处理器"
},
{
"id": "P002",
"name": "笔记本电脑",
"price": 5999,
"image": "https://example.com/images/laptop.jpg",
"description": "轻薄便携笔记本电脑,超长续航"
}
]
return ToolReturn(
return_value={"products": products, "total_count": len(products)},
content=[
"产品目录如下:",
*[
f"产品: {p['name']} - ?{p['price']}\n"
f"描述: {p['description']}\n"
f"图片: {p['image']}"
for p in products
],
"请基于产品信息回答用户问题。"
]
)
@agent.tool_plain
def analyze_document(document_url: str) -> ToolReturn:
"""
分析文档内容
Args:
document_url: 文档的URL地址
"""
# 模拟文档分析
document_content = {
"title": "季度报告",
"author": "财务部",
"pages": 25,
"key_points": [
"收入同比增长15%",
"利润率提升至20%",
"新客户增长30%"
],
"summary": "本季度业绩表现良好,各项指标均超预期"
}
return ToolReturn(
return_value=document_content,
content=[
f"已分析文档: {document_url}",
"文档摘要:",
TextContent(text=document_content["summary"]),
"关键要点:",
*[f"- {point}" for point in document_content["key_points"]],
"请基于文档内容回答用户问题。"
],
metadata={
"document_url": document_url,
"analysis_method": "ai_analysis",
"confidence_score": 0.95
}
)
async def main():
"""
演示多模态工具的高级功能
"""
print("=== 测试图表生成 ===")
result1 = await agent.run("生成销售数据图表并分析趋势")
print("响应:", result1.output)
print("\n=== 测试产品目录 ===")
result2 = await agent.run("显示产品目录并推荐适合的产品")
print("响应:", result2.output)
print("\n=== 测试文档分析 ===")
result3 = await agent.run("分析https://example.com/report.pdf这个文档")
print("响应:", result3.output)
# 显示工具返回的元数据
if result1.tool_results:
print("\n=== 工具元数据 ===")
for tool_result in result1.tool_results:
print(f"工具: {tool_result.tool_name}")
print(f"元数据: {tool_result.metadata}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
- ToolReturn 类:提供对工具返回值的精细控制。
- 多部分内容:可以同时返回文本、图像、文档等多种类型的内容。
- 元数据分离:
包含不发送给模型但应用程序需要的数据。metadata - 结构化返回值:支持返回复杂的数据结构。
return_value
包含工具的主要执行结果
丰富的上下文
:content
为模型提供分析和回答所需的额外信息
4.2 动态工具控制
动态工具功能允许根据运行时上下文启用、禁用或调整工具行为。
from pydantic_ai import Agent, RunContext, ToolDefinition, Tool
from typing import Optional, List
import asyncio
class UserRole:
"""用户角色定义"""
ADMIN = "admin"
USER = "user"
GUEST = "guest"
SUPPORT = "support"
agent = Agent('qwen-plus', deps_type=dict) # 依赖项为用户上下文字典
# 工具准备函数 - 基于用户角色过滤工具
async def prepare_tools_by_role(
ctx: RunContext[dict],
tool_defs: List[ToolDefinition]
) -> Optional[List[ToolDefinition]]:
"""
根据用户角色过滤可用工具
Args:
ctx: 运行上下文,包含用户信息
tool_defs: 所有可用的工具定义列表
"""
user_context = ctx.deps
user_role = user_context.get('role', UserRole.GUEST)
# 定义角色权限
role_permissions = {
UserRole.ADMIN: ['*'], # 管理员可以使用所有工具
UserRole.SUPPORT: ['get_user_info', 'view_tickets', 'add_comment'],
UserRole.USER: ['get_own_info', 'create_ticket', 'view_own_tickets'],
UserRole.GUEST: ['get_public_info'] # 访客只能使用公开工具
}
allowed_tools = role_permissions.get(user_role, [])
if '*' in allowed_tools:
return tool_defs # 管理员可以使用所有工具
# 过滤工具
filtered_tools = []
for tool_def in tool_defs:
if tool_def.name in allowed_tools:
filtered_tools.append(tool_def)
return filtered_tools
# 设置代理级别的工具准备函数
agent.prepare_tools = prepare_tools_by_role
# 定义各种工具
@agent.tool_plain
def get_public_info() -> dict:
"""获取公开信息"""
return {
"service_status": "正常运行",
"announcements": ["系统维护通知", "新功能上线"],
"contact_info": "support@example.com"
}
@agent.tool
def get_own_info(ctx: RunContext[dict]) -> dict:
"""获取自己的信息"""
user_context = ctx.deps
return {
"user_id": user_context.get('user_id'),
"name": user_context.get('name', '未知用户'),
"join_date": "2023-01-01",
"points": 1500
}
@agent.tool
def create_ticket(ctx: RunContext[dict], title: str, description: str) -> dict:
"""创建支持工单"""
user_context = ctx.deps
import random
ticket_id = f"TICKET{random.randint(10000, 99999)}"
return {
"ticket_id": ticket_id,
"title": title,
"description": description,
"created_by": user_context.get('user_id'),
"status": "open",
"created_at": "2024-01-01T10:00:00Z"
}
@agent.tool
def view_own_tickets(ctx: RunContext[dict]) -> List[dict]:
"""查看自己的工单"""
user_context = ctx.deps
user_id = user_context.get('user_id')
# 模拟工单数据
return [
{
"ticket_id": "TICKET1001",
"title": "登录问题",
"status": "resolved",
"created_at": "2024-01-01T09:00:00Z"
},
{
"ticket_id": "TICKET1002",
"title": "支付问题",
"status": "in_progress",
"created_at": "2024-01-01T10:00:00Z"
}
]
@agent.tool_plain
def get_user_info(user_id: str) -> dict:
"""获取用户信息(支持人员专用)"""
# 模拟用户数据库
users = {
"user_001": {"name": "张三", "role": "user", "status": "active"},
"user_002": {"name": "李四", "role": "user", "status": "active"},
}
if user_id in users:
return users[user_id]
else:
return {"error": "用户不存在"}
@agent.tool_plain
def view_tickets(status: str = "all") -> List[dict]:
"""查看所有工单(支持人员专用)"""
# 模拟工单数据库
all_tickets = [
{"ticket_id": "TICKET1001", "user_id": "user_001", "status": "open"},
{"ticket_id": "TICKET1002", "user_id": "user_002", "status": "in_progress"},
{"ticket_id": "TICKET1003", "user_id": "user_001", "status": "resolved"},
]
if status != "all":
return [t for t in all_tickets if t['status'] == status]
return all_tickets
@agent.tool_plain
def add_comment(ticket_id: str, comment: str) -> dict:
"""添加工单评论(支持人员专用)"""
return {
"ticket_id": ticket_id,
"comment": comment,
"added_by": "support_agent",
"added_at": "2024-01-01T10:00:00Z"
}
@agent.tool_plain
def system_maintenance(action: str) -> dict:
"""系统维护操作(管理员专用)"""
if action == "backup":
return {"status": "success", "message": "系统备份完成"}
elif action == "restart":
return {"status": "success", "message": "系统重启完成"}
else:
return {"error": "未知维护操作"}
async def main():
"""
演示基于角色的动态工具访问控制
"""
# 测试不同角色的工具访问权限
test_cases = [
{
"role": UserRole.GUEST,
"user_id": None,
"name": "访客用户",
"queries": [
"查看系统状态",
"查看我的信息", # 应该被拒绝
"查看所有工单" # 应该被拒绝
]
},
{
"role": UserRole.USER,
"user_id": "user_001",
"name": "普通用户",
"queries": [
"查看我的信息",
"创建工单:无法登录",
"查看所有工单" # 应该被拒绝
]
},
{
"role": UserRole.SUPPORT,
"user_id": "support_001",
"name": "客服人员",
"queries": [
"查看用户user_001的信息",
"查看所有工单",
"系统备份" # 应该被拒绝
]
},
{
"role": UserRole.ADMIN,
"user_id": "admin_001",
"name": "系统管理员",
"queries": [
"查看用户user_001的信息",
"系统备份",
"查看所有工单"
]
}
]
for test_case in test_cases:
print(f"\n=== 测试角色: {test_case['role']} ===")
user_context = {
'role': test_case['role'],
'user_id': test_case['user_id'],
'name': test_case['name']
}
for query in test_case['queries']:
print(f"\n查询: {query}")
try:
result = await agent.run(query, deps=user_context)
print(f"响应: {result.output}")
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
角色基础权限系统
:定义不同用户角色可以访问的工具
动态工具过滤
:prepare_tools_by_role
函数根据用户角色过滤可用工具
上下文感知
:工具可以通过RunContext
访问用户信息
权限分级
:
访客:仅能访问公开信息
用户:可以管理自己的数据和创建工单
客服:可以查看所有用户信息和工单
管理员:无限制访问所有功能
安全控制
:防止未经授权访问敏感操作
5. 工具执行与错误处理
5.1 参数验证与自动重试
Pydantic AI 提供了强大的参数验证和自动重试机制,确保工具调用的可靠性。
from pydantic_ai import Agent, ModelRetry
from pydantic import ValidationError
import asyncio
from typing import List
from datetime import datetime
agent = Agent('qwen-plus')
@agent.tool_plain
def schedule_meeting(
title: str,
participants: List[str],
start_time: str,
duration_minutes: int = 60,
location: str = "线上会议",
recurring: bool = False
) -> dict:
"""
安排会议
Args:
title: 会议标题,不能为空
participants: 参与者邮箱列表,至少1人,最多50人
start_time: 开始时间,格式YYYY-MM-DD HH:MM,必须是将来的时间
duration_minutes: 会议时长(分钟),范围15-480
location: 会议地点
recurring: 是否为周期性会议
Returns:
会议安排结果,包含会议ID和详情
"""
# 参数验证
if not title or len(title.strip()) == 0:
raise ModelRetry("会议标题不能为空,请提供有效的会议标题")
if len(participants) == 0:
raise ModelRetry("必须指定至少一名参与者")
if len(participants) > 50:
raise ModelRetry("参与者人数不能超过50人")
# 验证邮箱格式
for email in participants:
if '@' not in email:
raise ModelRetry(f"参与者邮箱格式无效: {email}")
# 验证时间格式
try:
meeting_time = datetime.strptime(start_time, "%Y-%m-%d %H:%M")
if meeting_time <= datetime.now():
raise ModelRetry("会议时间必须是将来时间,请调整开始时间")
except ValueError:
raise ModelRetry("时间格式错误,请使用 YYYY-MM-DD HH:MM 格式")
if duration_minutes < 15 or duration_minutes > 480:
raise ModelRetry("会议时长必须在15-480分钟之间")
# 模拟会议安排
import random
meeting_id = f"MTG{random.randint(10000, 99999)}"
return {
"meeting_id": meeting_id,
"title": title,
"participants": participants,
"start_time": start_time,
"duration_minutes": duration_minutes,
"location": location,
"recurring": recurring,
"status": "scheduled",
"created_at": datetime.now().isoformat()
}
@agent.tool_plain
def make_reservation(
restaurant: str,
date: str,
time: str,
party_size: int,
customer_name: str,
special_requests: str = ""
) -> dict:
"""
餐厅预订
Args:
restaurant: 餐厅名称,必须是合作餐厅
date: 预订日期,格式YYYY-MM-DD
time: 预订时间,格式HH:MM
party_size: 用餐人数,1-20人
customer_name: 顾客姓名
special_requests: 特殊要求
Returns:
预订确认信息
"""
# 验证餐厅
supported_restaurants = ["和平饭店", "星光餐厅", "海港酒楼", "花园咖啡"]
if restaurant not in supported_restaurants:
raise ModelRetry(
f"不支持餐厅 '{restaurant}',可选餐厅: {', '.join(supported_restaurants)}"
)
# 验证日期
try:
reservation_date = datetime.strptime(date, "%Y-%m-%d")
if reservation_date < datetime.now().date():
raise ModelRetry("预订日期不能是过去日期")
except ValueError:
raise ModelRetry("日期格式错误,请使用 YYYY-MM-DD 格式")
# 验证时间
try:
datetime.strptime(time, "%H:%M")
except ValueError:
raise ModelRetry("时间格式错误,请使用 HH:MM 格式")
# 验证人数
if party_size < 1 or party_size > 20:
raise ModelRetry("用餐人数必须在1-20人之间")
if not customer_name or len(customer_name.strip()) == 0:
raise ModelRetry("必须提供顾客姓名")
# 模拟预订
import random
reservation_id = f"RES{random.randint(1000, 9999)}"
return {
"reservation_id": reservation_id,
"restaurant": restaurant,
"date": date,
"time": time,
"party_size": party_size,
"customer_name": customer_name,
"special_requests": special_requests,
"status": "confirmed",
"confirmation_code": f"CODE{random.randint(100000, 999999)}"
}
@agent.tool_plain
def calculate_loan_payment(
loan_amount: float,
annual_interest_rate: float,
loan_term_years: int,
payment_frequency: str = "monthly"
) -> dict:
"""
计算贷款还款计划
Args:
loan_amount: 贷款金额,必须大于0
annual_interest_rate: 年利率,范围0.1-50.0
loan_term_years: 贷款年限,1-30年
payment_frequency: 还款频率,monthly/quarterly/yearly
Returns:
还款计划详情
"""
# 参数验证
if loan_amount <= 0:
raise ModelRetry("贷款金额必须大于0")
if annual_interest_rate < 0.1 or annual_interest_rate > 50.0:
raise ModelRetry("年利率必须在0.1%到50%之间")
if loan_term_years < 1 or loan_term_years > 30:
raise ModelRetry("贷款年限必须在1-30年之间")
valid_frequencies = ["monthly", "quarterly", "yearly"]
if payment_frequency not in valid_frequencies:
raise ModelRetry(f"还款频率必须是: {', '.join(valid_frequencies)}")
# 计算还款计划
monthly_rate = annual_interest_rate / 100 / 12
total_payments = loan_term_years * 12
if monthly_rate == 0:
monthly_payment = loan_amount / total_payments
else:
monthly_payment = loan_amount * monthly_rate * (1 + monthly_rate) ** total_payments / ((1 + monthly_rate) ** total_payments - 1)
total_payment = monthly_payment * total_payments
total_interest = total_payment - loan_amount
return {
"loan_amount": loan_amount,
"annual_interest_rate": annual_interest_rate,
"loan_term_years": loan_term_years,
"payment_frequency": payment_frequency,
"monthly_payment": round(monthly_payment, 2),
"total_payment": round(total_payment, 2),
"total_interest": round(total_interest, 2),
"calculation_date": datetime.now().strftime("%Y-%m-%d")
}
async def main():
"""
演示工具的错误处理和自动重试机制
"""
test_cases = [
# 有效的会议安排
"安排明天下午2点的团队会议,参与者zhangsan@example.com, lisi@example.com,主题项目评审",
# 无效的会议安排 - 缺少参与者
"安排会议,主题测试会议,开始时间2024-03-20 10:00",
# 有效的餐厅预订
"在和平饭店预订3月20日晚上7点的座位,4个人,姓名李四",
# 无效的餐厅预订 - 不支持的餐厅
"在未知餐厅预订座位,3月20日晚上7点,2个人,姓名张三",
# 有效的贷款计算
"计算贷款:金额100000元,年利率5%,期限10年,按月还款",
# 无效的贷款计算 - 金额为0
"计算贷款:金额0元,年利率5%,期限10年"
]
for i, query in enumerate(test_cases, 1):
print(f"\n=== 测试用例 {i} ===")
print(f"查询: {query}")
try:
result = await agent.run(query)
print(f"结果: {result.output}")
# 显示重试次数(如果有)
if hasattr(result, 'retry_count') and result.retry_count > 0:
print(f"重试次数: {result.retry_count}")
except Exception as e:
print(f"执行失败: {e}")
print("-" * 50)
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
参数验证
:每个工具都包含完整的参数验证逻辑
ModelRetry 异常
:用于指示模型需要重新尝试调用工具
具体错误信息
:提供详细的错误描述,帮助模型理解问题所在
格式验证
:验证日期、时间、邮箱等特定格式
业务规则验证
:验证业务逻辑约束,如时间必须是未来的、人数限制等
自动重试
:系统会自动处理ModelRetry
异常并重新调用工具
6. 实际应用示例
6.1 完整的电子商务助手
下面是一个完整的电子商务助手示例,展示了如何结合多种工具功能:
from pydantic_ai import Agent, RunContext, ModelRetry
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import asyncio
from datetime import datetime, timedelta
class Product(BaseModel):
id: str
name: str
price: float
category: str
stock: int
description: str
class CartItem(BaseModel):
product_id: str
quantity: int
added_at: str
class UserContext(BaseModel):
user_id: str
name: str
membership_level: str = "standard"
preferences: Dict[str, str] = Field(default_factory=dict)
# 模拟数据库
products_db = {
"P001": Product(id="P001", name="智能手机", price=2999.0, category="electronics", stock=50, description="最新款智能手机"),
"P002": Product(id="P002", name="笔记本电脑", price=5999.0, category="electronics", stock=30, description="高性能笔记本电脑"),
"P003": Product(id="P003", name="无线耳机", price=599.0, category="electronics", stock=100, description="降噪无线耳机"),
"P004": Product(id="P004", name="运动鞋", price=399.0, category="clothing", stock=80, description="舒适运动鞋"),
"P005": Product(id="P005", name="编程书籍", price=89.0, category="books", stock=200, description="Python编程指南"),
}
user_carts: Dict[str, List[CartItem]] = {}
order_history: Dict[str, List[Dict]] = {}
agent = Agent('qwen-plus', deps_type=UserContext)
@agent.tool
def search_products(
ctx: RunContext[UserContext],
query: str = "",
category: str = "",
min_price: float = 0,
max_price: float = 100000,
in_stock: bool = True
) -> List[Product]:
"""
搜索产品
Args:
query: 搜索关键词
category: 产品类别
min_price: 最低价格
max_price: 最高价格
in_stock: 是否只显示有库存的商品
"""
results = []
for product in products_db.values():
# 关键词匹配
keyword_match = not query or (
query.lower() in product.name.lower() or
query.lower() in product.description.lower()
)
# 类别匹配
category_match = not category or product.category == category
# 价格匹配
price_match = min_price <= product.price <= max_price
# 库存匹配
stock_match = not in_stock or product.stock > 0
if keyword_match and category_match and price_match and stock_match:
results.append(product)
return results
@agent.tool
def get_product_details(ctx: RunContext[UserContext], product_id: str) -> Product:
"""
获取产品详情
"""
if product_id not in products_db:
raise ModelRetry(f"产品 {product_id} 不存在")
return products_db[product_id]
@agent.tool
def add_to_cart(ctx: RunContext[UserContext], product_id: str, quantity: int = 1) -> Dict:
"""
添加商品到购物车
"""
user_id = ctx.deps.user_id
if product_id not in products_db:
raise ModelRetry(f"产品 {product_id} 不存在")
product = products_db[product_id]
if product.stock < quantity:
raise ModelRetry(f"产品 {product.name} 库存不足,当前库存: {product.stock}")
# 初始化用户购物车
if user_id not in user_carts:
user_carts[user_id] = []
# 检查是否已存在相同商品
for item in user_carts[user_id]:
if item.product_id == product_id:
item.quantity += quantity
break
else:
user_carts[user_id].append(CartItem(
product_id=product_id,
quantity=quantity,
added_at=datetime.now().isoformat()
))
return {
"success": True,
"message": f"已添加 {quantity} 件 {product.name} 到购物车",
"cart_item_count": len(user_carts[user_id])
}
@agent.tool
def view_cart(ctx: RunContext[UserContext]) -> Dict:
"""
查看购物车
"""
user_id = ctx.deps.user_id
if user_id not in user_carts or not user_carts[user_id]:
return {"empty": True, "message": "购物车为空"}
cart_items = []
total_amount = 0
for item in user_carts[user_id]:
product = products_db[item.product_id]
subtotal = product.price * item.quantity
cart_items.append({
"product_id": item.product_id,
"product_name": product.name,
"quantity": item.quantity,
"unit_price": product.price,
"subtotal": subtotal
})
total_amount += subtotal
# 会员折扣
discount = 0
if ctx.deps.membership_level == "vip":
discount = total_amount * 0.1 # VIP 9折
elif ctx.deps.membership_level == "premium":
discount = total_amount * 0.15 # 高级会员 85折
final_amount = total_amount - discount
return {
"cart_items": cart_items,
"total_amount": total_amount,
"discount": discount,
"final_amount": final_amount,
"item_count": len(cart_items),
"membership_level": ctx.deps.membership_level
}
@agent.tool
def remove_from_cart(ctx: RunContext[UserContext], product_id: str, quantity: int = None) -> Dict:
"""
从购物车移除商品
"""
user_id = ctx.deps.user_id
if user_id not in user_carts:
return {"success": False, "message": "购物车为空"}
for i, item in enumerate(user_carts[user_id]):
if item.product_id == product_id:
if quantity is None or quantity >= item.quantity:
# 移除整个商品
removed_item = user_carts[user_id].pop(i)
product = products_db[product_id]
return {
"success": True,
"message": f"已从购物车移除 {product.name}",
"removed_quantity": removed_item.quantity
}
else:
# 减少数量
item.quantity -= quantity
product = products_db[product_id]
return {
"success": True,
"message": f"已从购物车移除 {quantity} 件 {product.name}",
"remaining_quantity": item.quantity
}
return {"success": False, "message": "商品不在购物车中"}
@agent.tool
def checkout(ctx: RunContext[UserContext], shipping_address: str, payment_method: str) -> Dict:
"""
结算订单
"""
user_id = ctx.deps.user_id
if user_id not in user_carts or not user_carts[user_id]:
raise ModelRetry("购物车为空,无法结算")
cart_result = view_cart(ctx)
# 验证库存
for item in user_carts[user_id]:
product = products_db[item.product_id]
if product.stock < item.quantity:
raise ModelRetry(f"商品 {product.name} 库存不足,当前库存: {product.stock}")
# 创建订单
import random
order_id = f"ORD{datetime.now().strftime('%Y%m%d')}{random.randint(1000, 9999)}"
order = {
"order_id": order_id,
"user_id": user_id,
"items": cart_result["cart_items"],
"total_amount": cart_result["total_amount"],
"discount": cart_result["discount"],
"final_amount": cart_result["final_amount"],
"shipping_address": shipping_address,
"payment_method": payment_method,
"status": "pending",
"created_at": datetime.now().isoformat()
}
# 保存订单
if user_id not in order_history:
order_history[user_id] = []
order_history[user_id].append(order)
# 更新库存
for item in user_carts[user_id]:
products_db[item.product_id].stock -= item.quantity
# 清空购物车
user_carts[user_id] = []
return {
"success": True,
"order_id": order_id,
"message": "订单创建成功",
"order_details": order
}
@agent.tool
def get_order_history(ctx: RunContext[UserContext], limit: int = 10) -> List[Dict]:
"""
获取订单历史
"""
user_id = ctx.deps.user_id
if user_id not in order_history:
return []
return order_history[user_id][-limit:]
@agent.tool
def track_order(ctx: RunContext[UserContext], order_id: str) -> Dict:
"""
跟踪订单状态
"""
user_id = ctx.deps.user_id
if user_id not in order_history:
return {"error": "没有订单历史"}
for order in order_history[user_id]:
if order["order_id"] == order_id:
return order
return {"error": "订单不存在"}
async def main():
"""
演示完整的电商购物流程
"""
# 创建用户上下文
user_context = UserContext(
user_id="user_001",
name="张三",
membership_level="vip",
preferences={"language": "zh-CN", "currency": "CNY"}
)
print("=== 电商助手演示 ===\n")
# 1. 搜索产品
print("1. 搜索电子产品...")
result1 = await agent.run("搜索电子产品", deps=user_context)
print(f"响应: {result1.output}\n")
# 2. 查看产品详情
print("2. 查看智能手机详情...")
result2 = await agent.run("查看产品P001的详情", deps=user_context)
print(f"响应: {result2.output}\n")
# 3. 添加到购物车
print("3. 添加商品到购物车...")
result3 = await agent.run("把智能手机添加到购物车,数量2", deps=user_context)
print(f"响应: {result3.output}\n")
# 4. 查看购物车
print("4. 查看购物车...")
result4 = await agent.run("显示我的购物车", deps=user_context)
print(f"响应: {result4.output}\n")
# 5. 结算订单
print("5. 结算订单...")
result5 = await agent.run(
"结算订单,收货地址北京市海淀区,支付方式支付宝",
deps=user_context
)
print(f"响应: {result5.output}\n")
# 6. 查看订单历史
print("6. 查看订单历史...")
result6 = await agent.run("显示我的订单历史", deps=user_context)
print(f"响应: {result6.output}\n")
# 显示完整的工具调用序列
print("\n=== 完整执行轨迹 ===")
for i, message in enumerate(result6.all_messages()):
print(f"\n步骤 {i}:")
if hasattr(message, 'parts'):
for part in message.parts:
part_type = type(part).__name__
if hasattr(part, 'tool_name'):
print(f" ????? 工具调用: {part.tool_name}")
elif hasattr(part, 'content'):
content = str(part.content)
if len(content) > 100:
content = content[:100] + "..."
print(f" ???? 内容: {content}")
if __name__ == "__main__":
asyncio.run(main())
代码详细解释:
完整电商流程
:涵盖搜索、详情查看、购物车管理、结算、订单跟踪全流程
用户上下文管理
:使用 Pydantic 模型管理用户信息和偏好
库存管理
:实时检查库存并在下单时更新
会员系统
:根据会员等级提供不同的折扣
错误处理
:全面的错误检查和用户友好的错误信息
数据持久化
:模拟数据库操作,管理产品、购物车、订单数据
业务流程
:完整的业务逻辑,包括库存验证、价格计算、订单创建等
总结
Pydantic AI 的函数工具系统提供了一个强大且灵活的框架,用于扩展大型语言模型的能力。通过本文的详细讲解和完整示例,我们可以看到:
核心优势:
类型安全
:基于 Pydantic 的自动类型验证和模式生成
灵活注册
:多种工具注册方式,支持模块化和代码复用
上下文感知
:工具可以访问运行时的依赖项和上下文信息
动态控制
:基于运行时条件动态启用、禁用或调整工具
多模态支持
:支持返回文本、图像、文档等多种数据类型
错误处理
:完善的验证和自动重试机制
最佳实践:
清晰的文档
:为每个工具提供完整的参数描述和使用说明
充分的验证
:在工具内部实现业务逻辑验证
适当的错误信息
:提供具体、可操作的错误信息帮助模型重试
模块化设计
:将相关工具组织在一起,便于维护和复用
安全控制
:基于用户角色和权限动态控制工具访问
适用场景:
数据查询
:访问数据库、API 或文件系统
计算服务
:执行复杂的数学计算或业务逻辑
系统集成
:与外部系统和服务进行交互
工作流自动化
:执行多步骤的业务流程
实时信息获取
:获取最新的天气、股票、新闻等信息
通过合理使用函数工具,可以构建出功能强大、安全可靠、用户体验良好的 AI 应用,真正实现人工智能与现有系统的无缝集成。

雷达卡


京公网安备 11010802022788号







