Implement Interview feature for agent interactions in simulations
- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
@@ -12,7 +12,7 @@ import threading
|
||||
import subprocess
|
||||
import signal
|
||||
import atexit
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -21,6 +21,7 @@ from queue import Queue
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_graph_memory_updater import ZepGraphMemoryManager
|
||||
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
@@ -989,4 +990,365 @@ class SimulationRunner:
|
||||
if process.poll() is None:
|
||||
running.append(sim_id)
|
||||
return running
|
||||
|
||||
# ============== Interview 功能 ==============
|
||||
|
||||
@classmethod
|
||||
def check_env_alive(cls, simulation_id: str) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活(可以接收Interview命令)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
True 表示环境存活,False 表示环境已关闭
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
return False
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
return ipc_client.check_env_alive()
|
||||
|
||||
@classmethod
|
||||
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模拟环境的详细状态信息
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
状态详情字典,包含 status, twitter_available, reddit_available, timestamp
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
status_file = os.path.join(sim_dir, "env_status.json")
|
||||
|
||||
default_status = {
|
||||
"status": "stopped",
|
||||
"twitter_available": False,
|
||||
"reddit_available": False,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(status_file):
|
||||
return default_status
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
return {
|
||||
"status": status.get("status", "stopped"),
|
||||
"twitter_available": status.get("twitter_available", False),
|
||||
"reddit_available": status.get("reddit_available", False),
|
||||
"timestamp": status.get("timestamp")
|
||||
}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return default_status
|
||||
|
||||
@classmethod
|
||||
def interview_agent(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
agent_id: int,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 60.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访单个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,返回整合结果
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_interview(
|
||||
agent_id=agent_id,
|
||||
prompt=prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_agents_batch(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
interviews: List[Dict[str, Any]],
|
||||
platform: str = None,
|
||||
timeout: float = 120.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量采访多个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
批量采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_batch_interview(
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"interviews_count": len(interviews),
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"interviews_count": len(interviews),
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_all_agents(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 180.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访所有Agent(全局采访)
|
||||
|
||||
使用相同的问题采访模拟中的所有Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
prompt: 采访问题(所有Agent使用相同问题)
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
全局采访结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
# 从配置文件获取所有Agent信息
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(f"模拟配置不存在: {simulation_id}")
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
agent_configs = config.get("agent_configs", [])
|
||||
if not agent_configs:
|
||||
raise ValueError(f"模拟配置中没有Agent: {simulation_id}")
|
||||
|
||||
# 构建批量采访列表
|
||||
interviews = []
|
||||
for agent_config in agent_configs:
|
||||
agent_id = agent_config.get("agent_id")
|
||||
if agent_id is not None:
|
||||
interviews.append({
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt
|
||||
})
|
||||
|
||||
logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}")
|
||||
|
||||
return cls.interview_agents_batch(
|
||||
simulation_id=simulation_id,
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def close_simulation_env(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
timeout: float = 30.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
关闭模拟环境(而不是停止模拟进程)
|
||||
|
||||
向模拟发送关闭环境命令,使其优雅退出等待命令模式
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
操作结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境已经关闭"
|
||||
}
|
||||
|
||||
logger.info(f"发送关闭环境命令: simulation_id={simulation_id}")
|
||||
|
||||
try:
|
||||
response = ipc_client.send_close_env(timeout=timeout)
|
||||
|
||||
return {
|
||||
"success": response.status.value == "completed",
|
||||
"message": "环境关闭命令已发送",
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
except TimeoutError:
|
||||
# 超时可能是因为环境正在关闭
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "reddit",
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter)
|
||||
agent_id: 过滤Agent ID(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return []
|
||||
|
||||
results = []
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 构建查询
|
||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
||||
if agent_id is not None:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview' AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (agent_id, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
for user_id, info_json, created_at in cursor.fetchall():
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
except json.JSONDecodeError:
|
||||
info = {"raw": info_json}
|
||||
|
||||
results.append({
|
||||
"agent_id": user_id,
|
||||
"response": info.get("response", info),
|
||||
"prompt": info.get("prompt", ""),
|
||||
"timestamp": created_at,
|
||||
"platform": platform
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取Interview历史失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user