基于 LlamaIndex Workflow 与 Pydantic 构建大模型结构化输出与自动重试机制
初始化大语言模型
在构建结构化输出工作流之前,首先需要初始化大语言模型(LLM)。对于参数量较小的模型(如 0.5B 级别),生成合法 JSON 的稳定性可能较差,通常需要增加最大重试次数或调整采样参数。以下示例使用 Qwen2.5-7B 模型进行演示,以确保更高的输出成功率。
from llama_index.llms.huggingface import HuggingFaceLLM
# 初始化本地大语言模型客户端
llm_client = HuggingFaceLLM(
model_name="Qwen/Qwen2.5-7B-Instruct",
tokenizer_name="Qwen/Qwen2.5-7B-Instruct",
context_window=32768,
max_new_tokens=2048,
generate_kwargs={"temperature": 0.7, "top_k": 50, "top_p": 0.95},
device_map="auto",
)
构建自动化校验工作流
核心设计思路
为了确保 LLM 能够稳定输出符合预期的结构化数据,工作流被设计为包含两个核心节点的闭环系统:
- 生成节点:调用 LLM 生成 JSON 格式的结构化数据。
- 校验节点:使用 Pydantic 验证生成的 JSON 是否符合预定义的数据模型。
如果校验失败,工作流不会直接终止,而是将错误信息和无效输出作为反馈传递给生成节点,触发重试机制,直到输出合法或达到最大重试次数。
定义工作流事件
在 LlamaIndex 的 Workflow 中,节点之间的数据传递依赖于事件(Event)。我们需要定义两个自定义事件:一个用于传递生成结果,另一个用于传递校验失败的反馈。
from llama_index.core.workflow import Event
class GenerationCompleteEvent(Event):
generated_text: str
source_text: str
class ValidationFailedEvent(Event):
error_message: str
invalid_text: str
source_text: str
定义 Pydantic 数据模型
使用 Pydantic 定义目标数据结构。这里以提取文本中的省市地理信息为例,构建嵌套的数据模型。
from pydantic import BaseModel, Field
from typing import List
class RegionInfo(BaseModel):
province: str = Field(description="省份名称")
city: str = Field(description="城市名称")
class ExtractedLocations(BaseModel):
regions: List[RegionInfo]
实现工作流步骤
接下来定义工作流类及其执行步骤。通过 @step 装饰器将异步函数注册为工作流节点,并利用上下文(Context)管理重试状态。
import json
from llama_index.core.workflow import (
Workflow,
StartEvent,
StopEvent,
Context,
step,
)
EXTRACTION_SYSTEM_PROMPT = """
请根据以下提供的上下文信息提取数据,并严格生成符合指定 JSON Schema 的对象。
上下文:
---------------------
{context}
---------------------
目标 JSON Schema:
{schema}
"""
CORRECTION_PROMPT = """
你上一次生成的 JSON 如下:
---------------------
{invalid_json}
---------------------
该输出导致了以下解析错误:{error_msg}
请修正错误并重新生成。注意:响应中只能包含纯 JSON 字符串,不要包含任何 Markdown 标记或额外说明。
"""
class StructuredOutputWorkflow(Workflow):
max_attempts: int = 5
@step
async def generate_json(
self, ctx: Context, ev: StartEvent | ValidationFailedEvent
) -> GenerationCompleteEvent | StopEvent:
# 检查是否超出最大重试次数
attempts = await ctx.get("attempts", default=0)
if attempts >= self.max_attempts:
return StopEvent(result="达到最大重试次数,提取失败。")
await ctx.set("attempts", attempts + 1)
# 解析输入事件
if isinstance(ev, StartEvent):
context_text = ev.get("context")
if not context_text:
return StopEvent(result="未提供上下文文本。")
correction_instruction = ""
else:
context_text = ev.source_text
correction_instruction = CORRECTION_PROMPT.format(
invalid_json=ev.invalid_text, error_msg=ev.error_message
)
# 构建提示词
prompt = EXTRACTION_SYSTEM_PROMPT.format(
context=context_text, schema=json.dumps(ExtractedLocations.model_json_schema())
)
if correction_instruction:
prompt += "\n\n" + correction_instruction
# 调用大模型
response = await llm_client.acomplete(prompt)
return GenerationCompleteEvent(
generated_text=str(response), source_text=context_text
)
@step
async def validate_json(
self, ev: GenerationCompleteEvent
) -> StopEvent | ValidationFailedEvent:
try:
# 使用 Pydantic 进行严格校验
ExtractedLocations.model_validate_json(ev.generated_text)
return StopEvent(result=ev.generated_text)
except Exception as validation_error:
print(f"校验失败,准备重试: {validation_error}")
return ValidationFailedEvent(
error_message=str(validation_error),
invalid_text=ev.generated_text,
source_text=ev.source_text
)
工作流拓扑可视化
LlamaIndex 提供了内置工具,可以将工作流的执行路径渲染为可视化图表,便于调试和理解节点间的流转逻辑。
from llama_index.utils.workflow import draw_all_possible_flows
draw_all_possible_flows(
StructuredOutputWorkflow, filename="structured_output_workflow.html"
)
执行工作流与结果验证
最后,实例化工作流并传入包含地理信息的自然语言文本。通过异步运行,观察模型在遇到格式错误时如何利用反馈机制自我修正。
import asyncio
async def main():
workflow = StructuredOutputWorkflow(timeout=120, verbose=True)
# 执行工作流
result = await workflow.run(
context="我目前在四川省成都市读大学,我的家乡在湖南省岳阳市。"
)
print("最终提取结果:\n", result)
asyncio.run(main())