Enhance simulation configuration and management features

- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration.
- Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests.
- Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer.
- Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations.
- Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
666ghj
2025-12-05 15:50:54 +08:00
parent 3c1d554152
commit 5b4f02f421
9 changed files with 243 additions and 53 deletions

View File

@@ -251,8 +251,12 @@ class RedditSimulationRunner:
return active_agents
async def run(self):
"""运行Reddit模拟"""
async def run(self, max_rounds: int = None):
"""运行Reddit模拟
Args:
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
"""
print("=" * 60)
print("OASIS Reddit模拟")
print(f"配置文件: {self.config_path}")
@@ -264,10 +268,19 @@ class RedditSimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
print(f"\n模拟参数:")
print(f" - 总模拟时长: {total_hours}小时")
print(f" - 每轮时间: {minutes_per_round}分钟")
print(f" - 总轮数: {total_rounds}")
if max_rounds:
print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
print("\n初始化LLM模型...")
@@ -380,6 +393,12 @@ async def main():
required=True,
help='配置文件路径 (simulation_config.json)'
)
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args()
@@ -392,7 +411,7 @@ async def main():
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = RedditSimulationRunner(args.config)
await runner.run()
await runner.run(max_rounds=args.max_rounds)
if __name__ == "__main__":