Implement Interview feature for agent interactions in simulations

- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews.
- Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts.
- Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples.
- Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment.
- Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
666ghj
2025-12-08 15:55:39 +08:00
parent 29bff9ca27
commit 1042d50306
8 changed files with 2963 additions and 70 deletions

View File

@@ -28,6 +28,14 @@ from .zep_graph_memory_updater import (
ZepGraphMemoryManager,
AgentActivity
)
from .simulation_ipc import (
SimulationIPCClient,
SimulationIPCServer,
IPCCommand,
IPCResponse,
CommandType,
CommandStatus
)
__all__ = [
'OntologyGenerator',
@@ -55,5 +63,11 @@ __all__ = [
'ZepGraphMemoryUpdater',
'ZepGraphMemoryManager',
'AgentActivity',
'SimulationIPCClient',
'SimulationIPCServer',
'IPCCommand',
'IPCResponse',
'CommandType',
'CommandStatus',
]

View File

@@ -0,0 +1,394 @@
"""
模拟IPC通信模块
用于Flask后端和模拟脚本之间的进程间通信
通过文件系统实现简单的命令/响应模式:
1. Flask写入命令到 commands/ 目录
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
3. Flask轮询响应目录获取结果
"""
import os
import json
import time
import uuid
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from ..utils.logger import get_logger
logger = get_logger('mirofish.simulation_ipc')
class CommandType(str, Enum):
"""命令类型"""
INTERVIEW = "interview" # 单个Agent采访
BATCH_INTERVIEW = "batch_interview" # 批量采访
CLOSE_ENV = "close_env" # 关闭环境
class CommandStatus(str, Enum):
"""命令状态"""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class IPCCommand:
"""IPC命令"""
command_id: str
command_type: CommandType
args: Dict[str, Any]
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
return {
"command_id": self.command_id,
"command_type": self.command_type.value,
"args": self.args,
"timestamp": self.timestamp
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
return cls(
command_id=data["command_id"],
command_type=CommandType(data["command_type"]),
args=data.get("args", {}),
timestamp=data.get("timestamp", datetime.now().isoformat())
)
@dataclass
class IPCResponse:
"""IPC响应"""
command_id: str
status: CommandStatus
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
return {
"command_id": self.command_id,
"status": self.status.value,
"result": self.result,
"error": self.error,
"timestamp": self.timestamp
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
return cls(
command_id=data["command_id"],
status=CommandStatus(data["status"]),
result=data.get("result"),
error=data.get("error"),
timestamp=data.get("timestamp", datetime.now().isoformat())
)
class SimulationIPCClient:
"""
模拟IPC客户端Flask端使用
用于向模拟进程发送命令并等待响应
"""
def __init__(self, simulation_dir: str):
"""
初始化IPC客户端
Args:
simulation_dir: 模拟数据目录
"""
self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在
os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True)
def send_command(
self,
command_type: CommandType,
args: Dict[str, Any],
timeout: float = 60.0,
poll_interval: float = 0.5
) -> IPCResponse:
"""
发送命令并等待响应
Args:
command_type: 命令类型
args: 命令参数
timeout: 超时时间(秒)
poll_interval: 轮询间隔(秒)
Returns:
IPCResponse
Raises:
TimeoutError: 等待响应超时
"""
command_id = str(uuid.uuid4())
command = IPCCommand(
command_id=command_id,
command_type=command_type,
args=args
)
# 写入命令文件
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
with open(command_file, 'w', encoding='utf-8') as f:
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
# 等待响应
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
start_time = time.time()
while time.time() - start_time < timeout:
if os.path.exists(response_file):
try:
with open(response_file, 'r', encoding='utf-8') as f:
response_data = json.load(f)
response = IPCResponse.from_dict(response_data)
# 清理命令和响应文件
try:
os.remove(command_file)
os.remove(response_file)
except OSError:
pass
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
return response
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"解析响应失败: {e}")
time.sleep(poll_interval)
# 超时
logger.error(f"等待IPC响应超时: command_id={command_id}")
# 清理命令文件
try:
os.remove(command_file)
except OSError:
pass
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
def send_interview(
self,
agent_id: int,
prompt: str,
platform: str = None,
timeout: float = 60.0
) -> IPCResponse:
"""
发送单个Agent采访命令
Args:
agent_id: Agent ID
prompt: 采访问题
platform: 指定平台(可选)
- "twitter": 只采访Twitter平台
- "reddit": 只采访Reddit平台
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
timeout: 超时时间
Returns:
IPCResponseresult字段包含采访结果
"""
args = {
"agent_id": agent_id,
"prompt": prompt
}
if platform:
args["platform"] = platform
return self.send_command(
command_type=CommandType.INTERVIEW,
args=args,
timeout=timeout
)
def send_batch_interview(
self,
interviews: List[Dict[str, Any]],
platform: str = None,
timeout: float = 120.0
) -> IPCResponse:
"""
发送批量采访命令
Args:
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
platform: 默认平台可选会被每个采访项的platform覆盖
- "twitter": 默认只采访Twitter平台
- "reddit": 默认只采访Reddit平台
- None: 双平台模拟时每个Agent同时采访两个平台
timeout: 超时时间
Returns:
IPCResponseresult字段包含所有采访结果
"""
args = {"interviews": interviews}
if platform:
args["platform"] = platform
return self.send_command(
command_type=CommandType.BATCH_INTERVIEW,
args=args,
timeout=timeout
)
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
"""
发送关闭环境命令
Args:
timeout: 超时时间
Returns:
IPCResponse
"""
return self.send_command(
command_type=CommandType.CLOSE_ENV,
args={},
timeout=timeout
)
def check_env_alive(self) -> bool:
"""
检查模拟环境是否存活
通过检查 env_status.json 文件来判断
"""
status_file = os.path.join(self.simulation_dir, "env_status.json")
if not os.path.exists(status_file):
return False
try:
with open(status_file, 'r', encoding='utf-8') as f:
status = json.load(f)
return status.get("status") == "alive"
except (json.JSONDecodeError, OSError):
return False
class SimulationIPCServer:
"""
模拟IPC服务器模拟脚本端使用
轮询命令目录,执行命令并返回响应
"""
def __init__(self, simulation_dir: str):
"""
初始化IPC服务器
Args:
simulation_dir: 模拟数据目录
"""
self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在
os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True)
# 环境状态
self._running = False
def start(self):
"""标记服务器为运行状态"""
self._running = True
self._update_env_status("alive")
def stop(self):
"""标记服务器为停止状态"""
self._running = False
self._update_env_status("stopped")
def _update_env_status(self, status: str):
"""更新环境状态文件"""
status_file = os.path.join(self.simulation_dir, "env_status.json")
with open(status_file, 'w', encoding='utf-8') as f:
json.dump({
"status": status,
"timestamp": datetime.now().isoformat()
}, f, ensure_ascii=False, indent=2)
def poll_commands(self) -> Optional[IPCCommand]:
"""
轮询命令目录,返回第一个待处理的命令
Returns:
IPCCommand 或 None
"""
if not os.path.exists(self.commands_dir):
return None
# 按时间排序获取命令文件
command_files = []
for filename in os.listdir(self.commands_dir):
if filename.endswith('.json'):
filepath = os.path.join(self.commands_dir, filename)
command_files.append((filepath, os.path.getmtime(filepath)))
command_files.sort(key=lambda x: x[1])
for filepath, _ in command_files:
try:
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
return IPCCommand.from_dict(data)
except (json.JSONDecodeError, KeyError, OSError) as e:
logger.warning(f"读取命令文件失败: {filepath}, {e}")
continue
return None
def send_response(self, response: IPCResponse):
"""
发送响应
Args:
response: IPC响应
"""
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
with open(response_file, 'w', encoding='utf-8') as f:
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
# 删除命令文件
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
try:
os.remove(command_file)
except OSError:
pass
def send_success(self, command_id: str, result: Dict[str, Any]):
"""发送成功响应"""
self.send_response(IPCResponse(
command_id=command_id,
status=CommandStatus.COMPLETED,
result=result
))
def send_error(self, command_id: str, error: str):
"""发送错误响应"""
self.send_response(IPCResponse(
command_id=command_id,
status=CommandStatus.FAILED,
error=error
))

View File

@@ -12,7 +12,7 @@ import threading
import subprocess
import signal
import atexit
from typing import Dict, Any, List, Optional
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -21,6 +21,7 @@ from queue import Queue
from ..config import Config
from ..utils.logger import get_logger
from .zep_graph_memory_updater import ZepGraphMemoryManager
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
logger = get_logger('mirofish.simulation_runner')
@@ -989,4 +990,365 @@ class SimulationRunner:
if process.poll() is None:
running.append(sim_id)
return running
# ============== Interview 功能 ==============
@classmethod
def check_env_alive(cls, simulation_id: str) -> bool:
"""
检查模拟环境是否存活可以接收Interview命令
Args:
simulation_id: 模拟ID
Returns:
True 表示环境存活False 表示环境已关闭
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
if not os.path.exists(sim_dir):
return False
ipc_client = SimulationIPCClient(sim_dir)
return ipc_client.check_env_alive()
@classmethod
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
"""
获取模拟环境的详细状态信息
Args:
simulation_id: 模拟ID
Returns:
状态详情字典,包含 status, twitter_available, reddit_available, timestamp
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
status_file = os.path.join(sim_dir, "env_status.json")
default_status = {
"status": "stopped",
"twitter_available": False,
"reddit_available": False,
"timestamp": None
}
if not os.path.exists(status_file):
return default_status
try:
with open(status_file, 'r', encoding='utf-8') as f:
status = json.load(f)
return {
"status": status.get("status", "stopped"),
"twitter_available": status.get("twitter_available", False),
"reddit_available": status.get("reddit_available", False),
"timestamp": status.get("timestamp")
}
except (json.JSONDecodeError, OSError):
return default_status
@classmethod
def interview_agent(
cls,
simulation_id: str,
agent_id: int,
prompt: str,
platform: str = None,
timeout: float = 60.0
) -> Dict[str, Any]:
"""
采访单个Agent
Args:
simulation_id: 模拟ID
agent_id: Agent ID
prompt: 采访问题
platform: 指定平台(可选)
- "twitter": 只采访Twitter平台
- "reddit": 只采访Reddit平台
- None: 双平台模拟时同时采访两个平台,返回整合结果
timeout: 超时时间(秒)
Returns:
采访结果字典
Raises:
ValueError: 模拟不存在或环境未运行
TimeoutError: 等待响应超时
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
if not os.path.exists(sim_dir):
raise ValueError(f"模拟不存在: {simulation_id}")
ipc_client = SimulationIPCClient(sim_dir)
if not ipc_client.check_env_alive():
raise ValueError(f"模拟环境未运行或已关闭无法执行Interview: {simulation_id}")
logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}")
response = ipc_client.send_interview(
agent_id=agent_id,
prompt=prompt,
platform=platform,
timeout=timeout
)
if response.status.value == "completed":
return {
"success": True,
"agent_id": agent_id,
"prompt": prompt,
"result": response.result,
"timestamp": response.timestamp
}
else:
return {
"success": False,
"agent_id": agent_id,
"prompt": prompt,
"error": response.error,
"timestamp": response.timestamp
}
@classmethod
def interview_agents_batch(
cls,
simulation_id: str,
interviews: List[Dict[str, Any]],
platform: str = None,
timeout: float = 120.0
) -> Dict[str, Any]:
"""
批量采访多个Agent
Args:
simulation_id: 模拟ID
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
platform: 默认平台可选会被每个采访项的platform覆盖
- "twitter": 默认只采访Twitter平台
- "reddit": 默认只采访Reddit平台
- None: 双平台模拟时每个Agent同时采访两个平台
timeout: 超时时间(秒)
Returns:
批量采访结果字典
Raises:
ValueError: 模拟不存在或环境未运行
TimeoutError: 等待响应超时
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
if not os.path.exists(sim_dir):
raise ValueError(f"模拟不存在: {simulation_id}")
ipc_client = SimulationIPCClient(sim_dir)
if not ipc_client.check_env_alive():
raise ValueError(f"模拟环境未运行或已关闭无法执行Interview: {simulation_id}")
logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")
response = ipc_client.send_batch_interview(
interviews=interviews,
platform=platform,
timeout=timeout
)
if response.status.value == "completed":
return {
"success": True,
"interviews_count": len(interviews),
"result": response.result,
"timestamp": response.timestamp
}
else:
return {
"success": False,
"interviews_count": len(interviews),
"error": response.error,
"timestamp": response.timestamp
}
@classmethod
def interview_all_agents(
cls,
simulation_id: str,
prompt: str,
platform: str = None,
timeout: float = 180.0
) -> Dict[str, Any]:
"""
采访所有Agent全局采访
使用相同的问题采访模拟中的所有Agent
Args:
simulation_id: 模拟ID
prompt: 采访问题所有Agent使用相同问题
platform: 指定平台(可选)
- "twitter": 只采访Twitter平台
- "reddit": 只采访Reddit平台
- None: 双平台模拟时每个Agent同时采访两个平台
timeout: 超时时间(秒)
Returns:
全局采访结果字典
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
if not os.path.exists(sim_dir):
raise ValueError(f"模拟不存在: {simulation_id}")
# 从配置文件获取所有Agent信息
config_path = os.path.join(sim_dir, "simulation_config.json")
if not os.path.exists(config_path):
raise ValueError(f"模拟配置不存在: {simulation_id}")
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
agent_configs = config.get("agent_configs", [])
if not agent_configs:
raise ValueError(f"模拟配置中没有Agent: {simulation_id}")
# 构建批量采访列表
interviews = []
for agent_config in agent_configs:
agent_id = agent_config.get("agent_id")
if agent_id is not None:
interviews.append({
"agent_id": agent_id,
"prompt": prompt
})
logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}")
return cls.interview_agents_batch(
simulation_id=simulation_id,
interviews=interviews,
platform=platform,
timeout=timeout
)
@classmethod
def close_simulation_env(
cls,
simulation_id: str,
timeout: float = 30.0
) -> Dict[str, Any]:
"""
关闭模拟环境(而不是停止模拟进程)
向模拟发送关闭环境命令,使其优雅退出等待命令模式
Args:
simulation_id: 模拟ID
timeout: 超时时间(秒)
Returns:
操作结果字典
"""
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
if not os.path.exists(sim_dir):
raise ValueError(f"模拟不存在: {simulation_id}")
ipc_client = SimulationIPCClient(sim_dir)
if not ipc_client.check_env_alive():
return {
"success": True,
"message": "环境已经关闭"
}
logger.info(f"发送关闭环境命令: simulation_id={simulation_id}")
try:
response = ipc_client.send_close_env(timeout=timeout)
return {
"success": response.status.value == "completed",
"message": "环境关闭命令已发送",
"result": response.result,
"timestamp": response.timestamp
}
except TimeoutError:
# 超时可能是因为环境正在关闭
return {
"success": True,
"message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
}
@classmethod
def get_interview_history(
cls,
simulation_id: str,
platform: str = "reddit",
agent_id: Optional[int] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取Interview历史记录从数据库读取
Args:
simulation_id: 模拟ID
platform: 平台类型reddit/twitter
agent_id: 过滤Agent ID可选
limit: 返回数量限制
Returns:
Interview历史记录列表
"""
import sqlite3
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
if not os.path.exists(db_path):
return []
results = []
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 构建查询
# 注意ActionType.INTERVIEW.value 应该是字符串形式
if agent_id is not None:
cursor.execute("""
SELECT user_id, info, created_at
FROM trace
WHERE action = 'interview' AND user_id = ?
ORDER BY created_at DESC
LIMIT ?
""", (agent_id, limit))
else:
cursor.execute("""
SELECT user_id, info, created_at
FROM trace
WHERE action = 'interview'
ORDER BY created_at DESC
LIMIT ?
""", (limit,))
for user_id, info_json, created_at in cursor.fetchall():
try:
info = json.loads(info_json) if info_json else {}
except json.JSONDecodeError:
info = {"raw": info_json}
results.append({
"agent_id": user_id,
"response": info.get("response", info),
"prompt": info.get("prompt", ""),
"timestamp": created_at,
"platform": platform
})
conn.close()
except Exception as e:
logger.error(f"读取Interview历史失败: {e}")
return results