Enhance simulation configuration and management features

- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration.
- Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests.
- Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer.
- Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations.
- Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
666ghj
2025-12-05 15:50:54 +08:00
parent 3c1d554152
commit 5b4f02f421
9 changed files with 243 additions and 53 deletions

View File

@@ -280,7 +280,8 @@ class SimulationRunner:
def start_simulation(
cls,
simulation_id: str,
platform: str = "parallel" # twitter / reddit / parallel
platform: str = "parallel", # twitter / reddit / parallel
max_rounds: int = None # 最大模拟轮数(可选,用于截断过长的模拟)
) -> SimulationRunState:
"""
启动模拟
@@ -288,6 +289,7 @@ class SimulationRunner:
Args:
simulation_id: 模拟ID
platform: 运行平台 (twitter/reddit/parallel)
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
Returns:
SimulationRunState
@@ -313,6 +315,13 @@ class SimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = int(total_hours * 60 / minutes_per_round)
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
state = SimulationRunState(
simulation_id=simulation_id,
runner_status=RunnerStatus.STARTING,
@@ -358,6 +367,10 @@ class SimulationRunner:
"--config", config_path, # 使用完整配置文件路径
]
# 如果指定了最大轮数,添加到命令行参数
if max_rounds is not None and max_rounds > 0:
cmd.extend(["--max-rounds", str(max_rounds)])
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
main_log_path = os.path.join(sim_dir, "simulation.log")
main_log_file = open(main_log_path, 'w', encoding='utf-8')