Refactor simulation management and enhance logging capabilities

- Updated simulation preparation checks to exclude script files from the required files list, improving clarity on file management.
- Implemented a robust retry mechanism for Zep API calls in the ZepEntityReader service, enhancing reliability.
- Enhanced logging in simulation scripts to provide clearer insights into the simulation process and errors.
- Updated simulation runner to manage stdout and stderr logs more effectively, ensuring better error tracking.
- Improved profile generation to standardize gender fields and ensure all required fields are populated correctly.
This commit is contained in:
666ghj
2025-12-02 14:25:53 +08:00
parent af5c235695
commit 3cc5e3f479
8 changed files with 595 additions and 165 deletions

View File

@@ -10,6 +10,7 @@ OASIS Agent Profile生成器
import json
import random
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
@@ -315,32 +316,54 @@ class OasisProfileGenerator:
comprehensive_query = f"关于{entity_name}的所有信息、活动、事件、关系和背景"
def search_edges():
"""搜索边(事实/关系)"""
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=30,
scope="edges",
reranker="rrf"
)
except Exception as e:
logger.debug(f"Zep边搜索失败: {e}")
return None
"""搜索边(事实/关系)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=30,
scope="edges",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
def search_nodes():
"""搜索节点(实体摘要)"""
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=20,
scope="nodes",
reranker="rrf"
)
except Exception as e:
logger.debug(f"Zep节点搜索失败: {e}")
return None
"""搜索节点(实体摘要)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=20,
scope="nodes",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
try:
# 并行执行edges和nodes搜索
@@ -684,18 +707,20 @@ class OasisProfileGenerator:
- 立场观点(对话题的态度、可能被激怒/感动的内容)
- 独特特征(口头禅、特殊经历、个人爱好)
- 个人记忆(人设的重要部分,要介绍这个个体与事件的关联,以及这个个体在事件中的已有动作与反应)
3. age: 年龄数字
4. gender: 性别(男/女)
5. mbti: MBTI类型
6. country: 国家
3. age: 年龄数字(必须是整数)
4. gender: 性别,必须是英文: "male""female"
5. mbti: MBTI类型如INTJ、ENFP等
6. country: 国家(使用中文,如"中国"
7. profession: 职业
8. interested_topics: 感兴趣话题数组
重要:
- 所有字段值必须是字符串或数字,不要使用换行符
- persona必须是一段连贯的文字描述
- 使用中文
- 内容要与实体信息保持一致"""
- 使用中文除了gender字段必须用英文male/female
- 内容要与实体信息保持一致
- age必须是有效的整数gender必须是"male""female"
"""
def _build_group_persona_prompt(
self,
@@ -731,17 +756,18 @@ class OasisProfileGenerator:
- 立场态度(对核心话题的官方立场、面对争议的处理方式)
- 特殊说明(代表的群体画像、运营习惯)
- 机构记忆(机构人设的重要部分,要介绍这个机构与事件的关联,以及这个机构在事件中的已有动作与反应)
3. age: null机构不适用
4. gender: null机构不适用
5. mbti: 可选用于描述账号风格如ISTJ代表严谨保守
6. country: 国家
3. age: 固定填30机构账号的虚拟年龄
4. gender: 固定填"other"机构账号使用other表示非个人
5. mbti: MBTI类型用于描述账号风格如ISTJ代表严谨保守
6. country: 国家(使用中文,如"中国"
7. profession: 机构职能描述
8. interested_topics: 关注领域数组
重要:
- 所有字段值必须是字符串数字null
- 所有字段值必须是字符串数字,不允许null
- persona必须是一段连贯的文字描述不要使用换行符
- 使用中文
- 使用中文除了gender字段必须用英文"other"
- age必须是整数30gender必须是字符串"other"
- 机构账号发言要符合其身份定位"""
def _generate_profile_rule_based(
@@ -784,6 +810,10 @@ class OasisProfileGenerator:
return {
"bio": f"Official account for {entity_name}. News and updates.",
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "中国",
"profession": "Media",
"interested_topics": ["General News", "Current Events", "Public Affairs"],
}
@@ -792,6 +822,10 @@ class OasisProfileGenerator:
return {
"bio": f"Official account of {entity_name}.",
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
"age": 30, # 机构虚拟年龄
"gender": "other", # 机构使用other
"mbti": "ISTJ", # 机构风格:严谨保守
"country": "中国",
"profession": entity_type,
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
}
@@ -1039,6 +1073,31 @@ class OasisProfileGenerator:
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (OASIS CSV格式)")
def _normalize_gender(self, gender: Optional[str]) -> str:
    """
    Normalize a gender value to the English vocabulary required by OASIS.

    OASIS accepts exactly three values: "male", "female" and "other".

    Args:
        gender: Raw gender value. May be a Chinese label ("男", "女",
            "机构", "其他"), an English value in any letter case, a value
            padded with whitespace, an empty string, or None.

    Returns:
        "male", "female" or "other". Missing, blank and unrecognized
        values all map to "other".
    """
    if not gender:
        return "other"

    # Coerce to str so a truthy non-string value cannot crash on .lower();
    # strip/lower make the lookup tolerant of padding and letter case.
    key = str(gender).strip().lower()

    gender_map = {
        # Chinese labels
        "男": "male",
        "女": "female",
        "机构": "other",
        "其他": "other",
        # English values pass through unchanged
        "male": "male",
        "female": "female",
        "other": "other",
    }

    # Anything outside the table is reported as "other" rather than rejected.
    return gender_map.get(key, "other")
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
"""
保存Reddit Profile为JSON格式
@@ -1048,26 +1107,30 @@ class OasisProfileGenerator:
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
OASIS要求所有字段都必须存在
- age: 整数
- gender: "male", "female", 或 "other"
- mbti: MBTI类型字符串
- country: 国家字符串
"""
data = []
for profile in profiles:
# 使用详细格式(与用户示例兼容)
# 确保所有必需字段都有有效值
item = {
"realname": profile.name,
"username": profile.user_name,
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
"bio": profile.bio[:150] if profile.bio else f"{profile.name}",
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
# OASIS必需字段 - 确保都有默认值
"age": profile.age if profile.age else 30,
"gender": self._normalize_gender(profile.gender),
"mbti": profile.mbti if profile.mbti else "ISTJ",
"country": profile.country if profile.country else "中国",
}
# 添加人设详情字段
if profile.age:
item["age"] = profile.age
if profile.gender:
item["gender"] = profile.gender
if profile.mbti:
item["mbti"] = profile.mbti
if profile.country:
item["country"] = profile.country
# 可选字段
if profile.profession:
item["profession"] = profile.profession
if profile.interested_topics:
@@ -1078,7 +1141,7 @@ class OasisProfileGenerator:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式已标准化gender字段)")
# 保留旧方法名作为别名,保持向后兼容
def save_profiles_to_json(

View File

@@ -127,12 +127,6 @@ class SimulationManager:
'../../uploads/simulations'
)
# 预设脚本目录
SCRIPTS_DIR = os.path.join(
os.path.dirname(__file__),
'../../scripts'
)
def __init__(self):
# 确保目录存在
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
@@ -426,27 +420,8 @@ class SimulationManager:
total=3
)
# ========== 阶段4: 复制预设脚本 ==========
script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
"run_parallel_simulation.py", "action_logger.py"]
if progress_callback:
progress_callback(
"copying_scripts", 0,
"开始准备脚本...",
current=0,
total=len(script_files)
)
self._copy_preset_scripts(sim_dir)
if progress_callback:
progress_callback(
"copying_scripts", 100,
f"完成,共 {len(script_files)} 个脚本",
current=len(script_files),
total=len(script_files)
)
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
# 启动模拟时simulation_runner 会从 scripts/ 目录运行脚本
# 更新状态
state.status = SimulationStatus.READY
@@ -466,24 +441,6 @@ class SimulationManager:
self._save_simulation_state(state)
raise
def _copy_preset_scripts(self, sim_dir: str):
"""复制预设脚本到模拟目录"""
scripts = [
"run_twitter_simulation.py",
"run_reddit_simulation.py",
"run_parallel_simulation.py"
]
for script in scripts:
src = os.path.join(self.SCRIPTS_DIR, script)
dst = os.path.join(sim_dir, script)
if os.path.exists(src):
shutil.copy2(src, dst)
logger.debug(f"复制脚本: {script}")
else:
logger.warning(f"预设脚本不存在: {src}")
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
"""获取模拟状态"""
return self._load_simulation_state(simulation_id)
@@ -531,21 +488,22 @@ class SimulationManager:
"""获取运行说明"""
sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json")
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
return {
"simulation_dir": sim_dir,
"scripts_dir": scripts_dir,
"config_file": config_path,
"commands": {
"twitter": f"python run_twitter_simulation.py --config simulation_config.json",
"reddit": f"python run_reddit_simulation.py --config simulation_config.json",
"parallel": f"python run_parallel_simulation.py --config simulation_config.json",
"twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
"reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
},
"instructions": (
f"1. 进入模拟目录: cd {sim_dir}\n"
f"2. 激活conda环境: conda activate MiroFish\n"
f"3. 运行模拟:\n"
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
f"1. 激活conda环境: conda activate MiroFish\n"
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
)
}

View File

@@ -182,11 +182,19 @@ class SimulationRunner:
'../../uploads/simulations'
)
# 脚本目录
SCRIPTS_DIR = os.path.join(
os.path.dirname(__file__),
'../../scripts'
)
# 内存中的运行状态
_run_states: Dict[str, SimulationRunState] = {}
_processes: Dict[str, subprocess.Popen] = {}
_action_queues: Dict[str, Queue] = {}
_monitor_threads: Dict[str, threading.Thread] = {}
_stdout_files: Dict[str, Any] = {} # 存储 stdout 文件句柄
_stderr_files: Dict[str, Any] = {} # 存储 stderr 文件句柄
@classmethod
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
@@ -310,7 +318,7 @@ class SimulationRunner:
cls._save_run_state(state)
# 确定运行哪个脚本
# 确定运行哪个脚本(脚本位于 backend/scripts/ 目录)
if platform == "twitter":
script_name = "run_twitter_simulation.py"
state.twitter_running = True
@@ -322,7 +330,7 @@ class SimulationRunner:
state.twitter_running = True
state.reddit_running = True
script_path = os.path.join(sim_dir, script_name)
script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
if not os.path.exists(script_path):
raise ValueError(f"脚本不存在: {script_path}")
@@ -333,24 +341,36 @@ class SimulationRunner:
# 启动模拟进程
try:
# 构建运行命令
# 构建运行命令,使用完整路径
action_log_path = os.path.join(sim_dir, "actions.jsonl")
cmd = [
sys.executable, # Python解释器
script_path,
"--config", "simulation_config.json",
"--action-log", "actions.jsonl", # 动作日志文件
"--config", config_path, # 使用完整配置文件路径
"--action-log", action_log_path, # 动作日志文件完整路径
]
# 设置工作目录为模拟目录
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
# 设置工作目录为模拟目录(数据库等文件会生成在此)
process = subprocess.Popen(
cmd,
cwd=sim_dir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=stdout_file,
stderr=stderr_file,
text=True,
bufsize=1,
)
# 保存文件句柄以便后续关闭
cls._stdout_files[simulation_id] = stdout_file
cls._stderr_files[simulation_id] = stderr_file
state.process_pid = process.pid
state.runner_status = RunnerStatus.RUNNING
cls._processes[simulation_id] = process
@@ -434,8 +454,16 @@ class SimulationRunner:
logger.info(f"模拟完成: {simulation_id}")
else:
state.runner_status = RunnerStatus.FAILED
stderr = process.stderr.read() if process.stderr else ""
state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
# 从 stderr 日志文件读取错误信息
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stderr = ""
try:
if os.path.exists(stderr_log_path):
with open(stderr_log_path, 'r', encoding='utf-8') as f:
stderr = f.read()
except Exception:
pass
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
state.twitter_running = False
@@ -449,9 +477,23 @@ class SimulationRunner:
cls._save_run_state(state)
finally:
# 清理
# 清理进程资源
cls._processes.pop(simulation_id, None)
cls._action_queues.pop(simulation_id, None)
# 关闭日志文件句柄
if simulation_id in cls._stdout_files:
try:
cls._stdout_files[simulation_id].close()
except Exception:
pass
cls._stdout_files.pop(simulation_id, None)
if simulation_id in cls._stderr_files:
try:
cls._stderr_files[simulation_id].close()
except Exception:
pass
cls._stderr_files.pop(simulation_id, None)
@classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:

View File

@@ -3,7 +3,8 @@ Zep实体读取与过滤服务
从Zep图谱中读取节点筛选出符合预定义实体类型的节点
"""
from typing import Dict, Any, List, Optional, Set
import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field
from zep_cloud.client import Zep
@@ -13,6 +14,9 @@ from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_entity_reader')
# 用于泛型返回类型
T = TypeVar('T')
@dataclass
class EntityNode:
@@ -80,9 +84,48 @@ class ZepEntityReader:
self.client = Zep(api_key=self.api_key)
def _call_with_retry(
    self,
    func: Callable[[], T],
    operation_name: str,
    max_retries: int = 3,
    initial_delay: float = 2.0
) -> T:
    """
    Call a Zep API function, retrying with exponential backoff on failure.

    Args:
        func: Zero-argument callable (typically a lambda) that performs
            the API call.
        operation_name: Human-readable operation name used in log messages.
        max_retries: Total number of attempts (default 3, i.e. the call is
            tried at most three times). Must be at least 1.
        initial_delay: Seconds to wait before the second attempt; the wait
            doubles after every further failure.

    Returns:
        Whatever ``func`` returns on its first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The exception raised by the final attempt, re-raised
            unchanged once every attempt has failed.
    """
    if max_retries < 1:
        # With zero attempts there is no result to return and no exception
        # to re-raise, so reject the configuration explicitly.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    delay = initial_delay

    for attempt in range(1, max_retries + 1):
        try:
            return func()
        except Exception as e:
            if attempt == max_retries:
                logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
                # Bare raise keeps the original exception and traceback.
                raise
            logger.warning(
                f"Zep {operation_name} 第 {attempt} 次尝试失败: {str(e)[:100]}, "
                f"{delay:.1f}秒后重试..."
            )
            time.sleep(delay)
            delay *= 2  # exponential backoff
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
"""
获取图谱的所有节点
获取图谱的所有节点(带重试机制)
Args:
graph_id: 图谱ID
@@ -92,7 +135,11 @@ class ZepEntityReader:
"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
# 使用重试机制调用Zep API
nodes = self._call_with_retry(
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取节点(graph={graph_id})"
)
nodes_data = []
for node in nodes:
@@ -109,7 +156,7 @@ class ZepEntityReader:
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
"""
获取图谱的所有边
获取图谱的所有边(带重试机制)
Args:
graph_id: 图谱ID
@@ -119,7 +166,11 @@ class ZepEntityReader:
"""
logger.info(f"获取图谱 {graph_id} 的所有边...")
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
# 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
operation_name=f"获取边(graph={graph_id})"
)
edges_data = []
for edge in edges:
@@ -137,7 +188,7 @@ class ZepEntityReader:
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
"""
获取指定节点的所有相关边
获取指定节点的所有相关边(带重试机制)
Args:
node_uuid: 节点UUID
@@ -146,7 +197,11 @@ class ZepEntityReader:
边列表
"""
try:
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
# 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
)
edges_data = []
for edge in edges:
@@ -288,7 +343,7 @@ class ZepEntityReader:
entity_uuid: str
) -> Optional[EntityNode]:
"""
获取单个实体及其完整上下文(边和关联节点)
获取单个实体及其完整上下文(边和关联节点,带重试机制
Args:
graph_id: 图谱ID
@@ -298,8 +353,11 @@ class ZepEntityReader:
EntityNode或None
"""
try:
# 获取节点
node = self.client.graph.node.get(uuid_=entity_uuid)
# 使用重试机制获取节点
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
)
if not node:
return None