Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
561
backend/app/services/oasis_profile_generator.py
Normal file
561
backend/app/services/oasis_profile_generator.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
OASIS Agent Profile生成器
|
||||
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
logger = get_logger('mirofish.oasis_profile')
|
||||
|
||||
|
||||
@dataclass
|
||||
class OasisAgentProfile:
|
||||
"""OASIS Agent Profile数据结构"""
|
||||
# 通用字段
|
||||
user_id: int
|
||||
user_name: str
|
||||
name: str
|
||||
bio: str
|
||||
persona: str
|
||||
|
||||
# 可选字段 - Reddit风格
|
||||
karma: int = 1000
|
||||
|
||||
# 可选字段 - Twitter风格
|
||||
friend_count: int = 100
|
||||
follower_count: int = 150
|
||||
statuses_count: int = 500
|
||||
|
||||
# 额外人设信息
|
||||
age: Optional[int] = None
|
||||
gender: Optional[str] = None
|
||||
mbti: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
profession: Optional[str] = None
|
||||
interested_topics: List[str] = field(default_factory=list)
|
||||
|
||||
# 来源实体信息
|
||||
source_entity_uuid: Optional[str] = None
|
||||
source_entity_type: Optional[str] = None
|
||||
|
||||
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
||||
|
||||
def to_reddit_format(self) -> Dict[str, Any]:
|
||||
"""转换为Reddit平台格式"""
|
||||
profile = {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"karma": self.karma,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
# 添加额外人设信息(如果有)
|
||||
if self.age:
|
||||
profile["age"] = self.age
|
||||
if self.gender:
|
||||
profile["gender"] = self.gender
|
||||
if self.mbti:
|
||||
profile["mbti"] = self.mbti
|
||||
if self.country:
|
||||
profile["country"] = self.country
|
||||
if self.profession:
|
||||
profile["profession"] = self.profession
|
||||
if self.interested_topics:
|
||||
profile["interested_topics"] = self.interested_topics
|
||||
|
||||
return profile
|
||||
|
||||
def to_twitter_format(self) -> Dict[str, Any]:
|
||||
"""转换为Twitter平台格式"""
|
||||
profile = {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"friend_count": self.friend_count,
|
||||
"follower_count": self.follower_count,
|
||||
"statuses_count": self.statuses_count,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
# 添加额外人设信息
|
||||
if self.age:
|
||||
profile["age"] = self.age
|
||||
if self.gender:
|
||||
profile["gender"] = self.gender
|
||||
if self.mbti:
|
||||
profile["mbti"] = self.mbti
|
||||
if self.country:
|
||||
profile["country"] = self.country
|
||||
if self.profession:
|
||||
profile["profession"] = self.profession
|
||||
if self.interested_topics:
|
||||
profile["interested_topics"] = self.interested_topics
|
||||
|
||||
return profile
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为完整字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"karma": self.karma,
|
||||
"friend_count": self.friend_count,
|
||||
"follower_count": self.follower_count,
|
||||
"statuses_count": self.statuses_count,
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"mbti": self.mbti,
|
||||
"country": self.country,
|
||||
"profession": self.profession,
|
||||
"interested_topics": self.interested_topics,
|
||||
"source_entity_uuid": self.source_entity_uuid,
|
||||
"source_entity_type": self.source_entity_type,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class OasisProfileGenerator:
|
||||
"""
|
||||
OASIS Profile生成器
|
||||
|
||||
将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile
|
||||
"""
|
||||
|
||||
# MBTI类型列表
|
||||
MBTI_TYPES = [
|
||||
"INTJ", "INTP", "ENTJ", "ENTP",
|
||||
"INFJ", "INFP", "ENFJ", "ENFP",
|
||||
"ISTJ", "ISFJ", "ESTJ", "ESFJ",
|
||||
"ISTP", "ISFP", "ESTP", "ESFP"
|
||||
]
|
||||
|
||||
# 常见国家列表
|
||||
COUNTRIES = [
|
||||
"China", "US", "UK", "Japan", "Germany", "France",
|
||||
"Canada", "Australia", "Brazil", "India", "South Korea"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
model_name: Optional[str] = None
|
||||
):
|
||||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
self.model_name = model_name or Config.LLM_MODEL_NAME
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
def generate_profile_from_entity(
|
||||
self,
|
||||
entity: EntityNode,
|
||||
user_id: int,
|
||||
use_llm: bool = True
|
||||
) -> OasisAgentProfile:
|
||||
"""
|
||||
从Zep实体生成OASIS Agent Profile
|
||||
|
||||
Args:
|
||||
entity: Zep实体节点
|
||||
user_id: 用户ID(用于OASIS)
|
||||
use_llm: 是否使用LLM生成详细人设
|
||||
|
||||
Returns:
|
||||
OasisAgentProfile
|
||||
"""
|
||||
entity_type = entity.get_entity_type() or "Entity"
|
||||
|
||||
# 基础信息
|
||||
name = entity.name
|
||||
user_name = self._generate_username(name)
|
||||
|
||||
# 构建上下文信息
|
||||
context = self._build_entity_context(entity)
|
||||
|
||||
if use_llm:
|
||||
# 使用LLM生成详细人设
|
||||
profile_data = self._generate_profile_with_llm(
|
||||
entity_name=name,
|
||||
entity_type=entity_type,
|
||||
entity_summary=entity.summary,
|
||||
entity_attributes=entity.attributes,
|
||||
context=context
|
||||
)
|
||||
else:
|
||||
# 使用规则生成基础人设
|
||||
profile_data = self._generate_profile_rule_based(
|
||||
entity_name=name,
|
||||
entity_type=entity_type,
|
||||
entity_summary=entity.summary,
|
||||
entity_attributes=entity.attributes
|
||||
)
|
||||
|
||||
return OasisAgentProfile(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
name=name,
|
||||
bio=profile_data.get("bio", f"{entity_type}: {name}"),
|
||||
persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
|
||||
karma=profile_data.get("karma", random.randint(500, 5000)),
|
||||
friend_count=profile_data.get("friend_count", random.randint(50, 500)),
|
||||
follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
|
||||
statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
|
||||
age=profile_data.get("age"),
|
||||
gender=profile_data.get("gender"),
|
||||
mbti=profile_data.get("mbti"),
|
||||
country=profile_data.get("country"),
|
||||
profession=profile_data.get("profession"),
|
||||
interested_topics=profile_data.get("interested_topics", []),
|
||||
source_entity_uuid=entity.uuid,
|
||||
source_entity_type=entity_type,
|
||||
)
|
||||
|
||||
def _generate_username(self, name: str) -> str:
|
||||
"""生成用户名"""
|
||||
# 移除特殊字符,转换为小写
|
||||
username = name.lower().replace(" ", "_")
|
||||
username = ''.join(c for c in username if c.isalnum() or c == '_')
|
||||
|
||||
# 添加随机后缀避免重复
|
||||
suffix = random.randint(100, 999)
|
||||
return f"{username}_{suffix}"
|
||||
|
||||
def _build_entity_context(self, entity: EntityNode) -> str:
|
||||
"""构建实体的上下文信息"""
|
||||
context_parts = []
|
||||
|
||||
# 添加相关边信息
|
||||
if entity.related_edges:
|
||||
relationships = []
|
||||
for edge in entity.related_edges[:10]: # 最多取10条
|
||||
if edge.get("fact"):
|
||||
relationships.append(edge["fact"])
|
||||
|
||||
if relationships:
|
||||
context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships))
|
||||
|
||||
# 添加关联节点信息
|
||||
if entity.related_nodes:
|
||||
related_names = [n["name"] for n in entity.related_nodes[:5]]
|
||||
if related_names:
|
||||
context_parts.append(f"Related to: {', '.join(related_names)}")
|
||||
|
||||
return "\n\n".join(context_parts)
|
||||
|
||||
def _generate_profile_with_llm(
|
||||
self,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
entity_summary: str,
|
||||
entity_attributes: Dict[str, Any],
|
||||
context: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用LLM生成详细人设"""
|
||||
|
||||
prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.
|
||||
|
||||
Entity Information:
|
||||
- Name: {entity_name}
|
||||
- Type: {entity_type}
|
||||
- Summary: {entity_summary}
|
||||
- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)}
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Generate a JSON object with the following fields:
|
||||
{{
|
||||
"bio": "A short bio (max 150 chars) suitable for social media",
|
||||
"persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
|
||||
"age": <integer between 18-65, or null if not applicable>,
|
||||
"gender": "<male/female/other, or null if not applicable>",
|
||||
"mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
|
||||
"country": "<country name, or null>",
|
||||
"profession": "<profession/occupation, or null>",
|
||||
"interested_topics": ["topic1", "topic2", ...]
|
||||
}}
|
||||
|
||||
Important:
|
||||
- The profile should be consistent with the entity type and context
|
||||
- Make the persona feel realistic and suitable for social media simulation
|
||||
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
|
||||
- Return ONLY the JSON object, no additional text"""
|
||||
|
||||
try:
|
||||
# 使用重试机制调用LLM API
|
||||
from ..utils.retry import RetryableAPIClient
|
||||
|
||||
retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)
|
||||
|
||||
def call_llm():
|
||||
return self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
response = retry_client.call_with_retry(call_llm)
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
|
||||
def _generate_profile_rule_based(
|
||||
self,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
entity_summary: str,
|
||||
entity_attributes: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""使用规则生成基础人设"""
|
||||
|
||||
# 根据实体类型生成不同的人设
|
||||
entity_type_lower = entity_type.lower()
|
||||
|
||||
if entity_type_lower in ["student", "alumni"]:
|
||||
return {
|
||||
"bio": f"{entity_type} with interests in academics and social issues.",
|
||||
"persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
|
||||
"age": random.randint(18, 30),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": "Student",
|
||||
"interested_topics": ["Education", "Social Issues", "Technology"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
|
||||
return {
|
||||
"bio": f"Expert and thought leader in their field.",
|
||||
"persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
|
||||
"age": random.randint(35, 60),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_attributes.get("occupation", "Expert"),
|
||||
"interested_topics": ["Politics", "Economics", "Culture & Society"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
|
||||
return {
|
||||
"bio": f"Official account for {entity_name}. News and updates.",
|
||||
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||
"profession": "Media",
|
||||
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
|
||||
return {
|
||||
"bio": f"Official account of {entity_name}.",
|
||||
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||
}
|
||||
|
||||
else:
|
||||
# 默认人设
|
||||
return {
|
||||
"bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
|
||||
"persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
|
||||
"age": random.randint(25, 50),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["General", "Social Issues"],
|
||||
}
|
||||
|
||||
def generate_profiles_from_entities(
|
||||
self,
|
||||
entities: List[EntityNode],
|
||||
use_llm: bool = True,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> List[OasisAgentProfile]:
|
||||
"""
|
||||
批量从实体生成Agent Profile
|
||||
|
||||
Args:
|
||||
entities: 实体列表
|
||||
use_llm: 是否使用LLM生成详细人设
|
||||
progress_callback: 进度回调函数 (current, total, message)
|
||||
|
||||
Returns:
|
||||
Agent Profile列表
|
||||
"""
|
||||
profiles = []
|
||||
total = len(entities)
|
||||
|
||||
for idx, entity in enumerate(entities):
|
||||
if progress_callback:
|
||||
progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")
|
||||
|
||||
try:
|
||||
profile = self.generate_profile_from_entity(
|
||||
entity=entity,
|
||||
user_id=idx,
|
||||
use_llm=use_llm
|
||||
)
|
||||
profiles.append(profile)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
|
||||
# 创建一个基础profile
|
||||
profiles.append(OasisAgentProfile(
|
||||
user_id=idx,
|
||||
user_name=self._generate_username(entity.name),
|
||||
name=entity.name,
|
||||
bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
|
||||
persona=entity.summary or f"A participant in social discussions.",
|
||||
source_entity_uuid=entity.uuid,
|
||||
source_entity_type=entity.get_entity_type(),
|
||||
))
|
||||
|
||||
return profiles
|
||||
|
||||
def save_profiles(
|
||||
self,
|
||||
profiles: List[OasisAgentProfile],
|
||||
file_path: str,
|
||||
platform: str = "reddit"
|
||||
):
|
||||
"""
|
||||
保存Profile到文件(根据平台选择正确格式)
|
||||
|
||||
OASIS平台格式要求:
|
||||
- Twitter: CSV格式
|
||||
- Reddit: JSON格式
|
||||
|
||||
Args:
|
||||
profiles: Profile列表
|
||||
file_path: 文件路径
|
||||
platform: 平台类型 ("reddit" 或 "twitter")
|
||||
"""
|
||||
if platform == "twitter":
|
||||
self._save_twitter_csv(profiles, file_path)
|
||||
else:
|
||||
self._save_reddit_json(profiles, file_path)
|
||||
|
||||
def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Twitter Profile为CSV格式
|
||||
|
||||
OASIS Twitter要求的CSV字段:
|
||||
user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at
|
||||
"""
|
||||
import csv
|
||||
|
||||
# 确保文件扩展名是.csv
|
||||
if not file_path.endswith('.csv'):
|
||||
file_path = file_path.replace('.json', '.csv')
|
||||
|
||||
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
|
||||
# 写入表头
|
||||
headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
|
||||
'follower_count', 'statuses_count', 'created_at']
|
||||
writer.writerow(headers)
|
||||
|
||||
# 写入数据行
|
||||
for profile in profiles:
|
||||
# bio需要处理换行符和逗号
|
||||
bio = profile.bio.replace('\n', ' ').replace('\r', ' ')
|
||||
row = [
|
||||
profile.user_id,
|
||||
profile.user_name,
|
||||
profile.name,
|
||||
bio,
|
||||
profile.friend_count,
|
||||
profile.follower_count,
|
||||
profile.statuses_count,
|
||||
profile.created_at
|
||||
]
|
||||
writer.writerow(row)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")
|
||||
|
||||
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Reddit Profile为JSON格式
|
||||
|
||||
OASIS Reddit支持两种JSON格式:
|
||||
1. 基础格式: user_id, user_name, name, bio, karma, created_at
|
||||
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||
|
||||
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||
"""
|
||||
data = []
|
||||
for profile in profiles:
|
||||
# 使用详细格式(与用户示例兼容)
|
||||
item = {
|
||||
"realname": profile.name,
|
||||
"username": profile.user_name,
|
||||
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||
}
|
||||
|
||||
# 添加人设详情字段
|
||||
if profile.age:
|
||||
item["age"] = profile.age
|
||||
if profile.gender:
|
||||
item["gender"] = profile.gender
|
||||
if profile.mbti:
|
||||
item["mbti"] = profile.mbti
|
||||
if profile.country:
|
||||
item["country"] = profile.country
|
||||
if profile.profession:
|
||||
item["profession"] = profile.profession
|
||||
if profile.interested_topics:
|
||||
item["interested_topics"] = profile.interested_topics
|
||||
|
||||
data.append(item)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
||||
|
||||
# 保留旧方法名作为别名,保持向后兼容
|
||||
def save_profiles_to_json(
|
||||
self,
|
||||
profiles: List[OasisAgentProfile],
|
||||
file_path: str,
|
||||
platform: str = "reddit"
|
||||
):
|
||||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
Reference in New Issue
Block a user