Enhance simulation management and logging features
- Registered a cleanup function for simulation processes to ensure proper termination on server shutdown.
- Improved logging during application startup to confirm the registration of the cleanup function.
- Updated simulation preparation checks to clarify the conditions for considering a simulation ready, enhancing error handling and user feedback.
- Added detailed logging for simulation status changes, improving traceability during the simulation lifecycle.
- Introduced new files for simulation configuration and profile data, supporting enhanced testing and visualization capabilities.
This commit is contained in:
@@ -31,6 +31,12 @@ def create_app(config_class=Config):
|
||||
# 启用CORS
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||
|
||||
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程)
|
||||
from .services.simulation_runner import SimulationRunner
|
||||
SimulationRunner.register_cleanup()
|
||||
if should_log_startup:
|
||||
logger.info("已注册模拟进程清理函数")
|
||||
|
||||
# 请求日志中间件
|
||||
@app.before_request
|
||||
def log_request():
|
||||
|
||||
@@ -276,8 +276,16 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||
# 详细日志
|
||||
logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}")
|
||||
|
||||
# 如果状态是ready或preparing(已有文件),认为准备完成
|
||||
if status in ["ready", "preparing"] and config_generated:
|
||||
# 如果 config_generated=True 且文件存在,认为准备完成
|
||||
# 以下状态都说明准备工作已完成:
|
||||
# - ready: 准备完成,可以运行
|
||||
# - preparing: 如果 config_generated=True 说明已完成
|
||||
# - running: 正在运行,说明准备早就完成了
|
||||
# - completed: 运行完成,说明准备早就完成了
|
||||
# - stopped: 已停止,说明准备早就完成了
|
||||
# - failed: 运行失败(但准备是完成的)
|
||||
prepared_statuses = ["ready", "preparing", "running", "completed", "stopped", "failed"]
|
||||
if status in prepared_statuses and config_generated:
|
||||
# 获取文件统计信息
|
||||
profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
|
||||
config_file = os.path.join(simulation_dir, "simulation_config.json")
|
||||
@@ -315,7 +323,7 @@ def _check_simulation_prepared(simulation_id: str) -> tuple:
|
||||
else:
|
||||
logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})")
|
||||
return False, {
|
||||
"reason": f"状态不是ready或config_generated为false: status={status}, config_generated={config_generated}",
|
||||
"reason": f"状态不在已准备列表中或config_generated为false: status={status}, config_generated={config_generated}",
|
||||
"status": status,
|
||||
"config_generated": config_generated
|
||||
}
|
||||
@@ -1040,11 +1048,33 @@ def start_simulation():
|
||||
"error": f"模拟不存在: {simulation_id}"
|
||||
}), 404
|
||||
|
||||
# 智能处理状态:如果准备工作已完成,允许重新启动
|
||||
if state.status != SimulationStatus.READY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口"
|
||||
}), 400
|
||||
# 检查准备工作是否已完成
|
||||
is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
|
||||
|
||||
if is_prepared:
|
||||
# 准备工作已完成,检查是否有正在运行的进程
|
||||
if state.status == SimulationStatus.RUNNING:
|
||||
# 检查模拟进程是否真的在运行
|
||||
run_state = SimulationRunner.get_run_state(simulation_id)
|
||||
if run_state and run_state.runner_status.value == "running":
|
||||
# 进程确实在运行
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"模拟正在运行中,请先调用 /stop 接口停止"
|
||||
}), 400
|
||||
|
||||
# 进程不存在或已结束,重置状态为 ready
|
||||
logger.info(f"模拟 {simulation_id} 准备工作已完成,重置状态为 ready(原状态: {state.status.value})")
|
||||
state.status = SimulationStatus.READY
|
||||
manager._save_simulation_state(state)
|
||||
else:
|
||||
# 准备工作未完成
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": f"模拟未准备好,当前状态: {state.status.value},请先调用 /prepare 接口"
|
||||
}), 400
|
||||
|
||||
# 启动模拟
|
||||
run_state = SimulationRunner.start_simulation(simulation_id, platform)
|
||||
|
||||
@@ -10,6 +10,8 @@ import time
|
||||
import asyncio
|
||||
import threading
|
||||
import subprocess
|
||||
import signal
|
||||
import atexit
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
@@ -21,6 +23,9 @@ from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
# 标记是否已注册清理函数
|
||||
_cleanup_registered = False
|
||||
|
||||
|
||||
class RunnerStatus(str, Enum):
|
||||
"""运行器状态"""
|
||||
@@ -342,34 +347,36 @@ class SimulationRunner:
|
||||
# 启动模拟进程
|
||||
try:
|
||||
# 构建运行命令,使用完整路径
|
||||
action_log_path = os.path.join(sim_dir, "actions.jsonl")
|
||||
# 新的日志结构:
|
||||
# twitter/actions.jsonl - Twitter 动作日志
|
||||
# reddit/actions.jsonl - Reddit 动作日志
|
||||
# simulation.log - 主进程日志
|
||||
|
||||
cmd = [
|
||||
sys.executable, # Python解释器
|
||||
script_path,
|
||||
"--config", config_path, # 使用完整配置文件路径
|
||||
"--action-log", action_log_path, # 动作日志文件完整路径
|
||||
]
|
||||
|
||||
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
|
||||
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
|
||||
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||
main_log_path = os.path.join(sim_dir, "simulation.log")
|
||||
main_log_file = open(main_log_path, 'w', encoding='utf-8')
|
||||
|
||||
# 设置工作目录为模拟目录(数据库等文件会生成在此)
|
||||
# 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=sim_dir,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file,
|
||||
stdout=main_log_file,
|
||||
stderr=subprocess.STDOUT, # stderr 也写入同一个文件
|
||||
text=True,
|
||||
bufsize=1,
|
||||
start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程
|
||||
)
|
||||
|
||||
# 保存文件句柄以便后续关闭
|
||||
cls._stdout_files[simulation_id] = stdout_file
|
||||
cls._stderr_files[simulation_id] = stderr_file
|
||||
cls._stdout_files[simulation_id] = main_log_file
|
||||
cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr
|
||||
|
||||
state.process_pid = process.pid
|
||||
state.runner_status = RunnerStatus.RUNNING
|
||||
@@ -399,7 +406,10 @@ class SimulationRunner:
|
||||
def _monitor_simulation(cls, simulation_id: str):
|
||||
"""监控模拟进程,解析动作日志"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
actions_log = os.path.join(sim_dir, "actions.jsonl")
|
||||
|
||||
# 新的日志结构:分平台的动作日志
|
||||
twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
|
||||
reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")
|
||||
|
||||
process = cls._processes.get(simulation_id)
|
||||
state = cls.get_run_state(simulation_id)
|
||||
@@ -407,43 +417,32 @@ class SimulationRunner:
|
||||
if not process or not state:
|
||||
return
|
||||
|
||||
last_position = 0
|
||||
twitter_position = 0
|
||||
reddit_position = 0
|
||||
|
||||
try:
|
||||
while process.poll() is None: # 进程仍在运行
|
||||
# 读取动作日志
|
||||
if os.path.exists(actions_log):
|
||||
with open(actions_log, 'r', encoding='utf-8') as f:
|
||||
f.seek(last_position)
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
action_data = json.loads(line)
|
||||
action = AgentAction(
|
||||
round_num=action_data.get("round", 0),
|
||||
timestamp=action_data.get("timestamp", datetime.now().isoformat()),
|
||||
platform=action_data.get("platform", "unknown"),
|
||||
agent_id=action_data.get("agent_id", 0),
|
||||
agent_name=action_data.get("agent_name", ""),
|
||||
action_type=action_data.get("action_type", ""),
|
||||
action_args=action_data.get("action_args", {}),
|
||||
result=action_data.get("result"),
|
||||
success=action_data.get("success", True),
|
||||
)
|
||||
state.add_action(action)
|
||||
|
||||
# 更新轮次
|
||||
if action.round_num > state.current_round:
|
||||
state.current_round = action.round_num
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
last_position = f.tell()
|
||||
# 读取 Twitter 动作日志
|
||||
if os.path.exists(twitter_actions_log):
|
||||
twitter_position = cls._read_action_log(
|
||||
twitter_actions_log, twitter_position, state, "twitter"
|
||||
)
|
||||
|
||||
# 定期保存状态
|
||||
# 读取 Reddit 动作日志
|
||||
if os.path.exists(reddit_actions_log):
|
||||
reddit_position = cls._read_action_log(
|
||||
reddit_actions_log, reddit_position, state, "reddit"
|
||||
)
|
||||
|
||||
# 更新状态
|
||||
cls._save_run_state(state)
|
||||
time.sleep(1) # 每秒检查一次
|
||||
time.sleep(2)
|
||||
|
||||
# 进程结束后,最后读取一次日志
|
||||
if os.path.exists(twitter_actions_log):
|
||||
cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
|
||||
if os.path.exists(reddit_actions_log):
|
||||
cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")
|
||||
|
||||
# 进程结束
|
||||
exit_code = process.returncode
|
||||
@@ -454,16 +453,16 @@ class SimulationRunner:
|
||||
logger.info(f"模拟完成: {simulation_id}")
|
||||
else:
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
# 从 stderr 日志文件读取错误信息
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stderr = ""
|
||||
# 从主日志文件读取错误信息
|
||||
main_log_path = os.path.join(sim_dir, "simulation.log")
|
||||
error_info = ""
|
||||
try:
|
||||
if os.path.exists(stderr_log_path):
|
||||
with open(stderr_log_path, 'r', encoding='utf-8') as f:
|
||||
stderr = f.read()
|
||||
if os.path.exists(main_log_path):
|
||||
with open(main_log_path, 'r', encoding='utf-8') as f:
|
||||
error_info = f.read()[-2000:] # 取最后2000字符
|
||||
except Exception:
|
||||
pass
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
|
||||
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
||||
|
||||
state.twitter_running = False
|
||||
@@ -488,13 +487,70 @@ class SimulationRunner:
|
||||
except Exception:
|
||||
pass
|
||||
cls._stdout_files.pop(simulation_id, None)
|
||||
if simulation_id in cls._stderr_files:
|
||||
if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]:
|
||||
try:
|
||||
cls._stderr_files[simulation_id].close()
|
||||
except Exception:
|
||||
pass
|
||||
cls._stderr_files.pop(simulation_id, None)
|
||||
|
||||
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: SimulationRunState,
    platform: str
) -> int:
    """
    Incrementally read a platform action log and record new actions on ``state``.

    The log is a JSON Lines file that another process may still be appending
    to, so the reader only commits its offset after a line has been fully
    handled:

    * a trailing line that is not newline-terminated and does not parse is
      treated as "still being written" and is retried on the next call
      instead of being silently dropped;
    * if an unexpected error occurs part-way through, the offset of the last
      handled line is returned, so lines already added to ``state`` are not
      added a second time on the next call.

    Args:
        log_path: Path of the ``actions.jsonl`` file to read.
        position: Offset previously returned by this method (0 for the first read).
        state: Run state that receives the parsed actions and round counter.
        platform: Platform name stamped on every action (twitter/reddit).

    Returns:
        The offset to pass as ``position`` on the next call.
    """
    try:
        with open(log_path, 'r', encoding='utf-8') as f:
            f.seek(position)

            # readline() is used instead of ``for line in f`` because tell()
            # is disabled while a text file is being iterated.
            while True:
                raw_line = f.readline()
                if not raw_line:
                    break  # end of file

                line = raw_line.strip()
                action_data = None
                if line:
                    try:
                        action_data = json.loads(line)
                    except json.JSONDecodeError:
                        if not raw_line.endswith('\n'):
                            # Partially written last line: leave the offset
                            # in front of it and retry on the next poll.
                            break
                        # Complete but malformed line: skip it for good.

                # Only JSON objects can be actions; entries carrying
                # "event_type" (simulation_start, round_start, ...) are
                # lifecycle events, not agent actions.
                if isinstance(action_data, dict) and "event_type" not in action_data:
                    action = AgentAction(
                        round_num=action_data.get("round", 0),
                        timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                        platform=platform,
                        agent_id=action_data.get("agent_id", 0),
                        agent_name=action_data.get("agent_name", ""),
                        action_type=action_data.get("action_type", ""),
                        action_args=action_data.get("action_args", {}),
                        result=action_data.get("result"),
                        success=action_data.get("success", True),
                    )
                    state.add_action(action)

                    # Advance the round counter
                    if action.round_num and action.round_num > state.current_round:
                        state.current_round = action.round_num

                # The line has been fully handled: commit the offset.
                position = f.tell()
    except Exception as e:
        logger.warning(f"读取动作日志失败: {log_path}, error={e}")

    return position
|
||||
|
||||
@classmethod
|
||||
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
||||
"""停止模拟"""
|
||||
@@ -510,12 +566,35 @@ class SimulationRunner:
|
||||
|
||||
# 终止进程
|
||||
process = cls._processes.get(simulation_id)
|
||||
if process:
|
||||
process.terminate()
|
||||
if process and process.poll() is None:
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
# 使用进程组 ID 终止整个进程组(包括所有子进程)
|
||||
# 由于使用了 start_new_session=True,进程组 ID 等于主进程 PID
|
||||
pgid = os.getpgid(process.pid)
|
||||
logger.info(f"终止进程组: simulation={simulation_id}, pgid={pgid}")
|
||||
|
||||
# 先发送 SIGTERM 给整个进程组
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
# 如果 10 秒后还没结束,强制发送 SIGKILL
|
||||
logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
process.wait(timeout=5)
|
||||
|
||||
except ProcessLookupError:
|
||||
# 进程已经不存在
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"终止进程组失败: {simulation_id}, error={e}")
|
||||
# 回退到直接终止进程
|
||||
try:
|
||||
process.terminate()
|
||||
process.wait(timeout=5)
|
||||
except Exception:
|
||||
process.kill()
|
||||
|
||||
state.runner_status = RunnerStatus.STOPPED
|
||||
state.twitter_running = False
|
||||
@@ -709,4 +788,133 @@ class SimulationRunner:
|
||||
result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every simulation process that is still running.

    Intended to run when the server shuts down. For each live process the
    whole process group is sent SIGTERM, then SIGKILL if it has not exited
    after 5 seconds. When process-group signalling is not possible the
    method falls back to terminating the process directly. Persisted run
    state is marked as stopped, and all tracked log file handles and
    in-memory registries are released afterwards.
    """
    logger.info("正在清理所有模拟进程...")

    # Snapshot the registry so the loop does not iterate a dict that may be mutated
    processes = list(cls._processes.items())

    for simulation_id, process in processes:
        try:
            if process.poll() is None:  # process is still running
                logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}")

                try:
                    # Signal the whole process group so child processes die too
                    pgid = os.getpgid(process.pid)
                    os.killpg(pgid, signal.SIGTERM)

                    try:
                        process.wait(timeout=5)
                    except subprocess.TimeoutExpired:
                        logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
                        os.killpg(pgid, signal.SIGKILL)
                        process.wait(timeout=5)

                except (ProcessLookupError, OSError, AttributeError):
                    # ProcessLookupError/OSError: the group is already gone or
                    # cannot be signalled. AttributeError: os.getpgid /
                    # os.killpg / signal.SIGKILL do not exist on this platform.
                    # In every case fall back to the Popen API so the process
                    # is not left running.
                    try:
                        process.terminate()
                        process.wait(timeout=3)
                    except Exception:
                        process.kill()

                # Persist the fact that the run was interrupted by shutdown
                state = cls.get_run_state(simulation_id)
                if state:
                    state.runner_status = RunnerStatus.STOPPED
                    state.twitter_running = False
                    state.reddit_running = False
                    state.completed_at = datetime.now().isoformat()
                    state.error = "服务器关闭,模拟被终止"
                    cls._save_run_state(state)

        except Exception as e:
            # One failing simulation must not prevent cleanup of the others
            logger.error(f"清理进程失败: {simulation_id}, error={e}")

    # Close the log file handles (best effort; entries may be None)
    for simulation_id, file_handle in list(cls._stdout_files.items()):
        try:
            if file_handle:
                file_handle.close()
        except Exception:
            pass
    cls._stdout_files.clear()

    for simulation_id, file_handle in list(cls._stderr_files.items()):
        try:
            if file_handle:
                file_handle.close()
        except Exception:
            pass
    cls._stderr_files.clear()

    # Drop the in-memory registries
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info("模拟进程清理完成")
|
||||
|
||||
@classmethod
def register_cleanup(cls):
    """
    Install shutdown hooks that terminate all simulation processes.

    Registers ``cleanup_all_simulations`` with ``atexit`` and wraps the
    SIGTERM and SIGINT handlers so that cleanup runs before the previously
    installed handler is invoked. Calling this more than once is a no-op
    (guarded by the module-level ``_cleanup_registered`` flag).
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # Remember the handlers that are currently installed so they can be chained
    previous_handlers = {
        signal.SIGINT: signal.getsignal(signal.SIGINT),
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }

    def cleanup_handler(signum=None, frame=None):
        """Signal handler: clean up simulations, then defer to the previous handler."""
        logger.info(f"收到信号 {signum},开始清理...")
        cls.cleanup_all_simulations()

        chained = previous_handlers.get(signum)
        if not callable(chained):
            # The previous handler is not a Python callable (e.g. SIG_DFL),
            # or the signal is neither SIGINT nor SIGTERM
            raise KeyboardInterrupt
        chained(signum, frame)

    # atexit hook as a backup path
    atexit.register(cls.cleanup_all_simulations)

    # signal.signal() raises ValueError when called outside the main thread
    try:
        # SIGTERM is the default signal of `kill`; SIGINT is Ctrl+C
        for signo in (signal.SIGTERM, signal.SIGINT):
            signal.signal(signo, cleanup_handler)
    except ValueError:
        logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit")

    _cleanup_registered = True
|
||||
|
||||
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    Return the IDs of all simulations whose process has not exited yet.

    A simulation counts as running when ``poll()`` on its tracked process
    returns ``None``.
    """
    return [
        sim_id
        for sim_id, process in cls._processes.items()
        if process.poll() is None
    ]
|
||||
|
||||
|
||||
@@ -1,29 +1,214 @@
|
||||
"""
|
||||
动作日志记录器
|
||||
用于记录OASIS模拟中每个Agent的动作,供后端监控使用
|
||||
|
||||
日志结构:
|
||||
sim_xxx/
|
||||
├── twitter/
|
||||
│ └── actions.jsonl # Twitter 平台动作日志
|
||||
├── reddit/
|
||||
│ └── actions.jsonl # Reddit 平台动作日志
|
||||
├── simulation.log # 主模拟进程日志
|
||||
└── run_state.json # 运行状态(API 查询用)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class ActionLogger:
|
||||
"""动作日志记录器"""
|
||||
class PlatformActionLogger:
    """Action logger for a single platform.

    Every record is appended as one JSON line to
    ``<base_dir>/<platform>/actions.jsonl``.
    """

    def __init__(self, platform: str, base_dir: str):
        """
        Create the logger and make sure its directory exists.

        Args:
            platform: Platform name (twitter/reddit); also used as the sub-directory name.
            base_dir: Base path of the simulation directory.
        """
        self.platform = platform
        self.base_dir = base_dir
        self.log_dir = os.path.join(base_dir, platform)
        self.log_path = os.path.join(self.log_dir, "actions.jsonl")
        self._ensure_dir()

    def _ensure_dir(self):
        """Create the platform log directory if it is missing."""
        os.makedirs(self.log_dir, exist_ok=True)

    def _append_record(self, record: Dict[str, Any]):
        """Append ``record`` to the log file as a single JSON line."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

    def log_action(
        self,
        round_num: int,
        agent_id: int,
        agent_name: str,
        action_type: str,
        action_args: Optional[Dict[str, Any]] = None,
        result: Optional[str] = None,
        success: bool = True
    ):
        """Record one agent action."""
        self._append_record({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "agent_id": agent_id,
            "agent_name": agent_name,
            "action_type": action_type,
            "action_args": action_args or {},
            "result": result,
            "success": success,
        })

    def log_round_start(self, round_num: int, simulated_hour: int):
        """Record the start of a round."""
        self._append_record({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "event_type": "round_start",
            "simulated_hour": simulated_hour,
        })

    def log_round_end(self, round_num: int, actions_count: int):
        """Record the end of a round."""
        self._append_record({
            "round": round_num,
            "timestamp": datetime.now().isoformat(),
            "event_type": "round_end",
            "actions_count": actions_count,
        })

    def log_simulation_start(self, config: Dict[str, Any]):
        """Record the start of the simulation.

        ``total_rounds`` is written as twice ``time_config.total_simulation_hours``
        (72 hours when absent); ``agents_count`` is the length of ``agent_configs``.
        """
        self._append_record({
            "timestamp": datetime.now().isoformat(),
            "event_type": "simulation_start",
            "platform": self.platform,
            "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
            "agents_count": len(config.get("agent_configs", [])),
        })

    def log_simulation_end(self, total_rounds: int, total_actions: int):
        """Record the end of the simulation."""
        self._append_record({
            "timestamp": datetime.now().isoformat(),
            "event_type": "simulation_end",
            "platform": self.platform,
            "total_rounds": total_rounds,
            "total_actions": total_actions,
        })
|
||||
|
||||
|
||||
class SimulationLogManager:
    """
    Simulation log manager.

    Owns the main ``simulation.log`` logger of one simulation directory and
    lazily creates one ``PlatformActionLogger`` per platform, so each
    platform's actions go to its own file.
    """

    # Level names accepted by log(), mapped to the Logger method that is
    # called. Anything else falls back to "info". Restricting the lookup to
    # this table stops log() from resolving arbitrary Logger attributes
    # (e.g. "handlers" or "name") and then failing when it calls them.
    _LEVEL_METHODS = {
        "debug": "debug",
        "info": "info",
        "warning": "warning",
        "warn": "warning",
        "error": "error",
        "exception": "exception",
        "critical": "critical",
        "fatal": "critical",
    }

    def __init__(self, simulation_dir: str):
        """
        Initialise the log manager.

        Args:
            simulation_dir: Path of the simulation directory.
        """
        self.simulation_dir = simulation_dir
        self.twitter_logger: Optional["PlatformActionLogger"] = None
        self.reddit_logger: Optional["PlatformActionLogger"] = None
        self._main_logger: Optional[logging.Logger] = None

        # Set up the main log
        self._setup_main_logger()

    def _setup_main_logger(self):
        """Configure the main simulation logger (file + console handlers).

        NOTE(review): the file is opened with mode 'w'. If another writer
        holds the same ``simulation.log`` open (e.g. a parent process
        redirecting this process's stdout to it), the two writers may
        overwrite each other - confirm against the caller.
        """
        log_path = os.path.join(self.simulation_dir, "simulation.log")

        # logging.getLogger returns the same object for the same name, so a
        # previous manager for this directory may already have attached handlers.
        self._main_logger = logging.getLogger(f"simulation.{os.path.basename(self.simulation_dir)}")
        self._main_logger.setLevel(logging.INFO)

        # Detach AND close stale handlers; merely clearing the list would
        # leave the previous simulation.log file handle open.
        for stale_handler in list(self._main_logger.handlers):
            self._main_logger.removeHandler(stale_handler)
            try:
                stale_handler.close()
            except Exception:
                pass  # best effort: a broken old handler must not block setup

        # File handler
        file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        self._main_logger.addHandler(file_handler)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(logging.Formatter(
            '[%(asctime)s] %(message)s',
            datefmt='%H:%M:%S'
        ))
        self._main_logger.addHandler(console_handler)

        # Keep records out of the root logger to avoid duplicate output
        self._main_logger.propagate = False

    def get_twitter_logger(self) -> "PlatformActionLogger":
        """Return the Twitter action logger, creating it on first use."""
        if self.twitter_logger is None:
            self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir)
        return self.twitter_logger

    def get_reddit_logger(self) -> "PlatformActionLogger":
        """Return the Reddit action logger, creating it on first use."""
        if self.reddit_logger is None:
            self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir)
        return self.reddit_logger

    def log(self, message: str, level: str = "info"):
        """
        Write ``message`` to the main log.

        Args:
            message: Text to log.
            level: Level name, case-insensitive (debug/info/warning/error/
                critical and their usual aliases). Unknown names are logged
                at INFO level.
        """
        if self._main_logger:
            method_name = self._LEVEL_METHODS.get(str(level).lower(), "info")
            getattr(self._main_logger, method_name)(message)

    def info(self, message: str):
        """Log ``message`` at INFO level."""
        self.log(message, "info")

    def warning(self, message: str):
        """Log ``message`` at WARNING level."""
        self.log(message, "warning")

    def error(self, message: str):
        """Log ``message`` at ERROR level."""
        self.log(message, "error")

    def debug(self, message: str):
        """Log ``message`` at DEBUG level (dropped while the logger level is INFO)."""
        self.log(message, "debug")
|
||||
|
||||
|
||||
# ============ 兼容旧接口 ============
|
||||
|
||||
class ActionLogger:
|
||||
"""
|
||||
动作日志记录器(兼容旧接口)
|
||||
建议使用 SimulationLogManager 代替
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: str):
|
||||
self.log_path = log_path
|
||||
self._ensure_dir()
|
||||
|
||||
def _ensure_dir(self):
|
||||
log_dir = os.path.dirname(self.log_path)
|
||||
if log_dir:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
@@ -39,19 +224,6 @@ class ActionLogger:
|
||||
result: Optional[str] = None,
|
||||
success: bool = True
|
||||
):
|
||||
"""
|
||||
记录一个动作
|
||||
|
||||
Args:
|
||||
round_num: 轮次
|
||||
platform: 平台 (twitter/reddit)
|
||||
agent_id: Agent ID
|
||||
agent_name: Agent名称
|
||||
action_type: 动作类型
|
||||
action_args: 动作参数
|
||||
result: 执行结果
|
||||
success: 是否成功
|
||||
"""
|
||||
entry = {
|
||||
"round": round_num,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
@@ -68,7 +240,6 @@ class ActionLogger:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
|
||||
"""记录轮次开始"""
|
||||
entry = {
|
||||
"round": round_num,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
@@ -81,7 +252,6 @@ class ActionLogger:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
def log_round_end(self, round_num: int, actions_count: int, platform: str):
|
||||
"""记录轮次结束"""
|
||||
entry = {
|
||||
"round": round_num,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
@@ -94,7 +264,6 @@ class ActionLogger:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
def log_simulation_start(self, platform: str, config: Dict[str, Any]):
|
||||
"""记录模拟开始"""
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"platform": platform,
|
||||
@@ -107,7 +276,6 @@ class ActionLogger:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
|
||||
"""记录模拟结束"""
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"platform": platform,
|
||||
@@ -120,12 +288,12 @@ class ActionLogger:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
# 全局日志实例(可选)
|
||||
# 全局日志实例(兼容旧接口)
|
||||
_global_logger: Optional[ActionLogger] = None
|
||||
|
||||
|
||||
def get_logger(log_path: Optional[str] = None) -> ActionLogger:
|
||||
"""获取全局日志实例"""
|
||||
"""获取全局日志实例(兼容旧接口)"""
|
||||
global _global_logger
|
||||
|
||||
if log_path:
|
||||
@@ -135,4 +303,3 @@ def get_logger(log_path: Optional[str] = None) -> ActionLogger:
|
||||
_global_logger = ActionLogger("actions.jsonl")
|
||||
|
||||
return _global_logger
|
||||
|
||||
|
||||
@@ -3,7 +3,16 @@ OASIS 双平台并行模拟预设脚本
|
||||
同时运行Twitter和Reddit模拟,读取相同的配置文件
|
||||
|
||||
使用方式:
|
||||
python run_parallel_simulation.py --config simulation_config.json [--action-log actions.jsonl]
|
||||
python run_parallel_simulation.py --config simulation_config.json
|
||||
|
||||
日志结构:
|
||||
sim_xxx/
|
||||
├── twitter/
|
||||
│ └── actions.jsonl # Twitter 平台动作日志
|
||||
├── reddit/
|
||||
│ └── actions.jsonl # Reddit 平台动作日志
|
||||
├── simulation.log # 主模拟进程日志
|
||||
└── run_state.json # 运行状态(API 查询用)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -12,9 +21,10 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
# 添加 backend 目录到路径
|
||||
# 脚本固定位于 backend/scripts/ 目录
|
||||
@@ -38,91 +48,45 @@ else:
|
||||
print(f"已加载环境配置: {_backend_env}")
|
||||
|
||||
|
||||
class UnicodeFormatter(logging.Formatter):
|
||||
def disable_oasis_logging():
|
||||
"""
|
||||
自定义格式化器,将 Unicode 转义序列(如 \\uXXXX)转换为可读字符
|
||||
禁用 OASIS 库的详细日志输出
|
||||
OASIS 的日志太冗余(记录每个 agent 的观察和动作),我们使用自己的 action_logger
|
||||
"""
|
||||
# 禁用 OASIS 的所有日志器
|
||||
oasis_loggers = [
|
||||
"social.agent",
|
||||
"social.twitter",
|
||||
"social.rec",
|
||||
"oasis.env",
|
||||
"table",
|
||||
]
|
||||
|
||||
# 匹配 \uXXXX 形式的 Unicode 转义序列
|
||||
UNICODE_ESCAPE_PATTERN = None
|
||||
|
||||
@classmethod
|
||||
def _get_pattern(cls):
|
||||
if cls.UNICODE_ESCAPE_PATTERN is None:
|
||||
import re
|
||||
cls.UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
|
||||
return cls.UNICODE_ESCAPE_PATTERN
|
||||
|
||||
def format(self, record):
|
||||
# 先获取原始格式化结果
|
||||
result = super().format(record)
|
||||
# 使用正则表达式替换 Unicode 转义序列
|
||||
pattern = self._get_pattern()
|
||||
|
||||
def replace_unicode(match):
|
||||
try:
|
||||
return chr(int(match.group(1), 16))
|
||||
except (ValueError, OverflowError):
|
||||
return match.group(0)
|
||||
|
||||
return pattern.sub(replace_unicode, result)
|
||||
|
||||
|
||||
def setup_oasis_logging(log_dir: str):
|
||||
"""
|
||||
配置 OASIS 的日志,覆盖默认的带时间戳日志文件
|
||||
|
||||
Args:
|
||||
log_dir: 日志目录路径
|
||||
"""
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 清理旧的日志文件
|
||||
for f in os.listdir(log_dir):
|
||||
old_log = os.path.join(log_dir, f)
|
||||
if os.path.isfile(old_log) and f.endswith('.log'):
|
||||
try:
|
||||
os.remove(old_log)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# 创建自定义格式化器(支持 Unicode 解码)
|
||||
formatter = UnicodeFormatter(
|
||||
"%(levelname)s - %(asctime)s - %(name)s - %(message)s"
|
||||
)
|
||||
|
||||
# 重新配置 OASIS 使用的日志器,使用固定名称(不带时间戳)
|
||||
loggers_config = {
|
||||
"social.agent": os.path.join(log_dir, "social.agent.log"),
|
||||
"social.twitter": os.path.join(log_dir, "social.twitter.log"),
|
||||
"social.rec": os.path.join(log_dir, "social.rec.log"),
|
||||
"oasis.env": os.path.join(log_dir, "oasis.env.log"),
|
||||
"table": os.path.join(log_dir, "table.log"),
|
||||
}
|
||||
|
||||
for logger_name, log_file in loggers_config.items():
|
||||
for logger_name in oasis_loggers:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# 清除 OASIS 添加的现有处理器(带时间戳的日志文件)
|
||||
logger.setLevel(logging.CRITICAL) # 只记录严重错误
|
||||
logger.handlers.clear()
|
||||
# 添加新的文件处理器(使用 UTF-8 编码,固定文件名)
|
||||
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
# 防止日志向上传播(避免重复)
|
||||
logger.propagate = False
|
||||
|
||||
print(f"日志配置完成,日志目录: {log_dir}")
|
||||
|
||||
|
||||
def init_logging_for_simulation(simulation_dir: str):
|
||||
"""初始化模拟的日志配置"""
|
||||
log_dir = os.path.join(simulation_dir, "log")
|
||||
setup_oasis_logging(log_dir)
|
||||
"""
|
||||
初始化模拟的日志配置
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟目录路径
|
||||
"""
|
||||
# 禁用 OASIS 的详细日志
|
||||
disable_oasis_logging()
|
||||
|
||||
# 清理旧的 log 目录(如果存在)
|
||||
old_log_dir = os.path.join(simulation_dir, "log")
|
||||
if os.path.exists(old_log_dir):
|
||||
import shutil
|
||||
shutil.rmtree(old_log_dir, ignore_errors=True)
|
||||
|
||||
|
||||
from action_logger import ActionLogger
|
||||
from action_logger import SimulationLogManager, PlatformActionLogger
|
||||
|
||||
try:
|
||||
from camel.models import ModelFactory
|
||||
@@ -175,6 +139,120 @@ def load_config(config_path: str) -> Dict[str, Any]:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# Non-core action types that are filtered out (low value for analysis).
FILTERED_ACTIONS = {'refresh', 'sign_up'}

# Action type mapping (name stored in the database -> canonical name).
ACTION_TYPE_MAP = {
    'create_post': 'CREATE_POST',
    'like_post': 'LIKE_POST',
    'dislike_post': 'DISLIKE_POST',
    'repost': 'REPOST',
    'quote_post': 'QUOTE_POST',
    'follow': 'FOLLOW',
    'mute': 'MUTE',
    'create_comment': 'CREATE_COMMENT',
    'like_comment': 'LIKE_COMMENT',
    'dislike_comment': 'DISLIKE_COMMENT',
    'search_posts': 'SEARCH_POSTS',
    'search_user': 'SEARCH_USER',
    'trend': 'TREND',
    'do_nothing': 'DO_NOTHING',
    'interview': 'INTERVIEW',
}

# Keys of the raw action payload that are copied into the simplified
# ``action_args``; every other key is dropped.
_KEPT_ACTION_ARG_KEYS = (
    'content',
    'post_id',
    'comment_id',
    'quoted_id',
    'new_post_id',
    'follow_id',
    'query',
    'like_id',
    'dislike_id',
)

# String ``content`` longer than this is cut and suffixed with '...'.
_MAX_CONTENT_LENGTH = 200


def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Fetch the action records added to the ``trace`` table since the last read.

    Progress is tracked with SQLite's ``rowid`` rather than ``created_at``,
    because the platforms store ``created_at`` in different formats
    (Twitter uses integers, Reddit uses datetime strings).

    Reading is best-effort: if the database file does not exist, or any
    database error occurs, the failure is printed and whatever was collected
    so far is returned. The connection is always closed, even on failure.

    Args:
        db_path: Path of the SQLite database file.
        last_rowid: Largest rowid handled by the previous call; only rows
            with a greater rowid are read.
        agent_names: Mapping of agent_id -> agent_name. Unknown ids fall
            back to ``Agent_<id>``.

    Returns:
        (actions_list, new_last_rowid)
        - actions_list: one dict per kept row with the keys ``agent_id``,
          ``agent_name``, ``action_type`` and ``action_args``.
        - new_last_rowid: rowid of the last row visited (filtered rows
          included), or ``last_rowid`` when nothing new was read.
    """
    actions: List[Dict[str, Any]] = []
    new_last_rowid = last_rowid

    if not os.path.exists(db_path):
        return actions, new_last_rowid

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute("""
            SELECT rowid, user_id, action, info
            FROM trace
            WHERE rowid > ?
            ORDER BY rowid ASC
        """, (last_rowid,))

        for rowid, user_id, action, info_json in cursor.fetchall():
            # Advance the watermark before filtering so skipped rows are not
            # fetched again on the next call.
            new_last_rowid = rowid

            if action in FILTERED_ACTIONS:
                continue

            # Parse the payload. Anything that is not a JSON object is
            # treated as "no arguments" so one bad row cannot abort the batch.
            try:
                action_args = json.loads(info_json) if info_json else {}
            except (ValueError, TypeError):
                action_args = {}
            if not isinstance(action_args, dict):
                action_args = {}

            # Keep only the key fields.
            simplified_args = {
                key: action_args[key]
                for key in _KEPT_ACTION_ARG_KEYS
                if key in action_args
            }

            # Truncate overly long text; non-string content (e.g. null) is
            # passed through unchanged.
            content = simplified_args.get('content')
            if isinstance(content, str) and len(content) > _MAX_CONTENT_LENGTH:
                simplified_args['content'] = content[:_MAX_CONTENT_LENGTH] + '...'

            actions.append({
                'agent_id': user_id,
                'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                # Unmapped action names are upper-cased as a fallback.
                'action_type': ACTION_TYPE_MAP.get(action, action.upper()),
                'action_args': simplified_args,
            })
    except Exception as e:
        # Deliberately broad: reading the trace must never crash the caller.
        print(f"读取数据库动作失败: {e}")
    finally:
        # Close on both the success and the failure path.
        if conn is not None:
            conn.close()

    return actions, new_last_rowid
|
||||
|
||||
|
||||
def create_model(config: Dict[str, Any]):
|
||||
"""
|
||||
创建LLM模型
|
||||
@@ -269,17 +347,23 @@ def get_active_agents_for_round(
|
||||
async def run_twitter_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[ActionLogger] = None
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None
|
||||
):
|
||||
"""运行Twitter模拟"""
|
||||
print("[Twitter] 初始化...")
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Twitter] {msg}")
|
||||
print(f"[Twitter] {msg}")
|
||||
|
||||
log_info("初始化...")
|
||||
|
||||
model = create_model(config)
|
||||
|
||||
# OASIS Twitter使用CSV格式
|
||||
profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
|
||||
if not os.path.exists(profile_path):
|
||||
print(f"[Twitter] 错误: Profile文件不存在: {profile_path}")
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
@@ -304,12 +388,13 @@ async def run_twitter_simulation(
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
print("[Twitter] 环境已启动")
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_start("twitter", config)
|
||||
action_logger.log_simulation_start(config)
|
||||
|
||||
total_actions = 0
|
||||
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
|
||||
|
||||
# 执行初始事件
|
||||
event_config = config.get("event_config", {})
|
||||
@@ -330,7 +415,6 @@ async def run_twitter_simulation(
|
||||
if action_logger:
|
||||
action_logger.log_action(
|
||||
round_num=0,
|
||||
platform="twitter",
|
||||
agent_id=agent_id,
|
||||
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
|
||||
action_type="CREATE_POST",
|
||||
@@ -342,7 +426,7 @@ async def run_twitter_simulation(
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
print(f"[Twitter] 已发布 {len(initial_actions)} 条初始帖子")
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
time_config = config.get("time_config", {})
|
||||
@@ -365,54 +449,64 @@ async def run_twitter_simulation(
|
||||
continue
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_round_start(round_num + 1, simulated_hour, "twitter")
|
||||
action_logger.log_round_start(round_num + 1, simulated_hour)
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
|
||||
# 记录动作
|
||||
for agent_id, agent in active_agents:
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
db_path, last_rowid, agent_names
|
||||
)
|
||||
|
||||
round_action_count = 0
|
||||
for action_data in actual_actions:
|
||||
if action_logger:
|
||||
action_logger.log_action(
|
||||
round_num=round_num + 1,
|
||||
platform="twitter",
|
||||
agent_id=agent_id,
|
||||
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
|
||||
action_type="LLM_ACTION",
|
||||
action_args={}
|
||||
agent_id=action_data['agent_id'],
|
||||
agent_name=action_data['agent_name'],
|
||||
action_type=action_data['action_type'],
|
||||
action_args=action_data['action_args']
|
||||
)
|
||||
total_actions += 1
|
||||
round_action_count += 1
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_round_end(round_num + 1, len(active_agents), "twitter")
|
||||
action_logger.log_round_end(round_num + 1, round_action_count)
|
||||
|
||||
if (round_num + 1) % 20 == 0:
|
||||
progress = (round_num + 1) / total_rounds * 100
|
||||
print(f"[Twitter] Day {simulated_day}, {simulated_hour:02d}:00 "
|
||||
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end("twitter", total_rounds, total_actions)
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"[Twitter] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
|
||||
async def run_reddit_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[ActionLogger] = None
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None
|
||||
):
|
||||
"""运行Reddit模拟"""
|
||||
print("[Reddit] 初始化...")
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Reddit] {msg}")
|
||||
print(f"[Reddit] {msg}")
|
||||
|
||||
log_info("初始化...")
|
||||
|
||||
model = create_model(config)
|
||||
|
||||
profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
|
||||
if not os.path.exists(profile_path):
|
||||
print(f"[Reddit] 错误: Profile文件不存在: {profile_path}")
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_reddit_agent_graph(
|
||||
@@ -437,12 +531,13 @@ async def run_reddit_simulation(
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
print("[Reddit] 环境已启动")
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_start("reddit", config)
|
||||
action_logger.log_simulation_start(config)
|
||||
|
||||
total_actions = 0
|
||||
last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异)
|
||||
|
||||
# 执行初始事件
|
||||
event_config = config.get("event_config", {})
|
||||
@@ -471,7 +566,6 @@ async def run_reddit_simulation(
|
||||
if action_logger:
|
||||
action_logger.log_action(
|
||||
round_num=0,
|
||||
platform="reddit",
|
||||
agent_id=agent_id,
|
||||
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
|
||||
action_type="CREATE_POST",
|
||||
@@ -483,7 +577,7 @@ async def run_reddit_simulation(
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
print(f"[Reddit] 已发布 {len(initial_actions)} 条初始帖子")
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
time_config = config.get("time_config", {})
|
||||
@@ -506,39 +600,43 @@ async def run_reddit_simulation(
|
||||
continue
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_round_start(round_num + 1, simulated_hour, "reddit")
|
||||
action_logger.log_round_start(round_num + 1, simulated_hour)
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
|
||||
# 记录动作
|
||||
for agent_id, agent in active_agents:
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
db_path, last_rowid, agent_names
|
||||
)
|
||||
|
||||
round_action_count = 0
|
||||
for action_data in actual_actions:
|
||||
if action_logger:
|
||||
action_logger.log_action(
|
||||
round_num=round_num + 1,
|
||||
platform="reddit",
|
||||
agent_id=agent_id,
|
||||
agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
|
||||
action_type="LLM_ACTION",
|
||||
action_args={}
|
||||
agent_id=action_data['agent_id'],
|
||||
agent_name=action_data['agent_name'],
|
||||
action_type=action_data['action_type'],
|
||||
action_args=action_data['action_args']
|
||||
)
|
||||
total_actions += 1
|
||||
round_action_count += 1
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_round_end(round_num + 1, len(active_agents), "reddit")
|
||||
action_logger.log_round_end(round_num + 1, round_action_count)
|
||||
|
||||
if (round_num + 1) % 20 == 0:
|
||||
progress = (round_num + 1) / total_rounds * 100
|
||||
print(f"[Reddit] Day {simulated_day}, {simulated_hour:02d}:00 "
|
||||
f"- Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end("reddit", total_rounds, total_actions)
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"[Reddit] 模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -559,12 +657,6 @@ async def main():
|
||||
action='store_true',
|
||||
help='只运行Reddit模拟'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--action-log',
|
||||
type=str,
|
||||
default='actions.jsonl',
|
||||
help='动作日志文件路径 (默认: actions.jsonl)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -575,52 +667,53 @@ async def main():
|
||||
config = load_config(args.config)
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
|
||||
# 初始化日志配置(清理旧日志文件,使用固定名称)
|
||||
# 初始化日志配置(禁用 OASIS 日志,清理旧文件)
|
||||
init_logging_for_simulation(simulation_dir)
|
||||
|
||||
# 创建动作日志记录器
|
||||
action_log_path = os.path.join(simulation_dir, args.action_log)
|
||||
action_logger = ActionLogger(action_log_path)
|
||||
# 创建日志管理器
|
||||
log_manager = SimulationLogManager(simulation_dir)
|
||||
twitter_logger = log_manager.get_twitter_logger()
|
||||
reddit_logger = log_manager.get_reddit_logger()
|
||||
|
||||
print("=" * 60)
|
||||
print("OASIS 双平台并行模拟")
|
||||
print(f"配置文件: {args.config}")
|
||||
print(f"模拟ID: {config.get('simulation_id', 'unknown')}")
|
||||
print(f"动作日志: {action_log_path}")
|
||||
print("=" * 60)
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info("OASIS 双平台并行模拟")
|
||||
log_manager.info(f"配置文件: {args.config}")
|
||||
log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
time_config = config.get("time_config", {})
|
||||
print(f"\n模拟参数:")
|
||||
print(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
|
||||
print(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
|
||||
print(f" - Agent数量: {len(config.get('agent_configs', []))}")
|
||||
log_manager.info(f"模拟参数:")
|
||||
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
|
||||
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
|
||||
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
|
||||
|
||||
# LLM推理说明
|
||||
reasoning = config.get("generation_reasoning", "")
|
||||
if reasoning:
|
||||
print(f"\nLLM配置推理:")
|
||||
print(f" {reasoning[:500]}..." if len(reasoning) > 500 else f" {reasoning}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
log_manager.info("日志结构:")
|
||||
log_manager.info(f" - 主日志: simulation.log")
|
||||
log_manager.info(f" - Twitter动作: twitter/actions.jsonl")
|
||||
log_manager.info(f" - Reddit动作: reddit/actions.jsonl")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
if args.twitter_only:
|
||||
await run_twitter_simulation(config, simulation_dir, action_logger)
|
||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
|
||||
elif args.reddit_only:
|
||||
await run_reddit_simulation(config, simulation_dir, action_logger)
|
||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
|
||||
else:
|
||||
# 并行运行(共享同一个action_logger)
|
||||
# 并行运行(每个平台使用独立的日志记录器)
|
||||
await asyncio.gather(
|
||||
run_twitter_simulation(config, simulation_dir, action_logger),
|
||||
run_reddit_simulation(config, simulation_dir, action_logger),
|
||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
|
||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
|
||||
)
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print("\n" + "=" * 60)
|
||||
print(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f"动作日志已保存到: {action_log_path}")
|
||||
print("=" * 60)
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
log_manager.info(f"日志文件:")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user