fix: wire agent group selection through to simulation start
Step 2 -> Step 3 -> Simulation API now passes selected_agent_ids.
Backend filters reddit_profiles.json and twitter_profiles.csv
to only include selected agents before starting simulation.
Flow: Step2 checkboxes -> emit('next-step', {selectedAgentIds}) ->
router query -> Step3 props -> startSimulation API -> filter profiles
This commit is contained in:
@@ -1448,6 +1448,41 @@ def generate_profiles():
|
||||
|
||||
# ============== 模拟运行控制接口 ==============
|
||||
|
||||
def _filter_simulation_agents(simulation_id: str, selected_agent_ids: list):
|
||||
"""
|
||||
Filter simulation profiles to only include selected agents.
|
||||
Modifies the profile files in-place before simulation starts.
|
||||
"""
|
||||
import csv
|
||||
sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
|
||||
|
||||
# Filter Reddit profiles (JSON)
|
||||
reddit_path = os.path.join(sim_dir, 'reddit_profiles.json')
|
||||
if os.path.exists(reddit_path):
|
||||
with open(reddit_path, 'r', encoding='utf-8') as f:
|
||||
profiles = json.load(f)
|
||||
|
||||
if isinstance(profiles, list):
|
||||
filtered = [p for i, p in enumerate(profiles) if i in selected_agent_ids]
|
||||
with open(reddit_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(filtered, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Reddit profiles: {len(profiles)} -> {len(filtered)}")
|
||||
|
||||
# Filter Twitter profiles (CSV)
|
||||
twitter_path = os.path.join(sim_dir, 'twitter_profiles.csv')
|
||||
if os.path.exists(twitter_path):
|
||||
with open(twitter_path, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
rows = list(reader)
|
||||
fieldnames = reader.fieldnames
|
||||
|
||||
filtered = [r for i, r in enumerate(rows) if i in selected_agent_ids]
|
||||
with open(twitter_path, 'w', encoding='utf-8', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(filtered)
|
||||
logger.info(f"Twitter profiles: {len(rows)} -> {len(filtered)}")
|
||||
|
||||
@simulation_bp.route('/start', methods=['POST'])
|
||||
def start_simulation():
|
||||
"""
|
||||
@@ -1503,6 +1538,7 @@ def start_simulation():
|
||||
max_rounds = data.get('max_rounds') # 可选:最大模拟轮数
|
||||
enable_graph_memory_update = data.get('enable_graph_memory_update', False) # 可选:是否启用图谱记忆更新
|
||||
force = data.get('force', False) # 可选:强制重新开始
|
||||
selected_agent_ids = data.get('selected_agent_ids') # 可选:选中的Agent ID列表
|
||||
|
||||
# 验证 max_rounds 参数
|
||||
if max_rounds is not None:
|
||||
@@ -1600,6 +1636,11 @@ def start_simulation():
|
||||
|
||||
logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||
|
||||
# Filter agents if selected_agent_ids provided
|
||||
if selected_agent_ids:
|
||||
logger.info(f"Filtering agents: selected {len(selected_agent_ids)} agents")
|
||||
_filter_simulation_agents(simulation_id, selected_agent_ids)
|
||||
|
||||
# 启动模拟
|
||||
run_state = SimulationRunner.start_simulation(
|
||||
simulation_id=simulation_id,
|
||||
|
||||
Reference in New Issue
Block a user