Refactor simulation management and enhance logging capabilities
- Updated simulation preparation checks to exclude script files from the required files list, improving clarity on file management. - Implemented a robust retry mechanism for Zep API calls in the ZepEntityReader service, enhancing reliability. - Enhanced logging in simulation scripts to provide clearer insights into the simulation process and errors. - Updated simulation runner to manage stdout and stderr logs more effectively, ensuring better error tracking. - Improved profile generation to standardize gender fields and ensure all required fields are populated correctly.
This commit is contained in:
@@ -10,6 +10,7 @@ OASIS Agent Profile生成器
|
||||
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
@@ -315,32 +316,54 @@ class OasisProfileGenerator:
|
||||
comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景"
|
||||
|
||||
def search_edges():
|
||||
"""搜索边(事实/关系)"""
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=30,
|
||||
scope="edges",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Zep边搜索失败: {e}")
|
||||
return None
|
||||
"""搜索边(事实/关系)- 带重试机制"""
|
||||
max_retries = 3
|
||||
last_exception = None
|
||||
delay = 2.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=30,
|
||||
scope="edges",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||
return None
|
||||
|
||||
def search_nodes():
|
||||
"""搜索节点(实体摘要)"""
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=20,
|
||||
scope="nodes",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Zep节点搜索失败: {e}")
|
||||
return None
|
||||
"""搜索节点(实体摘要)- 带重试机制"""
|
||||
max_retries = 3
|
||||
last_exception = None
|
||||
delay = 2.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=20,
|
||||
scope="nodes",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 并行执行edges和nodes搜索
|
||||
@@ -684,18 +707,20 @@ class OasisProfileGenerator:
|
||||
- 立场观点(对话题的态度、可能被激怒/感动的内容)
|
||||
- 独特特征(口头禅、特殊经历、个人爱好)
|
||||
- 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应)
|
||||
3. age: 年龄数字
|
||||
4. gender: 性别(男/女)
|
||||
5. mbti: MBTI类型
|
||||
6. country: 国家
|
||||
3. age: 年龄数字(必须是整数)
|
||||
4. gender: 性别,必须是英文: "male" 或 "female"
|
||||
5. mbti: MBTI类型(如INTJ、ENFP等)
|
||||
6. country: 国家(使用中文,如"中国")
|
||||
7. profession: 职业
|
||||
8. interested_topics: 感兴趣话题数组
|
||||
|
||||
重要:
|
||||
- 所有字段值必须是字符串或数字,不要使用换行符
|
||||
- persona必须是一段连贯的文字描述
|
||||
- 使用中文
|
||||
- 内容要与实体信息保持一致"""
|
||||
- 使用中文(除了gender字段必须用英文male/female)
|
||||
- 内容要与实体信息保持一致
|
||||
- age必须是有效的整数,gender必须是"male"或"female"
|
||||
"""
|
||||
|
||||
def _build_group_persona_prompt(
|
||||
self,
|
||||
@@ -731,17 +756,18 @@ class OasisProfileGenerator:
|
||||
- 立场态度(对核心话题的官方立场、面对争议的处理方式)
|
||||
- 特殊说明(代表的群体画像、运营习惯)
|
||||
- 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应)
|
||||
3. age: null(机构不适用)
|
||||
4. gender: null(机构不适用)
|
||||
5. mbti: 可选,用于描述账号风格,如ISTJ代表严谨保守
|
||||
6. country: 国家
|
||||
3. age: 固定填30(机构账号的虚拟年龄)
|
||||
4. gender: 固定填"other"(机构账号使用other表示非个人)
|
||||
5. mbti: MBTI类型,用于描述账号风格,如ISTJ代表严谨保守
|
||||
6. country: 国家(使用中文,如"中国")
|
||||
7. profession: 机构职能描述
|
||||
8. interested_topics: 关注领域数组
|
||||
|
||||
重要:
|
||||
- 所有字段值必须是字符串、数字或null
|
||||
- 所有字段值必须是字符串或数字,不允许null值
|
||||
- persona必须是一段连贯的文字描述,不要使用换行符
|
||||
- 使用中文
|
||||
- 使用中文(除了gender字段必须用英文"other")
|
||||
- age必须是整数30,gender必须是字符串"other"
|
||||
- 机构账号发言要符合其身份定位"""
|
||||
|
||||
def _generate_profile_rule_based(
|
||||
@@ -784,6 +810,10 @@ class OasisProfileGenerator:
|
||||
return {
|
||||
"bio": f"Official account for {entity_name}. News and updates.",
|
||||
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||
"age": 30, # 机构虚拟年龄
|
||||
"gender": "other", # 机构使用other
|
||||
"mbti": "ISTJ", # 机构风格:严谨保守
|
||||
"country": "中国",
|
||||
"profession": "Media",
|
||||
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||
}
|
||||
@@ -792,6 +822,10 @@ class OasisProfileGenerator:
|
||||
return {
|
||||
"bio": f"Official account of {entity_name}.",
|
||||
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||
"age": 30, # 机构虚拟年龄
|
||||
"gender": "other", # 机构使用other
|
||||
"mbti": "ISTJ", # 机构风格:严谨保守
|
||||
"country": "中国",
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||
}
|
||||
@@ -1039,6 +1073,31 @@ class OasisProfileGenerator:
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
|
||||
|
||||
def _normalize_gender(self, gender: Optional[str]) -> str:
|
||||
"""
|
||||
标准化gender字段为OASIS要求的英文格式
|
||||
|
||||
OASIS要求: male, female, other
|
||||
"""
|
||||
if not gender:
|
||||
return "other"
|
||||
|
||||
gender_lower = gender.lower().strip()
|
||||
|
||||
# 中文映射
|
||||
gender_map = {
|
||||
"男": "male",
|
||||
"女": "female",
|
||||
"机构": "other",
|
||||
"其他": "other",
|
||||
# 英文已有
|
||||
"male": "male",
|
||||
"female": "female",
|
||||
"other": "other",
|
||||
}
|
||||
|
||||
return gender_map.get(gender_lower, "other")
|
||||
|
||||
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Reddit Profile为JSON格式
|
||||
@@ -1048,26 +1107,30 @@ class OasisProfileGenerator:
|
||||
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||
|
||||
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||
|
||||
OASIS要求所有字段都必须存在:
|
||||
- age: 整数
|
||||
- gender: "male", "female", 或 "other"
|
||||
- mbti: MBTI类型字符串
|
||||
- country: 国家字符串
|
||||
"""
|
||||
data = []
|
||||
for profile in profiles:
|
||||
# 使用详细格式(与用户示例兼容)
|
||||
# 确保所有必需字段都有有效值
|
||||
item = {
|
||||
"realname": profile.name,
|
||||
"username": profile.user_name,
|
||||
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
||||
"bio": profile.bio[:150] if profile.bio else f"{profile.name}",
|
||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||
# OASIS必需字段 - 确保都有默认值
|
||||
"age": profile.age if profile.age else 30,
|
||||
"gender": self._normalize_gender(profile.gender),
|
||||
"mbti": profile.mbti if profile.mbti else "ISTJ",
|
||||
"country": profile.country if profile.country else "中国",
|
||||
}
|
||||
|
||||
# 添加人设详情字段
|
||||
if profile.age:
|
||||
item["age"] = profile.age
|
||||
if profile.gender:
|
||||
item["gender"] = profile.gender
|
||||
if profile.mbti:
|
||||
item["mbti"] = profile.mbti
|
||||
if profile.country:
|
||||
item["country"] = profile.country
|
||||
# 可选字段
|
||||
if profile.profession:
|
||||
item["profession"] = profile.profession
|
||||
if profile.interested_topics:
|
||||
@@ -1078,7 +1141,7 @@ class OasisProfileGenerator:
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式,已标准化gender字段)")
|
||||
|
||||
# 保留旧方法名作为别名,保持向后兼容
|
||||
def save_profiles_to_json(
|
||||
|
||||
@@ -127,12 +127,6 @@ class SimulationManager:
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
# 预设脚本目录
|
||||
SCRIPTS_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../scripts'
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
# 确保目录存在
|
||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||
@@ -426,27 +420,8 @@ class SimulationManager:
|
||||
total=3
|
||||
)
|
||||
|
||||
# ========== 阶段4: 复制预设脚本 ==========
|
||||
script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
|
||||
"run_parallel_simulation.py", "action_logger.py"]
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"copying_scripts", 0,
|
||||
"开始准备脚本...",
|
||||
current=0,
|
||||
total=len(script_files)
|
||||
)
|
||||
|
||||
self._copy_preset_scripts(sim_dir)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"copying_scripts", 100,
|
||||
f"完成,共 {len(script_files)} 个脚本",
|
||||
current=len(script_files),
|
||||
total=len(script_files)
|
||||
)
|
||||
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
|
||||
# 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本
|
||||
|
||||
# 更新状态
|
||||
state.status = SimulationStatus.READY
|
||||
@@ -466,24 +441,6 @@ class SimulationManager:
|
||||
self._save_simulation_state(state)
|
||||
raise
|
||||
|
||||
def _copy_preset_scripts(self, sim_dir: str):
|
||||
"""复制预设脚本到模拟目录"""
|
||||
scripts = [
|
||||
"run_twitter_simulation.py",
|
||||
"run_reddit_simulation.py",
|
||||
"run_parallel_simulation.py"
|
||||
]
|
||||
|
||||
for script in scripts:
|
||||
src = os.path.join(self.SCRIPTS_DIR, script)
|
||||
dst = os.path.join(sim_dir, script)
|
||||
|
||||
if os.path.exists(src):
|
||||
shutil.copy2(src, dst)
|
||||
logger.debug(f"复制脚本: {script}")
|
||||
else:
|
||||
logger.warning(f"预设脚本不存在: {src}")
|
||||
|
||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""获取模拟状态"""
|
||||
return self._load_simulation_state(simulation_id)
|
||||
@@ -531,21 +488,22 @@ class SimulationManager:
|
||||
"""获取运行说明"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||
|
||||
return {
|
||||
"simulation_dir": sim_dir,
|
||||
"scripts_dir": scripts_dir,
|
||||
"config_file": config_path,
|
||||
"commands": {
|
||||
"twitter": f"python run_twitter_simulation.py --config simulation_config.json",
|
||||
"reddit": f"python run_reddit_simulation.py --config simulation_config.json",
|
||||
"parallel": f"python run_parallel_simulation.py --config simulation_config.json",
|
||||
"twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
|
||||
"reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
|
||||
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
|
||||
},
|
||||
"instructions": (
|
||||
f"1. 进入模拟目录: cd {sim_dir}\n"
|
||||
f"2. 激活conda环境: conda activate MiroFish\n"
|
||||
f"3. 运行模拟:\n"
|
||||
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
|
||||
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
|
||||
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
|
||||
f"1. 激活conda环境: conda activate MiroFish\n"
|
||||
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
|
||||
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
||||
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
||||
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -182,11 +182,19 @@ class SimulationRunner:
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
# 脚本目录
|
||||
SCRIPTS_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../scripts'
|
||||
)
|
||||
|
||||
# 内存中的运行状态
|
||||
_run_states: Dict[str, SimulationRunState] = {}
|
||||
_processes: Dict[str, subprocess.Popen] = {}
|
||||
_action_queues: Dict[str, Queue] = {}
|
||||
_monitor_threads: Dict[str, threading.Thread] = {}
|
||||
_stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄
|
||||
_stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄
|
||||
|
||||
@classmethod
|
||||
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
|
||||
@@ -310,7 +318,7 @@ class SimulationRunner:
|
||||
|
||||
cls._save_run_state(state)
|
||||
|
||||
# 确定运行哪个脚本
|
||||
# 确定运行哪个脚本(脚本位于 backend/scripts/ 目录)
|
||||
if platform == "twitter":
|
||||
script_name = "run_twitter_simulation.py"
|
||||
state.twitter_running = True
|
||||
@@ -322,7 +330,7 @@ class SimulationRunner:
|
||||
state.twitter_running = True
|
||||
state.reddit_running = True
|
||||
|
||||
script_path = os.path.join(sim_dir, script_name)
|
||||
script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
raise ValueError(f"脚本不存在: {script_path}")
|
||||
@@ -333,24 +341,36 @@ class SimulationRunner:
|
||||
|
||||
# 启动模拟进程
|
||||
try:
|
||||
# 构建运行命令
|
||||
# 构建运行命令,使用完整路径
|
||||
action_log_path = os.path.join(sim_dir, "actions.jsonl")
|
||||
|
||||
cmd = [
|
||||
sys.executable, # Python解释器
|
||||
script_path,
|
||||
"--config", "simulation_config.json",
|
||||
"--action-log", "actions.jsonl", # 动作日志文件
|
||||
"--config", config_path, # 使用完整配置文件路径
|
||||
"--action-log", action_log_path, # 动作日志文件完整路径
|
||||
]
|
||||
|
||||
# 设置工作目录为模拟目录
|
||||
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
|
||||
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
|
||||
|
||||
# 设置工作目录为模拟目录(数据库等文件会生成在此)
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=sim_dir,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# 保存文件句柄以便后续关闭
|
||||
cls._stdout_files[simulation_id] = stdout_file
|
||||
cls._stderr_files[simulation_id] = stderr_file
|
||||
|
||||
state.process_pid = process.pid
|
||||
state.runner_status = RunnerStatus.RUNNING
|
||||
cls._processes[simulation_id] = process
|
||||
@@ -434,8 +454,16 @@ class SimulationRunner:
|
||||
logger.info(f"模拟完成: {simulation_id}")
|
||||
else:
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
stderr = process.stderr.read() if process.stderr else ""
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
|
||||
# 从 stderr 日志文件读取错误信息
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stderr = ""
|
||||
try:
|
||||
if os.path.exists(stderr_log_path):
|
||||
with open(stderr_log_path, 'r', encoding='utf-8') as f:
|
||||
stderr = f.read()
|
||||
except Exception:
|
||||
pass
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
|
||||
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
||||
|
||||
state.twitter_running = False
|
||||
@@ -449,9 +477,23 @@ class SimulationRunner:
|
||||
cls._save_run_state(state)
|
||||
|
||||
finally:
|
||||
# 清理
|
||||
# 清理进程资源
|
||||
cls._processes.pop(simulation_id, None)
|
||||
cls._action_queues.pop(simulation_id, None)
|
||||
|
||||
# 关闭日志文件句柄
|
||||
if simulation_id in cls._stdout_files:
|
||||
try:
|
||||
cls._stdout_files[simulation_id].close()
|
||||
except Exception:
|
||||
pass
|
||||
cls._stdout_files.pop(simulation_id, None)
|
||||
if simulation_id in cls._stderr_files:
|
||||
try:
|
||||
cls._stderr_files[simulation_id].close()
|
||||
except Exception:
|
||||
pass
|
||||
cls._stderr_files.pop(simulation_id, None)
|
||||
|
||||
@classmethod
|
||||
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
||||
|
||||
@@ -3,7 +3,8 @@ Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
@@ -13,6 +14,9 @@ from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
# 用于泛型返回类型
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityNode:
|
||||
@@ -80,9 +84,48 @@ class ZepEntityReader:
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
def _call_with_retry(
|
||||
self,
|
||||
func: Callable[[], T],
|
||||
operation_name: str,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 2.0
|
||||
) -> T:
|
||||
"""
|
||||
带重试机制的Zep API调用
|
||||
|
||||
Args:
|
||||
func: 要执行的函数(无参数的lambda或callable)
|
||||
operation_name: 操作名称,用于日志
|
||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
||||
initial_delay: 初始延迟秒数
|
||||
|
||||
Returns:
|
||||
API调用结果
|
||||
"""
|
||||
last_exception = None
|
||||
delay = initial_delay
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func()
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
||||
f"{delay:.1f}秒后重试..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay *= 2 # 指数退避
|
||||
else:
|
||||
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
||||
|
||||
raise last_exception
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点
|
||||
获取图谱的所有节点(带重试机制)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
@@ -92,7 +135,11 @@ class ZepEntityReader:
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||||
# 使用重试机制调用Zep API
|
||||
nodes = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取节点(graph={graph_id})"
|
||||
)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
@@ -109,7 +156,7 @@ class ZepEntityReader:
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边
|
||||
获取图谱的所有边(带重试机制)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
@@ -119,7 +166,11 @@ class ZepEntityReader:
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取边(graph={graph_id})"
|
||||
)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
@@ -137,7 +188,7 @@ class ZepEntityReader:
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定节点的所有相关边
|
||||
获取指定节点的所有相关边(带重试机制)
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
@@ -146,7 +197,11 @@ class ZepEntityReader:
|
||||
边列表
|
||||
"""
|
||||
try:
|
||||
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
@@ -288,7 +343,7 @@ class ZepEntityReader:
|
||||
entity_uuid: str
|
||||
) -> Optional[EntityNode]:
|
||||
"""
|
||||
获取单个实体及其完整上下文(边和关联节点)
|
||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
@@ -298,8 +353,11 @@ class ZepEntityReader:
|
||||
EntityNode或None
|
||||
"""
|
||||
try:
|
||||
# 获取节点
|
||||
node = self.client.graph.node.get(uuid_=entity_uuid)
|
||||
# 使用重试机制获取节点
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
if not node:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user