Implement Report Agent for automated report generation and interaction

- Introduced the Report Agent module to facilitate the automatic generation of simulation analysis reports using LangChain and Zep, following the ReACT model.
- Added functionality for report outline planning, segmented content generation, and user interaction through a dialogue interface.
- Implemented new API endpoints for report generation, status checking, and retrieval, enhancing the overall reporting capabilities.
- Updated README.md to include detailed instructions on the new report generation features and API usage.
- Enhanced the project structure to accommodate the new report management functionalities, including report storage and retrieval mechanisms.
This commit is contained in:
666ghj
2025-12-09 15:10:55 +08:00
parent 8d63f40b71
commit 5ece3f670b
8 changed files with 3445 additions and 9 deletions

View File

@@ -63,9 +63,10 @@ def create_app(config_class=Config):
return response
# 注册蓝图
from .api import graph_bp, simulation_bp
from .api import graph_bp, simulation_bp, report_bp
app.register_blueprint(graph_bp, url_prefix='/api/graph')
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
app.register_blueprint(report_bp, url_prefix='/api/report')
# 健康检查
@app.route('/health')

View File

@@ -6,7 +6,9 @@ from flask import Blueprint
graph_bp = Blueprint('graph', __name__)
simulation_bp = Blueprint('simulation', __name__)
report_bp = Blueprint('report', __name__)
from . import graph # noqa: E402, F401
from . import simulation # noqa: E402, F401
from . import report # noqa: E402, F401

829
backend/app/api/report.py Normal file
View File

@@ -0,0 +1,829 @@
"""
Report API路由
提供模拟报告生成、获取、对话等接口
"""
import os
import traceback
import threading
from flask import request, jsonify, send_file
from . import report_bp
from ..config import Config
from ..services.report_agent import ReportAgent, ReportManager, ReportStatus
from ..services.simulation_manager import SimulationManager
from ..models.project import ProjectManager
from ..models.task import TaskManager, TaskStatus
from ..utils.logger import get_logger
logger = get_logger('mirofish.api.report')
# ============== 报告生成接口 ==============
@report_bp.route('/generate', methods=['POST'])
def generate_report():
"""
生成模拟分析报告(异步任务)
这是一个耗时操作接口会立即返回task_id
使用 GET /api/report/generate/status 查询进度
请求JSON
{
"simulation_id": "sim_xxxx", // 必填模拟ID
"force_regenerate": false // 可选,强制重新生成
}
返回:
{
"success": true,
"data": {
"simulation_id": "sim_xxxx",
"task_id": "task_xxxx",
"status": "generating",
"message": "报告生成任务已启动"
}
}
"""
try:
data = request.get_json() or {}
simulation_id = data.get('simulation_id')
if not simulation_id:
return jsonify({
"success": False,
"error": "请提供 simulation_id"
}), 400
force_regenerate = data.get('force_regenerate', False)
# 获取模拟信息
manager = SimulationManager()
state = manager.get_simulation(simulation_id)
if not state:
return jsonify({
"success": False,
"error": f"模拟不存在: {simulation_id}"
}), 404
# 检查是否已有报告
if not force_regenerate:
existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED:
return jsonify({
"success": True,
"data": {
"simulation_id": simulation_id,
"report_id": existing_report.report_id,
"status": "completed",
"message": "报告已存在",
"already_generated": True
}
})
# 获取项目信息
project = ProjectManager.get_project(state.project_id)
if not project:
return jsonify({
"success": False,
"error": f"项目不存在: {state.project_id}"
}), 404
graph_id = state.graph_id or project.graph_id
if not graph_id:
return jsonify({
"success": False,
"error": "缺少图谱ID请确保已构建图谱"
}), 400
simulation_requirement = project.simulation_requirement
if not simulation_requirement:
return jsonify({
"success": False,
"error": "缺少模拟需求描述"
}), 400
# 创建异步任务
task_manager = TaskManager()
task_id = task_manager.create_task(
task_type="report_generate",
metadata={
"simulation_id": simulation_id,
"graph_id": graph_id
}
)
# 定义后台任务
def run_generate():
try:
task_manager.update_task(
task_id,
status=TaskStatus.PROCESSING,
progress=0,
message="初始化Report Agent..."
)
# 创建Report Agent
agent = ReportAgent(
graph_id=graph_id,
simulation_id=simulation_id,
simulation_requirement=simulation_requirement
)
# 进度回调
def progress_callback(stage, progress, message):
task_manager.update_task(
task_id,
progress=progress,
message=f"[{stage}] {message}"
)
# 生成报告
report = agent.generate_report(progress_callback=progress_callback)
# 保存报告
ReportManager.save_report(report)
if report.status == ReportStatus.COMPLETED:
task_manager.complete_task(
task_id,
result={
"report_id": report.report_id,
"simulation_id": simulation_id,
"status": "completed"
}
)
else:
task_manager.fail_task(task_id, report.error or "报告生成失败")
except Exception as e:
logger.error(f"报告生成失败: {str(e)}")
task_manager.fail_task(task_id, str(e))
# 启动后台线程
thread = threading.Thread(target=run_generate, daemon=True)
thread.start()
return jsonify({
"success": True,
"data": {
"simulation_id": simulation_id,
"task_id": task_id,
"status": "generating",
"message": "报告生成任务已启动,请通过 /api/report/generate/status 查询进度",
"already_generated": False
}
})
except Exception as e:
logger.error(f"启动报告生成任务失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/generate/status', methods=['POST'])
def get_generate_status():
"""
查询报告生成任务进度
请求JSON
{
"task_id": "task_xxxx", // 可选generate返回的task_id
"simulation_id": "sim_xxxx" // 可选模拟ID
}
返回:
{
"success": true,
"data": {
"task_id": "task_xxxx",
"status": "processing|completed|failed",
"progress": 45,
"message": "..."
}
}
"""
try:
data = request.get_json() or {}
task_id = data.get('task_id')
simulation_id = data.get('simulation_id')
# 如果提供了simulation_id先检查是否已有完成的报告
if simulation_id:
existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED:
return jsonify({
"success": True,
"data": {
"simulation_id": simulation_id,
"report_id": existing_report.report_id,
"status": "completed",
"progress": 100,
"message": "报告已生成",
"already_completed": True
}
})
if not task_id:
return jsonify({
"success": False,
"error": "请提供 task_id 或 simulation_id"
}), 400
task_manager = TaskManager()
task = task_manager.get_task(task_id)
if not task:
return jsonify({
"success": False,
"error": f"任务不存在: {task_id}"
}), 404
return jsonify({
"success": True,
"data": task.to_dict()
})
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e)
}), 500
# ============== 报告获取接口 ==============
@report_bp.route('/<report_id>', methods=['GET'])
def get_report(report_id: str):
"""
获取报告详情
返回:
{
"success": true,
"data": {
"report_id": "report_xxxx",
"simulation_id": "sim_xxxx",
"status": "completed",
"outline": {...},
"markdown_content": "...",
"created_at": "...",
"completed_at": "..."
}
}
"""
try:
report = ReportManager.get_report(report_id)
if not report:
return jsonify({
"success": False,
"error": f"报告不存在: {report_id}"
}), 404
return jsonify({
"success": True,
"data": report.to_dict()
})
except Exception as e:
logger.error(f"获取报告失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
def get_report_by_simulation(simulation_id: str):
"""
根据模拟ID获取报告
返回:
{
"success": true,
"data": {
"report_id": "report_xxxx",
...
}
}
"""
try:
report = ReportManager.get_report_by_simulation(simulation_id)
if not report:
return jsonify({
"success": False,
"error": f"该模拟暂无报告: {simulation_id}",
"has_report": False
}), 404
return jsonify({
"success": True,
"data": report.to_dict(),
"has_report": True
})
except Exception as e:
logger.error(f"获取报告失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/list', methods=['GET'])
def list_reports():
"""
列出所有报告
Query参数
simulation_id: 按模拟ID过滤可选
limit: 返回数量限制默认50
返回:
{
"success": true,
"data": [...],
"count": 10
}
"""
try:
simulation_id = request.args.get('simulation_id')
limit = request.args.get('limit', 50, type=int)
reports = ReportManager.list_reports(
simulation_id=simulation_id,
limit=limit
)
return jsonify({
"success": True,
"data": [r.to_dict() for r in reports],
"count": len(reports)
})
except Exception as e:
logger.error(f"列出报告失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/<report_id>/download', methods=['GET'])
def download_report(report_id: str):
"""
下载报告Markdown格式
返回Markdown文件
"""
try:
report = ReportManager.get_report(report_id)
if not report:
return jsonify({
"success": False,
"error": f"报告不存在: {report_id}"
}), 404
md_path = ReportManager._get_report_markdown_path(report_id)
if not os.path.exists(md_path):
# 如果MD文件不存在生成一个临时文件
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
f.write(report.markdown_content)
temp_path = f.name
return send_file(
temp_path,
as_attachment=True,
download_name=f"{report_id}.md"
)
return send_file(
md_path,
as_attachment=True,
download_name=f"{report_id}.md"
)
except Exception as e:
logger.error(f"下载报告失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/<report_id>', methods=['DELETE'])
def delete_report(report_id: str):
"""删除报告"""
try:
success = ReportManager.delete_report(report_id)
if not success:
return jsonify({
"success": False,
"error": f"报告不存在: {report_id}"
}), 404
return jsonify({
"success": True,
"message": f"报告已删除: {report_id}"
})
except Exception as e:
logger.error(f"删除报告失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
# ============== Report Agent对话接口 ==============
@report_bp.route('/chat', methods=['POST'])
def chat_with_report_agent():
"""
与Report Agent对话
Report Agent可以在对话中自主调用检索工具来回答问题
请求JSON
{
"simulation_id": "sim_xxxx", // 必填模拟ID
"message": "请解释一下舆情走向", // 必填,用户消息
"chat_history": [ // 可选,对话历史
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
返回:
{
"success": true,
"data": {
"response": "Agent回复...",
"tool_calls": [调用的工具列表],
"sources": [信息来源]
}
}
"""
try:
data = request.get_json() or {}
simulation_id = data.get('simulation_id')
message = data.get('message')
chat_history = data.get('chat_history', [])
if not simulation_id:
return jsonify({
"success": False,
"error": "请提供 simulation_id"
}), 400
if not message:
return jsonify({
"success": False,
"error": "请提供 message"
}), 400
# 获取模拟和项目信息
manager = SimulationManager()
state = manager.get_simulation(simulation_id)
if not state:
return jsonify({
"success": False,
"error": f"模拟不存在: {simulation_id}"
}), 404
project = ProjectManager.get_project(state.project_id)
if not project:
return jsonify({
"success": False,
"error": f"项目不存在: {state.project_id}"
}), 404
graph_id = state.graph_id or project.graph_id
if not graph_id:
return jsonify({
"success": False,
"error": "缺少图谱ID"
}), 400
simulation_requirement = project.simulation_requirement or ""
# 创建Agent并进行对话
agent = ReportAgent(
graph_id=graph_id,
simulation_id=simulation_id,
simulation_requirement=simulation_requirement
)
result = agent.chat(message=message, chat_history=chat_history)
return jsonify({
"success": True,
"data": result
})
except Exception as e:
logger.error(f"对话失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
# ============== 报告进度与分章节接口 ==============
@report_bp.route('/<report_id>/progress', methods=['GET'])
def get_report_progress(report_id: str):
"""
获取报告生成进度(实时)
返回:
{
"success": true,
"data": {
"status": "generating",
"progress": 45,
"message": "正在生成章节: 关键发现",
"current_section": "关键发现",
"completed_sections": ["执行摘要", "模拟背景"],
"updated_at": "2025-12-09T..."
}
}
"""
try:
progress = ReportManager.get_progress(report_id)
if not progress:
return jsonify({
"success": False,
"error": f"报告不存在或进度信息不可用: {report_id}"
}), 404
return jsonify({
"success": True,
"data": progress
})
except Exception as e:
logger.error(f"获取报告进度失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/<report_id>/sections', methods=['GET'])
def get_report_sections(report_id: str):
"""
获取已生成的章节列表(分章节输出)
前端可以轮询此接口获取已生成的章节内容,无需等待整个报告完成
返回:
{
"success": true,
"data": {
"report_id": "report_xxxx",
"sections": [
{
"filename": "section_01.md",
"section_index": 1,
"content": "## 执行摘要\\n\\n..."
},
...
],
"total_sections": 3,
"is_complete": false
}
}
"""
try:
sections = ReportManager.get_generated_sections(report_id)
# 获取报告状态
report = ReportManager.get_report(report_id)
is_complete = report is not None and report.status == ReportStatus.COMPLETED
return jsonify({
"success": True,
"data": {
"report_id": report_id,
"sections": sections,
"total_sections": len(sections),
"is_complete": is_complete
}
})
except Exception as e:
logger.error(f"获取章节列表失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
def get_single_section(report_id: str, section_index: int):
"""
获取单个章节内容
返回:
{
"success": true,
"data": {
"filename": "section_01.md",
"content": "## 执行摘要\\n\\n..."
}
}
"""
try:
section_path = ReportManager._get_section_path(report_id, section_index)
if not os.path.exists(section_path):
return jsonify({
"success": False,
"error": f"章节不存在: section_{section_index:02d}.md"
}), 404
with open(section_path, 'r', encoding='utf-8') as f:
content = f.read()
return jsonify({
"success": True,
"data": {
"filename": f"section_{section_index:02d}.md",
"section_index": section_index,
"content": content
}
})
except Exception as e:
logger.error(f"获取章节内容失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
# ============== 报告状态检查接口 ==============
@report_bp.route('/check/<simulation_id>', methods=['GET'])
def check_report_status(simulation_id: str):
"""
检查模拟是否有报告,以及报告状态
用于前端判断是否解锁Interview功能
返回:
{
"success": true,
"data": {
"simulation_id": "sim_xxxx",
"has_report": true,
"report_status": "completed",
"report_id": "report_xxxx",
"interview_unlocked": true
}
}
"""
try:
report = ReportManager.get_report_by_simulation(simulation_id)
has_report = report is not None
report_status = report.status.value if report else None
report_id = report.report_id if report else None
# 只有报告完成后才解锁interview
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
return jsonify({
"success": True,
"data": {
"simulation_id": simulation_id,
"has_report": has_report,
"report_status": report_status,
"report_id": report_id,
"interview_unlocked": interview_unlocked
}
})
except Exception as e:
logger.error(f"检查报告状态失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
# ============== 工具调用接口(供调试使用)==============
@report_bp.route('/tools/search', methods=['POST'])
def search_graph_tool():
"""
图谱搜索工具接口(供调试使用)
请求JSON
{
"graph_id": "mirofish_xxxx",
"query": "搜索查询",
"limit": 10
}
"""
try:
data = request.get_json() or {}
graph_id = data.get('graph_id')
query = data.get('query')
limit = data.get('limit', 10)
if not graph_id or not query:
return jsonify({
"success": False,
"error": "请提供 graph_id 和 query"
}), 400
from ..services.zep_tools import ZepToolsService
tools = ZepToolsService()
result = tools.search_graph(
graph_id=graph_id,
query=query,
limit=limit
)
return jsonify({
"success": True,
"data": result.to_dict()
})
except Exception as e:
logger.error(f"图谱搜索失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
@report_bp.route('/tools/statistics', methods=['POST'])
def get_graph_statistics_tool():
"""
图谱统计工具接口(供调试使用)
请求JSON
{
"graph_id": "mirofish_xxxx"
}
"""
try:
data = request.get_json() or {}
graph_id = data.get('graph_id')
if not graph_id:
return jsonify({
"success": False,
"error": "请提供 graph_id"
}), 400
from ..services.zep_tools import ZepToolsService
tools = ZepToolsService()
result = tools.get_graph_statistics(graph_id)
return jsonify({
"success": True,
"data": result
})
except Exception as e:
logger.error(f"获取图谱统计失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500

View File

@@ -58,6 +58,11 @@ class Config:
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
]
# Report Agent配置
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
@classmethod
def validate(cls):
"""验证必要配置"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,621 @@
"""
Zep检索工具服务
封装图谱搜索、节点读取、边查询等工具供Report Agent使用
"""
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from zep_cloud.client import Zep
from ..config import Config
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_tools')
@dataclass
class SearchResult:
"""搜索结果"""
facts: List[str]
edges: List[Dict[str, Any]]
nodes: List[Dict[str, Any]]
query: str
total_count: int
def to_dict(self) -> Dict[str, Any]:
return {
"facts": self.facts,
"edges": self.edges,
"nodes": self.nodes,
"query": self.query,
"total_count": self.total_count
}
def to_text(self) -> str:
"""转换为文本格式供LLM理解"""
text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"]
if self.facts:
text_parts.append("\n### 相关事实:")
for i, fact in enumerate(self.facts, 1):
text_parts.append(f"{i}. {fact}")
return "\n".join(text_parts)
@dataclass
class NodeInfo:
"""节点信息"""
uuid: str
name: str
labels: List[str]
summary: str
attributes: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
return {
"uuid": self.uuid,
"name": self.name,
"labels": self.labels,
"summary": self.summary,
"attributes": self.attributes
}
def to_text(self) -> str:
"""转换为文本格式"""
entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型")
return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}"
@dataclass
class EdgeInfo:
"""边信息"""
uuid: str
name: str
fact: str
source_node_uuid: str
target_node_uuid: str
source_node_name: Optional[str] = None
target_node_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"uuid": self.uuid,
"name": self.name,
"fact": self.fact,
"source_node_uuid": self.source_node_uuid,
"target_node_uuid": self.target_node_uuid,
"source_node_name": self.source_node_name,
"target_node_name": self.target_node_name
}
def to_text(self) -> str:
"""转换为文本格式"""
source = self.source_node_name or self.source_node_uuid[:8]
target = self.target_node_name or self.target_node_uuid[:8]
return f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}"
class ZepToolsService:
"""
Zep检索工具服务
提供多种图谱检索工具可以被Report Agent调用
1. search_graph - 图谱语义搜索
2. get_all_nodes - 获取图谱所有节点
3. get_all_edges - 获取图谱所有边
4. get_node_detail - 获取节点详细信息
5. get_node_edges - 获取节点相关的边
6. get_entities_by_type - 按类型获取实体
7. get_entity_summary - 获取实体的关系摘要
"""
# 重试配置
MAX_RETRIES = 3
RETRY_DELAY = 2.0
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.client = Zep(api_key=self.api_key)
logger.info("ZepToolsService 初始化完成")
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
"""带重试机制的API调用"""
max_retries = max_retries or self.MAX_RETRIES
last_exception = None
delay = self.RETRY_DELAY
for attempt in range(max_retries):
try:
return func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.warning(
f"Zep {operation_name}{attempt + 1} 次尝试失败: {str(e)[:100]}, "
f"{delay:.1f}秒后重试..."
)
time.sleep(delay)
delay *= 2
else:
logger.error(f"Zep {operation_name}{max_retries} 次尝试后仍失败: {str(e)}")
raise last_exception
def search_graph(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges"
) -> SearchResult:
"""
图谱语义搜索
使用混合搜索(语义+BM25在图谱中搜索相关信息。
如果Zep Cloud的search API不可用则降级为本地关键词匹配。
Args:
graph_id: 图谱ID (Standalone Graph)
query: 搜索查询
limit: 返回结果数量
scope: 搜索范围,"edges""nodes"
Returns:
SearchResult: 搜索结果
"""
logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...")
# 尝试使用Zep Cloud Search API
try:
search_results = self._call_with_retry(
func=lambda: self.client.graph.search(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
reranker="cross_encoder"
),
operation_name=f"图谱搜索(graph={graph_id})"
)
facts = []
edges = []
nodes = []
# 解析边搜索结果
if hasattr(search_results, 'edges') and search_results.edges:
for edge in search_results.edges:
if hasattr(edge, 'fact') and edge.fact:
facts.append(edge.fact)
edges.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": getattr(edge, 'name', ''),
"fact": getattr(edge, 'fact', ''),
"source_node_uuid": getattr(edge, 'source_node_uuid', ''),
"target_node_uuid": getattr(edge, 'target_node_uuid', ''),
})
# 解析节点搜索结果
if hasattr(search_results, 'nodes') and search_results.nodes:
for node in search_results.nodes:
nodes.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": getattr(node, 'name', ''),
"labels": getattr(node, 'labels', []),
"summary": getattr(node, 'summary', ''),
})
# 节点摘要也算作事实
if hasattr(node, 'summary') and node.summary:
facts.append(f"[{node.name}]: {node.summary}")
logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实")
return SearchResult(
facts=facts,
edges=edges,
nodes=nodes,
query=query,
total_count=len(facts)
)
except Exception as e:
logger.warning(f"Zep Search API失败降级为本地搜索: {str(e)}")
# 降级:使用本地关键词匹配搜索
return self._local_search(graph_id, query, limit, scope)
def _local_search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges"
) -> SearchResult:
"""
本地关键词匹配搜索作为Zep Search API的降级方案
获取所有边/节点,然后在本地进行关键词匹配
Args:
graph_id: 图谱ID
query: 搜索查询
limit: 返回结果数量
scope: 搜索范围
Returns:
SearchResult: 搜索结果
"""
logger.info(f"使用本地搜索: query={query[:30]}...")
facts = []
edges_result = []
nodes_result = []
# 提取查询关键词(简单分词)
query_lower = query.lower()
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace('', ' ').split() if len(w.strip()) > 1]
def match_score(text: str) -> int:
"""计算文本与查询的匹配分数"""
if not text:
return 0
text_lower = text.lower()
# 完全匹配查询
if query_lower in text_lower:
return 100
# 关键词匹配
score = 0
for keyword in keywords:
if keyword in text_lower:
score += 10
return score
try:
if scope in ["edges", "both"]:
# 获取所有边并匹配
all_edges = self.get_all_edges(graph_id)
scored_edges = []
for edge in all_edges:
score = match_score(edge.fact) + match_score(edge.name)
if score > 0:
scored_edges.append((score, edge))
# 按分数排序
scored_edges.sort(key=lambda x: x[0], reverse=True)
for score, edge in scored_edges[:limit]:
if edge.fact:
facts.append(edge.fact)
edges_result.append({
"uuid": edge.uuid,
"name": edge.name,
"fact": edge.fact,
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
})
if scope in ["nodes", "both"]:
# 获取所有节点并匹配
all_nodes = self.get_all_nodes(graph_id)
scored_nodes = []
for node in all_nodes:
score = match_score(node.name) + match_score(node.summary)
if score > 0:
scored_nodes.append((score, node))
scored_nodes.sort(key=lambda x: x[0], reverse=True)
for score, node in scored_nodes[:limit]:
nodes_result.append({
"uuid": node.uuid,
"name": node.name,
"labels": node.labels,
"summary": node.summary,
})
if node.summary:
facts.append(f"[{node.name}]: {node.summary}")
logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实")
except Exception as e:
logger.error(f"本地搜索失败: {str(e)}")
return SearchResult(
facts=facts,
edges=edges_result,
nodes=nodes_result,
query=query,
total_count=len(facts)
)
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
"""
获取图谱的所有节点
Args:
graph_id: 图谱ID
Returns:
节点列表
"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = self._call_with_retry(
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取节点(graph={graph_id})"
)
result = []
for node in nodes:
result.append(NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
))
logger.info(f"获取到 {len(result)} 个节点")
return result
def get_all_edges(self, graph_id: str) -> List[EdgeInfo]:
"""
获取图谱的所有边
Args:
graph_id: 图谱ID
Returns:
边列表
"""
logger.info(f"获取图谱 {graph_id} 的所有边...")
edges = self._call_with_retry(
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取边(graph={graph_id})"
)
result = []
for edge in edges:
result.append(EdgeInfo(
uuid=getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
name=edge.name or "",
fact=edge.fact or "",
source_node_uuid=edge.source_node_uuid or "",
target_node_uuid=edge.target_node_uuid or ""
))
logger.info(f"获取到 {len(result)} 条边")
return result
def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
"""
获取单个节点的详细信息
Args:
node_uuid: 节点UUID
Returns:
节点信息或None
"""
logger.info(f"获取节点详情: {node_uuid[:8]}...")
try:
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)"
)
if not node:
return None
return NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
)
except Exception as e:
logger.error(f"获取节点详情失败: {str(e)}")
return None
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]:
"""
获取节点相关的所有边
通过获取图谱所有边,然后过滤出与指定节点相关的边
Args:
graph_id: 图谱ID
node_uuid: 节点UUID
Returns:
边列表
"""
logger.info(f"获取节点 {node_uuid[:8]}... 的相关边")
try:
# 获取图谱所有边,然后过滤
all_edges = self.get_all_edges(graph_id)
result = []
for edge in all_edges:
# 检查边是否与指定节点相关(作为源或目标)
if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid:
result.append(edge)
logger.info(f"找到 {len(result)} 条与节点相关的边")
return result
except Exception as e:
logger.warning(f"获取节点边失败: {str(e)}")
return []
def get_entities_by_type(
self,
graph_id: str,
entity_type: str
) -> List[NodeInfo]:
"""
按类型获取实体
Args:
graph_id: 图谱ID
entity_type: 实体类型(如 Student, PublicFigure 等)
Returns:
符合类型的实体列表
"""
logger.info(f"获取类型为 {entity_type} 的实体...")
all_nodes = self.get_all_nodes(graph_id)
filtered = []
for node in all_nodes:
# 检查labels是否包含指定类型
if entity_type in node.labels:
filtered.append(node)
logger.info(f"找到 {len(filtered)}{entity_type} 类型的实体")
return filtered
def get_entity_summary(
self,
graph_id: str,
entity_name: str
) -> Dict[str, Any]:
"""
获取指定实体的关系摘要
搜索与该实体相关的所有信息,并生成摘要
Args:
graph_id: 图谱ID
entity_name: 实体名称
Returns:
实体摘要信息
"""
logger.info(f"获取实体 {entity_name} 的关系摘要...")
# 先搜索该实体相关的信息
search_result = self.search_graph(
graph_id=graph_id,
query=entity_name,
limit=20
)
# 尝试在所有节点中找到该实体
all_nodes = self.get_all_nodes(graph_id)
entity_node = None
for node in all_nodes:
if node.name.lower() == entity_name.lower():
entity_node = node
break
related_edges = []
if entity_node:
# 传入graph_id参数
related_edges = self.get_node_edges(graph_id, entity_node.uuid)
return {
"entity_name": entity_name,
"entity_info": entity_node.to_dict() if entity_node else None,
"related_facts": search_result.facts,
"related_edges": [e.to_dict() for e in related_edges],
"total_relations": len(related_edges)
}
def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]:
"""
获取图谱的统计信息
Args:
graph_id: 图谱ID
Returns:
统计信息
"""
logger.info(f"获取图谱 {graph_id} 的统计信息...")
nodes = self.get_all_nodes(graph_id)
edges = self.get_all_edges(graph_id)
# 统计实体类型分布
entity_types = {}
for node in nodes:
for label in node.labels:
if label not in ["Entity", "Node"]:
entity_types[label] = entity_types.get(label, 0) + 1
# 统计关系类型分布
relation_types = {}
for edge in edges:
relation_types[edge.name] = relation_types.get(edge.name, 0) + 1
return {
"graph_id": graph_id,
"total_nodes": len(nodes),
"total_edges": len(edges),
"entity_types": entity_types,
"relation_types": relation_types
}
def get_simulation_context(
self,
graph_id: str,
simulation_requirement: str,
limit: int = 30
) -> Dict[str, Any]:
"""
获取模拟相关的上下文信息
综合搜索与模拟需求相关的所有信息
Args:
graph_id: 图谱ID
simulation_requirement: 模拟需求描述
limit: 每类信息的数量限制
Returns:
模拟上下文信息
"""
logger.info(f"获取模拟上下文: {simulation_requirement[:50]}...")
# 搜索与模拟需求相关的信息
search_result = self.search_graph(
graph_id=graph_id,
query=simulation_requirement,
limit=limit
)
# 获取图谱统计
stats = self.get_graph_statistics(graph_id)
# 获取所有实体节点
all_nodes = self.get_all_nodes(graph_id)
# 筛选有实际类型的实体非纯Entity节点
entities = []
for node in all_nodes:
custom_labels = [l for l in node.labels if l not in ["Entity", "Node"]]
if custom_labels:
entities.append({
"name": node.name,
"type": custom_labels[0],
"summary": node.summary
})
return {
"simulation_requirement": simulation_requirement,
"related_facts": search_result.facts,
"graph_statistics": stats,
"entities": entities[:limit], # 限制数量
"total_entities": len(entities)
}