Enhance simulation configuration and management features

- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration.
- Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests.
- Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer.
- Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations.
- Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
666ghj
2025-12-05 15:50:54 +08:00
parent 3c1d554152
commit 5b4f02f421
9 changed files with 243 additions and 53 deletions

View File

@@ -404,9 +404,18 @@ async def run_twitter_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
main_logger: Optional[SimulationLogManager] = None,
max_rounds: Optional[int] = None
):
"""运行Twitter模拟"""
"""运行Twitter模拟
Args:
config: 模拟配置
simulation_dir: 模拟目录
action_logger: 动作日志记录器
main_logger: 主日志管理器
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
"""
def log_info(msg):
if main_logger:
main_logger.info(f"[Twitter] {msg}")
@@ -494,6 +503,13 @@ async def run_twitter_simulation(
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
start_time = datetime.now()
for round_num in range(total_rounds):
@@ -552,9 +568,18 @@ async def run_reddit_simulation(
config: Dict[str, Any],
simulation_dir: str,
action_logger: Optional[PlatformActionLogger] = None,
main_logger: Optional[SimulationLogManager] = None
main_logger: Optional[SimulationLogManager] = None,
max_rounds: Optional[int] = None
):
"""运行Reddit模拟"""
"""运行Reddit模拟
Args:
config: 模拟配置
simulation_dir: 模拟目录
action_logger: 动作日志记录器
main_logger: 主日志管理器
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
"""
def log_info(msg):
if main_logger:
main_logger.info(f"[Reddit] {msg}")
@@ -649,6 +674,13 @@ async def run_reddit_simulation(
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
start_time = datetime.now()
for round_num in range(total_rounds):
@@ -721,6 +753,12 @@ async def main():
action='store_true',
help='只运行Reddit模拟'
)
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args()
@@ -746,9 +784,18 @@ async def main():
log_manager.info("=" * 60)
time_config = config.get("time_config", {})
total_hours = time_config.get('total_simulation_hours', 72)
minutes_per_round = time_config.get('minutes_per_round', 30)
config_total_rounds = (total_hours * 60) // minutes_per_round
log_manager.info(f"模拟参数:")
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
log_manager.info(f" - 总模拟时长: {total_hours}小时")
log_manager.info(f" - 每轮时间: {minutes_per_round}分钟")
log_manager.info(f" - 配置总轮数: {config_total_rounds}")
if args.max_rounds:
log_manager.info(f" - 最大轮数限制: {args.max_rounds}")
if args.max_rounds < config_total_rounds:
log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)")
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
log_manager.info("日志结构:")
@@ -760,14 +807,14 @@ async def main():
start_time = datetime.now()
if args.twitter_only:
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
elif args.reddit_only:
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
else:
# 并行运行(每个平台使用独立的日志记录器)
await asyncio.gather(
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
)
total_elapsed = (datetime.now() - start_time).total_seconds()

View File

@@ -251,8 +251,12 @@ class RedditSimulationRunner:
return active_agents
async def run(self):
"""运行Reddit模拟"""
async def run(self, max_rounds: int = None):
"""运行Reddit模拟
Args:
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
"""
print("=" * 60)
print("OASIS Reddit模拟")
print(f"配置文件: {self.config_path}")
@@ -264,10 +268,19 @@ class RedditSimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
print(f"\n模拟参数:")
print(f" - 总模拟时长: {total_hours}小时")
print(f" - 每轮时间: {minutes_per_round}分钟")
print(f" - 总轮数: {total_rounds}")
if max_rounds:
print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
print("\n初始化LLM模型...")
@@ -380,6 +393,12 @@ async def main():
required=True,
help='配置文件路径 (simulation_config.json)'
)
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args()
@@ -392,7 +411,7 @@ async def main():
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = RedditSimulationRunner(args.config)
await runner.run()
await runner.run(max_rounds=args.max_rounds)
if __name__ == "__main__":

View File

@@ -259,8 +259,12 @@ class TwitterSimulationRunner:
return active_agents
async def run(self):
"""运行Twitter模拟"""
async def run(self, max_rounds: int = None):
"""运行Twitter模拟
Args:
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
"""
print("=" * 60)
print("OASIS Twitter模拟")
print(f"配置文件: {self.config_path}")
@@ -275,10 +279,19 @@ class TwitterSimulationRunner:
# 计算总轮数
total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
print(f"\n模拟参数:")
print(f" - 总模拟时长: {total_hours}小时")
print(f" - 每轮时间: {minutes_per_round}分钟")
print(f" - 总轮数: {total_rounds}")
if max_rounds:
print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
# 创建模型
@@ -393,6 +406,12 @@ async def main():
required=True,
help='配置文件路径 (simulation_config.json)'
)
parser.add_argument(
'--max-rounds',
type=int,
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
args = parser.parse_args()
@@ -405,7 +424,7 @@ async def main():
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = TwitterSimulationRunner(args.config)
await runner.run()
await runner.run(max_rounds=args.max_rounds)
if __name__ == "__main__":