Implement Interview feature for agent interactions in simulations
- Added a new Interview module to facilitate interactions with agents post-simulation, allowing for single and batch interviews. - Introduced IPC communication mechanism for command and response handling between the Flask backend and simulation scripts. - Updated README.md to include detailed instructions on the new Interview functionality, including API endpoints and usage examples. - Enhanced simulation scripts to support waiting for commands after completion, improving user control over the simulation environment. - Implemented error handling and logging for interview processes, ensuring robust operation and traceability.
This commit is contained in:
@@ -2,8 +2,18 @@
|
||||
OASIS 双平台并行模拟预设脚本
|
||||
同时运行Twitter和Reddit模拟,读取相同的配置文件
|
||||
|
||||
功能特性:
|
||||
- 双平台(Twitter + Reddit)并行模拟
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_parallel_simulation.py --config simulation_config.json
|
||||
python run_parallel_simulation.py --config simulation_config.json --no-wait # 完成后立即关闭
|
||||
python run_parallel_simulation.py --config simulation_config.json --twitter-only
|
||||
python run_parallel_simulation.py --config simulation_config.json --reddit-only
|
||||
|
||||
日志结构:
|
||||
sim_xxx/
|
||||
@@ -119,7 +129,7 @@ except ImportError as e:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Twitter可用动作
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
TWITTER_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
@@ -129,7 +139,7 @@ TWITTER_ACTIONS = [
|
||||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
# Reddit可用动作
|
||||
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
REDDIT_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
@@ -147,6 +157,405 @@ REDDIT_ACTIONS = [
|
||||
]
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class ParallelIPCHandler:
|
||||
"""
|
||||
双平台IPC命令处理器
|
||||
|
||||
管理两个平台的环境,处理Interview命令
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
simulation_dir: str,
|
||||
twitter_env=None,
|
||||
twitter_agent_graph=None,
|
||||
reddit_env=None,
|
||||
reddit_agent_graph=None
|
||||
):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.twitter_env = twitter_env
|
||||
self.twitter_agent_graph = twitter_agent_graph
|
||||
self.reddit_env = reddit_env
|
||||
self.reddit_agent_graph = reddit_agent_graph
|
||||
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"twitter_available": self.twitter_env is not None,
|
||||
"reddit_available": self.reddit_env is not None,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_env_and_graph(self, platform: str):
|
||||
"""
|
||||
获取指定平台的环境和agent_graph
|
||||
|
||||
Args:
|
||||
platform: 平台名称 ("twitter" 或 "reddit")
|
||||
|
||||
Returns:
|
||||
(env, agent_graph, platform_name) 或 (None, None, None)
|
||||
"""
|
||||
if platform == "twitter" and self.twitter_env:
|
||||
return self.twitter_env, self.twitter_agent_graph, "twitter"
|
||||
elif platform == "reddit" and self.reddit_env:
|
||||
return self.reddit_env, self.reddit_agent_graph, "reddit"
|
||||
else:
|
||||
return None, None, None
|
||||
|
||||
async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]:
|
||||
"""
|
||||
在单个平台上执行Interview
|
||||
|
||||
Returns:
|
||||
包含结果的字典,或包含error的字典
|
||||
"""
|
||||
env, agent_graph, actual_platform = self._get_env_and_graph(platform)
|
||||
|
||||
if not env or not agent_graph:
|
||||
return {"platform": platform, "error": f"{platform}平台不可用"}
|
||||
|
||||
try:
|
||||
agent = agent_graph.get_agent(agent_id)
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
actions = {agent: interview_action}
|
||||
await env.step(actions)
|
||||
|
||||
result = self._get_interview_result(agent_id, actual_platform)
|
||||
result["platform"] = actual_platform
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {"platform": platform, "error": str(e)}
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Args:
|
||||
command_id: 命令ID
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None/不指定: 同时采访两个平台,返回整合结果
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
# 如果指定了平台,只采访该平台
|
||||
if platform in ("twitter", "reddit"):
|
||||
result = await self._interview_single_platform(agent_id, prompt, platform)
|
||||
|
||||
if "error" in result:
|
||||
self.send_response(command_id, "failed", error=result["error"])
|
||||
print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}")
|
||||
return False
|
||||
else:
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}, platform={platform}")
|
||||
return True
|
||||
|
||||
# 未指定平台:同时采访两个平台
|
||||
if not self.twitter_env and not self.reddit_env:
|
||||
self.send_response(command_id, "failed", error="没有可用的模拟环境")
|
||||
return False
|
||||
|
||||
results = {
|
||||
"agent_id": agent_id,
|
||||
"prompt": prompt,
|
||||
"platforms": {}
|
||||
}
|
||||
success_count = 0
|
||||
|
||||
# 并行采访两个平台
|
||||
tasks = []
|
||||
platforms_to_interview = []
|
||||
|
||||
if self.twitter_env:
|
||||
tasks.append(self._interview_single_platform(agent_id, prompt, "twitter"))
|
||||
platforms_to_interview.append("twitter")
|
||||
|
||||
if self.reddit_env:
|
||||
tasks.append(self._interview_single_platform(agent_id, prompt, "reddit"))
|
||||
platforms_to_interview.append("reddit")
|
||||
|
||||
# 并行执行
|
||||
platform_results = await asyncio.gather(*tasks)
|
||||
|
||||
for platform_name, platform_result in zip(platforms_to_interview, platform_results):
|
||||
results["platforms"][platform_name] = platform_result
|
||||
if "error" not in platform_result:
|
||||
success_count += 1
|
||||
|
||||
if success_count > 0:
|
||||
self.send_response(command_id, "completed", result=results)
|
||||
print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}")
|
||||
return True
|
||||
else:
|
||||
errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()]
|
||||
self.send_response(command_id, "failed", error="; ".join(errors))
|
||||
print(f" Interview失败: agent_id={agent_id}, 所有平台都失败")
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
command_id: 命令ID
|
||||
interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...]
|
||||
platform: 默认平台(可被每个interview项覆盖)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None/不指定: 每个Agent同时采访两个平台
|
||||
"""
|
||||
# 按平台分组
|
||||
twitter_interviews = []
|
||||
reddit_interviews = []
|
||||
both_platforms_interviews = [] # 需要同时采访两个平台的
|
||||
|
||||
for interview in interviews:
|
||||
item_platform = interview.get("platform", platform)
|
||||
if item_platform == "twitter":
|
||||
twitter_interviews.append(interview)
|
||||
elif item_platform == "reddit":
|
||||
reddit_interviews.append(interview)
|
||||
else:
|
||||
# 未指定平台:两个平台都采访
|
||||
both_platforms_interviews.append(interview)
|
||||
|
||||
# 把 both_platforms_interviews 拆分到两个平台
|
||||
if both_platforms_interviews:
|
||||
if self.twitter_env:
|
||||
twitter_interviews.extend(both_platforms_interviews)
|
||||
if self.reddit_env:
|
||||
reddit_interviews.extend(both_platforms_interviews)
|
||||
|
||||
results = {}
|
||||
|
||||
# 处理Twitter平台的采访
|
||||
if twitter_interviews and self.twitter_env:
|
||||
try:
|
||||
twitter_actions = {}
|
||||
for interview in twitter_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
try:
|
||||
agent = self.twitter_agent_graph.get_agent(agent_id)
|
||||
twitter_actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Twitter Agent {agent_id}: {e}")
|
||||
|
||||
if twitter_actions:
|
||||
await self.twitter_env.step(twitter_actions)
|
||||
|
||||
for interview in twitter_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
result = self._get_interview_result(agent_id, "twitter")
|
||||
result["platform"] = "twitter"
|
||||
results[f"twitter_{agent_id}"] = result
|
||||
except Exception as e:
|
||||
print(f" Twitter批量Interview失败: {e}")
|
||||
|
||||
# 处理Reddit平台的采访
|
||||
if reddit_interviews and self.reddit_env:
|
||||
try:
|
||||
reddit_actions = {}
|
||||
for interview in reddit_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
try:
|
||||
agent = self.reddit_agent_graph.get_agent(agent_id)
|
||||
reddit_actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Reddit Agent {agent_id}: {e}")
|
||||
|
||||
if reddit_actions:
|
||||
await self.reddit_env.step(reddit_actions)
|
||||
|
||||
for interview in reddit_interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
result = self._get_interview_result(agent_id, "reddit")
|
||||
result["platform"] = "reddit"
|
||||
results[f"reddit_{agent_id}"] = result
|
||||
except Exception as e:
|
||||
print(f" Reddit批量Interview失败: {e}")
|
||||
|
||||
if results:
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
else:
|
||||
self.send_response(command_id, "failed", error="没有成功的采访")
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", ""),
|
||||
args.get("platform")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", []),
|
||||
args.get("platform")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
@@ -398,13 +807,21 @@ def get_active_agents_for_round(
|
||||
return active_agents
|
||||
|
||||
|
||||
class PlatformSimulation:
|
||||
"""平台模拟结果容器"""
|
||||
def __init__(self):
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.total_actions = 0
|
||||
|
||||
|
||||
async def run_twitter_simulation(
|
||||
config: Dict[str, Any],
|
||||
simulation_dir: str,
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
) -> PlatformSimulation:
|
||||
"""运行Twitter模拟
|
||||
|
||||
Args:
|
||||
@@ -413,7 +830,12 @@ async def run_twitter_simulation(
|
||||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
|
||||
Returns:
|
||||
PlatformSimulation: 包含env和agent_graph的结果对象
|
||||
"""
|
||||
result = PlatformSimulation()
|
||||
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Twitter] {msg}")
|
||||
@@ -428,9 +850,9 @@ async def run_twitter_simulation(
|
||||
profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
|
||||
if not os.path.exists(profile_path):
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
return result
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
result.agent_graph = await generate_twitter_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=TWITTER_ACTIONS,
|
||||
@@ -439,7 +861,7 @@ async def run_twitter_simulation(
|
||||
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||
agent_names = get_agent_names_from_config(config)
|
||||
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||
for agent_id, agent in agent_graph.get_agents():
|
||||
for agent_id, agent in result.agent_graph.get_agents():
|
||||
if agent_id not in agent_names:
|
||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||
|
||||
@@ -447,14 +869,14 @@ async def run_twitter_simulation(
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
result.env = oasis.make(
|
||||
agent_graph=result.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await result.env.reset()
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
@@ -478,7 +900,7 @@ async def run_twitter_simulation(
|
||||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = result.env.agent_graph.get_agent(agent_id)
|
||||
initial_actions[agent] = ManualAction(
|
||||
action_type=ActionType.CREATE_POST,
|
||||
action_args={"content": content}
|
||||
@@ -498,7 +920,7 @@ async def run_twitter_simulation(
|
||||
pass
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await result.env.step(initial_actions)
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 记录 round 0 结束
|
||||
@@ -526,7 +948,7 @@ async def run_twitter_simulation(
|
||||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = get_active_agents_for_round(
|
||||
env, config, simulated_hour, round_num
|
||||
result.env, config, simulated_hour, round_num
|
||||
)
|
||||
|
||||
# 无论是否有活跃agent,都记录round开始
|
||||
@@ -540,7 +962,7 @@ async def run_twitter_simulation(
|
||||
continue
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
await result.env.step(actions)
|
||||
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
@@ -567,13 +989,16 @@ async def run_twitter_simulation(
|
||||
progress = (round_num + 1) / total_rounds * 100
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
# 注意:不关闭环境,保留给Interview使用
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
result.total_actions = total_actions
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def run_reddit_simulation(
|
||||
@@ -582,7 +1007,7 @@ async def run_reddit_simulation(
|
||||
action_logger: Optional[PlatformActionLogger] = None,
|
||||
main_logger: Optional[SimulationLogManager] = None,
|
||||
max_rounds: Optional[int] = None
|
||||
):
|
||||
) -> PlatformSimulation:
|
||||
"""运行Reddit模拟
|
||||
|
||||
Args:
|
||||
@@ -591,7 +1016,12 @@ async def run_reddit_simulation(
|
||||
action_logger: 动作日志记录器
|
||||
main_logger: 主日志管理器
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
|
||||
Returns:
|
||||
PlatformSimulation: 包含env和agent_graph的结果对象
|
||||
"""
|
||||
result = PlatformSimulation()
|
||||
|
||||
def log_info(msg):
|
||||
if main_logger:
|
||||
main_logger.info(f"[Reddit] {msg}")
|
||||
@@ -605,9 +1035,9 @@ async def run_reddit_simulation(
|
||||
profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
|
||||
if not os.path.exists(profile_path):
|
||||
log_info(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
return result
|
||||
|
||||
agent_graph = await generate_reddit_agent_graph(
|
||||
result.agent_graph = await generate_reddit_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=REDDIT_ACTIONS,
|
||||
@@ -616,7 +1046,7 @@ async def run_reddit_simulation(
|
||||
# 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X)
|
||||
agent_names = get_agent_names_from_config(config)
|
||||
# 如果配置中没有某个 agent,则使用 OASIS 的默认名称
|
||||
for agent_id, agent in agent_graph.get_agents():
|
||||
for agent_id, agent in result.agent_graph.get_agents():
|
||||
if agent_id not in agent_names:
|
||||
agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')
|
||||
|
||||
@@ -624,14 +1054,14 @@ async def run_reddit_simulation(
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
result.env = oasis.make(
|
||||
agent_graph=result.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.REDDIT,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await result.env.reset()
|
||||
log_info("环境已启动")
|
||||
|
||||
if action_logger:
|
||||
@@ -655,7 +1085,7 @@ async def run_reddit_simulation(
|
||||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = result.env.agent_graph.get_agent(agent_id)
|
||||
if agent in initial_actions:
|
||||
if not isinstance(initial_actions[agent], list):
|
||||
initial_actions[agent] = [initial_actions[agent]]
|
||||
@@ -683,7 +1113,7 @@ async def run_reddit_simulation(
|
||||
pass
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await result.env.step(initial_actions)
|
||||
log_info(f"已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 记录 round 0 结束
|
||||
@@ -711,7 +1141,7 @@ async def run_reddit_simulation(
|
||||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = get_active_agents_for_round(
|
||||
env, config, simulated_hour, round_num
|
||||
result.env, config, simulated_hour, round_num
|
||||
)
|
||||
|
||||
# 无论是否有活跃agent,都记录round开始
|
||||
@@ -725,7 +1155,7 @@ async def run_reddit_simulation(
|
||||
continue
|
||||
|
||||
actions = {agent: LLMAction() for _, agent in active_agents}
|
||||
await env.step(actions)
|
||||
await result.env.step(actions)
|
||||
|
||||
# 从数据库获取实际执行的动作并记录
|
||||
actual_actions, last_rowid = fetch_new_actions_from_db(
|
||||
@@ -752,13 +1182,16 @@ async def run_reddit_simulation(
|
||||
progress = (round_num + 1) / total_rounds * 100
|
||||
log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")
|
||||
|
||||
await env.close()
|
||||
# 注意:不关闭环境,保留给Interview使用
|
||||
|
||||
if action_logger:
|
||||
action_logger.log_simulation_end(total_rounds, total_actions)
|
||||
|
||||
result.total_actions = total_actions
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_info(f"模拟完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -785,6 +1218,12 @@ async def main():
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -794,6 +1233,7 @@ async def main():
|
||||
|
||||
config = load_config(args.config)
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
wait_for_commands = not args.no_wait
|
||||
|
||||
# 初始化日志配置(禁用 OASIS 日志,清理旧文件)
|
||||
init_logging_for_simulation(simulation_dir)
|
||||
@@ -807,6 +1247,7 @@ async def main():
|
||||
log_manager.info("OASIS 双平台并行模拟")
|
||||
log_manager.info(f"配置文件: {args.config}")
|
||||
log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
|
||||
log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
time_config = config.get("time_config", {})
|
||||
@@ -832,20 +1273,70 @@ async def main():
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
# 存储两个平台的模拟结果
|
||||
twitter_result: Optional[PlatformSimulation] = None
|
||||
reddit_result: Optional[PlatformSimulation] = None
|
||||
|
||||
if args.twitter_only:
|
||||
await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||
twitter_result = await run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds)
|
||||
elif args.reddit_only:
|
||||
await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||
reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds)
|
||||
else:
|
||||
# 并行运行(每个平台使用独立的日志记录器)
|
||||
await asyncio.gather(
|
||||
results = await asyncio.gather(
|
||||
run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
|
||||
run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
|
||||
)
|
||||
twitter_result, reddit_result = results
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info(f"全部模拟完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if wait_for_commands:
|
||||
log_manager.info("")
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info("进入等待命令模式 - 环境保持运行")
|
||||
log_manager.info("支持的命令: interview, batch_interview, close_env")
|
||||
log_manager.info("=" * 60)
|
||||
|
||||
# 创建IPC处理器
|
||||
ipc_handler = ParallelIPCHandler(
|
||||
simulation_dir=simulation_dir,
|
||||
twitter_env=twitter_result.env if twitter_result else None,
|
||||
twitter_agent_graph=twitter_result.agent_graph if twitter_result else None,
|
||||
reddit_env=reddit_result.env if reddit_result else None,
|
||||
reddit_agent_graph=reddit_result.agent_graph if reddit_result else None
|
||||
)
|
||||
ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
log_manager.info("\n关闭环境...")
|
||||
ipc_handler.update_status("stopped")
|
||||
|
||||
# 关闭环境
|
||||
if twitter_result and twitter_result.env:
|
||||
await twitter_result.env.close()
|
||||
log_manager.info("[Twitter] 环境已关闭")
|
||||
|
||||
if reddit_result and reddit_result.env:
|
||||
await reddit_result.env.close()
|
||||
log_manager.info("[Reddit] 环境已关闭")
|
||||
|
||||
log_manager.info("=" * 60)
|
||||
log_manager.info(f"全部完成!")
|
||||
log_manager.info(f"日志文件:")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
|
||||
log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
|
||||
@@ -855,4 +1346,3 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
@@ -2,8 +2,15 @@
|
||||
OASIS Reddit模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -13,8 +20,9 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -118,10 +126,261 @@ except ImportError as e:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.env = env
|
||||
self.agent_graph = agent_graph
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
|
||||
try:
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
agent_prompts[agent_id] = prompt
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Agent {agent_id}: {e}")
|
||||
|
||||
if not actions:
|
||||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
results[agent_id] = result
|
||||
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" 批量Interview失败: {error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", "")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", [])
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
class RedditSimulationRunner:
|
||||
"""Reddit模拟运行器"""
|
||||
|
||||
# Reddit可用动作
|
||||
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
@@ -138,16 +397,21 @@ class RedditSimulationRunner:
|
||||
ActionType.MUTE,
|
||||
]
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self.simulation_dir = os.path.dirname(config_path)
|
||||
self.wait_for_commands = wait_for_commands
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
@@ -261,6 +525,7 @@ class RedditSimulationRunner:
|
||||
print("OASIS Reddit模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
|
||||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
time_config = self.config.get("time_config", {})
|
||||
@@ -292,7 +557,7 @@ class RedditSimulationRunner:
|
||||
print(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_reddit_agent_graph(
|
||||
self.agent_graph = await generate_reddit_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
@@ -304,16 +569,20 @@ class RedditSimulationRunner:
|
||||
print(f"已删除旧数据库: {db_path}")
|
||||
|
||||
print("创建OASIS环境...")
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.REDDIT,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
@@ -325,7 +594,7 @@ class RedditSimulationRunner:
|
||||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = self.env.agent_graph.get_agent(agent_id)
|
||||
if agent in initial_actions:
|
||||
if not isinstance(initial_actions[agent], list):
|
||||
initial_actions[agent] = [initial_actions[agent]]
|
||||
@@ -342,7 +611,7 @@ class RedditSimulationRunner:
|
||||
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
@@ -355,7 +624,7 @@ class RedditSimulationRunner:
|
||||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
env, simulated_hour, round_num
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
||||
if not active_agents:
|
||||
@@ -366,7 +635,7 @@ class RedditSimulationRunner:
|
||||
for _, agent in active_agents
|
||||
}
|
||||
|
||||
await env.step(actions)
|
||||
await self.env.step(actions)
|
||||
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
@@ -376,12 +645,39 @@ class RedditSimulationRunner:
|
||||
f"- {len(active_agents)} agents active "
|
||||
f"- elapsed: {elapsed:.1f}s")
|
||||
|
||||
await env.close()
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"\n模拟完成!")
|
||||
print(f"\n模拟循环完成!")
|
||||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
print("支持的命令: interview, batch_interview, close_env")
|
||||
print("=" * 60)
|
||||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
print("环境已关闭")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@@ -399,6 +695,12 @@ async def main():
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -410,7 +712,10 @@ async def main():
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = RedditSimulationRunner(args.config)
|
||||
runner = RedditSimulationRunner(
|
||||
config_path=args.config,
|
||||
wait_for_commands=not args.no_wait
|
||||
)
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,15 @@
|
||||
OASIS Twitter模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
|
||||
使用方式:
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -13,8 +20,9 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -118,10 +126,261 @@ except ImportError as e:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
self.env = env
|
||||
self.agent_graph = agent_graph
|
||||
self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
|
||||
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
|
||||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
print(f" Interview完成: agent_id={agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
prompt = interview.get("prompt", "")
|
||||
|
||||
try:
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
actions[agent] = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
agent_prompts[agent_id] = prompt
|
||||
except Exception as e:
|
||||
print(f" 警告: 无法获取Agent {agent_id}: {e}")
|
||||
|
||||
if not actions:
|
||||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
results[agent_id] = result
|
||||
|
||||
self.send_response(command_id, "completed", result={
|
||||
"interviews_count": len(results),
|
||||
"results": results
|
||||
})
|
||||
print(f" 批量Interview完成: {len(results)} 个Agent")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f" 批量Interview失败: {error_msg}")
|
||||
self.send_response(command_id, "failed", error=error_msg)
|
||||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"response": None,
|
||||
"timestamp": None
|
||||
}
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return result
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
WHERE action = ? AND user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""", (ActionType.INTERVIEW.value, agent_id))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
user_id, info_json, created_at = row
|
||||
try:
|
||||
info = json.loads(info_json) if info_json else {}
|
||||
result["response"] = info.get("response", info)
|
||||
result["timestamp"] = created_at
|
||||
except json.JSONDecodeError:
|
||||
result["response"] = info_json
|
||||
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f" 读取Interview结果失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
return True
|
||||
|
||||
command_id = command.get("command_id")
|
||||
command_type = command.get("command_type")
|
||||
args = command.get("args", {})
|
||||
|
||||
print(f"\n收到IPC命令: {command_type}, id={command_id}")
|
||||
|
||||
if command_type == CommandType.INTERVIEW:
|
||||
await self.handle_interview(
|
||||
command_id,
|
||||
args.get("agent_id", 0),
|
||||
args.get("prompt", "")
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.BATCH_INTERVIEW:
|
||||
await self.handle_batch_interview(
|
||||
command_id,
|
||||
args.get("interviews", [])
|
||||
)
|
||||
return True
|
||||
|
||||
elif command_type == CommandType.CLOSE_ENV:
|
||||
print("收到关闭环境命令")
|
||||
self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
|
||||
return False
|
||||
|
||||
else:
|
||||
self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
|
||||
return True
|
||||
|
||||
|
||||
class TwitterSimulationRunner:
|
||||
"""Twitter模拟运行器"""
|
||||
|
||||
# Twitter可用动作
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
@@ -131,16 +390,21 @@ class TwitterSimulationRunner:
|
||||
ActionType.QUOTE_POST,
|
||||
]
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
self.simulation_dir = os.path.dirname(config_path)
|
||||
self.wait_for_commands = wait_for_commands
|
||||
self.env = None
|
||||
self.agent_graph = None
|
||||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
@@ -269,6 +533,7 @@ class TwitterSimulationRunner:
|
||||
print("OASIS Twitter模拟")
|
||||
print(f"配置文件: {self.config_path}")
|
||||
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
|
||||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载时间配置
|
||||
@@ -305,7 +570,7 @@ class TwitterSimulationRunner:
|
||||
print(f"错误: Profile文件不存在: {profile_path}")
|
||||
return
|
||||
|
||||
agent_graph = await generate_twitter_agent_graph(
|
||||
self.agent_graph = await generate_twitter_agent_graph(
|
||||
profile_path=profile_path,
|
||||
model=model,
|
||||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
@@ -319,16 +584,20 @@ class TwitterSimulationRunner:
|
||||
|
||||
# 创建环境
|
||||
print("创建OASIS环境...")
|
||||
env = oasis.make(
|
||||
agent_graph=agent_graph,
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
)
|
||||
|
||||
await env.reset()
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
@@ -340,7 +609,7 @@ class TwitterSimulationRunner:
|
||||
agent_id = post.get("poster_agent_id", 0)
|
||||
content = post.get("content", "")
|
||||
try:
|
||||
agent = env.agent_graph.get_agent(agent_id)
|
||||
agent = self.env.agent_graph.get_agent(agent_id)
|
||||
initial_actions[agent] = ManualAction(
|
||||
action_type=ActionType.CREATE_POST,
|
||||
action_args={"content": content}
|
||||
@@ -349,7 +618,7 @@ class TwitterSimulationRunner:
|
||||
print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
|
||||
|
||||
if initial_actions:
|
||||
await env.step(initial_actions)
|
||||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
@@ -364,7 +633,7 @@ class TwitterSimulationRunner:
|
||||
|
||||
# 获取本轮激活的Agent
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
env, simulated_hour, round_num
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
||||
if not active_agents:
|
||||
@@ -377,7 +646,7 @@ class TwitterSimulationRunner:
|
||||
}
|
||||
|
||||
# 执行动作
|
||||
await env.step(actions)
|
||||
await self.env.step(actions)
|
||||
|
||||
# 打印进度
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
@@ -388,13 +657,39 @@ class TwitterSimulationRunner:
|
||||
f"- {len(active_agents)} agents active "
|
||||
f"- elapsed: {elapsed:.1f}s")
|
||||
|
||||
# 关闭环境
|
||||
await env.close()
|
||||
|
||||
total_elapsed = (datetime.now() - start_time).total_seconds()
|
||||
print(f"\n模拟完成!")
|
||||
print(f"\n模拟循环完成!")
|
||||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
print("支持的命令: interview, batch_interview, close_env")
|
||||
print("=" * 60)
|
||||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环
|
||||
try:
|
||||
while True:
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
if not should_continue:
|
||||
break
|
||||
await asyncio.sleep(0.5) # 轮询间隔
|
||||
except KeyboardInterrupt:
|
||||
print("\n收到中断信号")
|
||||
except Exception as e:
|
||||
print(f"\n命令处理出错: {e}")
|
||||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
print("环境已关闭")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@@ -412,6 +707,12 @@ async def main():
|
||||
default=None,
|
||||
help='最大模拟轮数(可选,用于截断过长的模拟)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-wait',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='模拟完成后立即关闭环境,不进入等待命令模式'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -423,10 +724,12 @@ async def main():
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
runner = TwitterSimulationRunner(args.config)
|
||||
runner = TwitterSimulationRunner(
|
||||
config_path=args.config,
|
||||
wait_for_commands=not args.no_wait
|
||||
)
|
||||
await runner.run(max_rounds=args.max_rounds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user