Enhance simulation management and logging features

- Registered a cleanup function for simulation processes to ensure proper termination on server shutdown.
- Improved logging during application startup to confirm the registration of the cleanup function.
- Updated simulation preparation checks to clarify the conditions for considering a simulation ready, enhancing error handling and user feedback.
- Added detailed logging for simulation status changes, improving traceability during the simulation lifecycle.
- Introduced new files for simulation configuration and profile data, supporting enhanced testing and visualization capabilities.
This commit is contained in:
666ghj
2025-12-02 17:11:47 +08:00
parent 3cc5e3f479
commit d4fac63eb4
15 changed files with 8515 additions and 241 deletions

View File

@@ -3,7 +3,16 @@ OASIS 双平台并行模拟预设脚本
同时运行Twitter和Reddit模拟,读取相同的配置文件
使用方式:
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
python run_parallel_simulation.py --config simulation_config.json
日志结构:
sim_xxx/
├── twitter/
│ └── actions.jsonl # Twitter 平台动作日志
├── reddit/
│ └── actions.jsonl # Reddit 平台动作日志
├── simulation.log # 主模拟进程日志
└── run_state.json # 运行状态(API 查询用)
"""
import argparse
@@ -12,9 +21,10 @@ import json
import logging
import os
import random
import sqlite3
import sys
from datetime import datetime
from typing import Dict, Any, List, Optional
from typing import Dict, Any, List, Optional, Tuple
# 添加 backend 目录到路径
# 脚本固定位于 backend/scripts/ 目录
@@ -38,91 +48,45 @@ else:
print(f"已加载环境配置: {_backend_env}")
class UnicodeFormatter(logging.Formatter):
def disable_oasis_logging():
"""
自定义格式化器,将 Unicode 转义序列(如 \\uXXXX转换为可读字符
禁用 OASIS 库的详细日志输出
OASIS 的日志太冗余(记录每个 agent 的观察和动作),我们使用自己的 action_logger
"""
# 禁用 OASIS 的所有日志器
oasis_loggers = [
"social.agent",
"social.twitter",
"social.rec",
"oasis.env",
"table",
]
# 匹配 \uXXXX 形式的 Unicode 转义序列
UNICODE_ESCAPE_PATTERN = None
@classmethod
def _get_pattern(cls):
if cls.UNICODE_ESCAPE_PATTERN is None:
import re
cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
return cls.UNICODE_ESCAPE_PATTERN
def format(self, record):
# 先获取原始格式化结果
result = super().format(record)
# 使用正则表达式替换 Unicode 转义序列
pattern = self._get_pattern()
def replace_unicode(match):
try:
return chr(int(match.group(1), 16))
except (ValueError, OverflowError):
return match.group(0)
return pattern.sub(replace_unicode, result)
def setup_oasis_logging(log_dir: str):
"""
配置 OASIS 的日志,覆盖默认的带时间戳日志文件
Args:
log_dir: 日志目录路径
"""
os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件
for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'):
try:
os.remove(old_log)
except OSError:
pass
# 创建自定义格式化器(支持 Unicode 解码)
formatter = UnicodeFormatter(
"%(levelname)s - %(asctime)s - %(name)s - %(message)s"
)
# 重新配置 OASIS 使用的日志器,使用固定名称(不带时间戳)
loggers_config = {
"social.agent": os.path.join(log_dir, "social.agent.log"),
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
"social.rec": os.path.join(log_dir, "social.rec.log"),
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
"table": os.path.join(log_dir, "table.log"),
}
for logger_name, log_file in loggers_config.items():
for logger_name in oasis_loggers:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
# 清除 OASIS 添加的现有处理器(带时间戳的日志文件)
logger.setLevel(logging.CRITICAL) # 只记录严重错误
logger.handlers.clear()
# 添加新的文件处理器(使用 UTF-8 编码,固定文件名)
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 防止日志向上传播(避免重复)
logger.propagate = False
print(f"日志配置完成,日志目录: {log_dir}")
def init_logging_for_simulation(simulation_dir: str):
"""初始化模拟的日志配置"""
log_dir = os.path.join(simulation_dir, "log")
setup_oasis_logging(log_dir)
"""
初始化模拟的日志配置
Args:
simulation_dir: 模拟目录路径
"""
# 禁用 OASIS 的详细日志
disable_oasis_logging()
# 清理旧的 log 目录(如果存在)
old_log_dir = os.path.join(simulation_dir, "log")
if os.path.exists(old_log_dir):
import shutil
shutil.rmtree(old_log_dir, ignore_errors=True)
from action_logger import ActionLogger
from action_logger import SimulationLogManager, PlatformActionLogger
try:
from camel.models import ModelFactory
@@ -175,6 +139,120 @@ def load_config(config_path: str) -> Dict[str, Any]:
return json.load(f)
# Non-core action types that are dropped before logging (low analytical value).
FILTERED_ACTIONS = {'refresh', 'sign_up'}

# Action-type mapping: name as stored in the database -> canonical name.
ACTION_TYPE_MAP = {
    'create_post': 'CREATE_POST',
    'like_post': 'LIKE_POST',
    'dislike_post': 'DISLIKE_POST',
    'repost': 'REPOST',
    'quote_post': 'QUOTE_POST',
    'follow': 'FOLLOW',
    'mute': 'MUTE',
    'create_comment': 'CREATE_COMMENT',
    'like_comment': 'LIKE_COMMENT',
    'dislike_comment': 'DISLIKE_COMMENT',
    'search_posts': 'SEARCH_POSTS',
    'search_user': 'SEARCH_USER',
    'trend': 'TREND',
    'do_nothing': 'DO_NOTHING',
    'interview': 'INTERVIEW',
}

# Keys copied unchanged from a trace row's info payload into the logged args.
_PASSTHROUGH_ARG_KEYS = (
    'post_id',
    'comment_id',
    'quoted_id',
    'new_post_id',
    'follow_id',
    'query',
    'like_id',
    'dislike_id',
)

# Longest 'content' string kept verbatim; longer text is cut and suffixed '...'.
_MAX_CONTENT_LENGTH = 200


def _parse_action_info(info_json: Any) -> Dict[str, Any]:
    """Decode a trace row's info column, returning {} for anything unusable."""
    if not info_json:
        return {}
    try:
        parsed = json.loads(info_json)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError covers non-text values.
        return {}
    # Only a JSON object can carry the named fields we look for.
    return parsed if isinstance(parsed, dict) else {}


def _simplify_action_args(action_args: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the key fields of an action payload, truncating long content."""
    simplified: Dict[str, Any] = {}

    if 'content' in action_args:
        content = action_args['content']
        # Only strings are truncated; other values (e.g. null) pass through
        # unchanged instead of raising on len().
        if isinstance(content, str) and len(content) > _MAX_CONTENT_LENGTH:
            content = content[:_MAX_CONTENT_LENGTH] + '...'
        simplified['content'] = content

    for key in _PASSTHROUGH_ARG_KEYS:
        if key in action_args:
            simplified[key] = action_args[key]

    return simplified


def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Fetch action records added to the ``trace`` table since the last read.

    Progress is tracked by SQLite ``rowid`` rather than ``created_at`` because
    the ``created_at`` format differs between platforms (Twitter stores
    integers, Reddit stores datetime strings).

    Args:
        db_path: Path to the SQLite database file.
        last_rowid: Largest rowid already processed; only rows above it are read.
        agent_names: Mapping of agent_id -> agent_name. Unknown ids are
            reported as ``Agent_<id>``.

    Returns:
        ``(actions, new_last_rowid)`` where ``actions`` is a list of dicts with
        the keys ``agent_id``, ``agent_name``, ``action_type`` and
        ``action_args``, and ``new_last_rowid`` is the largest rowid seen
        (unchanged when nothing was read). Rows whose action is in
        ``FILTERED_ACTIONS`` advance the rowid but are not returned.

        Failures are reported on stdout and never raised; whatever was
        collected before the failure is returned.
    """
    actions: List[Dict[str, Any]] = []
    new_last_rowid = last_rowid

    if not os.path.exists(db_path):
        return actions, new_last_rowid

    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT rowid, user_id, action, info
                FROM trace
                WHERE rowid > ?
                ORDER BY rowid ASC
            """, (last_rowid,))

            for rowid, user_id, action, info_json in cursor.fetchall():
                # Rows arrive in ascending order, so the latest one is the max.
                new_last_rowid = rowid

                if action in FILTERED_ACTIONS:
                    continue

                action_type = ACTION_TYPE_MAP.get(action)
                if action_type is None:
                    # Unmapped names fall back to upper-case; str() guards
                    # against a NULL or non-text action column.
                    action_type = str(action).upper()

                actions.append({
                    'agent_id': user_id,
                    'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                    'action_type': action_type,
                    'action_args': _simplify_action_args(
                        _parse_action_info(info_json)
                    ),
                })
        finally:
            # Always release the connection, including when the query fails.
            conn.close()
    except Exception as e:
        # Deliberate best-effort boundary: a failed read must not abort the
        # caller, so the error is reported and partial results are returned.
        print(f"读取数据库动作失败: {e}")

    return actions, new_last_rowid
def create_model(config: Dict[str, Any]):
"""
创建LLM模型
@@ -269,17 +347,23 @@ def get_active_agents_for_round(
async def run_twitter_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[ActionLogger] = None
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
):
"""运行Twitter模拟"""
print("[Twitter] 初始化...")
def log_info(msg):
if main_logger:
main_logger.info(f"[Twitter] {msg}")
print(f"[Twitter] {msg}")
log_info("初始化...")
model = create_model(config)
# OASIS Twitter使用CSV格式
profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
if not os.path.exists(profile_path):
print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
log_info(f"错误: Profile文件不存在: {profile_path}")
return
agent_graph = await generate_twitter_agent_graph(
@@ -304,12 +388,13 @@ async def run_twitter_simulation(
)
await env.reset()
print("[Twitter] 环境已启动")
log_info("环境已启动")
if action_logger:
action_logger.log_simulation_start("twitter", config)
action_logger.log_simulation_start(config)
total_actions = 0
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
# 执行初始事件
event_config = config.get("event_config", {})
@@ -330,7 +415,6 @@ async def run_twitter_simulation(
if action_logger:
action_logger.log_action(
round_num=0,
platform="twitter",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="CREATE_POST",
@@ -342,7 +426,7 @@ async def run_twitter_simulation(
if initial_actions:
await env.step(initial_actions)
print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环
time_config = config.get("time_config", {})
@@ -365,54 +449,64 @@ async def run_twitter_simulation(
continue
if action_logger:
action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")
action_logger.log_round_start(round_num + 1, simulated_hour)
actions = {agent: LLMAction() for _, agent in active_agents}
await env.step(actions)
# 记录动作
for agent_id, agent in active_agents:
# 从数据库获取实际执行的动作并记录
actual_actions, last_rowid = fetch_new_actions_from_db(
db_path, last_rowid, agent_names
)
round_action_count = 0
for action_data in actual_actions:
if action_logger:
action_logger.log_action(
round_num=round_num + 1,
platform="twitter",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="LLM_ACTION",
action_args={}
agent_id=action_data['agent_id'],
agent_name=action_data['agent_name'],
action_type=action_data['action_type'],
action_args=action_data['action_args']
)
total_actions += 1
round_action_count += 1
if action_logger:
action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")
action_logger.log_round_end(round_num + 1, round_action_count)
if (round_num + 1) % 20 == 0:
progress = (round_num + 1) / total_rounds * 100
print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
await env.close()
if action_logger:
action_logger.log_simulation_end("twitter", total_rounds, total_actions)
action_logger.log_simulation_end(total_rounds, total_actions)
elapsed = (datetime.now() - start_time).total_seconds()
print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def run_reddit_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[ActionLogger] = None
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
):
"""运行Reddit模拟"""
print("[Reddit] 初始化...")
def log_info(msg):
if main_logger:
main_logger.info(f"[Reddit] {msg}")
print(f"[Reddit] {msg}")
log_info("初始化...")
model = create_model(config)
profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
if not os.path.exists(profile_path):
print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
log_info(f"错误: Profile文件不存在: {profile_path}")
return
agent_graph = await generate_reddit_agent_graph(
@@ -437,12 +531,13 @@ async def run_reddit_simulation(
)
await env.reset()
print("[Reddit] 环境已启动")
log_info("环境已启动")
if action_logger:
action_logger.log_simulation_start("reddit", config)
action_logger.log_simulation_start(config)
total_actions = 0
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
# 执行初始事件
event_config = config.get("event_config", {})
@@ -471,7 +566,6 @@ async def run_reddit_simulation(
if action_logger:
action_logger.log_action(
round_num=0,
platform="reddit",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="CREATE_POST",
@@ -483,7 +577,7 @@ async def run_reddit_simulation(
if initial_actions:
await env.step(initial_actions)
print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环
time_config = config.get("time_config", {})
@@ -506,39 +600,43 @@ async def run_reddit_simulation(
continue
if action_logger:
action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")
action_logger.log_round_start(round_num + 1, simulated_hour)
actions = {agent: LLMAction() for _, agent in active_agents}
await env.step(actions)
# 记录动作
for agent_id, agent in active_agents:
# 从数据库获取实际执行的动作并记录
actual_actions, last_rowid = fetch_new_actions_from_db(
db_path, last_rowid, agent_names
)
round_action_count = 0
for action_data in actual_actions:
if action_logger:
action_logger.log_action(
round_num=round_num + 1,
platform="reddit",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="LLM_ACTION",
action_args={}
agent_id=action_data['agent_id'],
agent_name=action_data['agent_name'],
action_type=action_data['action_type'],
action_args=action_data['action_args']
)
total_actions += 1
round_action_count += 1
if action_logger:
action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")
action_logger.log_round_end(round_num + 1, round_action_count)
if (round_num + 1) % 20 == 0:
progress = (round_num + 1) / total_rounds * 100
print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
await env.close()
if action_logger:
action_logger.log_simulation_end("reddit", total_rounds, total_actions)
action_logger.log_simulation_end(total_rounds, total_actions)
elapsed = (datetime.now() - start_time).total_seconds()
print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def main():
@@ -559,12 +657,6 @@ async def main():
action='store_true',
help='只运行Reddit模拟'
)
parser.add_argument(
'--action-log',
type=str,
default='actions.jsonl',
help='动作日志文件路径 (默认: actions.jsonl)'
)
args = parser.parse_args()
@@ -575,52 +667,53 @@ async def main():
config = load_config(args.config)
simulation_dir = os.path.dirname(args.config) or "."
# 初始化日志配置(清理旧日志文件,使用固定名称
# 初始化日志配置(禁用 OASIS 日志,清理旧文件
init_logging_for_simulation(simulation_dir)
# 创建动作日志记录
action_log_path = os.path.join(simulation_dir, args.action_log)
action_logger = ActionLogger(action_log_path)
# 创建日志管理
log_manager = SimulationLogManager(simulation_dir)
twitter_logger = log_manager.get_twitter_logger()
reddit_logger = log_manager.get_reddit_logger()
print("=" * 60)
print("OASIS 双平台并行模拟")
print(f"配置文件: {args.config}")
print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
print(f"动作日志: {action_log_path}")
print("=" * 60)
log_manager.info("=" * 60)
log_manager.info("OASIS 双平台并行模拟")
log_manager.info(f"配置文件: {args.config}")
log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
log_manager.info("=" * 60)
time_config = config.get("time_config", {})
print(f"\n模拟参数:")
print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
print(f" - Agent数量: {len(config.get('agent_configs', []))}")
log_manager.info(f"模拟参数:")
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
# LLM推理说明
reasoning = config.get("generation_reasoning", "")
if reasoning:
print(f"\nLLM配置推理:")
print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")
print("\n" + "=" * 60)
log_manager.info("日志结构:")
log_manager.info(f" - 主日志: simulation.log")
log_manager.info(f" - Twitter动作: twitter/actions.jsonl")
log_manager.info(f" - Reddit动作: reddit/actions.jsonl")
log_manager.info("=" * 60)
start_time = datetime.now()
if args.twitter_only:
await run_twitter_simulation(config, simulation_dir, action_logger)
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
elif args.reddit_only:
await run_reddit_simulation(config, simulation_dir, action_logger)
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
else:
# 并行运行(共享同一个action_logger
# 并行运行(每个平台使用独立的日志记录器
await asyncio.gather(
run_twitter_simulation(config, simulation_dir, action_logger),
run_reddit_simulation(config, simulation_dir, action_logger),
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
)
total_elapsed = (datetime.now() - start_time).total_seconds()
print("\n" + "=" * 60)
print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}")
print(f"动作日志已保存到: {action_log_path}")
print("=" * 60)
log_manager.info("=" * 60)
log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}")
log_manager.info(f"日志文件:")
log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
log_manager.info(f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}")
log_manager.info("=" * 60)
if __name__ == "__main__":