Enhance simulation management and logging features

- Registered a cleanup function for simulation processes to ensure proper termination on server shutdown.
- Improved logging during application startup to confirm the registration of the cleanup function.
- Updated simulation preparation checks to clarify the conditions for considering a simulation ready, enhancing error handling and user feedback.
- Added detailed logging for simulation status changes, improving traceability during the simulation lifecycle.
- Introduced new files for simulation configuration and profile data, supporting enhanced testing and visualization capabilities.
This commit is contained in:
666ghj
2025-12-02 17:11:47 +08:00
parent 3cc5e3f479
commit d4fac63eb4
15 changed files with 8515 additions and 241 deletions

View File

@@ -10,6 +10,8 @@ import time
import asyncio
import threading
import subprocess
import signal
import atexit
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
@@ -21,6 +23,9 @@ from ..utils.logger import get_logger
logger = get_logger('mirofish.simulation_runner')
# 标记是否已注册清理函数
_cleanup_registered = False
class RunnerStatus(str, Enum):
"""运行器状态"""
@@ -342,34 +347,36 @@ class SimulationRunner:
# 启动模拟进程
try:
# 构建运行命令,使用完整路径
action_log_path = os.path.join(sim_dir, "actions.jsonl")
# 新的日志结构:
# twitter/actions.jsonl - Twitter 动作日志
# reddit/actions.jsonl - Reddit 动作日志
# simulation.log - 主进程日志
cmd = [
sys.executable, # Python解释器
script_path,
"--config", config_path, # 使用完整配置文件路径
"--action-log", action_log_path, # 动作日志文件完整路径
]
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
# 创建日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
main_log_path = os.path.join(sim_dir, "simulation.log")
main_log_file = open(main_log_path, 'w', encoding='utf-8')
# 设置工作目录为模拟目录(数据库等文件会生成在此)
# 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程
process = subprocess.Popen(
cmd,
cwd=sim_dir,
stdout=stdout_file,
stderr=stderr_file,
stdout=main_log_file,
stderr=subprocess.STDOUT, # stderr 也写入同一个文件
text=True,
bufsize=1,
start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程
)
# 保存文件句柄以便后续关闭
cls._stdout_files[simulation_id] = stdout_file
cls._stderr_files[simulation_id] = stderr_file
cls._stdout_files[simulation_id] = main_log_file
cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr
state.process_pid = process.pid
state.runner_status = RunnerStatus.RUNNING
@@ -399,7 +406,10 @@ class SimulationRunner:
def _monitor_simulation(cls, simulation_id: str):
"""监控模拟进程,解析动作日志"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
actions_log = os.path.join(sim_dir, "actions.jsonl")
# 新的日志结构:分平台的动作日志
twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")
process = cls._processes.get(simulation_id)
state = cls.get_run_state(simulation_id)
@@ -407,43 +417,32 @@ class SimulationRunner:
if not process or not state:
return
last_position = 0
twitter_position = 0
reddit_position = 0
try:
while process.poll() is None: # 进程仍在运行
# 读取动作日志
if os.path.exists(actions_log):
with open(actions_log, 'r', encoding='utf-8') as f:
f.seek(last_position)
for line in f:
line = line.strip()
if line:
try:
action_data = json.loads(line)
action = AgentAction(
round_num=action_data.get("round", 0),
timestamp=action_data.get("timestamp", datetime.now().isoformat()),
platform=action_data.get("platform", "unknown"),
agent_id=action_data.get("agent_id", 0),
agent_name=action_data.get("agent_name", ""),
action_type=action_data.get("action_type", ""),
action_args=action_data.get("action_args", {}),
result=action_data.get("result"),
success=action_data.get("success", True),
)
state.add_action(action)
# 更新轮次
if action.round_num > state.current_round:
state.current_round = action.round_num
except json.JSONDecodeError:
pass
last_position = f.tell()
# 读取 Twitter 动作日志
if os.path.exists(twitter_actions_log):
twitter_position = cls._read_action_log(
twitter_actions_log, twitter_position, state, "twitter"
)
# 定期保存状态
# 读取 Reddit 动作日志
if os.path.exists(reddit_actions_log):
reddit_position = cls._read_action_log(
reddit_actions_log, reddit_position, state, "reddit"
)
# 更新状态
cls._save_run_state(state)
time.sleep(1) # 每秒检查一次
time.sleep(2)
# 进程结束后,最后读取一次日志
if os.path.exists(twitter_actions_log):
cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
if os.path.exists(reddit_actions_log):
cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")
# 进程结束
exit_code = process.returncode
@@ -454,16 +453,16 @@ class SimulationRunner:
logger.info(f"模拟完成: {simulation_id}")
else:
state.runner_status = RunnerStatus.FAILED
# 从 stderr 日志文件读取错误信息
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
stderr = ""
# 从日志文件读取错误信息
main_log_path = os.path.join(sim_dir, "simulation.log")
error_info = ""
try:
if os.path.exists(stderr_log_path):
with open(stderr_log_path, 'r', encoding='utf-8') as f:
stderr = f.read()
if os.path.exists(main_log_path):
with open(main_log_path, 'r', encoding='utf-8') as f:
error_info = f.read()[-2000:] # 取最后2000字符
except Exception:
pass
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
state.twitter_running = False
@@ -488,13 +487,70 @@ class SimulationRunner:
except Exception:
pass
cls._stdout_files.pop(simulation_id, None)
if simulation_id in cls._stderr_files:
if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]:
try:
cls._stderr_files[simulation_id].close()
except Exception:
pass
cls._stderr_files.pop(simulation_id, None)
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: SimulationRunState,
    platform: str
) -> int:
    """
    Incrementally read a per-platform JSONL action log.

    Every JSON object found after ``position`` is turned into an
    ``AgentAction`` and added to ``state``; objects carrying an
    ``event_type`` key (e.g. simulation_start, round_start) are skipped.

    The file is read in binary mode so the returned offset is an exact byte
    count, and the offset only moves past lines that have been fully
    handled. Consequences:

    * a trailing line the writer has not finished yet is left in place and
      picked up again on the next call instead of being silently dropped;
    * if a failure happens part-way through the file, the lines already
      processed are not re-read (and therefore not duplicated) next time.

    Args:
        log_path: Path of the action log file.
        position: Byte offset where the previous read stopped.
        state: Run-state object that receives the parsed actions.
        platform: Platform name (twitter/reddit) stamped on every action.

    Returns:
        The byte offset to resume reading from on the next call.
    """
    # Offset just past the last line that has been fully dealt with.
    consumed = position
    try:
        with open(log_path, 'rb') as f:
            f.seek(position)
            for raw_line in f:
                try:
                    text = raw_line.decode('utf-8').strip()
                    record = json.loads(text) if text else None
                except (UnicodeDecodeError, json.JSONDecodeError):
                    record = None

                if not isinstance(record, dict):
                    if not raw_line.endswith(b"\n"):
                        # Unterminated tail that is not (yet) a JSON object:
                        # most likely a write in progress, so retry it on
                        # the next poll rather than skipping over it.
                        break
                    # Complete line that is blank, corrupt, or not an
                    # object: nothing usable, move past it.
                    consumed += len(raw_line)
                    continue

                # Mark the line as handled before processing it so that a
                # record which fails below is skipped once instead of being
                # retried (and logged) on every poll.
                consumed += len(raw_line)

                # Skip event entries (simulation_start, round_start, ...).
                if "event_type" in record:
                    continue

                action = AgentAction(
                    round_num=record.get("round", 0),
                    timestamp=record.get("timestamp", datetime.now().isoformat()),
                    platform=platform,
                    agent_id=record.get("agent_id", 0),
                    agent_name=record.get("agent_name", ""),
                    action_type=record.get("action_type", ""),
                    action_args=record.get("action_args", {}),
                    result=record.get("result"),
                    success=record.get("success", True),
                )
                state.add_action(action)

                # Track the highest round number seen so far.
                if action.round_num and action.round_num > state.current_round:
                    state.current_round = action.round_num
    except Exception as e:
        # Best effort: the monitor loop must keep running even if one read
        # fails; resume after whatever was successfully consumed.
        logger.warning(f"读取动作日志失败: {log_path}, error={e}")
    return consumed
@classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
"""停止模拟"""
@@ -510,12 +566,35 @@ class SimulationRunner:
# 终止进程
process = cls._processes.get(simulation_id)
if process:
process.terminate()
if process and process.poll() is None:
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
process.kill()
# 使用进程组 ID 终止整个进程组(包括所有子进程)
# 由于使用了 start_new_session=True进程组 ID 等于主进程 PID
pgid = os.getpgid(process.pid)
logger.info(f"终止进程组: simulation={simulation_id}, pgid={pgid}")
# 先发送 SIGTERM 给整个进程组
os.killpg(pgid, signal.SIGTERM)
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
# 如果 10 秒后还没结束,强制发送 SIGKILL
logger.warning(f"进程组未响应 SIGTERM强制终止: {simulation_id}")
os.killpg(pgid, signal.SIGKILL)
process.wait(timeout=5)
except ProcessLookupError:
# 进程已经不存在
pass
except Exception as e:
logger.error(f"终止进程组失败: {simulation_id}, error={e}")
# 回退到直接终止进程
try:
process.terminate()
process.wait(timeout=5)
except Exception:
process.kill()
state.runner_status = RunnerStatus.STOPPED
state.twitter_running = False
@@ -709,4 +788,133 @@ class SimulationRunner:
result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
return result
@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every simulation process that is still running.

    Meant to run when the server shuts down so that no child processes are
    left behind. For each live process the whole process group is sent
    SIGTERM, then SIGKILL if it has not exited within 5 seconds; if the
    group cannot be signalled the process itself is terminated/killed.
    The run state of each terminated simulation is marked STOPPED and
    saved. Finally all tracked log file handles are closed and the
    in-memory process and action-queue registries are emptied.
    """
    logger.info("正在清理所有模拟进程...")

    # Work on a snapshot so the registry is not mutated while iterating.
    for simulation_id, process in list(cls._processes.items()):
        try:
            if process.poll() is not None:
                # Already exited: nothing to terminate.
                continue

            logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}")

            try:
                # Signal the whole process group so child processes of the
                # simulation are terminated as well.
                group_id = os.getpgid(process.pid)
                os.killpg(group_id, signal.SIGTERM)

                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(f"进程组未响应 SIGTERM强制终止: {simulation_id}")
                    os.killpg(group_id, signal.SIGKILL)
                    process.wait(timeout=5)
            except (ProcessLookupError, OSError):
                # The group may already be gone; fall back to signalling
                # the process directly.
                try:
                    process.terminate()
                    process.wait(timeout=3)
                except Exception:
                    process.kill()

            # Record that the simulation was stopped by the shutdown.
            run_state = cls.get_run_state(simulation_id)
            if run_state:
                run_state.runner_status = RunnerStatus.STOPPED
                run_state.twitter_running = False
                run_state.reddit_running = False
                run_state.completed_at = datetime.now().isoformat()
                run_state.error = "服务器关闭,模拟被终止"
                cls._save_run_state(run_state)

        except Exception as e:
            logger.error(f"清理进程失败: {simulation_id}, error={e}")

    # Close log file handles (stdout first, then stderr); entries may be
    # None, and close errors are deliberately ignored during shutdown.
    for handle_registry in (cls._stdout_files, cls._stderr_files):
        for handle in list(handle_registry.values()):
            try:
                if handle:
                    handle.close()
            except Exception:
                pass
        handle_registry.clear()

    # Drop in-memory bookkeeping.
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info("模拟进程清理完成")
@classmethod
def register_cleanup(cls):
    """
    Register the simulation cleanup hooks (at most once per process).

    Installs ``cleanup_all_simulations`` as an ``atexit`` callback and wraps
    the SIGTERM and SIGINT handlers so that simulations are cleaned up
    before the previously installed handler runs. When signal handlers
    cannot be installed (``signal.signal`` raises ``ValueError``, i.e. not
    in the main thread) only the ``atexit`` hook is used.
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # Remember the handlers that were active before we replace them.
    previous_handlers = {
        signal.SIGINT: signal.getsignal(signal.SIGINT),
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }

    def cleanup_handler(signum=None, frame=None):
        """Signal handler: clean up simulations, then defer to the old handler."""
        logger.info(f"收到信号 {signum},开始清理...")
        cls.cleanup_all_simulations()

        previous = previous_handlers.get(signum)
        if not callable(previous):
            # No callable predecessor (e.g. SIG_DFL): interrupt the process.
            raise KeyboardInterrupt
        # Let the earlier handler finish the shutdown.
        previous(signum, frame)

    # atexit hook acts as a fallback for exits not triggered by a signal.
    atexit.register(cls.cleanup_all_simulations)

    try:
        # SIGTERM is the default signal of `kill`; SIGINT is Ctrl+C.
        for handled_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(handled_signal, cleanup_handler)
    except ValueError:
        # Signal handlers can only be set from the main thread.
        logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit")

    _cleanup_registered = True
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    Return the IDs of all simulations whose process is still running.

    A simulation counts as running when ``poll()`` on its tracked process
    returns ``None``. IDs are returned in registry iteration order.
    """
    return [
        sim_id
        for sim_id, process in cls._processes.items()
        if process.poll() is None
    ]