Enhance simulation configuration and management features

- Added support for a `max_rounds` parameter in simulation API, allowing users to limit the number of simulation rounds, improving control over simulation duration.
- Updated README.md to reflect the new `max_rounds` parameter and its usage in simulation requests.
- Enhanced error handling for `max_rounds` input validation to ensure it is a positive integer.
- Modified simulation runner and related scripts to incorporate `max_rounds` functionality, ensuring consistent application across Twitter and Reddit simulations.
- Improved logging to indicate when the number of rounds is truncated due to the `max_rounds` setting, enhancing traceability during simulation execution.
This commit is contained in:
666ghj
2025-12-05 15:50:54 +08:00
parent 3c1d554152
commit 5b4f02f421
9 changed files with 243 additions and 53 deletions

View File

@@ -85,8 +85,8 @@ class TimeSimulationConfig:
# 模拟总时长(模拟小时数)
total_simulation_hours: int = 72 # 默认模拟72小时3天
# 每轮代表的时间(模拟分钟)
minutes_per_round: int = 30
# 每轮代表的时间(模拟分钟)- 默认60分钟1小时加快时间流速
minutes_per_round: int = 60
# 每小时激活的Agent数量范围
agents_per_hour_min: int = 5
@@ -205,7 +205,7 @@ class SimulationConfigGenerator:
采用分步生成策略:
1. 生成时间配置和事件配置(轻量级)
2. 分批生成Agent配置每批10-15个)
2. 分批生成Agent配置每批10-20个)
3. 生成平台配置
"""
@@ -214,6 +214,13 @@ class SimulationConfigGenerator:
# 每批生成的Agent数量
AGENTS_PER_BATCH = 15
# 各步骤的上下文截断长度(字符数)
TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置
EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置
ENTITY_SUMMARY_LENGTH = 300 # 实体摘要
AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要
ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量
def __init__(
self,
api_key: Optional[str] = None,
@@ -286,8 +293,9 @@ class SimulationConfigGenerator:
# ========== 步骤1: 生成时间配置 ==========
report_progress(1, "生成时间配置...")
time_config_result = self._generate_time_config(context, len(entities))
time_config = self._parse_time_config(time_config_result)
num_entities = len(entities)
time_config_result = self._generate_time_config(context, num_entities)
time_config = self._parse_time_config(time_config_result, num_entities)
reasoning_parts.append(f"时间配置: {time_config_result.get('reasoning', '成功')}")
# ========== 步骤2: 生成事件配置 ==========
@@ -411,11 +419,14 @@ class SimulationConfigGenerator:
for entity_type, type_entities in by_type.items():
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
for e in type_entities[:10]: # 每类最多显示10个
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
# 使用配置的显示数量和摘要长度
display_count = self.ENTITIES_PER_TYPE_DISPLAY
summary_len = self.ENTITY_SUMMARY_LENGTH
for e in type_entities[:display_count]:
summary_preview = (e.summary[:summary_len] + "...") if len(e.summary) > summary_len else e.summary
lines.append(f"- {e.name}: {summary_preview}")
if len(type_entities) > 10:
lines.append(f" ... 还有 {len(type_entities) - 10}")
if len(type_entities) > display_count:
lines.append(f" ... 还有 {len(type_entities) - display_count}")
return "\n".join(lines)
@@ -522,33 +533,56 @@ class SimulationConfigGenerator:
def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]:
"""生成时间配置"""
# 使用配置的上下文截断长度
context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH]
# 计算最大允许值80%的agent数
max_agents_allowed = max(1, int(num_entities * 0.9))
prompt = f"""基于以下模拟需求,生成时间模拟配置。
{context[:5000]}
{context_truncated}
## 任务
请生成时间配置JSON,注意:
请生成时间配置JSON
### 基本原则(仅供参考,需根据具体事件和参与群体灵活调整):
- 用户群体为中国人,需符合北京时间作息习惯
- 凌晨0-5点几乎无人活动活跃度系数0.05
- 早上6-8点逐渐活跃活跃度系数0.4
- 工作时间9-18点中等活跃活跃度系数0.7
- 晚间19-22点是高峰期活跃度系数1.5
- 23点后活跃度下降活跃度系数0.5
- 一般规律:凌晨低活跃、早间渐增、工作时段中等、晚间高峰
- **重要**:以下示例值仅供参考,你需要根据事件性质、参与群体特点来调整具体时段
- 例如学生群体高峰可能是21-23点媒体全天活跃官方机构只在工作时间
- 例如突发热点可能导致深夜也有讨论off_peak_hours 可适当缩短
当前实体数量: {num_entities}
### 返回JSON格式不要markdown
返回JSON格式不要markdown
示例
{{
"total_simulation_hours": <72-168根据事件性质决定>,
"minutes_per_round": <15-60>,
"agents_per_hour_min": <每小时最少激活Agent数>,
"agents_per_hour_max": <每小时最多激活Agent数>,
"total_simulation_hours": 72,
"minutes_per_round": 60,
"agents_per_hour_min": 5,
"agents_per_hour_max": 50,
"peak_hours": [19, 20, 21, 22],
"off_peak_hours": [0, 1, 2, 3, 4, 5],
"morning_hours": [6, 7, 8],
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
"reasoning": "<简要说明>"
}}"""
"reasoning": "针对该事件的时间配置说明"
}}
字段说明:
- total_simulation_hours (int): 模拟总时长24-168小时突发事件短、持续话题长
- minutes_per_round (int): 每轮时长30-120分钟建议60分钟
- agents_per_hour_min (int): 每小时最少激活Agent数取值范围: 1-{max_agents_allowed}
- agents_per_hour_max (int): 每小时最多激活Agent数取值范围: 1-{max_agents_allowed}
- peak_hours (int数组): 高峰时段,根据事件参与群体调整
- off_peak_hours (int数组): 低谷时段,通常深夜凌晨
- morning_hours (int数组): 早间时段
- work_hours (int数组): 工作时段
- reasoning (string): 简要说明为什么这样配置"""
system_prompt = "你是社交媒体模拟专家。返回纯JSON格式时间配置需符合中国人作息习惯。"
@@ -562,23 +596,41 @@ class SimulationConfigGenerator:
"""获取默认时间配置(中国人作息)"""
return {
"total_simulation_hours": 72,
"minutes_per_round": 30,
"minutes_per_round": 60, # 每轮1小时加快时间流速
"agents_per_hour_min": max(1, num_entities // 15),
"agents_per_hour_max": max(5, num_entities // 5),
"peak_hours": [19, 20, 21, 22],
"off_peak_hours": [0, 1, 2, 3, 4, 5],
"morning_hours": [6, 7, 8],
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
"reasoning": "使用默认中国人作息配置"
"reasoning": "使用默认中国人作息配置每轮1小时"
}
def _parse_time_config(self, result: Dict[str, Any]) -> TimeSimulationConfig:
"""解析时间配置结果"""
def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig:
"""解析时间配置结果并验证agents_per_hour值不超过总agent数"""
# 获取原始值
agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15))
agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5))
# 验证并修正确保不超过总agent数
if agents_per_hour_min > num_entities:
logger.warning(f"agents_per_hour_min ({agents_per_hour_min}) 超过总Agent数 ({num_entities}),已修正")
agents_per_hour_min = max(1, num_entities // 10)
if agents_per_hour_max > num_entities:
logger.warning(f"agents_per_hour_max ({agents_per_hour_max}) 超过总Agent数 ({num_entities}),已修正")
agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2)
# 确保 min < max
if agents_per_hour_min >= agents_per_hour_max:
agents_per_hour_min = max(1, agents_per_hour_max // 2)
logger.warning(f"agents_per_hour_min >= max已修正为 {agents_per_hour_min}")
return TimeSimulationConfig(
total_simulation_hours=result.get("total_simulation_hours", 72),
minutes_per_round=result.get("minutes_per_round", 30),
agents_per_hour_min=result.get("agents_per_hour_min", 5),
agents_per_hour_max=result.get("agents_per_hour_max", 20),
minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时
agents_per_hour_min=agents_per_hour_min,
agents_per_hour_max=agents_per_hour_max,
peak_hours=result.get("peak_hours", [19, 20, 21, 22]),
off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
off_peak_activity_multiplier=0.05, # 凌晨几乎无人
@@ -616,11 +668,14 @@ class SimulationConfigGenerator:
for t, examples in type_examples.items()
])
# 使用配置的上下文截断长度
context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH]
prompt = f"""基于以下模拟需求,生成事件配置。
模拟需求: {simulation_requirement}
{context[:3000]}
{context_truncated}
## 可用实体类型及示例
{type_info}
@@ -761,14 +816,15 @@ class SimulationConfigGenerator:
) -> List[AgentActivityConfig]:
"""分批生成Agent配置"""
# 构建实体信息
# 构建实体信息(使用配置的摘要长度)
entity_list = []
summary_len = self.AGENT_SUMMARY_LENGTH
for i, e in enumerate(entities):
entity_list.append({
"agent_id": start_idx + i,
"entity_name": e.name,
"entity_type": e.get_entity_type() or "Unknown",
"summary": e.summary[:150] if e.summary else ""
"summary": e.summary[:summary_len] if e.summary else ""
})
prompt = f"""基于以下信息,为每个实体生成社交媒体活动配置。

View File

@@ -280,7 +280,8 @@ class SimulationRunner:
def start_simulation(
cls,
simulation_id: str,
platform: str = "parallel" # twitter / reddit / parallel
platform: str = "parallel", # twitter / reddit / parallel
max_rounds: int = None # 最大模拟轮数(可选,用于截断过长的模拟)
) -> SimulationRunState:
"""
启动模拟
@@ -288,6 +289,7 @@ class SimulationRunner:
Args:
simulation_id: 模拟ID
platform: 运行平台 (twitter/reddit/parallel)
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
Returns:
SimulationRunState
@@ -313,6 +315,13 @@ class SimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = int(total_hours * 60 / minutes_per_round)
# 如果指定了最大轮数,则截断
if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds)
if total_rounds < original_rounds:
logger.info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")
state = SimulationRunState(
simulation_id=simulation_id,
runner_status=RunnerStatus.STARTING,
@@ -358,6 +367,10 @@ class SimulationRunner:
"--config", config_path, # 使用完整配置文件路径
]
# 如果指定了最大轮数,添加到命令行参数
if max_rounds is not None and max_rounds > 0:
cmd.extend(["--max-rounds", str(max_rounds)])
# 创建主日志文件,避免 stdout/stderr 管道缓冲区满导致进程阻塞
main_log_path = os.path.join(sim_dir, "simulation.log")
main_log_file = open(main_log_path, 'w', encoding='utf-8')