Implement real-time profile retrieval and saving in simulation API

- Added a new endpoint to retrieve real-time agent profiles during simulation, allowing users to monitor progress without going through the SimulationManager.
- Enhanced the profile generation process to support real-time saving of generated profiles to specified file formats (JSON for Reddit, CSV for Twitter).
- Updated the simulation configuration generator to assign appropriate agents to initial posts based on their types, improving the relevance of generated content.
- Improved error handling and logging for better traceability during profile generation and retrieval processes.
This commit is contained in:
666ghj
2025-12-04 19:02:10 +08:00
parent 39253b3213
commit 88676e8207
4 changed files with 292 additions and 7 deletions

View File

@@ -853,7 +853,9 @@ class OasisProfileGenerator:
use_llm: bool = True,
progress_callback: Optional[callable] = None,
graph_id: Optional[str] = None,
parallel_count: int = 5
parallel_count: int = 5,
realtime_output_path: Optional[str] = None,
output_platform: str = "reddit"
) -> List[OasisAgentProfile]:
"""
批量从实体生成Agent Profile支持并行生成
@@ -864,6 +866,8 @@ class OasisProfileGenerator:
progress_callback: 进度回调函数 (current, total, message)
graph_id: 图谱ID用于Zep检索获取更丰富上下文
parallel_count: 并行生成数量默认5
realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次)
output_platform: 输出平台格式 ("reddit""twitter")
Returns:
Agent Profile列表
@@ -880,6 +884,37 @@ class OasisProfileGenerator:
completed_count = [0] # 使用列表以便在闭包中修改
lock = Lock()
# 实时写入文件的辅助函数
def save_profiles_realtime():
"""实时保存已生成的 profiles 到文件"""
if not realtime_output_path:
return
with lock:
# 过滤出已生成的 profiles
existing_profiles = [p for p in profiles if p is not None]
if not existing_profiles:
return
try:
if output_platform == "reddit":
# Reddit JSON 格式
profiles_data = [p.to_reddit_format() for p in existing_profiles]
with open(realtime_output_path, 'w', encoding='utf-8') as f:
json.dump(profiles_data, f, ensure_ascii=False, indent=2)
else:
# Twitter CSV 格式
import csv
profiles_data = [p.to_twitter_format() for p in existing_profiles]
if profiles_data:
fieldnames = list(profiles_data[0].keys())
with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(profiles_data)
except Exception as e:
logger.warning(f"实时保存 profiles 失败: {e}")
def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
"""生成单个profile的工作函数"""
entity_type = entity.get_entity_type() or "Entity"
@@ -936,6 +971,9 @@ class OasisProfileGenerator:
completed_count[0] += 1
current = completed_count[0]
# 实时写入文件
save_profiles_realtime()
if progress_callback:
progress_callback(
current,
@@ -961,6 +999,8 @@ class OasisProfileGenerator:
source_entity_uuid=entity.uuid,
source_entity_type=entity_type,
)
# 实时写入文件(即使是备用人设)
save_profiles_realtime()
print(f"\n{'='*60}")
print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent")

View File

@@ -292,7 +292,7 @@ class SimulationConfigGenerator:
# ========== 步骤2: 生成事件配置 ==========
report_progress(2, "生成事件配置和热点话题...")
event_config_result = self._generate_event_config(context, simulation_requirement)
event_config_result = self._generate_event_config(context, simulation_requirement, entities)
event_config = self._parse_event_config(event_config_result)
reasoning_parts.append(f"事件配置: {event_config_result.get('reasoning', '成功')}")
@@ -318,6 +318,12 @@ class SimulationConfigGenerator:
reasoning_parts.append(f"Agent配置: 成功生成 {len(all_agent_configs)}")
# ========== 为初始帖子分配发布者 Agent ==========
logger.info("为初始帖子分配合适的发布者 Agent...")
event_config = self._assign_initial_post_agents(event_config, all_agent_configs)
assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None])
reasoning_parts.append(f"初始帖子分配: {assigned_count} 个帖子已分配发布者")
# ========== 最后一步: 生成平台配置 ==========
report_progress(total_steps, "生成平台配置...")
twitter_config = None
@@ -583,32 +589,63 @@ class SimulationConfigGenerator:
peak_activity_multiplier=1.5
)
def _generate_event_config(self, context: str, simulation_requirement: str) -> Dict[str, Any]:
def _generate_event_config(
self,
context: str,
simulation_requirement: str,
entities: List[EntityNode]
) -> Dict[str, Any]:
"""生成事件配置"""
# 获取可用的实体类型列表,供 LLM 参考
entity_types_available = list(set(
e.get_entity_type() or "Unknown" for e in entities
))
# 为每种类型列出代表性实体名称
type_examples = {}
for e in entities:
etype = e.get_entity_type() or "Unknown"
if etype not in type_examples:
type_examples[etype] = []
if len(type_examples[etype]) < 3:
type_examples[etype].append(e.name)
type_info = "\n".join([
f"- {t}: {', '.join(examples)}"
for t, examples in type_examples.items()
])
prompt = f"""基于以下模拟需求,生成事件配置。
模拟需求: {simulation_requirement}
{context[:3000]}
## 可用实体类型及示例
{type_info}
## 任务
请生成事件配置JSON
- 提取热点话题关键词
- 描述舆论发展方向
- 设计初始帖子内容
- 设计初始帖子内容**每个帖子必须指定 poster_type发布者类型**
**重要**: poster_type 必须从上面的"可用实体类型"中选择,这样初始帖子才能分配给合适的 Agent 发布。
例如:官方声明应由 Official/University 类型发布,新闻由 MediaOutlet 发布,学生观点由 Student 发布。
返回JSON格式不要markdown
{{
"hot_topics": ["关键词1", "关键词2", ...],
"narrative_direction": "<舆论发展方向描述>",
"initial_posts": [
{{"content": "帖子内容", "poster_type": "MediaOutlet"}},
{{"content": "帖子内容", "poster_type": "实体类型(必须从可用类型中选择)"}},
...
],
"reasoning": "<简要说明>"
}}"""
system_prompt = "你是舆论分析专家。返回纯JSON格式。"
system_prompt = "你是舆论分析专家。返回纯JSON格式。注意 poster_type 必须精确匹配可用实体类型。"
try:
return self._call_llm_with_retry(prompt, system_prompt)
@@ -630,6 +667,91 @@ class SimulationConfigGenerator:
narrative_direction=result.get("narrative_direction", "")
)
def _assign_initial_post_agents(
self,
event_config: EventConfig,
agent_configs: List[AgentActivityConfig]
) -> EventConfig:
"""
为初始帖子分配合适的发布者 Agent
根据每个帖子的 poster_type 匹配最合适的 agent_id
"""
if not event_config.initial_posts:
return event_config
# 按实体类型建立 agent 索引
agents_by_type: Dict[str, List[AgentActivityConfig]] = {}
for agent in agent_configs:
etype = agent.entity_type.lower()
if etype not in agents_by_type:
agents_by_type[etype] = []
agents_by_type[etype].append(agent)
# 类型映射表(处理 LLM 可能输出的不同格式)
type_aliases = {
"official": ["official", "university", "governmentagency", "government"],
"university": ["university", "official"],
"mediaoutlet": ["mediaoutlet", "media"],
"student": ["student", "person"],
"professor": ["professor", "expert", "teacher"],
"alumni": ["alumni", "person"],
"organization": ["organization", "ngo", "company", "group"],
"person": ["person", "student", "alumni"],
}
# 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent
used_indices: Dict[str, int] = {}
updated_posts = []
for post in event_config.initial_posts:
poster_type = post.get("poster_type", "").lower()
content = post.get("content", "")
# 尝试找到匹配的 agent
matched_agent_id = None
# 1. 直接匹配
if poster_type in agents_by_type:
agents = agents_by_type[poster_type]
idx = used_indices.get(poster_type, 0) % len(agents)
matched_agent_id = agents[idx].agent_id
used_indices[poster_type] = idx + 1
else:
# 2. 使用别名匹配
for alias_key, aliases in type_aliases.items():
if poster_type in aliases or alias_key == poster_type:
for alias in aliases:
if alias in agents_by_type:
agents = agents_by_type[alias]
idx = used_indices.get(alias, 0) % len(agents)
matched_agent_id = agents[idx].agent_id
used_indices[alias] = idx + 1
break
if matched_agent_id is not None:
break
# 3. 如果仍未找到,使用影响力最高的 agent
if matched_agent_id is None:
logger.warning(f"未找到类型 '{poster_type}' 的匹配 Agent使用影响力最高的 Agent")
if agent_configs:
# 按影响力排序,选择影响力最高的
sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True)
matched_agent_id = sorted_agents[0].agent_id
else:
matched_agent_id = 0
updated_posts.append({
"content": content,
"poster_type": post.get("poster_type", "Unknown"),
"poster_agent_id": matched_agent_id
})
logger.info(f"初始帖子分配: poster_type='{poster_type}' -> agent_id={matched_agent_id}")
event_config.initial_posts = updated_posts
return event_config
def _generate_agent_configs_batch(
self,
context: str,

View File

@@ -324,17 +324,30 @@ class SimulationManager:
item_name=msg
)
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式)
realtime_output_path = None
realtime_platform = "reddit"
if state.enable_reddit:
realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
realtime_platform = "reddit"
elif state.enable_twitter:
realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
realtime_platform = "twitter"
profiles = generator.generate_profiles_from_entities(
entities=filtered.entities,
use_llm=use_llm_for_profiles,
progress_callback=profile_progress,
graph_id=state.graph_id, # 传入graph_id用于Zep检索
parallel_count=parallel_profile_count # 并行生成数量
parallel_count=parallel_profile_count, # 并行生成数量
realtime_output_path=realtime_output_path, # 实时保存路径
output_platform=realtime_platform # 输出格式
)
state.profiles_count = len(profiles)
# 保存Profile文件注意Twitter使用CSV格式Reddit使用JSON格式
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
if progress_callback:
progress_callback(
"generating_profiles", 95,