Implement Interview feature for agent interactions in simulations
- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
@@ -2,8 +2,15 @@
|
||||
OASIS Twitter模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -13,8 +20,9 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -118,10 +126,261 @@ except ImportError as e:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.env = env
|
||||
self.agent_graph = agent_graph
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
|
||||
try:
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
agent_prompts[agent_id] = prompt
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Agent {agent_id}: {e}")
|
||||
|
||||
if not actions:
|
||||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
results[agent_id] = result
|
||||
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" 批量Interview失败: {error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", "")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", [])
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
class TwitterSimulationRunner:
|
||||
"""Twitter模拟运行器"""
|
||||
|
||||
# Twitter可用动作
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
@@ -131,16 +390,21 @@ class TwitterSimulationRunner:
|
||||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self.simulation_dir = os.path.dirname(config_path)
|
||||
self.wait_for_commands = wait_for_commands
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
@@ -269,6 +533,7 @@ class TwitterSimulationRunner:
|
||||
print("OASIS Twitter模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
|
||||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载时间配置
|
||||
@@ -305,7 +570,7 @@ class TwitterSimulationRunner:
|
||||
print(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
self.agent_graph = await generate_twitter_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
@@ -319,16 +584,20 @@ class TwitterSimulationRunner:
|
||||
|
||||
# 创建环境
|
||||
print("创建OASIS环境...")
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
@@ -340,7 +609,7 @@ class TwitterSimulationRunner:
|
||||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = self.env.agent_graph.get_agent(agent_id)
|
||||
initial_actions[agent] = ManualAction(
|
||||
action_type=ActionType.CREATE_POST,
|
||||
action_args={"content": content}
|
||||
@@ -349,7 +618,7 @@ class TwitterSimulationRunner:
|
||||
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
@@ -364,7 +633,7 @@ class TwitterSimulationRunner:
|
||||
|
||||
# 获取本轮激活的Agent
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
env, simulated_hour, round_num
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
||||
if not active_agents:
|
||||
@@ -377,7 +646,7 @@ class TwitterSimulationRunner:
|
||||
}
|
||||
|
||||
# 执行动作
|
||||
await env.step(actions)
|
||||
await self.env.step(actions)
|
||||
|
||||
# 打印进度
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
@@ -388,13 +657,39 @@ class TwitterSimulationRunner:
|
||||
f"- {len(active_agents)} agents active "
|
||||
f"- elapsed: {elapsed:.1f}s")
|
||||
|
||||
# 关闭环境
|
||||
await env.close()
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"\n模拟完成!")
|
||||
print(f"\n模拟循环完成!")
|
||||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
print("支持的命令: interview, batch_interview, close_env")
|
||||
print("=" * 60)
|
||||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
print("环境已关闭")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@@ -412,6 +707,12 @@ async def main():
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -423,10 +724,12 @@ async def main():
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = TwitterSimulationRunner(args.config)
|
||||
runner = TwitterSimulationRunner(
|
||||
config_path=args.config,
|
||||
wait_for_commands=not args.no_wait
|
||||
)
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user