fix: wire agent group selection through to simulation start
Step 2 -> Step 3 -> Simulation API now passes selected_agent_ids.
Backend filters reddit_profiles.json and twitter_profiles.csv
to only include selected agents before starting simulation.
Flow: Step2 checkboxes -> emit('next-step', {selectedAgentIds}) ->
router query -> Step3 props -> startSimulation API -> filter profiles
This commit is contained in:
@@ -1448,6 +1448,41 @@ def generate_profiles():
|
|||||||
|
|
||||||
# ============== 模拟运行控制接口 ==============
|
# ============== 模拟运行控制接口 ==============
|
||||||
|
|
||||||
|
def _filter_simulation_agents(simulation_id: str, selected_agent_ids: list):
|
||||||
|
"""
|
||||||
|
Filter simulation profiles to only include selected agents.
|
||||||
|
Modifies the profile files in-place before simulation starts.
|
||||||
|
"""
|
||||||
|
import csv
|
||||||
|
sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
|
||||||
|
|
||||||
|
# Filter Reddit profiles (JSON)
|
||||||
|
reddit_path = os.path.join(sim_dir, 'reddit_profiles.json')
|
||||||
|
if os.path.exists(reddit_path):
|
||||||
|
with open(reddit_path, 'r', encoding='utf-8') as f:
|
||||||
|
profiles = json.load(f)
|
||||||
|
|
||||||
|
if isinstance(profiles, list):
|
||||||
|
filtered = [p for i, p in enumerate(profiles) if i in selected_agent_ids]
|
||||||
|
with open(reddit_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(filtered, f, ensure_ascii=False, indent=2)
|
||||||
|
logger.info(f"Reddit profiles: {len(profiles)} -> {len(filtered)}")
|
||||||
|
|
||||||
|
# Filter Twitter profiles (CSV)
|
||||||
|
twitter_path = os.path.join(sim_dir, 'twitter_profiles.csv')
|
||||||
|
if os.path.exists(twitter_path):
|
||||||
|
with open(twitter_path, 'r', encoding='utf-8') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
rows = list(reader)
|
||||||
|
fieldnames = reader.fieldnames
|
||||||
|
|
||||||
|
filtered = [r for i, r in enumerate(rows) if i in selected_agent_ids]
|
||||||
|
with open(twitter_path, 'w', encoding='utf-8', newline='') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(filtered)
|
||||||
|
logger.info(f"Twitter profiles: {len(rows)} -> {len(filtered)}")
|
||||||
|
|
||||||
@simulation_bp.route('/start', methods=['POST'])
|
@simulation_bp.route('/start', methods=['POST'])
|
||||||
def start_simulation():
|
def start_simulation():
|
||||||
"""
|
"""
|
||||||
@@ -1503,6 +1538,7 @@ def start_simulation():
|
|||||||
max_rounds = data.get('max_rounds') # 可选:最大模拟轮数
|
max_rounds = data.get('max_rounds') # 可选:最大模拟轮数
|
||||||
enable_graph_memory_update = data.get('enable_graph_memory_update', False) # 可选:是否启用图谱记忆更新
|
enable_graph_memory_update = data.get('enable_graph_memory_update', False) # 可选:是否启用图谱记忆更新
|
||||||
force = data.get('force', False) # 可选:强制重新开始
|
force = data.get('force', False) # 可选:强制重新开始
|
||||||
|
selected_agent_ids = data.get('selected_agent_ids') # 可选:选中的Agent ID列表
|
||||||
|
|
||||||
# 验证 max_rounds 参数
|
# 验证 max_rounds 参数
|
||||||
if max_rounds is not None:
|
if max_rounds is not None:
|
||||||
@@ -1600,6 +1636,11 @@ def start_simulation():
|
|||||||
|
|
||||||
logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}")
|
logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||||
|
|
||||||
|
# Filter agents if selected_agent_ids provided
|
||||||
|
if selected_agent_ids:
|
||||||
|
logger.info(f"Filtering agents: selected {len(selected_agent_ids)} agents")
|
||||||
|
_filter_simulation_agents(simulation_id, selected_agent_ids)
|
||||||
|
|
||||||
# 启动模拟
|
# 启动模拟
|
||||||
run_state = SimulationRunner.start_simulation(
|
run_state = SimulationRunner.start_simulation(
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
|
|||||||
@@ -834,6 +834,12 @@ const handleStartSimulation = () => {
|
|||||||
addLog(t('log.startSimAutoRounds', { rounds: autoGeneratedRounds.value }))
|
addLog(t('log.startSimAutoRounds', { rounds: autoGeneratedRounds.value }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pass selected agent IDs (filtered by group checkboxes)
|
||||||
|
params.selectedAgentIds = getSelectedAgentIds()
|
||||||
|
if (agentGroups.value.length > 0) {
|
||||||
|
addLog(`✅ ส่ง ${params.selectedAgentIds.length} agents ไปขั้นตอนถัดไป`)
|
||||||
|
}
|
||||||
|
|
||||||
emit('next-step', params)
|
emit('next-step', params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -301,7 +301,8 @@ const { t } = useI18n()
|
|||||||
|
|
||||||
const props = defineProps({
|
const props = defineProps({
|
||||||
simulationId: String,
|
simulationId: String,
|
||||||
maxRounds: Number, // 从Step2传入的最大轮数
|
maxRounds: Number,
|
||||||
|
selectedAgentIds: Array, // Agent IDs selected from Step 2 grouping // 从Step2传入的最大轮数
|
||||||
minutesPerRound: {
|
minutesPerRound: {
|
||||||
type: Number,
|
type: Number,
|
||||||
default: 30 // 默认每轮30分钟
|
default: 30 // 默认每轮30分钟
|
||||||
@@ -407,6 +408,12 @@ const doStartSimulation = async () => {
|
|||||||
addLog(t('log.setMaxRounds', { rounds: props.maxRounds }))
|
addLog(t('log.setMaxRounds', { rounds: props.maxRounds }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pass selected agent IDs from Step 2 grouping
|
||||||
|
if (props.selectedAgentIds && props.selectedAgentIds.length > 0) {
|
||||||
|
params.selected_agent_ids = props.selectedAgentIds
|
||||||
|
addLog(`✅ ใช้ ${props.selectedAgentIds.length} agents ที่เลือก`)
|
||||||
|
}
|
||||||
|
|
||||||
addLog(t('log.graphMemoryUpdateEnabled'))
|
addLog(t('log.graphMemoryUpdateEnabled'))
|
||||||
|
|
||||||
const res = await startSimulation(params)
|
const res = await startSimulation(params)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@
|
|||||||
<Step3Simulation
|
<Step3Simulation
|
||||||
:simulationId="currentSimulationId"
|
:simulationId="currentSimulationId"
|
||||||
:maxRounds="maxRounds"
|
:maxRounds="maxRounds"
|
||||||
|
:selectedAgentIds="selectedAgentIds"
|
||||||
:minutesPerRound="minutesPerRound"
|
:minutesPerRound="minutesPerRound"
|
||||||
:projectData="projectData"
|
:projectData="projectData"
|
||||||
:graphData="graphData"
|
:graphData="graphData"
|
||||||
@@ -94,6 +95,7 @@ const viewMode = ref('split')
|
|||||||
const currentSimulationId = ref(route.params.simulationId)
|
const currentSimulationId = ref(route.params.simulationId)
|
||||||
// 直接在初始化时从 query 参数获取 maxRounds,确保子组件能立即获取到值
|
// 直接在初始化时从 query 参数获取 maxRounds,确保子组件能立即获取到值
|
||||||
const maxRounds = ref(route.query.maxRounds ? parseInt(route.query.maxRounds) : null)
|
const maxRounds = ref(route.query.maxRounds ? parseInt(route.query.maxRounds) : null)
|
||||||
|
const selectedAgentIds = ref(route.query.selectedAgentIds ? route.query.selectedAgentIds.split(',').map(Number) : null)
|
||||||
const minutesPerRound = ref(30) // 默认每轮30分钟
|
const minutesPerRound = ref(30) // 默认每轮30分钟
|
||||||
const projectData = ref(null)
|
const projectData = ref(null)
|
||||||
const graphData = ref(null)
|
const graphData = ref(null)
|
||||||
|
|||||||
@@ -167,8 +167,15 @@ const handleNextStep = (params = {}) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果有自定义轮数,通过 query 参数传递
|
// 如果有自定义轮数,通过 query 参数传递
|
||||||
|
const query = {}
|
||||||
if (params.maxRounds) {
|
if (params.maxRounds) {
|
||||||
routeParams.query = { maxRounds: params.maxRounds }
|
query.maxRounds = params.maxRounds
|
||||||
|
}
|
||||||
|
if (params.selectedAgentIds && params.selectedAgentIds.length > 0) {
|
||||||
|
query.selectedAgentIds = params.selectedAgentIds.join(',')
|
||||||
|
}
|
||||||
|
if (Object.keys(query).length > 0) {
|
||||||
|
routeParams.query = query
|
||||||
}
|
}
|
||||||
|
|
||||||
// 跳转到 Step 3 页面
|
// 跳转到 Step 3 页面
|
||||||
|
|||||||
Reference in New Issue
Block a user