Implement Interview feature for agent interactions in simulations

- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews.
- Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts.
- Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples.
- Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment.
- Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
666ghj
2025-12-08 15:55:39 +08:00
parent 29bff9ca27
commit 1042d50306
8 changed files with 2963 additions and 70 deletions

View File

@@ -2,8 +2,15 @@
OASIS Reddit模拟预设脚本
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
功能特性:
- 完成模拟后不立即关闭环境,进入等待命令模式
- 支持通过IPC接收Interview命令
- 支持单个Agent采访和批量采访
- 支持远程关闭环境命令
使用方式:
python run_reddit_simulation.py --config /path/to/simulation_config.json
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
"""
import argparse
@@ -13,8 +20,9 @@ import logging
import os
import random
import sys
import sqlite3
from datetime import datetime
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
# 添加项目路径
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
@@ -118,10 +126,261 @@ except ImportError as e:
sys.exit(1)
# IPC-related constants. IPCHandler joins each of these onto the simulation
# directory to build the paths it works with.
IPC_COMMANDS_DIR = "ipc_commands"  # sub-directory polled for command JSON files
IPC_RESPONSES_DIR = "ipc_responses"  # sub-directory where response JSON files are written
ENV_STATUS_FILE = "env_status.json"  # file holding the current environment status
class CommandType:
    """String identifiers of the commands accepted over the IPC channel."""

    # Interview a single agent.
    INTERVIEW = "interview"
    # Interview several agents in one environment step.
    BATCH_INTERVIEW = "batch_interview"
    # Ask the simulation process to close the environment and exit.
    CLOSE_ENV = "close_env"
class IPCHandler:
    """File-based IPC command handler for a running simulation environment.

    Commands are JSON files placed in the commands directory. Each command is
    answered by writing ``<command_id>.json`` into the responses directory,
    after which the command file is removed. The environment status is
    published in a separate status file.
    """

    def __init__(self, simulation_dir: str, env, agent_graph):
        """Set up the IPC directories below ``simulation_dir``.

        Args:
            simulation_dir: Directory that holds the simulation files.
            env: Environment object; its ``step`` coroutine runs interviews.
            agent_graph: Object whose ``get_agent(agent_id)`` returns an agent.
        """
        self.simulation_dir = simulation_dir
        self.env = env
        self.agent_graph = agent_graph
        self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
        self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
        self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
        self._running = True
        # Maps str(command_id) -> path of the file the command was read from,
        # so the right file is removed even when its name is not
        # "<command_id>.json".
        self._command_files: Dict[str, str] = {}

        # Make sure both IPC directories exist.
        os.makedirs(self.commands_dir, exist_ok=True)
        os.makedirs(self.responses_dir, exist_ok=True)

    @staticmethod
    def _write_json_atomic(path: str, payload: Dict[str, Any]) -> None:
        """Write ``payload`` as JSON to ``path`` without exposing partial data.

        The data is written to a sibling temporary file and then moved into
        place, so a process polling ``path`` never reads a half-written file.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            # Do not leave a stale temporary file behind; re-raise the error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def update_status(self, status: str) -> None:
        """Publish the environment status together with the current time."""
        self._write_json_atomic(self.status_file, {
            "status": status,
            "timestamp": datetime.now().isoformat()
        })

    def poll_command(self) -> Optional[Dict[str, Any]]:
        """Return the oldest readable pending command, or None if there is none.

        Command files are visited in order of modification time. Files that
        cannot be read, are not valid JSON, or do not contain a JSON object are
        skipped and left in place.
        """
        try:
            filenames = os.listdir(self.commands_dir)
        except OSError:
            return None

        candidates = []
        for filename in filenames:
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.commands_dir, filename)
            try:
                candidates.append((os.path.getmtime(filepath), filepath))
            except OSError:
                # The file disappeared between listdir() and getmtime().
                continue
        candidates.sort()

        for _, filepath in candidates:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    command = json.load(f)
            except (ValueError, OSError):
                # ValueError covers both invalid JSON and undecodable bytes.
                continue
            if not isinstance(command, dict):
                continue
            # Remember the source file so send_response() can delete it.
            self._command_files[str(command.get("command_id"))] = filepath
            return command

        return None

    def send_response(self, command_id: str, status: str,
                      result: Optional[Dict] = None,
                      error: Optional[str] = None) -> None:
        """Write the response for ``command_id`` and remove its command file.

        Args:
            command_id: Identifier of the command being answered.
            status: Outcome string stored in the response.
            result: Optional result payload.
            error: Optional error message.
        """
        response = {
            "command_id": command_id,
            "status": status,
            "result": result,
            "error": error,
            "timestamp": datetime.now().isoformat()
        }
        response_file = os.path.join(self.responses_dir, f"{command_id}.json")
        self._write_json_atomic(response_file, response)

        # Prefer the path the command was actually read from; fall back to the
        # conventional "<command_id>.json" name for commands that were not
        # obtained through poll_command().
        fallback = os.path.join(self.commands_dir, f"{command_id}.json")
        command_file = self._command_files.pop(str(command_id), fallback)
        try:
            os.remove(command_file)
        except OSError:
            pass

    async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
        """Interview one agent and write the response file.

        Returns:
            True on success, False on failure. A response is sent either way.
        """
        try:
            agent = self.agent_graph.get_agent(agent_id)
            interview_action = ManualAction(
                action_type=ActionType.INTERVIEW,
                action_args={"prompt": prompt}
            )
            await self.env.step({agent: interview_action})

            # The answer is read back from the simulation database.
            result = self._get_interview_result(agent_id)
            self.send_response(command_id, "completed", result=result)
            print(f" Interview完成: agent_id={agent_id}")
            return True
        except Exception as e:
            error_msg = str(e)
            print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False

    async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
        """Interview several agents in a single environment step.

        Args:
            command_id: Identifier of the command being answered.
            interviews: [{"agent_id": int, "prompt": str}, ...]

        Returns:
            True on success, False on failure. Agents that cannot be looked up
            are skipped with a warning; if none remain the command fails.
        """
        try:
            actions = {}
            agent_prompts = {}  # prompt recorded per successfully resolved agent id
            for interview in interviews:
                agent_id = interview.get("agent_id")
                prompt = interview.get("prompt", "")
                try:
                    agent = self.agent_graph.get_agent(agent_id)
                    actions[agent] = ManualAction(
                        action_type=ActionType.INTERVIEW,
                        action_args={"prompt": prompt}
                    )
                    agent_prompts[agent_id] = prompt
                except Exception as e:
                    print(f" 警告: 无法获取Agent {agent_id}: {e}")

            if not actions:
                self.send_response(command_id, "failed", error="没有有效的Agent")
                return False

            await self.env.step(actions)

            results = {
                agent_id: self._get_interview_result(agent_id)
                for agent_id in agent_prompts
            }
            self.send_response(command_id, "completed", result={
                "interviews_count": len(results),
                "results": results
            })
            print(f" 批量Interview完成: {len(results)} 个Agent")
            return True
        except Exception as e:
            error_msg = str(e)
            print(f" 批量Interview失败: {error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False

    def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
        """Fetch the most recent interview record of ``agent_id`` from the database.

        Returns:
            A dict with ``agent_id``, ``response`` and ``timestamp``. The last
            two stay None when the database or a matching row is missing, or
            when reading fails.
        """
        db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
        result = {
            "agent_id": agent_id,
            "response": None,
            "timestamp": None
        }
        if not os.path.exists(db_path):
            return result

        try:
            conn = sqlite3.connect(db_path)
            try:
                row = conn.execute("""
                    SELECT user_id, info, created_at
                    FROM trace
                    WHERE action = ? AND user_id = ?
                    ORDER BY created_at DESC
                    LIMIT 1
                """, (ActionType.INTERVIEW.value, agent_id)).fetchone()
            finally:
                # Close the connection even when the query raises.
                conn.close()
        except Exception as e:
            # Best effort: a failed read is reported as an empty result.
            print(f" 读取Interview结果失败: {e}")
            return result

        if row:
            _, info_json, created_at = row
            result["timestamp"] = created_at
            try:
                info = json.loads(info_json) if info_json else {}
            except (TypeError, ValueError):
                # Not JSON: hand back the raw stored value.
                result["response"] = info_json
            else:
                if isinstance(info, dict):
                    result["response"] = info.get("response", info)
                else:
                    result["response"] = info
        return result

    async def process_commands(self) -> bool:
        """Process at most one pending command.

        Returns:
            True to keep running, False when the environment should be closed.
        """
        command = self.poll_command()
        if command is None:
            return True

        command_id = command.get("command_id")
        command_type = command.get("command_type")
        # Tolerate a missing, null or non-object "args" entry.
        args = command.get("args")
        if not isinstance(args, dict):
            args = {}
        print(f"\n收到IPC命令: {command_type}, id={command_id}")

        if command_type == CommandType.INTERVIEW:
            await self.handle_interview(
                command_id,
                args.get("agent_id", 0),
                args.get("prompt", "")
            )
            return True

        if command_type == CommandType.BATCH_INTERVIEW:
            await self.handle_batch_interview(
                command_id,
                args.get("interviews", [])
            )
            return True

        if command_type == CommandType.CLOSE_ENV:
            print("收到关闭环境命令")
            self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
            return False

        self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
        return True
class RedditSimulationRunner:
"""Reddit模拟运行器"""
# Reddit可用动作
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
AVAILABLE_ACTIONS = [
ActionType.LIKE_POST,
ActionType.DISLIKE_POST,
@@ -138,16 +397,21 @@ class RedditSimulationRunner:
ActionType.MUTE,
]
def __init__(self, config_path: str):
def __init__(self, config_path: str, wait_for_commands: bool = True):
"""
初始化模拟运行器
Args:
config_path: 配置文件路径 (simulation_config.json)
wait_for_commands: 模拟完成后是否等待命令(默认True)
"""
self.config_path = config_path
self.config = self._load_config()
self.simulation_dir = os.path.dirname(config_path)
self.wait_for_commands = wait_for_commands
self.env = None
self.agent_graph = None
self.ipc_handler = None
def _load_config(self) -> Dict[str, Any]:
"""加载配置文件"""
@@ -261,6 +525,7 @@ class RedditSimulationRunner:
print("OASIS Reddit模拟")
print(f"配置文件: {self.config_path}")
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
print("=" * 60)
time_config = self.config.get("time_config", {})
@@ -292,7 +557,7 @@ class RedditSimulationRunner:
print(f"错误: Profile文件不存在: {profile_path}")
return
agent_graph = await generate_reddit_agent_graph(
self.agent_graph = await generate_reddit_agent_graph(
profile_path=profile_path,
model=model,
available_actions=self.AVAILABLE_ACTIONS,
@@ -304,16 +569,20 @@ class RedditSimulationRunner:
print(f"已删除旧数据库: {db_path}")
print("创建OASIS环境...")
env = oasis.make(
agent_graph=agent_graph,
self.env = oasis.make(
agent_graph=self.agent_graph,
platform=oasis.DefaultPlatformType.REDDIT,
database_path=db_path,
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
)
await env.reset()
await self.env.reset()
print("环境初始化完成\n")
# 初始化IPC处理器
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
self.ipc_handler.update_status("running")
# 执行初始事件
event_config = self.config.get("event_config", {})
initial_posts = event_config.get("initial_posts", [])
@@ -325,7 +594,7 @@ class RedditSimulationRunner:
agent_id = post.get("poster_agent_id", 0)
content = post.get("content", "")
try:
agent = env.agent_graph.get_agent(agent_id)
agent = self.env.agent_graph.get_agent(agent_id)
if agent in initial_actions:
if not isinstance(initial_actions[agent], list):
initial_actions[agent] = [initial_actions[agent]]
@@ -342,7 +611,7 @@ class RedditSimulationRunner:
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
if initial_actions:
await env.step(initial_actions)
await self.env.step(initial_actions)
print(f" 已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环
@@ -355,7 +624,7 @@ class RedditSimulationRunner:
simulated_day = simulated_minutes // (60 * 24) + 1
active_agents = self._get_active_agents_for_round(
env, simulated_hour, round_num
self.env, simulated_hour, round_num
)
if not active_agents:
@@ -366,7 +635,7 @@ class RedditSimulationRunner:
for _, agent in active_agents
}
await env.step(actions)
await self.env.step(actions)
if (round_num + 1) % 10 == 0 or round_num == 0:
elapsed = (datetime.now() - start_time).total_seconds()
@@ -376,12 +645,39 @@ class RedditSimulationRunner:
f"- {len(active_agents)} agents active "
f"- elapsed: {elapsed:.1f}s")
await env.close()
total_elapsed = (datetime.now() - start_time).total_seconds()
print(f"\n模拟完成!")
print(f"\n模拟循环完成!")
print(f" - 总耗时: {total_elapsed:.1f}")
print(f" - 数据库: {db_path}")
# 是否进入等待命令模式
if self.wait_for_commands:
print("\n" + "=" * 60)
print("进入等待命令模式 - 环境保持运行")
print("支持的命令: interview, batch_interview, close_env")
print("=" * 60)
self.ipc_handler.update_status("alive")
# 等待命令循环
try:
while True:
should_continue = await self.ipc_handler.process_commands()
if not should_continue:
break
await asyncio.sleep(0.5) # 轮询间隔
except KeyboardInterrupt:
print("\n收到中断信号")
except Exception as e:
print(f"\n命令处理出错: {e}")
print("\n关闭环境...")
# 关闭环境
self.ipc_handler.update_status("stopped")
await self.env.close()
print("环境已关闭")
print("=" * 60)
@@ -399,6 +695,12 @@ async def main():
default=None,
help='最大模拟轮数(可选,用于截断过长的模拟)'
)
parser.add_argument(
'--no-wait',
action='store_true',
default=False,
help='模拟完成后立即关闭环境,不进入等待命令模式'
)
args = parser.parse_args()
@@ -410,7 +712,10 @@ async def main():
simulation_dir = os.path.dirname(args.config) or "."
setup_oasis_logging(os.path.join(simulation_dir, "log"))
runner = RedditSimulationRunner(args.config)
runner = RedditSimulationRunner(
config_path=args.config,
wait_for_commands=not args.no_wait
)
await runner.run(max_rounds=args.max_rounds)