Implement Interview feature for agent interactions in simulations
- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
@@ -28,6 +28,14 @@ from .zep_graph_memory_updater import (
|
||||
ZepGraphMemoryManager,
|
||||
AgentActivity
|
||||
)
|
||||
from .simulation_ipc import (
|
||||
SimulationIPCClient,
|
||||
SimulationIPCServer,
|
||||
IPCCommand,
|
||||
IPCResponse,
|
||||
CommandType,
|
||||
CommandStatus
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'OntologyGenerator',
|
||||
@@ -55,5 +63,11 @@ __all__ = [
|
||||
'ZepGraphMemoryUpdater',
|
||||
'ZepGraphMemoryManager',
|
||||
'AgentActivity',
|
||||
'SimulationIPCClient',
|
||||
'SimulationIPCServer',
|
||||
'IPCCommand',
|
||||
'IPCResponse',
|
||||
'CommandType',
|
||||
'CommandStatus',
|
||||
]
|
||||
|
||||
|
||||
394
backend/app/services/simulation_ipc.py
Normal file
394
backend/app/services/simulation_ipc.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
模拟IPC通信模块
|
||||
用于Flask后端和模拟脚本之间的进程间通信
|
||||
|
||||
通过文件系统实现简单的命令/响应模式:
|
||||
1. Flask写入命令到 commands/ 目录
|
||||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
||||
3. Flask轮询响应目录获取结果
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_ipc')
|
||||
|
||||
|
||||
class CommandType(str, Enum):
|
||||
"""命令类型"""
|
||||
INTERVIEW = "interview" # 单个Agent采访
|
||||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
||||
CLOSE_ENV = "close_env" # 关闭环境
|
||||
|
||||
|
||||
class CommandStatus(str, Enum):
|
||||
"""命令状态"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCCommand:
|
||||
"""IPC命令"""
|
||||
command_id: str
|
||||
command_type: CommandType
|
||||
args: Dict[str, Any]
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
"command_type": self.command_type.value,
|
||||
"args": self.args,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
|
||||
return cls(
|
||||
command_id=data["command_id"],
|
||||
command_type=CommandType(data["command_type"]),
|
||||
args=data.get("args", {}),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCResponse:
|
||||
"""IPC响应"""
|
||||
command_id: str
|
||||
status: CommandStatus
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
"status": self.status.value,
|
||||
"result": self.result,
|
||||
"error": self.error,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
|
||||
return cls(
|
||||
command_id=data["command_id"],
|
||||
status=CommandStatus(data["status"]),
|
||||
result=data.get("result"),
|
||||
error=data.get("error"),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
class SimulationIPCClient:
|
||||
"""
|
||||
模拟IPC客户端(Flask端使用)
|
||||
|
||||
用于向模拟进程发送命令并等待响应
|
||||
"""
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC客户端
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def send_command(
|
||||
self,
|
||||
command_type: CommandType,
|
||||
args: Dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
poll_interval: float = 0.5
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送命令并等待响应
|
||||
|
||||
Args:
|
||||
command_type: 命令类型
|
||||
args: 命令参数
|
||||
timeout: 超时时间(秒)
|
||||
poll_interval: 轮询间隔(秒)
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
|
||||
Raises:
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
command_id = str(uuid.uuid4())
|
||||
command = IPCCommand(
|
||||
command_id=command_id,
|
||||
command_type=command_type,
|
||||
args=args
|
||||
)
|
||||
|
||||
# 写入命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
with open(command_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
|
||||
|
||||
# 等待响应
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
if os.path.exists(response_file):
|
||||
try:
|
||||
with open(response_file, 'r', encoding='utf-8') as f:
|
||||
response_data = json.load(f)
|
||||
response = IPCResponse.from_dict(response_data)
|
||||
|
||||
# 清理命令和响应文件
|
||||
try:
|
||||
os.remove(command_file)
|
||||
os.remove(response_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
|
||||
return response
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"解析响应失败: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
# 超时
|
||||
logger.error(f"等待IPC响应超时: command_id={command_id}")
|
||||
|
||||
# 清理命令文件
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
|
||||
|
||||
def send_interview(
|
||||
self,
|
||||
agent_id: int,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 60.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送单个Agent采访命令
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含采访结果
|
||||
"""
|
||||
args = {
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt
|
||||
}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def send_batch_interview(
|
||||
self,
|
||||
interviews: List[Dict[str, Any]],
|
||||
platform: str = None,
|
||||
timeout: float = 120.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含所有采访结果
|
||||
"""
|
||||
args = {"interviews": interviews}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.BATCH_INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||||
"""
|
||||
发送关闭环境命令
|
||||
|
||||
Args:
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
"""
|
||||
return self.send_command(
|
||||
command_type=CommandType.CLOSE_ENV,
|
||||
args={},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def check_env_alive(self) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活
|
||||
|
||||
通过检查 env_status.json 文件来判断
|
||||
"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
if not os.path.exists(status_file):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
return status.get("status") == "alive"
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
class SimulationIPCServer:
|
||||
"""
|
||||
模拟IPC服务器(模拟脚本端使用)
|
||||
|
||||
轮询命令目录,执行命令并返回响应
|
||||
"""
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC服务器
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
# 环境状态
|
||||
self._running = False
|
||||
|
||||
def start(self):
|
||||
"""标记服务器为运行状态"""
|
||||
self._running = True
|
||||
self._update_env_status("alive")
|
||||
|
||||
def stop(self):
|
||||
"""标记服务器为停止状态"""
|
||||
self._running = False
|
||||
self._update_env_status("stopped")
|
||||
|
||||
def _update_env_status(self, status: str):
|
||||
"""更新环境状态文件"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
with open(status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_commands(self) -> Optional[IPCCommand]:
|
||||
"""
|
||||
轮询命令目录,返回第一个待处理的命令
|
||||
|
||||
Returns:
|
||||
IPCCommand 或 None
|
||||
"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 按时间排序获取命令文件
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return IPCCommand.from_dict(data)
|
||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||
logger.warning(f"读取命令文件失败: {filepath}, {e}")
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, response: IPCResponse):
|
||||
"""
|
||||
发送响应
|
||||
|
||||
Args:
|
||||
response: IPC响应
|
||||
"""
|
||||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||||
"""发送成功响应"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.COMPLETED,
|
||||
result=result
|
||||
))
|
||||
|
||||
def send_error(self, command_id: str, error: str):
|
||||
"""发送错误响应"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.FAILED,
|
||||
error=error
|
||||
))
|
||||
@@ -12,7 +12,7 @@ import threading
|
||||
import subprocess
|
||||
import signal
|
||||
import atexit
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -21,6 +21,7 @@ from queue import Queue
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_graph_memory_updater import ZepGraphMemoryManager
|
||||
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
@@ -989,4 +990,365 @@ class SimulationRunner:
|
||||
if process.poll() is None:
|
||||
running.append(sim_id)
|
||||
return running
|
||||
|
||||
# ============== Interview 功能 ==============
|
||||
|
||||
@classmethod
|
||||
def check_env_alive(cls, simulation_id: str) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活(可以接收Interview命令)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
True 表示环境存活,False 表示环境已关闭
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
return False
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
return ipc_client.check_env_alive()
|
||||
|
||||
@classmethod
|
||||
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取模拟环境的详细状态信息
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
状态详情字典,包含 status, twitter_available, reddit_available, timestamp
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
status_file = os.path.join(sim_dir, "env_status.json")
|
||||
|
||||
default_status = {
|
||||
"status": "stopped",
|
||||
"twitter_available": False,
|
||||
"reddit_available": False,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(status_file):
|
||||
return default_status
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
return {
|
||||
"status": status.get("status", "stopped"),
|
||||
"twitter_available": status.get("twitter_available", False),
|
||||
"reddit_available": status.get("reddit_available", False),
|
||||
"timestamp": status.get("timestamp")
|
||||
}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return default_status
|
||||
|
||||
@classmethod
|
||||
def interview_agent(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
agent_id: int,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 60.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访单个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,返回整合结果
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送Interview命令: simulation_id={simulation_id}, agent_id={agent_id}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_interview(
|
||||
agent_id=agent_id,
|
||||
prompt=prompt,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_agents_batch(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
interviews: List[Dict[str, Any]],
|
||||
platform: str = None,
|
||||
timeout: float = 120.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量采访多个Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
批量采访结果字典
|
||||
|
||||
Raises:
|
||||
ValueError: 模拟不存在或环境未运行
|
||||
TimeoutError: 等待响应超时
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}")
|
||||
|
||||
logger.info(f"发送批量Interview命令: simulation_id={simulation_id}, count={len(interviews)}, platform={platform}")
|
||||
|
||||
response = ipc_client.send_batch_interview(
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if response.status.value == "completed":
|
||||
return {
|
||||
"success": True,
|
||||
"interviews_count": len(interviews),
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"interviews_count": len(interviews),
|
||||
"error": response.error,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def interview_all_agents(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
prompt: str,
|
||||
platform: str = None,
|
||||
timeout: float = 180.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
采访所有Agent(全局采访)
|
||||
|
||||
使用相同的问题采访模拟中的所有Agent
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
prompt: 采访问题(所有Agent使用相同问题)
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
全局采访结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
# 从配置文件获取所有Agent信息
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(f"模拟配置不存在: {simulation_id}")
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
agent_configs = config.get("agent_configs", [])
|
||||
if not agent_configs:
|
||||
raise ValueError(f"模拟配置中没有Agent: {simulation_id}")
|
||||
|
||||
# 构建批量采访列表
|
||||
interviews = []
|
||||
for agent_config in agent_configs:
|
||||
agent_id = agent_config.get("agent_id")
|
||||
if agent_id is not None:
|
||||
interviews.append({
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt
|
||||
})
|
||||
|
||||
logger.info(f"发送全局Interview命令: simulation_id={simulation_id}, agent_count={len(interviews)}, platform={platform}")
|
||||
|
||||
return cls.interview_agents_batch(
|
||||
simulation_id=simulation_id,
|
||||
interviews=interviews,
|
||||
platform=platform,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def close_simulation_env(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
timeout: float = 30.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
关闭模拟环境(而不是停止模拟进程)
|
||||
|
||||
向模拟发送关闭环境命令,使其优雅退出等待命令模式
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
操作结果字典
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
if not os.path.exists(sim_dir):
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
ipc_client = SimulationIPCClient(sim_dir)
|
||||
|
||||
if not ipc_client.check_env_alive():
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境已经关闭"
|
||||
}
|
||||
|
||||
logger.info(f"发送关闭环境命令: simulation_id={simulation_id}")
|
||||
|
||||
try:
|
||||
response = ipc_client.send_close_env(timeout=timeout)
|
||||
|
||||
return {
|
||||
"success": response.status.value == "completed",
|
||||
"message": "环境关闭命令已发送",
|
||||
"result": response.result,
|
||||
"timestamp": response.timestamp
|
||||
}
|
||||
except TimeoutError:
|
||||
# 超时可能是因为环境正在关闭
|
||||
return {
|
||||
"success": True,
|
||||
"message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_interview_history(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "reddit",
|
||||
agent_id: Optional[int] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Interview历史记录(从数据库读取)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 平台类型(reddit/twitter)
|
||||
agent_id: 过滤Agent ID(可选)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
Interview历史记录列表
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return []
|
||||
|
||||
results = []
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 构建查询
|
||||
# 注意:ActionType.INTERVIEW.value 应该是字符串形式
|
||||
if agent_id is not None:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview' AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (agent_id, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = 'interview'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
for user_id, info_json, created_at in cursor.fetchall():
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
except json.JSONDecodeError:
|
||||
info = {"raw": info_json}
|
||||
|
||||
results.append({
|
||||
"agent_id": user_id,
|
||||
"response": info.get("response", info),
|
||||
"prompt": info.get("prompt", ""),
|
||||
"timestamp": created_at,
|
||||
"platform": platform
|
||||
})
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取Interview历史失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user