楼主: dongbo0810
70 0

Pydantic AI:函数工具详解(三) [推广有奖]

  • 0关注
  • 0粉丝

等待验证会员

学前班

40%

还不是VIP/贵宾

-

威望
0
论坛币
0 个
通用积分
0
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
20 点
帖子
1
精华
0
在线时间
0 小时
注册时间
2018-4-18
最后登录
2018-4-18

楼主
dongbo0810 发表于 2025-11-18 15:19:14 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

Pydantic AI 函数工具全面详解

1. 函数工具基础概念

1.1 什么是函数工具

函数工具是 Pydantic AI 中的关键机制,它为大型语言模型提供了与现实世界互动的能力。简而言之,函数工具就像是为 AI 模型安装的“手和脚”,使模型不仅能思考和回答问题,还能执行具体操作、获取外部数据、进行计算等实际任务。

核心价值:

  • 扩展模型能力:超越纯文本生成的局限,使模型能够执行具体操作
  • 访问实时数据:获取模型训练时未知的最新信息
  • 执行确定性操作:将不确定的 AI 推理与确定性的程序逻辑结合
  • 集成现有系统:连接数据库、API、文件系统等现有基础设施

1.2 函数工具与相关技术对比

函数工具 vs. RAG(检索增强生成):

  • RAG 主要用于向量搜索和信息检索
  • 函数工具更为通用,可以执行任意操作
  • 两者可以结合使用,RAG 负责信息查找,函数工具负责执行操作

函数工具 vs. 结构化输出:

  • 结构化输出定义模型的最终响应格式
  • 函数工具定义模型在生成响应过程中可以调用的操作
  • 一个模型可以同时使用多个工具,其中一些用于中间操作,一些用于最终输出

2. 工具注册机制详解

2.1 装饰器注册方式

2.1.1 @agent.tool 装饰器

@agent.tool

这是主要的装饰器,适用于需要访问代理运行上下文(RunContext)的工具。RunContext 提供了对依赖项、模型信息和其他运行时数据的访问。

from pydantic_ai import Agent, RunContext
import asyncio
import random
 
# 创建代理实例,指定依赖项类型为字符串(用户ID)
agent = Agent(
    'qwen-plus',  # 使用阿里云通义千问模型
    deps_type=str,  # 依赖项类型为用户ID字符串
    system_prompt=(
        "你是一个智能客服助手,可以查询用户信息和订单状态。"
        "请根据用户请求调用适当的工具来获取信息。"
    ),
)
 
@agent.tool
def get_user_profile(ctx: RunContext[str]) -> dict:
    """
    获取用户个人信息
    
    通过运行上下文中的用户ID来查询用户资料
    """
    user_id = ctx.deps  # 从上下文中获取依赖项(用户ID)
    
    # 模拟数据库查询 - 实际应用中这里可能是真实的数据库操作
    user_profiles = {
        "user_001": {"name": "张三", "level": "VIP", "join_date": "2023-01-15"},
        "user_002": {"name": "李四", "level": "普通", "join_date": "2023-03-20"},
    }
    
    if user_id in user_profiles:
        return user_profiles[user_id]
    else:
        return {"error": "用户不存在"}
 
@agent.tool  
def get_order_status(ctx: RunContext[str], order_id: str) -> dict:
    """
    查询订单状态
    
    Args:
        order_id: 订单编号,格式为 ORDER_XXXX
    """
    user_id = ctx.deps
    
    # 模拟订单数据查询
    orders = {
        "ORDER_1001": {"status": "已发货", "product": "智能手机", "amount": 2999},
        "ORDER_1002": {"status": "处理中", "product": "笔记本电脑", "amount": 5999},
    }
    
    if order_id in orders:
        order_info = orders[order_id]
        return {
            "order_id": order_id,
            "status": order_info["status"],
            "product": order_info["product"],
            "amount": order_info["amount"],
            "user_id": user_id
        }
    else:
        return {"error": "订单不存在"}
 
async def main():
    """
    演示工具调用的完整流程
    """
    # 使用用户ID "user_001" 作为依赖项
    result = await agent.run(
        "请查询我的个人信息和订单ORDER_1001的状态",
        deps="user_001"  # 传递用户ID作为依赖项
    )
    
    print("=== 模型响应 ===")
    print(result.output)
    
    print("\n=== 工具调用详情 ===")
    for i, message in enumerate(result.all_messages()):
        print(f"消息 {i}: {message}")
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • 代理初始化:创建 Agent 时指定了
    deps_type=str
    ,这意味着每个工具都可以通过
    RunContext
    访问字符串类型的依赖项。
  • 工具函数定义:
    get_user_profile
    不需要额外参数,直接从上下文中获取用户ID;
    get_order_status
    需要
    order_id
    参数,由模型在调用时提供。
  • 上下文访问:通过
    ctx.deps
    访问在
    agent.run()
    中传递的依赖项数据。
  • 执行流程:模型分析用户请求,决定需要调用哪些工具,然后依次执行并整合结果。

2.1.2 @agent.tool_plain 装饰器

@agent.tool_plain

用于不需要访问运行上下文的简单工具。这些工具更加纯粹,只依赖于传入的参数。

from pydantic_ai import Agent
import asyncio
import requests
from datetime import datetime, timedelta
 
agent = Agent('qwen-plus')
 
@agent.tool_plain
def get_weather(city: str, date: str = None) -> dict:
    """
    获取城市天气信息
    
    Args:
        city: 城市名称,如"北京"、"上海"
        date: 日期字符串,格式YYYY-MM-DD,默认为今天
    """
    if date is None:
        date = datetime.now().strftime("%Y-%m-%d")
    
    # 模拟天气API调用 - 实际应用中这里调用真实天气API
    weather_data = {
        "北京": {"temperature": "22°C", "condition": "晴", "humidity": "45%"},
        "上海": {"temperature": "25°C", "condition": "多云", "humidity": "65%"},
        "广州": {"temperature": "28°C", "condition": "雨", "humidity": "80%"},
    }
    
    if city in weather_data:
        return {
            "city": city,
            "date": date,
            "weather": weather_data[city]
        }
    else:
        return {"error": f"未找到城市 {city} 的天气信息"}
 
@agent.tool_plain
def calculate_expression(expression: str) -> dict:
    """
    计算数学表达式
    
    Args:
        expression: 数学表达式,如 "2 + 3 * 4"
    """
    try:
        # 安全地计算数学表达式
        result = eval(expression)  # 注意:实际生产环境需要更安全的方式
        return {
            "expression": expression,
            "result": result,
            "type": type(result).__name__
        }
    except Exception as e:
        return {"error": f"计算失败: {str(e)}"}
 
@agent.tool_plain
def get_current_time(timezone: str = "UTC") -> dict:
    """
    获取当前时间
    
    Args:
        timezone: 时区,如"UTC", "Asia/Shanghai"
    """
    from datetime import datetime
    import pytz  # 需要安装pytz库
    
    try:
        tz = pytz.timezone(timezone)
        current_time = datetime.now(tz)
        return {
            "timezone": timezone,
            "current_time": current_time.strftime("%Y-%m-%d %H:%M:%S"),
            "iso_format": current_time.isoformat()
        }
    except Exception as e:
        return {"error": f"时区错误: {str(e)}"}
 
async def main():
    """
    演示多个工具的组合使用
    """
    result = await agent.run(
        "请告诉我北京今天的天气,并计算(15 + 25) * 2的结果,还有上海当前时间"
    )
    
    print("=== 综合查询结果 ===")
    print(result.output)
    
    # 显示详细的工具调用序列
    print("\n=== 执行轨迹 ===")
    for i, msg in enumerate(result.all_messages()):
        print(f"步骤 {i}: {type(msg).__name__}")
        if hasattr(msg, 'parts'):
            for part in msg.parts:
                part_type = type(part).__name__
                if hasattr(part, 'content'):
                    print(f"  - {part_type}: {part.content}")
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • 工具独立性:
    @agent.tool_plain
    装饰的工具不依赖运行上下文,更加纯粹。
  • 参数处理:每个工具都明确定义了参数和返回类型,模型会根据这些信息正确调用。
  • 错误处理:工具内部包含完整的错误处理逻辑,返回结构化的错误信息。
  • 多工具协作:模型可以在一个对话中调用多个不同的工具来满足复杂需求。

2.2 通过 Agent 构造函数注册

除了装饰器,还可以在创建 Agent 时通过

tools
参数注册工具,这种方式更适合工具复用和模块化设计。

from pydantic_ai import Agent, RunContext, Tool
import asyncio
import json
 
# 独立的工具函数 - 可以在多个代理间共享
def currency_converter(amount: float, from_currency: str, to_currency: str) -> dict:
    """
    货币兑换计算
    
    Args:
        amount: 金额数量
        from_currency: 源货币代码,如"USD", "CNY"
        to_currency: 目标货币代码,如"EUR", "JPY"
    """
    # 模拟汇率数据
    exchange_rates = {
        "USD": {"CNY": 7.2, "EUR": 0.92, "JPY": 150},
        "CNY": {"USD": 0.14, "EUR": 0.13, "JPY": 21},
        "EUR": {"USD": 1.09, "CNY": 7.8, "JPY": 163},
    }
    
    if from_currency in exchange_rates and to_currency in exchange_rates[from_currency]:
        rate = exchange_rates[from_currency][to_currency]
        converted = amount * rate
        return {
            "original": f"{amount} {from_currency}",
            "converted": f"{converted:.2f} {to_currency}",
            "exchange_rate": rate
        }
    else:
        return {"error": "不支持的货币兑换"}
 
def get_stock_price(symbol: str) -> dict:
    """
    获取股票价格
    
    Args:
        symbol: 股票代码,如"AAPL", "TSLA"
    """
    # 模拟股票数据
    stock_prices = {
        "AAPL": {"price": 185.32, "change": "+1.25", "change_percent": "+0.68%"},
        "TSLA": {"price": 245.18, "change": "-3.42", "change_percent": "-1.38%"},
        "GOOGL": {"price": 138.45, "change": "+0.85", "change_percent": "+0.62%"},
    }
    
    if symbol.upper() in stock_prices:
        return stock_prices[symbol.upper()]
    else:
        return {"error": f"未找到股票 {symbol} 的信息"}
 
def create_user_session(ctx: RunContext[str]) -> dict:
    """
    创建用户会话
    
    需要访问运行上下文来获取用户信息
    """
    user_id = ctx.deps
    return {
        "session_id": f"SESS_{user_id}_{hash(user_id)}",
        "user_id": user_id,
        "created_at": "2024-01-01T10:00:00Z",
        "status": "active"
    }
 
# 方式1:直接传递函数列表
# Agent会自动检测哪些函数需要上下文
agent_simple = Agent(
    'qwen-plus',
    deps_type=str,
    tools=[currency_converter, get_stock_price, create_user_session],
    system_prompt="金融信息查询助手,可以处理货币兑换和股票查询。"
)
 
# 方式2:使用Tool类进行精细控制
# 明确指定每个工具是否需要上下文
agent_advanced = Agent(
    'qwen-plus', 
    deps_type=str,
    tools=[
        Tool(currency_converter, takes_ctx=False),
        Tool(get_stock_price, takes_ctx=False),
        Tool(create_user_session, takes_ctx=True),
    ],
    system_prompt="高级金融助手,提供货币兑换和股票查询服务。"
)
 
async def main():
    """
    比较两种注册方式的效果
    """
    print("=== 简单注册方式 ===")
    result1 = await agent_simple.run(
        "将100美元换成人民币,并查看苹果股票价格",
        deps="user_123"
    )
    print(result1.output)
    
    print("\n=== 高级注册方式 ===")
    result2 = await agent_advanced.run(
        "创建用户会话并查询特斯拉股票",
        deps="user_456" 
    )
    print(result2.output)
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • 工具函数独立性:工具函数被定义为独立的函数,不依赖于特定的代理实例。
  • 两种注册方式对比:简单方式——直接传递函数列表,Agent 自动检测上下文需求;高级方式——使用
    Tool
    类明确指定每个工具的配置。
  • 工具复用:相同的工具函数可以在多个代理实例间共享。
  • 明确性:高级方式虽然代码稍多,但意图更加明确,便于维护。

3. 工具模式与参数系统

3.1 自动模式生成机制

Pydantic AI 会自动从函数签名和文档字符串中提取信息,生成完整的工具模式(Schema)。这个模式告诉模型如何调用工具,包括参数类型、描述、约束条件等。

from pydantic_ai import Agent
import asyncio
from typing import List, Optional
from pydantic import Field
 
agent = Agent('qwen-plus')
 
@agent.tool_plain(
    docstring_format='google',  # 指定文档字符串格式
    require_parameter_descriptions=True  # 要求必须提供参数描述
)
def book_hotel(
    city: str,
    check_in_date: str,
    check_out_date: str,
    guests: int = Field(1, description="入住人数", ge=1, le=10),
    room_type: str = Field("standard", description="房间类型: standard, deluxe, suite"),
    amenities: Optional[List[str]] = Field(None, description="需要的设施")
) -> dict:
    """
    预订酒店房间
    
    根据用户需求在指定城市预订酒店,支持多种房型和设施选择。
 
    Args:
        city: 城市名称,必须是支持服务的城市
        check_in_date: 入住日期,格式YYYY-MM-DD
        check_out_date: 退房日期,格式YYYY-MM-DD
        guests: 入住人数,范围1-10人
        room_type: 房间类型,可选标准间、豪华间、套房
        amenities: 额外设施需求,如wifi、parking、breakfast等
 
    Returns:
        包含预订详情的字典,包括预订ID和价格信息
 
    Raises:
        ModelRetry: 当参数无效或预订失败时抛出重试异常
    """
    # 参数验证逻辑
    if check_in_date >= check_out_date:
        raise ModelRetry("退房日期必须晚于入住日期,请检查日期输入")
    
    supported_cities = ["北京", "上海", "广州", "深圳", "杭州"]
    if city not in supported_cities:
        raise ModelRetry(
            f"暂不支持城市 {city},请选择以下城市: {', '.join(supported_cities)}"
        )
    
    # 模拟预订逻辑
    import random
    booking_id = f"HTL{random.randint(10000, 99999)}"
    base_prices = {"standard": 300, "deluxe": 500, "suite": 800}
    base_price = base_prices.get(room_type, 300)
    total_price = base_price * (guests if guests <= 2 else guests - 1)
    
    return {
        "booking_id": booking_id,
        "city": city,
        "check_in": check_in_date,
        "check_out": check_out_date,
        "guests": guests,
        "room_type": room_type,
        "amenities": amenities or [],
        "total_price": f"?{total_price}",
        "status": "confirmed"
    }
 
@agent.tool_plain
def search_flights(
    departure: str,
    arrival: str, 
    date: str,
    passengers: int = 1,
    class_type: str = "economy"
) -> List[dict]:
    """
    搜索航班信息
    
    Args:
        departure: 出发城市机场代码
        arrival: 到达城市机场代码  
        date: 出发日期,格式YYYY-MM-DD
        passengers: 乘客人数
        class_type: 舱位类型,economy/business/first
 
    Returns:
        可用航班列表,包含时间和价格信息
    """
    # 模拟航班搜索
    flights = [
        {
            "flight_no": "CA1234",
            "departure_time": "08:00",
            "arrival_time": "11:00", 
            "price": 1200,
            "airline": "中国国际航空"
        },
        {
            "flight_no": "MU5678", 
            "departure_time": "14:30",
            "arrival_time": "17:30",
            "price": 980,
            "airline": "中国东方航空"
        }
    ]
    
    return [
        {
            **flight,
            "class_type": class_type,
            "passengers": passengers,
            "total_price": flight["price"] * passengers
        }
        for flight in flights
    ]
 
async def main():
    """
    演示自动生成的工具模式如何工作
    """
    # 测试复杂参数的工具调用
    result = await agent.run(
        "我想在北京预订酒店,3月20日入住,3月25日退房,2个人,要豪华间,需要wifi和早餐"
    )
    
    print("=== 酒店预订结果 ===")
    print(result.output)
    
    # 航班搜索测试
    result2 = await agent.run(
        "搜索3月15日从北京到上海的航班,2个乘客,经济舱"
    )
    
    print("\n=== 航班搜索结果 ===")
    print(result2.output)
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • 文档字符串解析:系统会解析 Google 风格的文档字符串,提取参数描述。
  • 类型注解:Python 类型注解被转换为 JSON Schema 类型约束。
  • 参数验证:Pydantic 的
    Field
    类提供额外的验证规则(如范围约束)。
  • 错误处理:使用
    ModelRetry
    异常让模型知道需要重新尝试。
  • 复杂返回类型:支持返回字典、列表等复杂数据结构。

3.2 使用 Pydantic 模型作为参数

对于复杂的工具参数,可以使用 Pydantic 模型来定义数据结构,这样既能获得更好的类型安全,也能生成更清晰的工具模式。

from pydantic import BaseModel, Field, validator
from typing import List, Optional
from datetime import datetime
from pydantic_ai import Agent, ModelRetry
import asyncio
 
class Address(BaseModel):
    """地址信息"""
    street: str = Field(description="街道地址")
    city: str = Field(description="城市")
    state: str = Field(description="省/州")
    postal_code: str = Field(description="邮政编码")
    country: str = Field(default="中国", description="国家")
 
class CustomerInfo(BaseModel):
    """客户信息"""
    name: str = Field(description="客户姓名")
    email: str = Field(description="邮箱地址")
    phone: str = Field(description="手机号码")
    address: Address = Field(description="联系地址")
    
    @validator('email')
    def validate_email(cls, v):
        if '@' not in v:
            raise ValueError('邮箱格式无效')
        return v
    
    @validator('phone') 
    def validate_phone(cls, v):
        if not v.replace('+', '').replace(' ', '').isdigit():
            raise ValueError('手机号码格式无效')
        return v
 
class OrderItem(BaseModel):
    """订单项"""
    product_id: str = Field(description="产品ID")
    product_name: str = Field(description="产品名称") 
    quantity: int = Field(description="数量", ge=1)
    unit_price: float = Field(description="单价", gt=0)
 
class CreateOrderRequest(BaseModel):
    """创建订单请求"""
    customer: CustomerInfo = Field(description="客户信息")
    items: List[OrderItem] = Field(description="订单项目列表")
    shipping_method: str = Field(
        default="standard", 
        description="配送方式: standard/express/overnight"
    )
    notes: Optional[str] = Field(None, description="订单备注")
 
agent = Agent('qwen-plus')
 
@agent.tool_plain
def create_order(request: CreateOrderRequest) -> dict:
    """
    创建新订单
    
    接收完整的订单信息,验证后创建订单并返回订单详情。
    
    Args:
        request: 包含客户信息、订单项目和配送方式的完整请求
        
    Returns:
        创建的订单详情,包括订单ID和总金额
    """
    try:
        # 计算总金额
        total_amount = sum(item.quantity * item.unit_price for item in request.items)
        
        # 生成订单ID
        import random
        order_id = f"ORD{datetime.now().strftime('%Y%m%d')}{random.randint(1000, 9999)}"
        
        # 模拟订单创建
        return {
            "order_id": order_id,
            "customer_name": request.customer.name,
            "customer_email": request.customer.email,
            "items": [
                {
                    "product_name": item.product_name,
                    "quantity": item.quantity,
                    "unit_price": item.unit_price,
                    "subtotal": item.quantity * item.unit_price
                }
                for item in request.items
            ],
            "total_amount": total_amount,
            "shipping_method": request.shipping_method,
            "status": "created",
            "created_at": datetime.now().isoformat()
        }
        
    except Exception as e:
        raise ModelRetry(f"订单创建失败: {str(e)}")
 
@agent.tool_plain
def update_order_status(order_id: str, status: str, notes: Optional[str] = None) -> dict:
    """
    更新订单状态
    
    Args:
        order_id: 订单ID
        status: 新状态: processing/shipped/delivered/cancelled
        notes: 状态更新备注
    """
    valid_statuses = ["processing", "shipped", "delivered", "cancelled"]
    if status not in valid_statuses:
        raise ModelRetry(f"无效状态: {status},可选: {', '.join(valid_statuses)}")
    
    return {
        "order_id": order_id,
        "old_status": "created",  # 模拟原有状态
        "new_status": status,
        "updated_at": datetime.now().isoformat(),
        "notes": notes
    }
 
async def main():
    """
    演示复杂参数工具的使用
    """
    result = await agent.run(
        "创建新订单:"
        "客户张三,邮箱zhangsan@example.com,电话13800138000,"
        "地址:北京市海淀区中关村大街100号,北京,北京,100080;"
        "订购产品:笔记本电脑2台单价5000元,鼠标1个单价100元;"
        "使用快递配送,备注尽快发货"
    )
    
    print("=== 订单创建结果 ===")
    print(result.output)
    
    # 显示工具调用的详细信息
    print("\n=== 执行详情 ===")
    for i, message in enumerate(result.all_messages()):
        if hasattr(message, 'parts'):
            print(f"消息 {i}:")
            for part in message.parts:
                if hasattr(part, 'tool_name'):
                    print(f"  工具调用: {part.tool_name}")
                elif hasattr(part, 'content'):
                    content = part.content
                    if len(str(content)) > 100:
                        content = str(content)[:100] + "..."
                    print(f"  内容: {content}")
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • Pydantic 模型定义:使用
    BaseModel
    定义复杂的数据结构。
  • 字段验证:通过
    Field
    设置字段描述和约束条件。
  • 自定义验证器:使用
    @validator
    装饰器实现自定义验证逻辑。
  • 嵌套模型:模型可以嵌套其他模型,构建复杂的数据结构。
  • 自动模式生成:Pydantic AI 会自动将这些模型转换为 JSON Schema。

4. 高级工具功能

4.1 多模态工具输出

Pydantic AI 支持工具返回多种类型的数据,包括文本、图像、文档等,这对于多模态模型特别有用。

from pydantic_ai import Agent, ImageUrl, DocumentUrl, ToolReturn, BinaryContent
from pydantic_ai.messages import TextContent
import asyncio
import base64
 
# 使用多模态模型
agent = Agent('qwen-vl-plus')
 
@agent.tool_plain
def generate_chart(data_type: str) -> ToolReturn:
    """
    生成数据图表
    
    Args:
        data_type: 数据类型: sales/users/performance
    """
    # 模拟图表生成 - 实际应用中这里会调用图表生成库
    chart_data = {
        "sales": {
            "title": "月度销售趋势",
            "data": [120, 150, 180, 200, 240, 300, 280],
            "labels": ["1月", "2月", "3月", "4月", "5月", "6月", "7月"]
        },
        "users": {
            "title": "用户增长统计", 
            "data": [1000, 1500, 2200, 3000, 4000, 5200, 6500],
            "labels": ["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Q7"]
        }
    }
    
    if data_type not in chart_data:
        return ToolReturn(
            return_value={"error": f"未知数据类型: {data_type}"},
            content=[f"无法生成 {data_type} 类型的图表"]
        )
    
    data = chart_data[data_type]
    
    # 模拟生成图表图像数据(base64编码)
    # 实际应用中这里会使用matplotlib、plotly等库生成真实图表
    fake_image_data = base64.b64encode(f"chart_{data_type}".encode()).decode()
    
    return ToolReturn(
        return_value={
            "chart_type": data_type,
            "title": data["title"],
            "data_points": len(data["data"]),
            "max_value": max(data["data"])
        },
        content=[
            f"已生成 {data['title']} 图表",
            "图表数据摘要:",
            TextContent(text=str(data)),
            "生成的图表:",
            BinaryContent(data=fake_image_data, media_type="image/png"),
            "请分析图表趋势并提供见解。"
        ],
        metadata={
            "chart_type": data_type,
            "generated_at": "2024-01-01T10:00:00Z",
            "data_source": "internal_system"
        }
    )
 
@agent.tool_plain
def get_product_catalog() -> ToolReturn:
    """
    获取产品目录,包含产品图片和描述
    """
    products = [
        {
            "id": "P001",
            "name": "智能手机",
            "price": 2999,
            "image": "https://example.com/images/phone.jpg",
            "description": "最新款智能手机,高性能处理器"
        },
        {
            "id": "P002", 
            "name": "笔记本电脑",
            "price": 5999,
            "image": "https://example.com/images/laptop.jpg",
            "description": "轻薄便携笔记本电脑,超长续航"
        }
    ]
    
    return ToolReturn(
        return_value={"products": products, "total_count": len(products)},
        content=[
            "产品目录如下:",
            *[
                f"产品: {p['name']} - ?{p['price']}\n"
                f"描述: {p['description']}\n"
                f"图片: {p['image']}"
                for p in products
            ],
            "请基于产品信息回答用户问题。"
        ]
    )
 
@agent.tool_plain
def analyze_document(document_url: str) -> ToolReturn:
    """
    分析文档内容
    
    Args:
        document_url: 文档的URL地址
    """
    # 模拟文档分析
    document_content = {
        "title": "季度报告",
        "author": "财务部", 
        "pages": 25,
        "key_points": [
            "收入同比增长15%",
            "利润率提升至20%",
            "新客户增长30%"
        ],
        "summary": "本季度业绩表现良好,各项指标均超预期"
    }
    
    return ToolReturn(
        return_value=document_content,
        content=[
            f"已分析文档: {document_url}",
            "文档摘要:",
            TextContent(text=document_content["summary"]),
            "关键要点:",
            *[f"- {point}" for point in document_content["key_points"]],
            "请基于文档内容回答用户问题。"
        ],
        metadata={
            "document_url": document_url,
            "analysis_method": "ai_analysis",
            "confidence_score": 0.95
        }
    )
 
async def main():
    """
    演示多模态工具的高级功能
    """
    print("=== 测试图表生成 ===")
    result1 = await agent.run("生成销售数据图表并分析趋势")
    print("响应:", result1.output)
    
    print("\n=== 测试产品目录 ===")
    result2 = await agent.run("显示产品目录并推荐适合的产品")
    print("响应:", result2.output)
    
    print("\n=== 测试文档分析 ===")
    result3 = await agent.run("分析https://example.com/report.pdf这个文档")
    print("响应:", result3.output)
    
    # 显示工具返回的元数据
    if result1.tool_results:
        print("\n=== 工具元数据 ===")
        for tool_result in result1.tool_results:
            print(f"工具: {tool_result.tool_name}")
            print(f"元数据: {tool_result.metadata}")
 
if __name__ == "__main__":
    asyncio.run(main())

代码详细解释:

  • ToolReturn 类:提供对工具返回值的精细控制。
  • 多部分内容:可以同时返回文本、图像、文档等多种类型的内容。
  • 元数据分离:
    metadata
    包含不发送给模型但应用程序需要的数据。
  • 结构化返回值:支持返回复杂的数据结构。
return_value
包含工具的主要执行结果 丰富的上下文 :
content
为模型提供分析和回答所需的额外信息 4.2 动态工具控制 动态工具功能允许根据运行时上下文启用、禁用或调整工具行为。
from pydantic_ai import Agent, RunContext, ToolDefinition, Tool
from typing import Optional, List
import asyncio
 
class UserRole:
    """用户角色定义"""
    ADMIN = "admin"
    USER = "user" 
    GUEST = "guest"
    SUPPORT = "support"
 
agent = Agent('qwen-plus', deps_type=dict)  # 依赖项为用户上下文字典
 
# 工具准备函数 - 基于用户角色过滤工具
async def prepare_tools_by_role(
    ctx: RunContext[dict], 
    tool_defs: List[ToolDefinition]
) -> Optional[List[ToolDefinition]]:
    """
    根据用户角色过滤可用工具
    
    Args:
        ctx: 运行上下文,包含用户信息
        tool_defs: 所有可用的工具定义列表
    """
    user_context = ctx.deps
    user_role = user_context.get('role', UserRole.GUEST)
    
    # 定义角色权限
    role_permissions = {
        UserRole.ADMIN: ['*'],  # 管理员可以使用所有工具
        UserRole.SUPPORT: ['get_user_info', 'view_tickets', 'add_comment'],
        UserRole.USER: ['get_own_info', 'create_ticket', 'view_own_tickets'],
        UserRole.GUEST: ['get_public_info']  # 访客只能使用公开工具
    }
    
    allowed_tools = role_permissions.get(user_role, [])
    
    if '*' in allowed_tools:
        return tool_defs  # 管理员可以使用所有工具
    
    # 过滤工具
    filtered_tools = []
    for tool_def in tool_defs:
        if tool_def.name in allowed_tools:
            filtered_tools.append(tool_def)
    
    return filtered_tools
 
# 设置代理级别的工具准备函数
agent.prepare_tools = prepare_tools_by_role
 
# 定义各种工具
@agent.tool_plain
def get_public_info() -> dict:
    """获取公开信息"""
    return {
        "service_status": "正常运行",
        "announcements": ["系统维护通知", "新功能上线"],
        "contact_info": "support@example.com"
    }
 
@agent.tool
def get_own_info(ctx: RunContext[dict]) -> dict:
    """获取自己的信息"""
    user_context = ctx.deps
    return {
        "user_id": user_context.get('user_id'),
        "name": user_context.get('name', '未知用户'),
        "join_date": "2023-01-01",
        "points": 1500
    }
 
@agent.tool
def create_ticket(ctx: RunContext[dict], title: str, description: str) -> dict:
    """创建支持工单"""
    user_context = ctx.deps
    import random
    ticket_id = f"TICKET{random.randint(10000, 99999)}"
    
    return {
        "ticket_id": ticket_id,
        "title": title,
        "description": description,
        "created_by": user_context.get('user_id'),
        "status": "open",
        "created_at": "2024-01-01T10:00:00Z"
    }
 
@agent.tool
def view_own_tickets(ctx: RunContext[dict]) -> List[dict]:
    """查看自己的工单"""
    user_context = ctx.deps
    user_id = user_context.get('user_id')
    
    # 模拟工单数据
    return [
        {
            "ticket_id": "TICKET1001",
            "title": "登录问题",
            "status": "resolved",
            "created_at": "2024-01-01T09:00:00Z"
        },
        {
            "ticket_id": "TICKET1002", 
            "title": "支付问题",
            "status": "in_progress",
            "created_at": "2024-01-01T10:00:00Z"
        }
    ]
 
@agent.tool_plain
def get_user_info(user_id: str) -> dict:
    """获取用户信息(支持人员专用)"""
    # 模拟用户数据库
    users = {
        "user_001": {"name": "张三", "role": "user", "status": "active"},
        "user_002": {"name": "李四", "role": "user", "status": "active"},
    }
    
    if user_id in users:
        return users[user_id]
    else:
        return {"error": "用户不存在"}
 
@agent.tool_plain  
def view_tickets(status: str = "all") -> List[dict]:
    """查看所有工单(支持人员专用)"""
    # 模拟工单数据库
    all_tickets = [
        {"ticket_id": "TICKET1001", "user_id": "user_001", "status": "open"},
        {"ticket_id": "TICKET1002", "user_id": "user_002", "status": "in_progress"},
        {"ticket_id": "TICKET1003", "user_id": "user_001", "status": "resolved"},
    ]
    
    if status != "all":
        return [t for t in all_tickets if t['status'] == status]
    return all_tickets
 
@agent.tool_plain
def add_comment(ticket_id: str, comment: str) -> dict:
    """添加工单评论(支持人员专用)"""
    return {
        "ticket_id": ticket_id,
        "comment": comment,
        "added_by": "support_agent",
        "added_at": "2024-01-01T10:00:00Z"
    }
 
@agent.tool_plain
def system_maintenance(action: str) -> dict:
    """系统维护操作(管理员专用)"""
    if action == "backup":
        return {"status": "success", "message": "系统备份完成"}
    elif action == "restart":
        return {"status": "success", "message": "系统重启完成"}
    else:
        return {"error": "未知维护操作"}
 
async def main():
    """
    演示基于角色的动态工具访问控制
    """
    # 测试不同角色的工具访问权限
    test_cases = [
        {
            "role": UserRole.GUEST,
            "user_id": None,
            "name": "访客用户",
            "queries": [
                "查看系统状态",
                "查看我的信息",  # 应该被拒绝
                "查看所有工单"   # 应该被拒绝
            ]
        },
        {
            "role": UserRole.USER, 
            "user_id": "user_001",
            "name": "普通用户",
            "queries": [
                "查看我的信息",
                "创建工单:无法登录",
                "查看所有工单"  # 应该被拒绝
            ]
        },
        {
            "role": UserRole.SUPPORT,
            "user_id": "support_001", 
            "name": "客服人员",
            "queries": [
                "查看用户user_001的信息",
                "查看所有工单",
                "系统备份"  # 应该被拒绝
            ]
        },
        {
            "role": UserRole.ADMIN,
            "user_id": "admin_001",
            "name": "系统管理员", 
            "queries": [
                "查看用户user_001的信息",
                "系统备份",
                "查看所有工单"
            ]
        }
    ]
    
    for test_case in test_cases:
        print(f"\n=== 测试角色: {test_case['role']} ===")
        user_context = {
            'role': test_case['role'],
            'user_id': test_case['user_id'], 
            'name': test_case['name']
        }
        
        for query in test_case['queries']:
            print(f"\n查询: {query}")
            try:
                result = await agent.run(query, deps=user_context)
                print(f"响应: {result.output}")
            except Exception as e:
                print(f"错误: {e}")
 
if __name__ == "__main__":
    asyncio.run(main())
代码详细解释: 角色基础权限系统 :定义不同用户角色可以访问的工具 动态工具过滤 :
prepare_tools_by_role
函数根据用户角色过滤可用工具 上下文感知 :工具可以通过
RunContext
访问用户信息 权限分级 : 访客:仅能访问公开信息 用户:可以管理自己的数据和创建工单 客服:可以查看所有用户信息和工单 管理员:无限制访问所有功能 安全控制 :防止未经授权访问敏感操作 5. 工具执行与错误处理 5.1 参数验证与自动重试 Pydantic AI 提供了强大的参数验证和自动重试机制,确保工具调用的可靠性。
from pydantic_ai import Agent, ModelRetry
from pydantic import ValidationError
import asyncio
from typing import List
from datetime import datetime
 
agent = Agent('qwen-plus')
 
@agent.tool_plain
def schedule_meeting(
    title: str,
    participants: List[str],
    start_time: str,
    duration_minutes: int = 60,
    location: str = "线上会议",
    recurring: bool = False
) -> dict:
    """
    安排会议
    
    Args:
        title: 会议标题,不能为空
        participants: 参与者邮箱列表,至少1人,最多50人
        start_time: 开始时间,格式YYYY-MM-DD HH:MM,必须是将来的时间
        duration_minutes: 会议时长(分钟),范围15-480
        location: 会议地点
        recurring: 是否为周期性会议
 
    Returns:
        会议安排结果,包含会议ID和详情
    """
    # 参数验证
    if not title or len(title.strip()) == 0:
        raise ModelRetry("会议标题不能为空,请提供有效的会议标题")
    
    if len(participants) == 0:
        raise ModelRetry("必须指定至少一名参与者")
    
    if len(participants) > 50:
        raise ModelRetry("参与者人数不能超过50人")
    
    # 验证邮箱格式
    for email in participants:
        if '@' not in email:
            raise ModelRetry(f"参与者邮箱格式无效: {email}")
    
    # 验证时间格式
    try:
        meeting_time = datetime.strptime(start_time, "%Y-%m-%d %H:%M")
        if meeting_time <= datetime.now():
            raise ModelRetry("会议时间必须是将来时间,请调整开始时间")
    except ValueError:
        raise ModelRetry("时间格式错误,请使用 YYYY-MM-DD HH:MM 格式")
    
    if duration_minutes < 15 or duration_minutes > 480:
        raise ModelRetry("会议时长必须在15-480分钟之间")
    
    # 模拟会议安排
    import random
    meeting_id = f"MTG{random.randint(10000, 99999)}"
    
    return {
        "meeting_id": meeting_id,
        "title": title,
        "participants": participants,
        "start_time": start_time,
        "duration_minutes": duration_minutes,
        "location": location,
        "recurring": recurring,
        "status": "scheduled",
        "created_at": datetime.now().isoformat()
    }
 
@agent.tool_plain
def make_reservation(
    restaurant: str,
    date: str,
    time: str,
    party_size: int,
    customer_name: str,
    special_requests: str = ""
) -> dict:
    """
    餐厅预订
    
    Args:
        restaurant: 餐厅名称,必须是合作餐厅
        date: 预订日期,格式YYYY-MM-DD
        time: 预订时间,格式HH:MM
        party_size: 用餐人数,1-20人
        customer_name: 顾客姓名
        special_requests: 特殊要求
 
    Returns:
        预订确认信息
    """
    # 验证餐厅
    supported_restaurants = ["和平饭店", "星光餐厅", "海港酒楼", "花园咖啡"]
    if restaurant not in supported_restaurants:
        raise ModelRetry(
            f"不支持餐厅 '{restaurant}',可选餐厅: {', '.join(supported_restaurants)}"
        )
    
    # 验证日期
    try:
        reservation_date = datetime.strptime(date, "%Y-%m-%d")
        if reservation_date < datetime.now().date():
            raise ModelRetry("预订日期不能是过去日期")
    except ValueError:
        raise ModelRetry("日期格式错误,请使用 YYYY-MM-DD 格式")
    
    # 验证时间
    try:
        datetime.strptime(time, "%H:%M")
    except ValueError:
        raise ModelRetry("时间格式错误,请使用 HH:MM 格式")
    
    # 验证人数
    if party_size < 1 or party_size > 20:
        raise ModelRetry("用餐人数必须在1-20人之间")
    
    if not customer_name or len(customer_name.strip()) == 0:
        raise ModelRetry("必须提供顾客姓名")
    
    # 模拟预订
    import random
    reservation_id = f"RES{random.randint(1000, 9999)}"
    
    return {
        "reservation_id": reservation_id,
        "restaurant": restaurant,
        "date": date,
        "time": time,
        "party_size": party_size,
        "customer_name": customer_name,
        "special_requests": special_requests,
        "status": "confirmed",
        "confirmation_code": f"CODE{random.randint(100000, 999999)}"
    }
 
@agent.tool_plain
def calculate_loan_payment(
    loan_amount: float,
    annual_interest_rate: float,
    loan_term_years: int,
    payment_frequency: str = "monthly"
) -> dict:
    """
    计算贷款还款计划
    
    Args:
        loan_amount: 贷款金额,必须大于0
        annual_interest_rate: 年利率,范围0.1-50.0
        loan_term_years: 贷款年限,1-30年
        payment_frequency: 还款频率,monthly/quarterly/yearly
 
    Returns:
        还款计划详情
    """
    # 参数验证
    if loan_amount <= 0:
        raise ModelRetry("贷款金额必须大于0")
    
    if annual_interest_rate < 0.1 or annual_interest_rate > 50.0:
        raise ModelRetry("年利率必须在0.1%到50%之间")
    
    if loan_term_years < 1 or loan_term_years > 30:
        raise ModelRetry("贷款年限必须在1-30年之间")
    
    valid_frequencies = ["monthly", "quarterly", "yearly"]
    if payment_frequency not in valid_frequencies:
        raise ModelRetry(f"还款频率必须是: {', '.join(valid_frequencies)}")
    
    # 计算还款计划
    monthly_rate = annual_interest_rate / 100 / 12
    total_payments = loan_term_years * 12
    
    if monthly_rate == 0:
        monthly_payment = loan_amount / total_payments
    else:
        monthly_payment = loan_amount * monthly_rate * (1 + monthly_rate) ** total_payments / ((1 + monthly_rate) ** total_payments - 1)
    
    total_payment = monthly_payment * total_payments
    total_interest = total_payment - loan_amount
    
    return {
        "loan_amount": loan_amount,
        "annual_interest_rate": annual_interest_rate,
        "loan_term_years": loan_term_years,
        "payment_frequency": payment_frequency,
        "monthly_payment": round(monthly_payment, 2),
        "total_payment": round(total_payment, 2),
        "total_interest": round(total_interest, 2),
        "calculation_date": datetime.now().strftime("%Y-%m-%d")
    }
 
async def main():
    """
    演示工具的错误处理和自动重试机制
    """
    test_cases = [
        # 有效的会议安排
        "安排明天下午2点的团队会议,参与者zhangsan@example.com, lisi@example.com,主题项目评审",
        
        # 无效的会议安排 - 缺少参与者
        "安排会议,主题测试会议,开始时间2024-03-20 10:00",
        
        # 有效的餐厅预订
        "在和平饭店预订3月20日晚上7点的座位,4个人,姓名李四",
        
        # 无效的餐厅预订 - 不支持的餐厅
        "在未知餐厅预订座位,3月20日晚上7点,2个人,姓名张三",
        
        # 有效的贷款计算
        "计算贷款:金额100000元,年利率5%,期限10年,按月还款",
        
        # 无效的贷款计算 - 金额为0
        "计算贷款:金额0元,年利率5%,期限10年"
    ]
    
    for i, query in enumerate(test_cases, 1):
        print(f"\n=== 测试用例 {i} ===")
        print(f"查询: {query}")
        
        try:
            result = await agent.run(query)
            print(f"结果: {result.output}")
            
            # 显示重试次数(如果有)
            if hasattr(result, 'retry_count') and result.retry_count > 0:
                print(f"重试次数: {result.retry_count}")
                
        except Exception as e:
            print(f"执行失败: {e}")
        
        print("-" * 50)
 
if __name__ == "__main__":
    asyncio.run(main())
代码详细解释: 参数验证 :每个工具都包含完整的参数验证逻辑 ModelRetry 异常 :用于指示模型需要重新尝试调用工具 具体错误信息 :提供详细的错误描述,帮助模型理解问题所在 格式验证 :验证日期、时间、邮箱等特定格式 业务规则验证 :验证业务逻辑约束,如时间必须是未来的、人数限制等 自动重试 :系统会自动处理
ModelRetry
异常并重新调用工具 6. 实际应用示例 6.1 完整的电子商务助手 下面是一个完整的电子商务助手示例,展示了如何结合多种工具功能:
from pydantic_ai import Agent, RunContext, ModelRetry
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import asyncio
from datetime import datetime, timedelta
 
class Product(BaseModel):
    id: str
    name: str
    price: float
    category: str
    stock: int
    description: str
 
class CartItem(BaseModel):
    product_id: str
    quantity: int
    added_at: str
 
class UserContext(BaseModel):
    user_id: str
    name: str
    membership_level: str = "standard"
    preferences: Dict[str, str] = Field(default_factory=dict)
 
# 模拟数据库
products_db = {
    "P001": Product(id="P001", name="智能手机", price=2999.0, category="electronics", stock=50, description="最新款智能手机"),
    "P002": Product(id="P002", name="笔记本电脑", price=5999.0, category="electronics", stock=30, description="高性能笔记本电脑"),
    "P003": Product(id="P003", name="无线耳机", price=599.0, category="electronics", stock=100, description="降噪无线耳机"),
    "P004": Product(id="P004", name="运动鞋", price=399.0, category="clothing", stock=80, description="舒适运动鞋"),
    "P005": Product(id="P005", name="编程书籍", price=89.0, category="books", stock=200, description="Python编程指南"),
}
 
user_carts: Dict[str, List[CartItem]] = {}
order_history: Dict[str, List[Dict]] = {}
 
agent = Agent('qwen-plus', deps_type=UserContext)
 
@agent.tool
def search_products(
    ctx: RunContext[UserContext],
    query: str = "",
    category: str = "",
    min_price: float = 0,
    max_price: float = 100000,
    in_stock: bool = True
) -> List[Product]:
    """
    搜索产品
    
    Args:
        query: 搜索关键词
        category: 产品类别
        min_price: 最低价格
        max_price: 最高价格  
        in_stock: 是否只显示有库存的商品
    """
    results = []
    
    for product in products_db.values():
        # 关键词匹配
        keyword_match = not query or (
            query.lower() in product.name.lower() or 
            query.lower() in product.description.lower()
        )
        
        # 类别匹配
        category_match = not category or product.category == category
        
        # 价格匹配
        price_match = min_price <= product.price <= max_price
        
        # 库存匹配
        stock_match = not in_stock or product.stock > 0
        
        if keyword_match and category_match and price_match and stock_match:
            results.append(product)
    
    return results
 
@agent.tool
def get_product_details(ctx: RunContext[UserContext], product_id: str) -> Product:
    """
    获取产品详情
    """
    if product_id not in products_db:
        raise ModelRetry(f"产品 {product_id} 不存在")
    
    return products_db[product_id]
 
@agent.tool
def add_to_cart(ctx: RunContext[UserContext], product_id: str, quantity: int = 1) -> Dict:
    """
    添加商品到购物车
    """
    user_id = ctx.deps.user_id
    
    if product_id not in products_db:
        raise ModelRetry(f"产品 {product_id} 不存在")
    
    product = products_db[product_id]
    
    if product.stock < quantity:
        raise ModelRetry(f"产品 {product.name} 库存不足,当前库存: {product.stock}")
    
    # 初始化用户购物车
    if user_id not in user_carts:
        user_carts[user_id] = []
    
    # 检查是否已存在相同商品
    for item in user_carts[user_id]:
        if item.product_id == product_id:
            item.quantity += quantity
            break
    else:
        user_carts[user_id].append(CartItem(
            product_id=product_id,
            quantity=quantity,
            added_at=datetime.now().isoformat()
        ))
    
    return {
        "success": True,
        "message": f"已添加 {quantity} 件 {product.name} 到购物车",
        "cart_item_count": len(user_carts[user_id])
    }
 
@agent.tool
def view_cart(ctx: RunContext[UserContext]) -> Dict:
    """
    查看购物车
    """
    user_id = ctx.deps.user_id
    
    if user_id not in user_carts or not user_carts[user_id]:
        return {"empty": True, "message": "购物车为空"}
    
    cart_items = []
    total_amount = 0
    
    for item in user_carts[user_id]:
        product = products_db[item.product_id]
        subtotal = product.price * item.quantity
        cart_items.append({
            "product_id": item.product_id,
            "product_name": product.name,
            "quantity": item.quantity,
            "unit_price": product.price,
            "subtotal": subtotal
        })
        total_amount += subtotal
    
    # 会员折扣
    discount = 0
    if ctx.deps.membership_level == "vip":
        discount = total_amount * 0.1  # VIP 9折
    elif ctx.deps.membership_level == "premium":
        discount = total_amount * 0.15  # 高级会员 85折
    
    final_amount = total_amount - discount
    
    return {
        "cart_items": cart_items,
        "total_amount": total_amount,
        "discount": discount,
        "final_amount": final_amount,
        "item_count": len(cart_items),
        "membership_level": ctx.deps.membership_level
    }
 
@agent.tool
def remove_from_cart(ctx: RunContext[UserContext], product_id: str, quantity: int = None) -> Dict:
    """
    从购物车移除商品
    """
    user_id = ctx.deps.user_id
    
    if user_id not in user_carts:
        return {"success": False, "message": "购物车为空"}
    
    for i, item in enumerate(user_carts[user_id]):
        if item.product_id == product_id:
            if quantity is None or quantity >= item.quantity:
                # 移除整个商品
                removed_item = user_carts[user_id].pop(i)
                product = products_db[product_id]
                return {
                    "success": True,
                    "message": f"已从购物车移除 {product.name}",
                    "removed_quantity": removed_item.quantity
                }
            else:
                # 减少数量
                item.quantity -= quantity
                product = products_db[product_id]
                return {
                    "success": True, 
                    "message": f"已从购物车移除 {quantity} 件 {product.name}",
                    "remaining_quantity": item.quantity
                }
    
    return {"success": False, "message": "商品不在购物车中"}
 
@agent.tool
def checkout(ctx: RunContext[UserContext], shipping_address: str, payment_method: str) -> Dict:
    """
    结算订单
    """
    user_id = ctx.deps.user_id
    
    if user_id not in user_carts or not user_carts[user_id]:
        raise ModelRetry("购物车为空,无法结算")
    
    cart_result = view_cart(ctx)
    
    # 验证库存
    for item in user_carts[user_id]:
        product = products_db[item.product_id]
        if product.stock < item.quantity:
            raise ModelRetry(f"商品 {product.name} 库存不足,当前库存: {product.stock}")
    
    # 创建订单
    import random
    order_id = f"ORD{datetime.now().strftime('%Y%m%d')}{random.randint(1000, 9999)}"
    
    order = {
        "order_id": order_id,
        "user_id": user_id,
        "items": cart_result["cart_items"],
        "total_amount": cart_result["total_amount"],
        "discount": cart_result["discount"],
        "final_amount": cart_result["final_amount"],
        "shipping_address": shipping_address,
        "payment_method": payment_method,
        "status": "pending",
        "created_at": datetime.now().isoformat()
    }
    
    # 保存订单
    if user_id not in order_history:
        order_history[user_id] = []
    order_history[user_id].append(order)
    
    # 更新库存
    for item in user_carts[user_id]:
        products_db[item.product_id].stock -= item.quantity
    
    # 清空购物车
    user_carts[user_id] = []
    
    return {
        "success": True,
        "order_id": order_id,
        "message": "订单创建成功",
        "order_details": order
    }
 
@agent.tool
def get_order_history(ctx: RunContext[UserContext], limit: int = 10) -> List[Dict]:
    """
    获取订单历史
    """
    user_id = ctx.deps.user_id
    
    if user_id not in order_history:
        return []
    
    return order_history[user_id][-limit:]
 
@agent.tool
def track_order(ctx: RunContext[UserContext], order_id: str) -> Dict:
    """
    跟踪订单状态
    """
    user_id = ctx.deps.user_id
    
    if user_id not in order_history:
        return {"error": "没有订单历史"}
    
    for order in order_history[user_id]:
        if order["order_id"] == order_id:
            return order
    
    return {"error": "订单不存在"}
 
async def main():
    """
    演示完整的电商购物流程
    """
    # 创建用户上下文
    user_context = UserContext(
        user_id="user_001",
        name="张三",
        membership_level="vip",
        preferences={"language": "zh-CN", "currency": "CNY"}
    )
    
    print("=== 电商助手演示 ===\n")
    
    # 1. 搜索产品
    print("1. 搜索电子产品...")
    result1 = await agent.run("搜索电子产品", deps=user_context)
    print(f"响应: {result1.output}\n")
    
    # 2. 查看产品详情
    print("2. 查看智能手机详情...")
    result2 = await agent.run("查看产品P001的详情", deps=user_context)
    print(f"响应: {result2.output}\n")
    
    # 3. 添加到购物车
    print("3. 添加商品到购物车...")
    result3 = await agent.run("把智能手机添加到购物车,数量2", deps=user_context)
    print(f"响应: {result3.output}\n")
    
    # 4. 查看购物车
    print("4. 查看购物车...")
    result4 = await agent.run("显示我的购物车", deps=user_context)
    print(f"响应: {result4.output}\n")
    
    # 5. 结算订单
    print("5. 结算订单...")
    result5 = await agent.run(
        "结算订单,收货地址北京市海淀区,支付方式支付宝", 
        deps=user_context
    )
    print(f"响应: {result5.output}\n")
    
    # 6. 查看订单历史
    print("6. 查看订单历史...")
    result6 = await agent.run("显示我的订单历史", deps=user_context)
    print(f"响应: {result6.output}\n")
    
    # 显示完整的工具调用序列
    print("\n=== 完整执行轨迹 ===")
    for i, message in enumerate(result6.all_messages()):
        print(f"\n步骤 {i}:")
        if hasattr(message, 'parts'):
            for part in message.parts:
                part_type = type(part).__name__
                if hasattr(part, 'tool_name'):
                    print(f"  ????? 工具调用: {part.tool_name}")
                elif hasattr(part, 'content'):
                    content = str(part.content)
                    if len(content) > 100:
                        content = content[:100] + "..."
                    print(f"  ???? 内容: {content}")
 
if __name__ == "__main__":
    asyncio.run(main())
代码详细解释: 完整电商流程 :涵盖搜索、详情查看、购物车管理、结算、订单跟踪全流程 用户上下文管理 :使用 Pydantic 模型管理用户信息和偏好 库存管理 :实时检查库存并在下单时更新 会员系统 :根据会员等级提供不同的折扣 错误处理 :全面的错误检查和用户友好的错误信息 数据持久化 :模拟数据库操作,管理产品、购物车、订单数据 业务流程 :完整的业务逻辑,包括库存验证、价格计算、订单创建等 总结 Pydantic AI 的函数工具系统提供了一个强大且灵活的框架,用于扩展大型语言模型的能力。通过本文的详细讲解和完整示例,我们可以看到: 核心优势: 类型安全 :基于 Pydantic 的自动类型验证和模式生成 灵活注册 :多种工具注册方式,支持模块化和代码复用 上下文感知 :工具可以访问运行时的依赖项和上下文信息 动态控制 :基于运行时条件动态启用、禁用或调整工具 多模态支持 :支持返回文本、图像、文档等多种数据类型 错误处理 :完善的验证和自动重试机制 最佳实践: 清晰的文档 :为每个工具提供完整的参数描述和使用说明 充分的验证 :在工具内部实现业务逻辑验证 适当的错误信息 :提供具体、可操作的错误信息帮助模型重试 模块化设计 :将相关工具组织在一起,便于维护和复用 安全控制 :基于用户角色和权限动态控制工具访问 适用场景: 数据查询 :访问数据库、API 或文件系统 计算服务 :执行复杂的数学计算或业务逻辑 系统集成 :与外部系统和服务进行交互 工作流自动化 :执行多步骤的业务流程 实时信息获取 :获取最新的天气、股票、新闻等信息 通过合理使用函数工具,可以构建出功能强大、安全可靠、用户体验良好的 AI 应用,真正实现人工智能与现有系统的无缝集成。
二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:NTIC Ant TIC Participants Descriptions

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
jg-xs1
拉您进交流群
GMT+8, 2025-12-5 17:01