Enhance simulation management and logging features

- Registered a cleanup function for simulation processes to ensure proper termination on server shutdown.
- Improved logging during application startup to confirm the registration of the cleanup function.
- Updated simulation preparation checks to clarify the conditions for considering a simulation ready, enhancing error handling and user feedback.
- Added detailed logging for simulation status changes, improving traceability during the simulation lifecycle.
- Introduced new files for simulation configuration and profile data, supporting enhanced testing and visualization capabilities.
This commit is contained in:
666ghj
2025-12-02 17:11:47 +08:00
parent 3cc5e3f479
commit d4fac63eb4
15 changed files with 8515 additions and 241 deletions

View File

@@ -31,6 +31,12 @@ def create_app(config_class=Config):
# 启用CORS
CORS(app, resources={r"/api/*": {"origins": "*"}})
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程)
from .services.simulation_runner import SimulationRunner
SimulationRunner.register_cleanup()
if should_log_startup:
logger.info("已注册模拟进程清理函数")
# 请求日志中间件
@app.before_request
def log_request():

View File

@@ -276,8 +276,16 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
# 详细日志
logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}")
# 如果状态是ready或preparing已有文件,认为准备完成
if status in ["ready", "preparing"] and config_generated:
# 如果 config_generated=True 且文件存在,认为准备完成
# 以下状态都说明准备工作已完成:
# - ready: 准备完成,可以运行
# - preparing: 如果 config_generated=True 说明已完成
# - running: 正在运行,说明准备早就完成了
# - completed: 运行完成,说明准备早就完成了
# - stopped: 已停止,说明准备早就完成了
# - failed: 运行失败(但准备是完成的)
prepared_statuses = ["ready", "preparing", "running", "completed", "stopped", "failed"]
if status in prepared_statuses and config_generated:
# 获取文件统计信息
profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
config_file = os.path.join(simulation_dir, "simulation_config.json")
@@ -315,7 +323,7 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
else:
logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})")
return False, {
"reason": f"状态不是ready或config_generated为false: status={status}, config_generated={config_generated}",
"reason": f"状态不在已准备列表中或config_generated为false: status={status}, config_generated={config_generated}",
"status": status,
"config_generated": config_generated
}
@@ -1040,11 +1048,33 @@ def start_simulation():
"error": f"模拟不存在: {simulation_id}"
}), 404
# 智能处理状态:如果准备工作已完成,允许重新启动
if state.status != SimulationStatus.READY:
return jsonify({
"success": False,
"error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口"
}), 400
# 检查准备工作是否已完成
is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
if is_prepared:
# 准备工作已完成,检查是否有正在运行的进程
if state.status == SimulationStatus.RUNNING:
# 检查模拟进程是否真的在运行
run_state = SimulationRunner.get_run_state(simulation_id)
if run_state and run_state.runner_status.value == "running":
# 进程确实在运行
return jsonify({
"success": False,
"error": f"模拟正在运行中,请先调用 /stop 接口停止"
}), 400
# 进程不存在或已结束,重置状态为 ready
logger.info(f"模拟 {simulation_id} 准备工作已完成,重置状态为 ready原状态: {state.status.value}")
state.status = SimulationStatus.READY
manager._save_simulation_state(state)
else:
# 准备工作未完成
return jsonify({
"success": False,
"error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口"
}), 400
# 启动模拟
run_state = SimulationRunner.start_simulation(simulation_id, platform)

View File

@@ -10,6 +10,8 @@ import time
import asyncio
import threading
import subprocess
import signal
import atexit
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
@@ -21,6 +23,9 @@ from ..utils.logger import get_logger
logger = get_logger('mirofish.simulation_runner')
# 标记是否已注册清理函数
_cleanup_registered = False
class RunnerStatus(str, Enum):
"""运行器状态"""
@@ -342,34 +347,36 @@ class SimulationRunner:
# 启动模拟进程
try:
# 构建运行命令,使用完整路径
action_log_path = os.path.join(sim_dir, "actions.jsonl")
# 新的日志结构:
# twitter/actions.jsonl - Twitter 动作日志
# reddit/actions.jsonl - Reddit 动作日志
# simulation.log - 主进程日志
cmd = [
sys.executable, # Python解释器
script_path,
"--config", config_path, # 使用完整配置文件路径
"--action-log", action_log_path, # 动作日志文件完整路径
]
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
# 创建日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
main_log_path = os.path.join(sim_dir, "simulation.log")
main_log_file = open(main_log_path, 'w', encoding='utf-8')
# 设置工作目录为模拟目录(数据库等文件会生成在此)
# 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程
process = subprocess.Popen(
cmd,
cwd=sim_dir,
stdout=stdout_file,
stderr=stderr_file,
stdout=main_log_file,
stderr=subprocess.STDOUT, # stderr 也写入同一个文件
text=True,
bufsize=1,
start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程
)
# 保存文件句柄以便后续关闭
cls._stdout_files[simulation_id] = stdout_file
cls._stderr_files[simulation_id] = stderr_file
cls._stdout_files[simulation_id] = main_log_file
cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr
state.process_pid = process.pid
state.runner_status = RunnerStatus.RUNNING
@@ -399,7 +406,10 @@ class SimulationRunner:
def _monitor_simulation(cls, simulation_id: str):
"""监控模拟进程,解析动作日志"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
actions_log = os.path.join(sim_dir, "actions.jsonl")
# 新的日志结构:分平台的动作日志
twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")
process = cls._processes.get(simulation_id)
state = cls.get_run_state(simulation_id)
@@ -407,43 +417,32 @@ class SimulationRunner:
if not process or not state:
return
last_position = 0
twitter_position = 0
reddit_position = 0
try:
while process.poll() is None: # 进程仍在运行
# 读取动作日志
if os.path.exists(actions_log):
with open(actions_log, 'r', encoding='utf-8') as f:
f.seek(last_position)
for line in f:
line = line.strip()
if line:
try:
action_data = json.loads(line)
action = AgentAction(
round_num=action_data.get("round", 0),
timestamp=action_data.get("timestamp", datetime.now().isoformat()),
platform=action_data.get("platform", "unknown"),
agent_id=action_data.get("agent_id", 0),
agent_name=action_data.get("agent_name", ""),
action_type=action_data.get("action_type", ""),
action_args=action_data.get("action_args", {}),
result=action_data.get("result"),
success=action_data.get("success", True),
)
state.add_action(action)
# 更新轮次
if action.round_num > state.current_round:
state.current_round = action.round_num
except json.JSONDecodeError:
pass
last_position = f.tell()
# 读取 Twitter 动作日志
if os.path.exists(twitter_actions_log):
twitter_position = cls._read_action_log(
twitter_actions_log, twitter_position, state, "twitter"
)
# 定期保存状态
# 读取 Reddit 动作日志
if os.path.exists(reddit_actions_log):
reddit_position = cls._read_action_log(
reddit_actions_log, reddit_position, state, "reddit"
)
# 更新状态
cls._save_run_state(state)
time.sleep(1) # 每秒检查一次
time.sleep(2)
# 进程结束后,最后读取一次日志
if os.path.exists(twitter_actions_log):
cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
if os.path.exists(reddit_actions_log):
cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")
# 进程结束
exit_code = process.returncode
@@ -454,16 +453,16 @@ class SimulationRunner:
logger.info(f"模拟完成: {simulation_id}")
else:
state.runner_status = RunnerStatus.FAILED
# 从 stderr 日志文件读取错误信息
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stderr = ""
# 从日志文件读取错误信息
main_log_path = os.path.join(sim_dir, "simulation.log")
error_info = ""
try:
if os.path.exists(stderr_log_path):
with open(stderr_log_path, 'r', encoding='utf-8') as f:
stderr = f.read()
if os.path.exists(main_log_path):
with open(main_log_path, 'r', encoding='utf-8') as f:
error_info = f.read()[-2000:] # 取最后2000字符
except Exception:
pass
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
state.twitter_running = False
@@ -488,13 +487,70 @@ class SimulationRunner:
except Exception:
pass
cls._stdout_files.pop(simulation_id, None)
if simulation_id in cls._stderr_files:
if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]:
try:
cls._stderr_files[simulation_id].close()
except Exception:
pass
cls._stderr_files.pop(simulation_id, None)
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: SimulationRunState,
    platform: str
) -> int:
    """
    Incrementally read a platform action log and record the new actions.

    The file is read in binary mode so that ``position`` is a plain byte
    offset that is advanced one complete line at a time. This guarantees
    that a line is never processed twice (even if a later line raises) and
    that a line the writer has not finished yet is left for the next poll.

    Args:
        log_path: Path of the JSONL action log.
        position: Byte offset returned by the previous call (0 on first use).
        state: Run state that receives the parsed actions.
        platform: Platform name stored on every action (twitter/reddit).

    Returns:
        The byte offset just past the last fully processed line.
    """
    try:
        with open(log_path, 'rb') as f:
            f.seek(position)
            for raw_line in f:
                if not raw_line.endswith(b'\n'):
                    # The writer is still in the middle of this line;
                    # retry it on the next poll instead of dropping it.
                    break
                # Advance first: whatever happens below, this line is done.
                position += len(raw_line)

                line = raw_line.decode('utf-8', errors='replace').strip()
                if not line:
                    continue
                try:
                    action_data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(action_data, dict):
                    continue
                # Skip event entries (simulation_start, round_start, ...).
                if "event_type" in action_data:
                    continue

                action = AgentAction(
                    round_num=action_data.get("round", 0),
                    timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                    platform=platform,
                    agent_id=action_data.get("agent_id", 0),
                    agent_name=action_data.get("agent_name", ""),
                    action_type=action_data.get("action_type", ""),
                    action_args=action_data.get("action_args", {}),
                    result=action_data.get("result"),
                    success=action_data.get("success", True),
                )
                state.add_action(action)

                # Keep the run state's round counter monotonically increasing.
                if action.round_num and action.round_num > state.current_round:
                    state.current_round = action.round_num
    except Exception as e:
        logger.warning(f"读取动作日志失败: {log_path}, error={e}")
    return position
@classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
"""停止模拟"""
@@ -510,12 +566,35 @@ class SimulationRunner:
# 终止进程
process = cls._processes.get(simulation_id)
if process:
process.terminate()
if process and process.poll() is None:
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
process.kill()
# 使用进程组 ID 终止整个进程组(包括所有子进程)
# 由于使用了 start_new_session=True进程组 ID 等于主进程 PID
pgid = os.getpgid(process.pid)
logger.info(f"终止进程组: simulation={simulation_id}, pgid={pgid}")
# 先发送 SIGTERM 给整个进程组
os.killpg(pgid, signal.SIGTERM)
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
# 如果 10 秒后还没结束,强制发送 SIGKILL
logger.warning(f"进程组未响应 SIGTERM强制终止: {simulation_id}")
os.killpg(pgid, signal.SIGKILL)
process.wait(timeout=5)
except ProcessLookupError:
# 进程已经不存在
pass
except Exception as e:
logger.error(f"终止进程组失败: {simulation_id}, error={e}")
# 回退到直接终止进程
try:
process.terminate()
process.wait(timeout=5)
except Exception:
process.kill()
state.runner_status = RunnerStatus.STOPPED
state.twitter_running = False
@@ -709,4 +788,133 @@ class SimulationRunner:
result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
return result
@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every simulation process that is still running.

    For each live process the whole process group is sent SIGTERM and,
    if it has not exited after 5 seconds, SIGKILL. When process-group
    signalling fails or is unavailable, the process itself is terminated
    directly. The run state of every terminated simulation is marked as
    stopped and saved, then all tracked log file handles and in-memory
    registries are released.
    """
    logger.info("正在清理所有模拟进程...")

    # Iterate over a snapshot so the registry may change while we work.
    for simulation_id, process in list(cls._processes.items()):
        try:
            if process.poll() is not None:
                continue  # already exited, nothing to terminate

            logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}")

            try:
                # Signal the whole group so child processes die as well.
                pgid = os.getpgid(process.pid)
                os.killpg(pgid, signal.SIGTERM)
                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(f"进程组未响应 SIGTERM强制终止: {simulation_id}")
                    os.killpg(pgid, signal.SIGKILL)
                    process.wait(timeout=5)
            except (OSError, AttributeError):
                # OSError: the group/process vanished or cannot be signalled.
                # AttributeError: os.getpgid / os.killpg / signal.SIGKILL do
                # not exist on this platform. Either way, fall back to
                # terminating the process directly.
                try:
                    process.terminate()
                    process.wait(timeout=3)
                except Exception:
                    process.kill()

            # Record that the run was interrupted by the shutdown.
            state = cls.get_run_state(simulation_id)
            if state:
                state.runner_status = RunnerStatus.STOPPED
                state.twitter_running = False
                state.reddit_running = False
                state.completed_at = datetime.now().isoformat()
                state.error = "服务器关闭,模拟被终止"
                cls._save_run_state(state)
        except Exception as e:
            logger.error(f"清理进程失败: {simulation_id}, error={e}")

    # Close the tracked log file handles (entries may be None).
    for handles in (cls._stdout_files, cls._stderr_files):
        for file_handle in list(handles.values()):
            try:
                if file_handle:
                    file_handle.close()
            except Exception:
                pass
        handles.clear()

    # Drop the in-memory registries.
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info("模拟进程清理完成")
@classmethod
def register_cleanup(cls):
    """
    Install the shutdown hooks that clean up simulation processes.

    Registers cleanup_all_simulations with atexit and wraps the current
    SIGTERM / SIGINT handlers so that cleanup runs before the previous
    handler is invoked. Calling this more than once has no further effect.
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # Remember the handlers that were active before we replace them.
    previous_handlers = {
        signal.SIGINT: signal.getsignal(signal.SIGINT),
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }

    def cleanup_handler(signum=None, frame=None):
        """Clean up the simulations, then delegate to the previous handler."""
        logger.info(f"收到信号 {signum},开始清理...")
        cls.cleanup_all_simulations()

        previous = previous_handlers.get(signum)
        if not callable(previous):
            # Previous disposition is not a Python callable (e.g. SIG_DFL),
            # or the signal is not one we wrapped: unwind the main thread.
            raise KeyboardInterrupt
        previous(signum, frame)

    # atexit acts as a fallback when no signal handler can be installed.
    atexit.register(cls.cleanup_all_simulations)

    try:
        # signal.signal raises ValueError outside the main thread.
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, cleanup_handler)
    except ValueError:
        logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit")

    _cleanup_registered = True
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    Return the IDs of all simulations whose process has not exited yet.
    """
    return [
        sim_id
        for sim_id, proc in cls._processes.items()
        if proc.poll() is None
    ]

View File

@@ -1,29 +1,214 @@
"""
动作日志记录器
用于记录OASIS模拟中每个Agent的动作供后端监控使用
日志结构:
sim_xxx/
├── twitter/
│ └── actions.jsonl # Twitter 平台动作日志
├── reddit/
│ └── actions.jsonl # Reddit 平台动作日志
├── simulation.log # 主模拟进程日志
└── run_state.json # 运行状态(API 查询用)
"""
import json
import os
import logging
from datetime import datetime
from typing import Dict, Any, Optional
class ActionLogger:
"""动作日志记录器"""
class PlatformActionLogger:
"""单平台动作日志记录器"""
def __init__(self, log_path: str):
def __init__(self, platform: str, base_dir: str):
"""
初始化日志记录器
Args:
log_path: 日志文件路径(.jsonl格式
platform: 平台名称 (twitter/reddit)
base_dir: 模拟目录的基础路径
"""
self.log_path = log_path
self.platform = platform
self.base_dir = base_dir
self.log_dir = os.path.join(base_dir, platform)
self.log_path = os.path.join(self.log_dir, "actions.jsonl")
self._ensure_dir()
def _ensure_dir(self):
"""确保目录存在"""
os.makedirs(self.log_dir, exist_ok=True)
def log_action(
    self,
    round_num: int,
    agent_id: int,
    agent_name: str,
    action_type: str,
    action_args: Optional[Dict[str, Any]] = None,
    result: Optional[str] = None,
    success: bool = True
):
    """Append one agent action as a JSON line to this platform's log."""
    record = dict(
        round=round_num,
        timestamp=datetime.now().isoformat(),
        agent_id=agent_id,
        agent_name=agent_name,
        action_type=action_type,
        # A missing or empty argument mapping is stored as {}.
        action_args=action_args if action_args else {},
        result=result,
        success=success,
    )

    with open(self.log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{json.dumps(record, ensure_ascii=False)}\n")
def log_round_start(self, round_num: int, simulated_hour: int):
    """Append a ``round_start`` event for the given round to the log."""
    record = dict(
        round=round_num,
        timestamp=datetime.now().isoformat(),
        event_type="round_start",
        simulated_hour=simulated_hour,
    )

    with open(self.log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{json.dumps(record, ensure_ascii=False)}\n")
def log_round_end(self, round_num: int, actions_count: int):
    """Append a ``round_end`` event with the round's action count."""
    record = dict(
        round=round_num,
        timestamp=datetime.now().isoformat(),
        event_type="round_end",
        actions_count=actions_count,
    )

    with open(self.log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{json.dumps(record, ensure_ascii=False)}\n")
def log_simulation_start(self, config: Dict[str, Any]):
    """Append a ``simulation_start`` event summarising the run config."""
    time_config = config.get("time_config", {})
    # Round total is twice the configured simulation hours (default 72).
    planned_rounds = time_config.get("total_simulation_hours", 72) * 2

    record = dict(
        timestamp=datetime.now().isoformat(),
        event_type="simulation_start",
        platform=self.platform,
        total_rounds=planned_rounds,
        agents_count=len(config.get("agent_configs", [])),
    )

    with open(self.log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{json.dumps(record, ensure_ascii=False)}\n")
def log_simulation_end(self, total_rounds: int, total_actions: int):
    """Append a ``simulation_end`` event with the final run totals."""
    record = dict(
        timestamp=datetime.now().isoformat(),
        event_type="simulation_end",
        platform=self.platform,
        total_rounds=total_rounds,
        total_actions=total_actions,
    )

    with open(self.log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{json.dumps(record, ensure_ascii=False)}\n")
class SimulationLogManager:
    """
    Log manager for a single simulation directory.

    Owns the main ``simulation.log`` logger (file + console handlers) and
    lazily creates one PlatformActionLogger per platform on request.
    """

    # Level names accepted by log(), mapped to logging levels. Anything
    # else falls back to INFO.
    _LEVELS = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "warn": logging.WARNING,
        "error": logging.ERROR,
        "exception": logging.ERROR,
        "critical": logging.CRITICAL,
        "fatal": logging.CRITICAL,
    }

    def __init__(self, simulation_dir: str):
        """
        Create the manager and set up the main logger.

        Args:
            simulation_dir: Path of the simulation directory.
        """
        self.simulation_dir = simulation_dir
        self.twitter_logger: Optional["PlatformActionLogger"] = None
        self.reddit_logger: Optional["PlatformActionLogger"] = None
        self._main_logger: Optional[logging.Logger] = None

        self._setup_main_logger()

    def _setup_main_logger(self):
        """Configure the main logger: a truncated log file plus console output."""
        # FileHandler cannot create missing parent directories itself.
        if self.simulation_dir:
            os.makedirs(self.simulation_dir, exist_ok=True)
        log_path = os.path.join(self.simulation_dir, "simulation.log")

        main_logger = logging.getLogger(
            f"simulation.{os.path.basename(self.simulation_dir)}"
        )
        main_logger.setLevel(logging.INFO)

        # Loggers are process-wide singletons keyed by name: close the
        # handlers left by an earlier manager instead of merely dropping
        # them, otherwise their file descriptors stay open.
        for stale_handler in list(main_logger.handlers):
            main_logger.removeHandler(stale_handler)
            stale_handler.close()

        # File handler (mode 'w' starts a fresh log for every run).
        # NOTE(review): the runner also appears to redirect this process's
        # stdout/stderr into simulation.log; confirm the two writers do not
        # clobber or duplicate each other's output.
        file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        main_logger.addHandler(file_handler)

        # Console handler (StreamHandler defaults to stderr).
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter(
            '[%(asctime)s] %(message)s',
            datefmt='%H:%M:%S'
        ))
        main_logger.addHandler(console_handler)

        # Do not forward records to ancestor loggers (avoids duplicates).
        main_logger.propagate = False
        self._main_logger = main_logger

    def get_twitter_logger(self) -> "PlatformActionLogger":
        """Return the Twitter action logger, creating it on first use."""
        if self.twitter_logger is None:
            self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir)
        return self.twitter_logger

    def get_reddit_logger(self) -> "PlatformActionLogger":
        """Return the Reddit action logger, creating it on first use."""
        if self.reddit_logger is None:
            self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir)
        return self.reddit_logger

    def log(self, message: str, level: str = "info"):
        """
        Write a message to the main log.

        Args:
            message: Text to log.
            level: Level name (case-insensitive); unknown names log at INFO.
        """
        if not self._main_logger:
            return
        # Resolve through a fixed table rather than getattr() so that names
        # such as "handlers" or "setLevel" can never be looked up and called.
        levelno = self._LEVELS.get(str(level).lower(), logging.INFO)
        self._main_logger.log(levelno, message)

    def info(self, message: str):
        """Log a message at INFO level."""
        self.log(message, "info")

    def warning(self, message: str):
        """Log a message at WARNING level."""
        self.log(message, "warning")

    def error(self, message: str):
        """Log a message at ERROR level."""
        self.log(message, "error")

    def debug(self, message: str):
        """Log a message at DEBUG level (filtered out by the INFO threshold)."""
        self.log(message, "debug")
# ============ 兼容旧接口 ============
class ActionLogger:
"""
动作日志记录器(兼容旧接口)
建议使用 SimulationLogManager 代替
"""
def __init__(self, log_path: str):
self.log_path = log_path
self._ensure_dir()
def _ensure_dir(self):
log_dir = os.path.dirname(self.log_path)
if log_dir:
os.makedirs(log_dir, exist_ok=True)
@@ -39,19 +224,6 @@ class ActionLogger:
result: Optional[str] = None,
success: bool = True
):
"""
记录一个动作
Args:
round_num: 轮次
platform: 平台 (twitter/reddit)
agent_id: Agent ID
agent_name: Agent名称
action_type: 动作类型
action_args: 动作参数
result: 执行结果
success: 是否成功
"""
entry = {
"round": round_num,
"timestamp": datetime.now().isoformat(),
@@ -68,7 +240,6 @@ class ActionLogger:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
"""记录轮次开始"""
entry = {
"round": round_num,
"timestamp": datetime.now().isoformat(),
@@ -81,7 +252,6 @@ class ActionLogger:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_end(self, round_num: int, actions_count: int, platform: str):
"""记录轮次结束"""
entry = {
"round": round_num,
"timestamp": datetime.now().isoformat(),
@@ -94,7 +264,6 @@ class ActionLogger:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_start(self, platform: str, config: Dict[str, Any]):
"""记录模拟开始"""
entry = {
"timestamp": datetime.now().isoformat(),
"platform": platform,
@@ -107,7 +276,6 @@ class ActionLogger:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
"""记录模拟结束"""
entry = {
"timestamp": datetime.now().isoformat(),
"platform": platform,
@@ -120,12 +288,12 @@ class ActionLogger:
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
# 全局日志实例(可选
# 全局日志实例(兼容旧接口
_global_logger: Optional[ActionLogger] = None
def get_logger(log_path: Optional[str] = None) -> ActionLogger:
"""获取全局日志实例"""
"""获取全局日志实例(兼容旧接口)"""
global _global_logger
if log_path:
@@ -135,4 +303,3 @@ def get_logger(log_path: Optional[str] = None) -> ActionLogger:
_global_logger = ActionLogger("actions.jsonl")
return _global_logger

View File

@@ -3,7 +3,16 @@ OASIS 双平台并行模拟预设脚本
同时运行Twitter和Reddit模拟读取相同的配置文件
使用方式:
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
python run_parallel_simulation.py --config simulation_config.json
日志结构:
sim_xxx/
├── twitter/
│ └── actions.jsonl # Twitter 平台动作日志
├── reddit/
│ └── actions.jsonl # Reddit 平台动作日志
├── simulation.log # 主模拟进程日志
└── run_state.json # 运行状态(API 查询用)
"""
import argparse
@@ -12,9 +21,10 @@ import json
import logging
import os
import random
import sqlite3
import sys
from datetime import datetime
from typing import Dict, Any, List, Optional
from typing import Dict, Any, List, Optional, Tuple
# 添加 backend 目录到路径
# 脚本固定位于 backend/scripts/ 目录
@@ -38,91 +48,45 @@ else:
print(f"已加载环境配置: {_backend_env}")
class UnicodeFormatter(logging.Formatter):
def disable_oasis_logging():
"""
自定义格式化器,将 Unicode 转义序列(如 \\uXXXX转换为可读字符
禁用 OASIS 库的详细日志输出
OASIS 的日志太冗余(记录每个 agent 的观察和动作),我们使用自己的 action_logger
"""
# 禁用 OASIS 的所有日志器
oasis_loggers = [
"social.agent",
"social.twitter",
"social.rec",
"oasis.env",
"table",
]
# 匹配 \uXXXX 形式的 Unicode 转义序列
UNICODE_ESCAPE_PATTERN = None
@classmethod
def _get_pattern(cls):
if cls.UNICODE_ESCAPE_PATTERN is None:
import re
cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
return cls.UNICODE_ESCAPE_PATTERN
def format(self, record):
# 先获取原始格式化结果
result = super().format(record)
# 使用正则表达式替换 Unicode 转义序列
pattern = self._get_pattern()
def replace_unicode(match):
try:
return chr(int(match.group(1), 16))
except (ValueError, OverflowError):
return match.group(0)
return pattern.sub(replace_unicode, result)
def setup_oasis_logging(log_dir: str):
"""
配置 OASIS 的日志,覆盖默认的带时间戳日志文件
Args:
log_dir: 日志目录路径
"""
os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件
for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'):
try:
os.remove(old_log)
except OSError:
pass
# 创建自定义格式化器(支持 Unicode 解码)
formatter = UnicodeFormatter(
"%(levelname)s - %(asctime)s - %(name)s - %(message)s"
)
# 重新配置 OASIS 使用的日志器,使用固定名称(不带时间戳)
loggers_config = {
"social.agent": os.path.join(log_dir, "social.agent.log"),
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
"social.rec": os.path.join(log_dir, "social.rec.log"),
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
"table": os.path.join(log_dir, "table.log"),
}
for logger_name, log_file in loggers_config.items():
for logger_name in oasis_loggers:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
# 清除 OASIS 添加的现有处理器(带时间戳的日志文件)
logger.setLevel(logging.CRITICAL) # 只记录严重错误
logger.handlers.clear()
# 添加新的文件处理器(使用 UTF-8 编码,固定文件名)
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 防止日志向上传播(避免重复)
logger.propagate = False
print(f"日志配置完成,日志目录: {log_dir}")
def init_logging_for_simulation(simulation_dir: str):
"""初始化模拟的日志配置"""
log_dir = os.path.join(simulation_dir, "log")
setup_oasis_logging(log_dir)
"""
初始化模拟的日志配置
Args:
simulation_dir: 模拟目录路径
"""
# 禁用 OASIS 的详细日志
disable_oasis_logging()
# 清理旧的 log 目录(如果存在)
old_log_dir = os.path.join(simulation_dir, "log")
if os.path.exists(old_log_dir):
import shutil
shutil.rmtree(old_log_dir, ignore_errors=True)
from action_logger import ActionLogger
from action_logger import SimulationLogManager, PlatformActionLogger
try:
from camel.models import ModelFactory
@@ -175,6 +139,120 @@ def load_config(config_path: str) -> Dict[str, Any]:
return json.load(f)
# Low-value action types that are excluded from the action log.
FILTERED_ACTIONS = {'refresh', 'sign_up'}

# Maps action names as stored in the database to their standard names.
ACTION_TYPE_MAP = {
    'create_post': 'CREATE_POST',
    'like_post': 'LIKE_POST',
    'dislike_post': 'DISLIKE_POST',
    'repost': 'REPOST',
    'quote_post': 'QUOTE_POST',
    'follow': 'FOLLOW',
    'mute': 'MUTE',
    'create_comment': 'CREATE_COMMENT',
    'like_comment': 'LIKE_COMMENT',
    'dislike_comment': 'DISLIKE_COMMENT',
    'search_posts': 'SEARCH_POSTS',
    'search_user': 'SEARCH_USER',
    'trend': 'TREND',
    'do_nothing': 'DO_NOTHING',
    'interview': 'INTERVIEW',
}

# Keys of a trace row's info payload that are copied verbatim into the
# simplified action_args ('content' is handled separately: it is truncated).
_KEPT_ACTION_ARG_KEYS = (
    'post_id',
    'comment_id',
    'quoted_id',
    'new_post_id',
    'follow_id',
    'query',
    'like_id',
    'dislike_id',
)


def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Fetch the action records added to the ``trace`` table since the last call.

    Rows are tracked by SQLite's built-in ``rowid`` rather than by
    ``created_at``, so progress tracking does not depend on the format of
    the timestamp column.

    Args:
        db_path: Path of the SQLite database file.
        last_rowid: Largest rowid already processed (0 to read everything).
        agent_names: Mapping of agent_id to agent_name.

    Returns:
        ``(actions, new_last_rowid)`` where ``actions`` is a list of dicts
        with the keys agent_id, agent_name, action_type and action_args, and
        ``new_last_rowid`` is the largest rowid seen. A missing or unreadable
        database yields ``([], last_rowid)``; errors are reported with
        ``print`` and never raised.
    """
    actions: List[Dict[str, Any]] = []
    new_last_rowid = last_rowid

    if not os.path.exists(db_path):
        return actions, new_last_rowid

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute("""
            SELECT rowid, user_id, action, info
            FROM trace
            WHERE rowid > ?
            ORDER BY rowid ASC
        """, (last_rowid,))

        for rowid, user_id, action, info_json in cursor.fetchall():
            # Advance even for filtered rows so they are not re-read.
            new_last_rowid = rowid

            # Skip non-core actions and rows without a usable action name.
            if not isinstance(action, str) or action in FILTERED_ACTIONS:
                continue

            # Parse the action parameters; malformed payloads become {}.
            try:
                action_args = json.loads(info_json) if info_json else {}
            except (json.JSONDecodeError, TypeError):
                action_args = {}
            if not isinstance(action_args, dict):
                action_args = {}

            # Keep only the key fields, truncating long text content.
            simplified_args = {}
            if 'content' in action_args:
                content = action_args['content']
                if isinstance(content, str) and len(content) > 200:
                    content = content[:200] + '...'
                simplified_args['content'] = content
            for key in _KEPT_ACTION_ARG_KEYS:
                if key in action_args:
                    simplified_args[key] = action_args[key]

            actions.append({
                'agent_id': user_id,
                'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                # Unknown action names are upper-cased as-is.
                'action_type': ACTION_TYPE_MAP.get(action, action.upper()),
                'action_args': simplified_args,
            })
    except Exception as e:
        print(f"读取数据库动作失败: {e}")
    finally:
        # Always release the connection, including when the query fails
        # (e.g. the trace table does not exist yet).
        if conn is not None:
            conn.close()

    return actions, new_last_rowid
def create_model(config: Dict[str, Any]):
"""
创建LLM模型
@@ -269,17 +347,23 @@ def get_active_agents_for_round(
async def run_twitter_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[ActionLogger] = None
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
):
"""运行Twitter模拟"""
print("[Twitter] 初始化...")
def log_info(msg):
if main_logger:
main_logger.info(f"[Twitter] {msg}")
print(f"[Twitter] {msg}")
log_info("初始化...")
model = create_model(config)
# OASIS Twitter使用CSV格式
profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
if not os.path.exists(profile_path):
print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
log_info(f"错误: Profile文件不存在: {profile_path}")
return
agent_graph = await generate_twitter_agent_graph(
@@ -304,12 +388,13 @@ async def run_twitter_simulation(
)
await env.reset()
print("[Twitter] 环境已启动")
log_info("环境已启动")
if action_logger:
action_logger.log_simulation_start("twitter", config)
action_logger.log_simulation_start(config)
total_actions = 0
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
# 执行初始事件
event_config = config.get("event_config", {})
@@ -330,7 +415,6 @@ async def run_twitter_simulation(
if action_logger:
action_logger.log_action(
round_num=0,
platform="twitter",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="CREATE_POST",
@@ -342,7 +426,7 @@ async def run_twitter_simulation(
if initial_actions:
await env.step(initial_actions)
print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环
time_config = config.get("time_config", {})
@@ -365,54 +449,64 @@ async def run_twitter_simulation(
continue
if action_logger:
action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")
action_logger.log_round_start(round_num + 1, simulated_hour)
actions = {agent: LLMAction() for _, agent in active_agents}
await env.step(actions)
# 记录动作
for agent_id, agent in active_agents:
# 从数据库获取实际执行的动作并记录
actual_actions, last_rowid = fetch_new_actions_from_db(
db_path, last_rowid, agent_names
)
round_action_count = 0
for action_data in actual_actions:
if action_logger:
action_logger.log_action(
round_num=round_num + 1,
platform="twitter",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="LLM_ACTION",
action_args={}
agent_id=action_data['agent_id'],
agent_name=action_data['agent_name'],
action_type=action_data['action_type'],
action_args=action_data['action_args']
)
total_actions += 1
round_action_count += 1
if action_logger:
action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")
action_logger.log_round_end(round_num + 1, round_action_count)
if (round_num + 1) % 20 == 0:
progress = (round_num + 1) / total_rounds * 100
print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
await env.close()
if action_logger:
action_logger.log_simulation_end("twitter", total_rounds, total_actions)
action_logger.log_simulation_end(total_rounds, total_actions)
elapsed = (datetime.now() - start_time).total_seconds()
print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def run_reddit_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[ActionLogger] = None
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
):
"""运行Reddit模拟"""
print("[Reddit] 初始化...")
def log_info(msg):
if main_logger:
main_logger.info(f"[Reddit] {msg}")
print(f"[Reddit] {msg}")
log_info("初始化...")
model = create_model(config)
profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
if not os.path.exists(profile_path):
print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
log_info(f"错误: Profile文件不存在: {profile_path}")
return
agent_graph = await generate_reddit_agent_graph(
@@ -437,12 +531,13 @@ async def run_reddit_simulation(
)
await env.reset()
print("[Reddit] 环境已启动")
log_info("环境已启动")
if action_logger:
action_logger.log_simulation_start("reddit", config)
action_logger.log_simulation_start(config)
total_actions = 0
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
# 执行初始事件
event_config = config.get("event_config", {})
@@ -471,7 +566,6 @@ async def run_reddit_simulation(
if action_logger:
action_logger.log_action(
round_num=0,
platform="reddit",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="CREATE_POST",
@@ -483,7 +577,7 @@ async def run_reddit_simulation(
if initial_actions:
await env.step(initial_actions)
print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环
time_config = config.get("time_config", {})
@@ -506,39 +600,43 @@ async def run_reddit_simulation(
continue
if action_logger:
action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")
action_logger.log_round_start(round_num + 1, simulated_hour)
actions = {agent: LLMAction() for _, agent in active_agents}
await env.step(actions)
# 记录动作
for agent_id, agent in active_agents:
# 从数据库获取实际执行的动作并记录
actual_actions, last_rowid = fetch_new_actions_from_db(
db_path, last_rowid, agent_names
)
round_action_count = 0
for action_data in actual_actions:
if action_logger:
action_logger.log_action(
round_num=round_num + 1,
platform="reddit",
agent_id=agent_id,
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
action_type="LLM_ACTION",
action_args={}
agent_id=action_data['agent_id'],
agent_name=action_data['agent_name'],
action_type=action_data['action_type'],
action_args=action_data['action_args']
)
total_actions += 1
round_action_count += 1
if action_logger:
action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")
action_logger.log_round_end(round_num + 1, round_action_count)
if (round_num + 1) % 20 == 0:
progress = (round_num + 1) / total_rounds * 100
print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
await env.close()
if action_logger:
action_logger.log_simulation_end("reddit", total_rounds, total_actions)
action_logger.log_simulation_end(total_rounds, total_actions)
elapsed = (datetime.now() - start_time).total_seconds()
print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
async def main():
@@ -559,12 +657,6 @@ async def main():
action='store_true',
help='只运行Reddit模拟'
)
parser.add_argument(
'--action-log',
type=str,
default='actions.jsonl',
help='动作日志文件路径 (默认: actions.jsonl)'
)
args = parser.parse_args()
@@ -575,52 +667,53 @@ async def main():
config = load_config(args.config)
simulation_dir = os.path.dirname(args.config) or "."
# 初始化日志配置(清理旧日志文件,使用固定名称)
# 初始化日志配置(禁用 OASIS 日志,清理旧文件)
init_logging_for_simulation(simulation_dir)
# 创建动作日志记录器
action_log_path = os.path.join(simulation_dir, args.action_log)
action_logger = ActionLogger(action_log_path)
# 创建日志管理器
log_manager = SimulationLogManager(simulation_dir)
twitter_logger = log_manager.get_twitter_logger()
reddit_logger = log_manager.get_reddit_logger()
print("=" * 60)
print("OASIS 双平台并行模拟")
print(f"配置文件: {args.config}")
print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
print(f"动作日志: {action_log_path}")
print("=" * 60)
log_manager.info("=" * 60)
log_manager.info("OASIS 双平台并行模拟")
log_manager.info(f"配置文件: {args.config}")
log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
log_manager.info("=" * 60)
time_config = config.get("time_config", {})
print(f"\n模拟参数:")
print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
print(f" - Agent数量: {len(config.get('agent_configs', []))}")
log_manager.info(f"模拟参数:")
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
# LLM推理说明
reasoning = config.get("generation_reasoning", "")
if reasoning:
print(f"\nLLM配置推理:")
print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")
print("\n" + "=" * 60)
log_manager.info("日志结构:")
log_manager.info(f" - 主日志: simulation.log")
log_manager.info(f" - Twitter动作: twitter/actions.jsonl")
log_manager.info(f" - Reddit动作: reddit/actions.jsonl")
log_manager.info("=" * 60)
start_time = datetime.now()
if args.twitter_only:
await run_twitter_simulation(config, simulation_dir, action_logger)
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
elif args.reddit_only:
await run_reddit_simulation(config, simulation_dir, action_logger)
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
else:
# 并行运行(共享同一个action_logger)
# 并行运行(每个平台使用独立的日志记录器)
await asyncio.gather(
run_twitter_simulation(config, simulation_dir, action_logger),
run_reddit_simulation(config, simulation_dir, action_logger),
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
)
total_elapsed = (datetime.now() - start_time).total_seconds()
print("\n" + "=" * 60)
print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}")
print(f"动作日志已保存到: {action_log_path}")
print("=" * 60)
log_manager.info("=" * 60)
log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}")
log_manager.info(f"日志文件:")
log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
log_manager.info(f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}")
log_manager.info("=" * 60)
if __name__ == "__main__":