Implement real-time profile retrieval and saving in simulation API
- Added a new endpoint to retrieve real-time agent profiles during simulation, allowing users to monitor progress without going through the SimulationManager. - Enhanced the profile generation process to support real-time saving of generated profiles to specified file formats (JSON for Reddit, CSV for Twitter). - Updated the simulation configuration generator to assign appropriate agents to initial posts based on their types, improving the relevance of generated content. - Improved error handling and logging for better traceability during profile generation and retrieval processes.
This commit is contained in:
@@ -853,7 +853,9 @@ class OasisProfileGenerator:
|
||||
use_llm: bool = True,
|
||||
progress_callback: Optional[callable] = None,
|
||||
graph_id: Optional[str] = None,
|
||||
parallel_count: int = 5
|
||||
parallel_count: int = 5,
|
||||
realtime_output_path: Optional[str] = None,
|
||||
output_platform: str = "reddit"
|
||||
) -> List[OasisAgentProfile]:
|
||||
"""
|
||||
批量从实体生成Agent Profile(支持并行生成)
|
||||
@@ -864,6 +866,8 @@ class OasisProfileGenerator:
|
||||
progress_callback: 进度回调函数 (current, total, message)
|
||||
graph_id: 图谱ID,用于Zep检索获取更丰富上下文
|
||||
parallel_count: 并行生成数量,默认5
|
||||
realtime_output_path: 实时写入的文件路径(如果提供,每生成一个就写入一次)
|
||||
output_platform: 输出平台格式 ("reddit" 或 "twitter")
|
||||
|
||||
Returns:
|
||||
Agent Profile列表
|
||||
@@ -880,6 +884,37 @@ class OasisProfileGenerator:
|
||||
completed_count = [0] # 使用列表以便在闭包中修改
|
||||
lock = Lock()
|
||||
|
||||
# 实时写入文件的辅助函数
|
||||
def save_profiles_realtime():
|
||||
"""实时保存已生成的 profiles 到文件"""
|
||||
if not realtime_output_path:
|
||||
return
|
||||
|
||||
with lock:
|
||||
# 过滤出已生成的 profiles
|
||||
existing_profiles = [p for p in profiles if p is not None]
|
||||
if not existing_profiles:
|
||||
return
|
||||
|
||||
try:
|
||||
if output_platform == "reddit":
|
||||
# Reddit JSON 格式
|
||||
profiles_data = [p.to_reddit_format() for p in existing_profiles]
|
||||
with open(realtime_output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(profiles_data, f, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
# Twitter CSV 格式
|
||||
import csv
|
||||
profiles_data = [p.to_twitter_format() for p in existing_profiles]
|
||||
if profiles_data:
|
||||
fieldnames = list(profiles_data[0].keys())
|
||||
with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(profiles_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"实时保存 profiles 失败: {e}")
|
||||
|
||||
def generate_single_profile(idx: int, entity: EntityNode) -> tuple:
|
||||
"""生成单个profile的工作函数"""
|
||||
entity_type = entity.get_entity_type() or "Entity"
|
||||
@@ -936,6 +971,9 @@ class OasisProfileGenerator:
|
||||
completed_count[0] += 1
|
||||
current = completed_count[0]
|
||||
|
||||
# 实时写入文件
|
||||
save_profiles_realtime()
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
current,
|
||||
@@ -961,6 +999,8 @@ class OasisProfileGenerator:
|
||||
source_entity_uuid=entity.uuid,
|
||||
source_entity_type=entity_type,
|
||||
)
|
||||
# 实时写入文件(即使是备用人设)
|
||||
save_profiles_realtime()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"人设生成完成!共生成 {len([p for p in profiles if p])} 个Agent")
|
||||
|
||||
Reference in New Issue
Block a user