Enhance simulation configuration and management features
- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration. - Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests. - Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer. - Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations. - Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
@@ -404,9 +404,18 @@ async def run_twitter_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
"""运行Twitter模拟"""
|
||||
"""运行Twitter模拟
|
||||
|
||||
Args:
|
||||
config: 模拟配置
|
||||
simulation_dir: 模拟目录
|
||||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
"""
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Twitter] {msg}")
|
||||
@@ -494,6 +503,13 @@ async def run_twitter_simulation(
|
||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
if total_rounds < original_rounds:
|
||||
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
for round_num in range(total_rounds):
|
||||
@@ -552,9 +568,18 @@ async def run_reddit_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
"""运行Reddit模拟"""
|
||||
"""运行Reddit模拟
|
||||
|
||||
Args:
|
||||
config: 模拟配置
|
||||
simulation_dir: 模拟目录
|
||||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
"""
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Reddit] {msg}")
|
||||
@@ -649,6 +674,13 @@ async def run_reddit_simulation(
|
||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
if total_rounds < original_rounds:
|
||||
log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
for round_num in range(total_rounds):
|
||||
@@ -721,6 +753,12 @@ async def main():
|
||||
action='store_true',
|
||||
help='只运行Reddit模拟'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-rounds',
|
||||
type=int,
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -746,9 +784,18 @@ async def main():
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
time_config = config.get("time_config", {})
|
||||
total_hours = time_config.get('total_simulation_hours', 72)
|
||||
minutes_per_round = time_config.get('minutes_per_round', 30)
|
||||
config_total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
log_manager.info(f"模拟参数:")
|
||||
log_manager.info(f" - 总模拟时长: {time_config.get('total_simulation_hours', 72)}小时")
|
||||
log_manager.info(f" - 每轮时间: {time_config.get('minutes_per_round', 30)}分钟")
|
||||
log_manager.info(f" - 总模拟时长: {total_hours}小时")
|
||||
log_manager.info(f" - 每轮时间: {minutes_per_round}分钟")
|
||||
log_manager.info(f" - 配置总轮数: {config_total_rounds}")
|
||||
if args.max_rounds:
|
||||
log_manager.info(f" - 最大轮数限制: {args.max_rounds}")
|
||||
if args.max_rounds < config_total_rounds:
|
||||
log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)")
|
||||
log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
|
||||
|
||||
log_manager.info("日志结构:")
|
||||
@@ -760,14 +807,14 @@ async def main():
|
||||
start_time = datetime.now()
|
||||
|
||||
if args.twitter_only:
|
||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager)
|
||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||
elif args.reddit_only:
|
||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager)
|
||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||
else:
|
||||
# 并行运行(每个平台使用独立的日志记录器)
|
||||
await asyncio.gather(
|
||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager),
|
||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager),
|
||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
|
||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
|
||||
)
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
@@ -251,8 +251,12 @@ class RedditSimulationRunner:
|
||||
|
||||
return active_agents
|
||||
|
||||
async def run(self):
|
||||
"""运行Reddit模拟"""
|
||||
async def run(self, max_rounds: int = None):
|
||||
"""运行Reddit模拟
|
||||
|
||||
Args:
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("OASIS Reddit模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
@@ -264,10 +268,19 @@ class RedditSimulationRunner:
|
||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
if total_rounds < original_rounds:
|
||||
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||
|
||||
print(f"\n模拟参数:")
|
||||
print(f" - 总模拟时长: {total_hours}小时")
|
||||
print(f" - 每轮时间: {minutes_per_round}分钟")
|
||||
print(f" - 总轮数: {total_rounds}")
|
||||
if max_rounds:
|
||||
print(f" - 最大轮数限制: {max_rounds}")
|
||||
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
||||
|
||||
print("\n初始化LLM模型...")
|
||||
@@ -380,6 +393,12 @@ async def main():
|
||||
required=True,
|
||||
help='配置文件路径 (simulation_config.json)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-rounds',
|
||||
type=int,
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -392,7 +411,7 @@ async def main():
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = RedditSimulationRunner(args.config)
|
||||
await runner.run()
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -259,8 +259,12 @@ class TwitterSimulationRunner:
|
||||
|
||||
return active_agents
|
||||
|
||||
async def run(self):
|
||||
"""运行Twitter模拟"""
|
||||
async def run(self, max_rounds: int = None):
|
||||
"""运行Twitter模拟
|
||||
|
||||
Args:
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("OASIS Twitter模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
@@ -275,10 +279,19 @@ class TwitterSimulationRunner:
|
||||
# 计算总轮数
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
if total_rounds < original_rounds:
|
||||
print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
|
||||
|
||||
print(f"\n模拟参数:")
|
||||
print(f" - 总模拟时长: {total_hours}小时")
|
||||
print(f" - 每轮时间: {minutes_per_round}分钟")
|
||||
print(f" - 总轮数: {total_rounds}")
|
||||
if max_rounds:
|
||||
print(f" - 最大轮数限制: {max_rounds}")
|
||||
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
||||
|
||||
# 创建模型
|
||||
@@ -393,6 +406,12 @@ async def main():
|
||||
required=True,
|
||||
help='配置文件路径 (simulation_config.json)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-rounds',
|
||||
type=int,
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -405,7 +424,7 @@ async def main():
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = TwitterSimulationRunner(args.config)
|
||||
await runner.run()
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user