From 3fc96f39dd4701dc8c9b77215bac57c1263846d7 Mon Sep 17 00:00:00 2001 From: Kunthawat Greethong Date: Fri, 26 Jun 2026 13:48:23 +0700 Subject: [PATCH] fix: wire agent group selection through to simulation start Step 2 -> Step 3 -> Simulation API now passes selected_agent_ids. Backend filters reddit_profiles.json and twitter_profiles.csv to only include selected agents before starting simulation. Flow: Step2 checkboxes -> emit('next-step', {selectedAgentIds}) -> router query -> Step3 props -> startSimulation API -> filter profiles --- backend/app/api/simulation.py | 41 +++++++++++++++++++++ frontend/src/components/Step2EnvSetup.vue | 6 +++ frontend/src/components/Step3Simulation.vue | 9 ++++- frontend/src/views/SimulationRunView.vue | 2 + frontend/src/views/SimulationView.vue | 9 ++++- 5 files changed, 65 insertions(+), 2 deletions(-) diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index a3aa712..7de5bed 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -1448,6 +1448,41 @@ def generate_profiles(): # ============== 模拟运行控制接口 ============== +def _filter_simulation_agents(simulation_id: str, selected_agent_ids: list): + """ + Filter simulation profiles to only include selected agents. + Modifies the profile files in-place before simulation starts. + """ + import csv + sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) + + # Filter Reddit profiles (JSON) + reddit_path = os.path.join(sim_dir, 'reddit_profiles.json') + if os.path.exists(reddit_path): + with open(reddit_path, 'r', encoding='utf-8') as f: + profiles = json.load(f) + + if isinstance(profiles, list): + filtered = [p for i, p in enumerate(profiles) if i in selected_agent_ids] + with open(reddit_path, 'w', encoding='utf-8') as f: + json.dump(filtered, f, ensure_ascii=False, indent=2) + logger.info(f"Reddit profiles: {len(profiles)} -> {len(filtered)}") + + # Filter Twitter profiles (CSV) + twitter_path = os.path.join(sim_dir, 'twitter_profiles.csv') + if os.path.exists(twitter_path): + with open(twitter_path, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + rows = list(reader) + fieldnames = reader.fieldnames + + filtered = [r for i, r in enumerate(rows) if i in selected_agent_ids] + with open(twitter_path, 'w', encoding='utf-8', newline='') as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(filtered) + logger.info(f"Twitter profiles: {len(rows)} -> {len(filtered)}") + @simulation_bp.route('/start', methods=['POST']) def start_simulation(): """ @@ -1503,6 +1538,7 @@ def start_simulation(): max_rounds = data.get('max_rounds') # 可选:最大模拟轮数 enable_graph_memory_update = data.get('enable_graph_memory_update', False) # 可选:是否启用图谱记忆更新 force = data.get('force', False) # 可选:强制重新开始 + selected_agent_ids = data.get('selected_agent_ids') # 可选:选中的Agent ID列表 # 验证 max_rounds 参数 if max_rounds is not None: @@ -1600,6 +1636,11 @@ def start_simulation(): logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}") + # Filter agents if selected_agent_ids provided + if selected_agent_ids: + logger.info(f"Filtering agents: selected {len(selected_agent_ids)} agents") + _filter_simulation_agents(simulation_id, selected_agent_ids) + # 启动模拟 run_state = SimulationRunner.start_simulation( simulation_id=simulation_id, diff --git a/frontend/src/components/Step2EnvSetup.vue b/frontend/src/components/Step2EnvSetup.vue index 62ddefd..823dd6e 100644 --- a/frontend/src/components/Step2EnvSetup.vue +++ b/frontend/src/components/Step2EnvSetup.vue @@ -834,6 +834,12 @@ const handleStartSimulation = () => { addLog(t('log.startSimAutoRounds', { rounds: autoGeneratedRounds.value })) } + // Pass selected agent IDs (filtered by group checkboxes) + params.selectedAgentIds = getSelectedAgentIds() + if (agentGroups.value.length > 0) { + addLog(`✅ ส่ง ${params.selectedAgentIds.length} agents ไปขั้นตอนถัดไป`) + } + emit('next-step', params) } diff --git a/frontend/src/components/Step3Simulation.vue b/frontend/src/components/Step3Simulation.vue index 5b0f968..87a8c15 100644 --- a/frontend/src/components/Step3Simulation.vue +++ b/frontend/src/components/Step3Simulation.vue @@ -301,7 +301,8 @@ const { t } = useI18n() const props = defineProps({ simulationId: String, - maxRounds: Number, // 从Step2传入的最大轮数 + maxRounds: Number, + selectedAgentIds: Array, // Agent IDs selected from Step 2 grouping // 从Step2传入的最大轮数 minutesPerRound: { type: Number, default: 30 // 默认每轮30分钟 @@ -407,6 +408,12 @@ const doStartSimulation = async () => { addLog(t('log.setMaxRounds', { rounds: props.maxRounds })) } + // Pass selected agent IDs from Step 2 grouping + if (props.selectedAgentIds && props.selectedAgentIds.length > 0) { + params.selected_agent_ids = props.selectedAgentIds + addLog(`✅ ใช้ ${props.selectedAgentIds.length} agents ที่เลือก`) + } + addLog(t('log.graphMemoryUpdateEnabled')) const res = await startSimulation(params) diff --git a/frontend/src/views/SimulationRunView.vue b/frontend/src/views/SimulationRunView.vue index 6166c3d..2b6d966 100644 --- a/frontend/src/views/SimulationRunView.vue +++ b/frontend/src/views/SimulationRunView.vue @@ -54,6 +54,7 @@ { } // 如果有自定义轮数,通过 query 参数传递 + const query = {} if (params.maxRounds) { - routeParams.query = { maxRounds: params.maxRounds } + query.maxRounds = params.maxRounds + } + if (params.selectedAgentIds && params.selectedAgentIds.length > 0) { + query.selectedAgentIds = params.selectedAgentIds.join(',') + } + if (Object.keys(query).length > 0) { + routeParams.query = query } // 跳转到 Step 3 页面