Enhance simulation management and logging features
- Registered a cleanup function for simulation processes to ensure proper termination on server shutdown.
- Improved logging during application startup to confirm that the cleanup function was registered.
- Updated the simulation preparation checks to clarify the conditions under which a simulation is considered ready, improving error handling and user feedback.
- Added detailed logging for simulation status changes, improving traceability throughout the simulation lifecycle.
- Introduced new files for simulation configuration and profile data, supporting enhanced testing and visualization capabilities.
This commit is contained in:
@@ -10,6 +10,8 @@ import time
|
||||
import asyncio
|
||||
import threading
|
||||
import subprocess
|
||||
import signal
|
||||
import atexit
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
@@ -21,6 +23,9 @@ from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
# 标记是否已注册清理函数
|
||||
_cleanup_registered = False
|
||||
|
||||
|
||||
class RunnerStatus(str, Enum):
|
||||
"""运行器状态"""
|
||||
@@ -342,34 +347,36 @@ class SimulationRunner:
|
||||
# 启动模拟进程
|
||||
try:
|
||||
# 构建运行命令,使用完整路径
|
||||
action_log_path = os.path.join(sim_dir, "actions.jsonl")
|
||||
# 新的日志结构:
|
||||
# twitter/actions.jsonl - Twitter 动作日志
|
||||
# reddit/actions.jsonl - Reddit 动作日志
|
||||
# simulation.log - 主进程日志
|
||||
|
||||
cmd = [
|
||||
sys.executable, # Python解释器
|
||||
script_path,
|
||||
"--config", config_path, # 使用完整配置文件路径
|
||||
"--action-log", action_log_path, # 动作日志文件完整路径
|
||||
]
|
||||
|
||||
# 创建输出日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||
stdout_log_path = os.path.join(sim_dir, "simulation_stdout.log")
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stdout_file = open(stdout_log_path, 'w', encoding='utf-8')
|
||||
stderr_file = open(stderr_log_path, 'w', encoding='utf-8')
|
||||
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
|
||||
main_log_path = os.path.join(sim_dir, "simulation.log")
|
||||
main_log_file = open(main_log_path, 'w', encoding='utf-8')
|
||||
|
||||
# 设置工作目录为模拟目录(数据库等文件会生成在此)
|
||||
# 使用 start_new_session=True 创建新的进程组,确保可以通过 os.killpg 终止所有子进程
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=sim_dir,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file,
|
||||
stdout=main_log_file,
|
||||
stderr=subprocess.STDOUT, # stderr 也写入同一个文件
|
||||
text=True,
|
||||
bufsize=1,
|
||||
start_new_session=True, # 创建新进程组,确保服务器关闭时能终止所有相关进程
|
||||
)
|
||||
|
||||
# 保存文件句柄以便后续关闭
|
||||
cls._stdout_files[simulation_id] = stdout_file
|
||||
cls._stderr_files[simulation_id] = stderr_file
|
||||
cls._stdout_files[simulation_id] = main_log_file
|
||||
cls._stderr_files[simulation_id] = None # 不再需要单独的 stderr
|
||||
|
||||
state.process_pid = process.pid
|
||||
state.runner_status = RunnerStatus.RUNNING
|
||||
@@ -399,7 +406,10 @@ class SimulationRunner:
|
||||
def _monitor_simulation(cls, simulation_id: str):
|
||||
"""监控模拟进程,解析动作日志"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
actions_log = os.path.join(sim_dir, "actions.jsonl")
|
||||
|
||||
# 新的日志结构:分平台的动作日志
|
||||
twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
|
||||
reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")
|
||||
|
||||
process = cls._processes.get(simulation_id)
|
||||
state = cls.get_run_state(simulation_id)
|
||||
@@ -407,43 +417,32 @@ class SimulationRunner:
|
||||
if not process or not state:
|
||||
return
|
||||
|
||||
last_position = 0
|
||||
twitter_position = 0
|
||||
reddit_position = 0
|
||||
|
||||
try:
|
||||
while process.poll() is None: # 进程仍在运行
|
||||
# 读取动作日志
|
||||
if os.path.exists(actions_log):
|
||||
with open(actions_log, 'r', encoding='utf-8') as f:
|
||||
f.seek(last_position)
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
action_data = json.loads(line)
|
||||
action = AgentAction(
|
||||
round_num=action_data.get("round", 0),
|
||||
timestamp=action_data.get("timestamp", datetime.now().isoformat()),
|
||||
platform=action_data.get("platform", "unknown"),
|
||||
agent_id=action_data.get("agent_id", 0),
|
||||
agent_name=action_data.get("agent_name", ""),
|
||||
action_type=action_data.get("action_type", ""),
|
||||
action_args=action_data.get("action_args", {}),
|
||||
result=action_data.get("result"),
|
||||
success=action_data.get("success", True),
|
||||
)
|
||||
state.add_action(action)
|
||||
|
||||
# 更新轮次
|
||||
if action.round_num > state.current_round:
|
||||
state.current_round = action.round_num
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
last_position = f.tell()
|
||||
# 读取 Twitter 动作日志
|
||||
if os.path.exists(twitter_actions_log):
|
||||
twitter_position = cls._read_action_log(
|
||||
twitter_actions_log, twitter_position, state, "twitter"
|
||||
)
|
||||
|
||||
# 定期保存状态
|
||||
# 读取 Reddit 动作日志
|
||||
if os.path.exists(reddit_actions_log):
|
||||
reddit_position = cls._read_action_log(
|
||||
reddit_actions_log, reddit_position, state, "reddit"
|
||||
)
|
||||
|
||||
# 更新状态
|
||||
cls._save_run_state(state)
|
||||
time.sleep(1) # 每秒检查一次
|
||||
time.sleep(2)
|
||||
|
||||
# 进程结束后,最后读取一次日志
|
||||
if os.path.exists(twitter_actions_log):
|
||||
cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
|
||||
if os.path.exists(reddit_actions_log):
|
||||
cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")
|
||||
|
||||
# 进程结束
|
||||
exit_code = process.returncode
|
||||
@@ -454,16 +453,16 @@ class SimulationRunner:
|
||||
logger.info(f"模拟完成: {simulation_id}")
|
||||
else:
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
# 从 stderr 日志文件读取错误信息
|
||||
stderr_log_path = os.path.join(sim_dir, "simulation_stderr.log")
|
||||
stderr = ""
|
||||
# 从主日志文件读取错误信息
|
||||
main_log_path = os.path.join(sim_dir, "simulation.log")
|
||||
error_info = ""
|
||||
try:
|
||||
if os.path.exists(stderr_log_path):
|
||||
with open(stderr_log_path, 'r', encoding='utf-8') as f:
|
||||
stderr = f.read()
|
||||
if os.path.exists(main_log_path):
|
||||
with open(main_log_path, 'r', encoding='utf-8') as f:
|
||||
error_info = f.read()[-2000:] # 取最后2000字符
|
||||
except Exception:
|
||||
pass
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[-1000:]}" # 取最后1000字符
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
|
||||
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
||||
|
||||
state.twitter_running = False
|
||||
@@ -488,13 +487,70 @@ class SimulationRunner:
|
||||
except Exception:
|
||||
pass
|
||||
cls._stdout_files.pop(simulation_id, None)
|
||||
if simulation_id in cls._stderr_files:
|
||||
if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]:
|
||||
try:
|
||||
cls._stderr_files[simulation_id].close()
|
||||
except Exception:
|
||||
pass
|
||||
cls._stderr_files.pop(simulation_id, None)
|
||||
|
||||
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: SimulationRunState,
    platform: str
) -> int:
    """
    Incrementally read a per-platform action log and record new actions.

    The file is read from ``position`` to its current end. Every JSON line
    that is not an event entry (i.e. has no ``event_type`` key) is turned
    into an ``AgentAction`` tagged with ``platform`` and added to ``state``;
    ``state.current_round`` is raised when a later round is seen.

    Args:
        log_path: Path of the JSONL action log.
        position: Byte offset returned by the previous call (0 at start).
        state: Run state object that receives the parsed actions.
        platform: Platform name stamped on each action (twitter/reddit).

    Returns:
        The byte offset just past the last line that was consumed. Pass it
        back as ``position`` on the next call.
    """
    new_position = position
    try:
        # Binary mode gives plain byte offsets, so the offset can be
        # advanced line by line instead of only once at the very end.
        with open(log_path, 'rb') as f:
            f.seek(position)
            for raw_line in f:
                is_complete = raw_line.endswith(b"\n")
                try:
                    text = raw_line.decode('utf-8').strip()
                    action_data = json.loads(text) if text else None
                except ValueError:
                    # Covers both invalid UTF-8 and invalid JSON.
                    if not is_complete:
                        # The writer is probably still in the middle of
                        # this line: leave it unconsumed and retry on the
                        # next poll instead of dropping the action.
                        break
                    new_position += len(raw_line)
                    continue

                # The line is consumed from here on, whatever happens to
                # the record, so one bad record can never rewind the
                # offset and cause earlier actions to be added again.
                new_position += len(raw_line)

                # Skip blank lines, non-object JSON and event entries
                # (e.g. simulation_start, round_start).
                if not isinstance(action_data, dict) or "event_type" in action_data:
                    continue

                try:
                    action = AgentAction(
                        round_num=action_data.get("round", 0),
                        timestamp=action_data.get("timestamp", datetime.now().isoformat()),
                        platform=platform,
                        agent_id=action_data.get("agent_id", 0),
                        agent_name=action_data.get("agent_name", ""),
                        action_type=action_data.get("action_type", ""),
                        action_args=action_data.get("action_args", {}),
                        result=action_data.get("result"),
                        success=action_data.get("success", True),
                    )
                    state.add_action(action)

                    # Track the highest round seen so far.
                    if action.round_num and action.round_num > state.current_round:
                        state.current_round = action.round_num
                except Exception as e:
                    # Best effort: report the bad record and keep going.
                    logger.warning(f"读取动作日志失败: {log_path}, error={e}")
    except Exception as e:
        # Best effort: the monitor loop must not die because of a log read.
        logger.warning(f"读取动作日志失败: {log_path}, error={e}")
    return new_position
|
||||
|
||||
@classmethod
|
||||
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
||||
"""停止模拟"""
|
||||
@@ -510,12 +566,35 @@ class SimulationRunner:
|
||||
|
||||
# 终止进程
|
||||
process = cls._processes.get(simulation_id)
|
||||
if process:
|
||||
process.terminate()
|
||||
if process and process.poll() is None:
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
# 使用进程组 ID 终止整个进程组(包括所有子进程)
|
||||
# 由于使用了 start_new_session=True,进程组 ID 等于主进程 PID
|
||||
pgid = os.getpgid(process.pid)
|
||||
logger.info(f"终止进程组: simulation={simulation_id}, pgid={pgid}")
|
||||
|
||||
# 先发送 SIGTERM 给整个进程组
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
# 如果 10 秒后还没结束,强制发送 SIGKILL
|
||||
logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
process.wait(timeout=5)
|
||||
|
||||
except ProcessLookupError:
|
||||
# 进程已经不存在
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"终止进程组失败: {simulation_id}, error={e}")
|
||||
# 回退到直接终止进程
|
||||
try:
|
||||
process.terminate()
|
||||
process.wait(timeout=5)
|
||||
except Exception:
|
||||
process.kill()
|
||||
|
||||
state.runner_status = RunnerStatus.STOPPED
|
||||
state.twitter_running = False
|
||||
@@ -709,4 +788,133 @@ class SimulationRunner:
|
||||
result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every simulation process that is still running.

    Intended to be called on server shutdown. For each running process the
    whole process group is sent SIGTERM, then SIGKILL if it has not exited
    after 5 seconds. If group termination fails for any reason, the process
    itself is terminated directly. The run state of each terminated
    simulation is marked STOPPED and saved. Finally all tracked log file
    handles are closed and the in-memory registries are cleared.
    """
    logger.info("正在清理所有模拟进程...")

    # Iterate over a snapshot so the registry can change underneath us.
    for simulation_id, process in list(cls._processes.items()):
        try:
            if process.poll() is not None:
                # Already exited: nothing to terminate or update.
                continue

            logger.info(f"终止模拟进程: {simulation_id}, pid={process.pid}")

            try:
                # Signal the whole process group (includes all children).
                pgid = os.getpgid(process.pid)
                os.killpg(pgid, signal.SIGTERM)

                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    logger.warning(f"进程组未响应 SIGTERM,强制终止: {simulation_id}")
                    os.killpg(pgid, signal.SIGKILL)
                    process.wait(timeout=5)

            except Exception:
                # Group termination failed (process already gone, the wait
                # after SIGKILL timed out, process-group calls unavailable,
                # ...). Fall back to terminating the process directly so
                # the state update below still runs.
                try:
                    process.terminate()
                    process.wait(timeout=3)
                except Exception:
                    process.kill()

            # Persist the fact that the simulation was stopped by shutdown.
            state = cls.get_run_state(simulation_id)
            if state:
                state.runner_status = RunnerStatus.STOPPED
                state.twitter_running = False
                state.reddit_running = False
                state.completed_at = datetime.now().isoformat()
                state.error = "服务器关闭,模拟被终止"
                cls._save_run_state(state)

        except Exception as e:
            logger.error(f"清理进程失败: {simulation_id}, error={e}")

    # Close the tracked log file handles; entries may be None.
    for handles in (cls._stdout_files, cls._stderr_files):
        for file_handle in list(handles.values()):
            try:
                if file_handle:
                    file_handle.close()
            except Exception:
                # Best effort: a failing close must not stop the cleanup.
                pass
        handles.clear()

    # Drop the in-memory registries.
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info("模拟进程清理完成")
|
||||
|
||||
@classmethod
def register_cleanup(cls):
    """
    Install the shutdown hooks that clean up simulation processes.

    Registers ``cleanup_all_simulations`` with ``atexit`` and installs a
    handler for SIGTERM and SIGINT that runs the cleanup and then hands
    over to whatever handler was installed before. Guarded by the
    module-level ``_cleanup_registered`` flag, so repeated calls do nothing.
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # Remember the handlers that were active before ours is installed.
    previous_handlers = {
        signal.SIGINT: signal.getsignal(signal.SIGINT),
        signal.SIGTERM: signal.getsignal(signal.SIGTERM),
    }

    def _on_signal(signum=None, frame=None):
        """Run the cleanup first, then delegate to the previous handler."""
        logger.info(f"收到信号 {signum},开始清理...")
        cls.cleanup_all_simulations()

        prior = previous_handlers.get(signum)
        if not callable(prior):
            # The previous handler is not callable (e.g. SIG_DFL) or the
            # signal is not one we recorded: interrupt the process.
            raise KeyboardInterrupt
        prior(signum, frame)

    # atexit hook acts as a backup path.
    atexit.register(cls.cleanup_all_simulations)

    # signal.signal raises ValueError when not called from the main thread.
    try:
        for signo in (signal.SIGTERM, signal.SIGINT):
            signal.signal(signo, _on_signal)
    except ValueError:
        logger.warning("无法注册信号处理器(不在主线程),仅使用 atexit")

    _cleanup_registered = True
|
||||
|
||||
@classmethod
def get_running_simulations(cls) -> List[str]:
    """
    Return the IDs of all simulations whose process is still running.

    A process counts as running while its ``poll()`` returns ``None``.

    Returns:
        Simulation IDs in registry order.
    """
    # Iterate over a snapshot, as cleanup_all_simulations does, so that a
    # registry entry removed meanwhile cannot raise
    # "dictionary changed size during iteration".
    return [
        sim_id
        for sim_id, process in list(cls._processes.items())
        if process.poll() is None
    ]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user