Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
@@ -5,6 +5,47 @@
|
||||
from .ontology_generator import OntologyGenerator
|
||||
from .graph_builder import GraphBuilderService
|
||||
from .text_processor import TextProcessor
|
||||
from .zep_entity_reader import ZepEntityReader, EntityNode, FilteredEntities
|
||||
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
from .simulation_manager import SimulationManager, SimulationState, SimulationStatus
|
||||
from .simulation_config_generator import (
|
||||
SimulationConfigGenerator,
|
||||
SimulationParameters,
|
||||
AgentActivityConfig,
|
||||
TimeSimulationConfig,
|
||||
EventConfig,
|
||||
PlatformConfig
|
||||
)
|
||||
from .simulation_runner import (
|
||||
SimulationRunner,
|
||||
SimulationRunState,
|
||||
RunnerStatus,
|
||||
AgentAction,
|
||||
RoundSummary
|
||||
)
|
||||
|
||||
__all__ = ['OntologyGenerator', 'GraphBuilderService', 'TextProcessor']
|
||||
__all__ = [
|
||||
'OntologyGenerator',
|
||||
'GraphBuilderService',
|
||||
'TextProcessor',
|
||||
'ZepEntityReader',
|
||||
'EntityNode',
|
||||
'FilteredEntities',
|
||||
'OasisProfileGenerator',
|
||||
'OasisAgentProfile',
|
||||
'SimulationManager',
|
||||
'SimulationState',
|
||||
'SimulationStatus',
|
||||
'SimulationConfigGenerator',
|
||||
'SimulationParameters',
|
||||
'AgentActivityConfig',
|
||||
'TimeSimulationConfig',
|
||||
'EventConfig',
|
||||
'PlatformConfig',
|
||||
'SimulationRunner',
|
||||
'SimulationRunState',
|
||||
'RunnerStatus',
|
||||
'AgentAction',
|
||||
'RoundSummary',
|
||||
]
|
||||
|
||||
|
||||
561
backend/app/services/oasis_profile_generator.py
Normal file
561
backend/app/services/oasis_profile_generator.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
OASIS Agent Profile生成器
|
||||
将Zep图谱中的实体转换为OASIS模拟平台所需的Agent Profile格式
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
logger = get_logger('mirofish.oasis_profile')
|
||||
|
||||
|
||||
@dataclass
|
||||
class OasisAgentProfile:
|
||||
"""OASIS Agent Profile数据结构"""
|
||||
# 通用字段
|
||||
user_id: int
|
||||
user_name: str
|
||||
name: str
|
||||
bio: str
|
||||
persona: str
|
||||
|
||||
# 可选字段 - Reddit风格
|
||||
karma: int = 1000
|
||||
|
||||
# 可选字段 - Twitter风格
|
||||
friend_count: int = 100
|
||||
follower_count: int = 150
|
||||
statuses_count: int = 500
|
||||
|
||||
# 额外人设信息
|
||||
age: Optional[int] = None
|
||||
gender: Optional[str] = None
|
||||
mbti: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
profession: Optional[str] = None
|
||||
interested_topics: List[str] = field(default_factory=list)
|
||||
|
||||
# 来源实体信息
|
||||
source_entity_uuid: Optional[str] = None
|
||||
source_entity_type: Optional[str] = None
|
||||
|
||||
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
||||
|
||||
def to_reddit_format(self) -> Dict[str, Any]:
|
||||
"""转换为Reddit平台格式"""
|
||||
profile = {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"karma": self.karma,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
# 添加额外人设信息(如果有)
|
||||
if self.age:
|
||||
profile["age"] = self.age
|
||||
if self.gender:
|
||||
profile["gender"] = self.gender
|
||||
if self.mbti:
|
||||
profile["mbti"] = self.mbti
|
||||
if self.country:
|
||||
profile["country"] = self.country
|
||||
if self.profession:
|
||||
profile["profession"] = self.profession
|
||||
if self.interested_topics:
|
||||
profile["interested_topics"] = self.interested_topics
|
||||
|
||||
return profile
|
||||
|
||||
def to_twitter_format(self) -> Dict[str, Any]:
|
||||
"""转换为Twitter平台格式"""
|
||||
profile = {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"friend_count": self.friend_count,
|
||||
"follower_count": self.follower_count,
|
||||
"statuses_count": self.statuses_count,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
# 添加额外人设信息
|
||||
if self.age:
|
||||
profile["age"] = self.age
|
||||
if self.gender:
|
||||
profile["gender"] = self.gender
|
||||
if self.mbti:
|
||||
profile["mbti"] = self.mbti
|
||||
if self.country:
|
||||
profile["country"] = self.country
|
||||
if self.profession:
|
||||
profile["profession"] = self.profession
|
||||
if self.interested_topics:
|
||||
profile["interested_topics"] = self.interested_topics
|
||||
|
||||
return profile
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为完整字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"name": self.name,
|
||||
"bio": self.bio,
|
||||
"persona": self.persona,
|
||||
"karma": self.karma,
|
||||
"friend_count": self.friend_count,
|
||||
"follower_count": self.follower_count,
|
||||
"statuses_count": self.statuses_count,
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"mbti": self.mbti,
|
||||
"country": self.country,
|
||||
"profession": self.profession,
|
||||
"interested_topics": self.interested_topics,
|
||||
"source_entity_uuid": self.source_entity_uuid,
|
||||
"source_entity_type": self.source_entity_type,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class OasisProfileGenerator:
|
||||
"""
|
||||
OASIS Profile生成器
|
||||
|
||||
将Zep图谱中的实体转换为OASIS模拟所需的Agent Profile
|
||||
"""
|
||||
|
||||
# MBTI类型列表
|
||||
MBTI_TYPES = [
|
||||
"INTJ", "INTP", "ENTJ", "ENTP",
|
||||
"INFJ", "INFP", "ENFJ", "ENFP",
|
||||
"ISTJ", "ISFJ", "ESTJ", "ESFJ",
|
||||
"ISTP", "ISFP", "ESTP", "ESFP"
|
||||
]
|
||||
|
||||
# 常见国家列表
|
||||
COUNTRIES = [
|
||||
"China", "US", "UK", "Japan", "Germany", "France",
|
||||
"Canada", "Australia", "Brazil", "India", "South Korea"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
model_name: Optional[str] = None
|
||||
):
|
||||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
self.model_name = model_name or Config.LLM_MODEL_NAME
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
def generate_profile_from_entity(
|
||||
self,
|
||||
entity: EntityNode,
|
||||
user_id: int,
|
||||
use_llm: bool = True
|
||||
) -> OasisAgentProfile:
|
||||
"""
|
||||
从Zep实体生成OASIS Agent Profile
|
||||
|
||||
Args:
|
||||
entity: Zep实体节点
|
||||
user_id: 用户ID(用于OASIS)
|
||||
use_llm: 是否使用LLM生成详细人设
|
||||
|
||||
Returns:
|
||||
OasisAgentProfile
|
||||
"""
|
||||
entity_type = entity.get_entity_type() or "Entity"
|
||||
|
||||
# 基础信息
|
||||
name = entity.name
|
||||
user_name = self._generate_username(name)
|
||||
|
||||
# 构建上下文信息
|
||||
context = self._build_entity_context(entity)
|
||||
|
||||
if use_llm:
|
||||
# 使用LLM生成详细人设
|
||||
profile_data = self._generate_profile_with_llm(
|
||||
entity_name=name,
|
||||
entity_type=entity_type,
|
||||
entity_summary=entity.summary,
|
||||
entity_attributes=entity.attributes,
|
||||
context=context
|
||||
)
|
||||
else:
|
||||
# 使用规则生成基础人设
|
||||
profile_data = self._generate_profile_rule_based(
|
||||
entity_name=name,
|
||||
entity_type=entity_type,
|
||||
entity_summary=entity.summary,
|
||||
entity_attributes=entity.attributes
|
||||
)
|
||||
|
||||
return OasisAgentProfile(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
name=name,
|
||||
bio=profile_data.get("bio", f"{entity_type}: {name}"),
|
||||
persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."),
|
||||
karma=profile_data.get("karma", random.randint(500, 5000)),
|
||||
friend_count=profile_data.get("friend_count", random.randint(50, 500)),
|
||||
follower_count=profile_data.get("follower_count", random.randint(100, 1000)),
|
||||
statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)),
|
||||
age=profile_data.get("age"),
|
||||
gender=profile_data.get("gender"),
|
||||
mbti=profile_data.get("mbti"),
|
||||
country=profile_data.get("country"),
|
||||
profession=profile_data.get("profession"),
|
||||
interested_topics=profile_data.get("interested_topics", []),
|
||||
source_entity_uuid=entity.uuid,
|
||||
source_entity_type=entity_type,
|
||||
)
|
||||
|
||||
def _generate_username(self, name: str) -> str:
|
||||
"""生成用户名"""
|
||||
# 移除特殊字符,转换为小写
|
||||
username = name.lower().replace(" ", "_")
|
||||
username = ''.join(c for c in username if c.isalnum() or c == '_')
|
||||
|
||||
# 添加随机后缀避免重复
|
||||
suffix = random.randint(100, 999)
|
||||
return f"{username}_{suffix}"
|
||||
|
||||
def _build_entity_context(self, entity: EntityNode) -> str:
|
||||
"""构建实体的上下文信息"""
|
||||
context_parts = []
|
||||
|
||||
# 添加相关边信息
|
||||
if entity.related_edges:
|
||||
relationships = []
|
||||
for edge in entity.related_edges[:10]: # 最多取10条
|
||||
if edge.get("fact"):
|
||||
relationships.append(edge["fact"])
|
||||
|
||||
if relationships:
|
||||
context_parts.append("Related facts:\n" + "\n".join(f"- {r}" for r in relationships))
|
||||
|
||||
# 添加关联节点信息
|
||||
if entity.related_nodes:
|
||||
related_names = [n["name"] for n in entity.related_nodes[:5]]
|
||||
if related_names:
|
||||
context_parts.append(f"Related to: {', '.join(related_names)}")
|
||||
|
||||
return "\n\n".join(context_parts)
|
||||
|
||||
def _generate_profile_with_llm(
|
||||
self,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
entity_summary: str,
|
||||
entity_attributes: Dict[str, Any],
|
||||
context: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用LLM生成详细人设"""
|
||||
|
||||
prompt = f"""Based on the following entity information, generate a detailed social media user profile for simulation purposes.
|
||||
|
||||
Entity Information:
|
||||
- Name: {entity_name}
|
||||
- Type: {entity_type}
|
||||
- Summary: {entity_summary}
|
||||
- Attributes: {json.dumps(entity_attributes, ensure_ascii=False)}
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Generate a JSON object with the following fields:
|
||||
{{
|
||||
"bio": "A short bio (max 150 chars) suitable for social media",
|
||||
"persona": "A detailed persona description (2-3 sentences) describing personality, interests, and behavior patterns",
|
||||
"age": <integer between 18-65, or null if not applicable>,
|
||||
"gender": "<male/female/other, or null if not applicable>",
|
||||
"mbti": "<MBTI type like INTJ, ENFP, etc., or null>",
|
||||
"country": "<country name, or null>",
|
||||
"profession": "<profession/occupation, or null>",
|
||||
"interested_topics": ["topic1", "topic2", ...]
|
||||
}}
|
||||
|
||||
Important:
|
||||
- The profile should be consistent with the entity type and context
|
||||
- Make the persona feel realistic and suitable for social media simulation
|
||||
- If the entity is an organization, institution, or non-person, adapt the profile accordingly (e.g., as an official account)
|
||||
- Return ONLY the JSON object, no additional text"""
|
||||
|
||||
try:
|
||||
# 使用重试机制调用LLM API
|
||||
from ..utils.retry import RetryableAPIClient
|
||||
|
||||
retry_client = RetryableAPIClient(max_retries=3, initial_delay=1.0)
|
||||
|
||||
def call_llm():
|
||||
return self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a profile generator for social media simulation. Generate realistic user profiles based on entity information."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
response = retry_client.call_with_retry(call_llm)
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM生成人设失败(已重试): {str(e)}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
|
||||
def _generate_profile_rule_based(
|
||||
self,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
entity_summary: str,
|
||||
entity_attributes: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""使用规则生成基础人设"""
|
||||
|
||||
# 根据实体类型生成不同的人设
|
||||
entity_type_lower = entity_type.lower()
|
||||
|
||||
if entity_type_lower in ["student", "alumni"]:
|
||||
return {
|
||||
"bio": f"{entity_type} with interests in academics and social issues.",
|
||||
"persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.",
|
||||
"age": random.randint(18, 30),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": "Student",
|
||||
"interested_topics": ["Education", "Social Issues", "Technology"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["publicfigure", "expert", "faculty"]:
|
||||
return {
|
||||
"bio": f"Expert and thought leader in their field.",
|
||||
"persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.",
|
||||
"age": random.randint(35, 60),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_attributes.get("occupation", "Expert"),
|
||||
"interested_topics": ["Politics", "Economics", "Culture & Society"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]:
|
||||
return {
|
||||
"bio": f"Official account for {entity_name}. News and updates.",
|
||||
"persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.",
|
||||
"profession": "Media",
|
||||
"interested_topics": ["General News", "Current Events", "Public Affairs"],
|
||||
}
|
||||
|
||||
elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]:
|
||||
return {
|
||||
"bio": f"Official account of {entity_name}.",
|
||||
"persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.",
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["Public Policy", "Community", "Official Announcements"],
|
||||
}
|
||||
|
||||
else:
|
||||
# 默认人设
|
||||
return {
|
||||
"bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}",
|
||||
"persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.",
|
||||
"age": random.randint(25, 50),
|
||||
"gender": random.choice(["male", "female"]),
|
||||
"mbti": random.choice(self.MBTI_TYPES),
|
||||
"country": random.choice(self.COUNTRIES),
|
||||
"profession": entity_type,
|
||||
"interested_topics": ["General", "Social Issues"],
|
||||
}
|
||||
|
||||
def generate_profiles_from_entities(
|
||||
self,
|
||||
entities: List[EntityNode],
|
||||
use_llm: bool = True,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> List[OasisAgentProfile]:
|
||||
"""
|
||||
批量从实体生成Agent Profile
|
||||
|
||||
Args:
|
||||
entities: 实体列表
|
||||
use_llm: 是否使用LLM生成详细人设
|
||||
progress_callback: 进度回调函数 (current, total, message)
|
||||
|
||||
Returns:
|
||||
Agent Profile列表
|
||||
"""
|
||||
profiles = []
|
||||
total = len(entities)
|
||||
|
||||
for idx, entity in enumerate(entities):
|
||||
if progress_callback:
|
||||
progress_callback(idx + 1, total, f"生成 {entity.name} 的人设...")
|
||||
|
||||
try:
|
||||
profile = self.generate_profile_from_entity(
|
||||
entity=entity,
|
||||
user_id=idx,
|
||||
use_llm=use_llm
|
||||
)
|
||||
profiles.append(profile)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成实体 {entity.name} 的人设失败: {str(e)}")
|
||||
# 创建一个基础profile
|
||||
profiles.append(OasisAgentProfile(
|
||||
user_id=idx,
|
||||
user_name=self._generate_username(entity.name),
|
||||
name=entity.name,
|
||||
bio=f"{entity.get_entity_type() or 'Entity'}: {entity.name}",
|
||||
persona=entity.summary or f"A participant in social discussions.",
|
||||
source_entity_uuid=entity.uuid,
|
||||
source_entity_type=entity.get_entity_type(),
|
||||
))
|
||||
|
||||
return profiles
|
||||
|
||||
def save_profiles(
|
||||
self,
|
||||
profiles: List[OasisAgentProfile],
|
||||
file_path: str,
|
||||
platform: str = "reddit"
|
||||
):
|
||||
"""
|
||||
保存Profile到文件(根据平台选择正确格式)
|
||||
|
||||
OASIS平台格式要求:
|
||||
- Twitter: CSV格式
|
||||
- Reddit: JSON格式
|
||||
|
||||
Args:
|
||||
profiles: Profile列表
|
||||
file_path: 文件路径
|
||||
platform: 平台类型 ("reddit" 或 "twitter")
|
||||
"""
|
||||
if platform == "twitter":
|
||||
self._save_twitter_csv(profiles, file_path)
|
||||
else:
|
||||
self._save_reddit_json(profiles, file_path)
|
||||
|
||||
def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Twitter Profile为CSV格式
|
||||
|
||||
OASIS Twitter要求的CSV字段:
|
||||
user_id, user_name, name, bio, friend_count, follower_count, statuses_count, created_at
|
||||
"""
|
||||
import csv
|
||||
|
||||
# 确保文件扩展名是.csv
|
||||
if not file_path.endswith('.csv'):
|
||||
file_path = file_path.replace('.json', '.csv')
|
||||
|
||||
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
||||
writer = csv.writer(f)
|
||||
|
||||
# 写入表头
|
||||
headers = ['user_id', 'user_name', 'name', 'bio', 'friend_count',
|
||||
'follower_count', 'statuses_count', 'created_at']
|
||||
writer.writerow(headers)
|
||||
|
||||
# 写入数据行
|
||||
for profile in profiles:
|
||||
# bio需要处理换行符和逗号
|
||||
bio = profile.bio.replace('\n', ' ').replace('\r', ' ')
|
||||
row = [
|
||||
profile.user_id,
|
||||
profile.user_name,
|
||||
profile.name,
|
||||
bio,
|
||||
profile.friend_count,
|
||||
profile.follower_count,
|
||||
profile.statuses_count,
|
||||
profile.created_at
|
||||
]
|
||||
writer.writerow(row)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Twitter Profile到 {file_path} (CSV格式)")
|
||||
|
||||
def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str):
|
||||
"""
|
||||
保存Reddit Profile为JSON格式
|
||||
|
||||
OASIS Reddit支持两种JSON格式:
|
||||
1. 基础格式: user_id, user_name, name, bio, karma, created_at
|
||||
2. 详细格式: realname, username, bio, persona, age, gender, mbti, country, profession, interested_topics
|
||||
|
||||
我们使用详细格式,与用户示例数据(36个简单人设.json)保持一致
|
||||
"""
|
||||
data = []
|
||||
for profile in profiles:
|
||||
# 使用详细格式(与用户示例兼容)
|
||||
item = {
|
||||
"realname": profile.name,
|
||||
"username": profile.user_name,
|
||||
"bio": profile.bio[:150] if profile.bio else "", # OASIS bio限制150字符
|
||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
||||
}
|
||||
|
||||
# 添加人设详情字段
|
||||
if profile.age:
|
||||
item["age"] = profile.age
|
||||
if profile.gender:
|
||||
item["gender"] = profile.gender
|
||||
if profile.mbti:
|
||||
item["mbti"] = profile.mbti
|
||||
if profile.country:
|
||||
item["country"] = profile.country
|
||||
if profile.profession:
|
||||
item["profession"] = profile.profession
|
||||
if profile.interested_topics:
|
||||
item["interested_topics"] = profile.interested_topics
|
||||
|
||||
data.append(item)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"已保存 {len(profiles)} 个Reddit Profile到 {file_path} (JSON详细格式)")
|
||||
|
||||
# 保留旧方法名作为别名,保持向后兼容
|
||||
def save_profiles_to_json(
|
||||
self,
|
||||
profiles: List[OasisAgentProfile],
|
||||
file_path: str,
|
||||
platform: str = "reddit"
|
||||
):
|
||||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
584
backend/app/services/simulation_config_generator.py
Normal file
584
backend/app/services/simulation_config_generator.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
模拟配置智能生成器
|
||||
使用LLM根据模拟需求、文档内容、图谱信息自动生成细致的模拟参数
|
||||
实现全程自动化,无需人工设置参数
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
logger = get_logger('mirofish.simulation_config')
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentActivityConfig:
|
||||
"""单个Agent的活动配置"""
|
||||
agent_id: int
|
||||
entity_uuid: str
|
||||
entity_name: str
|
||||
entity_type: str
|
||||
|
||||
# 活跃度配置 (0.0-1.0)
|
||||
activity_level: float = 0.5 # 整体活跃度
|
||||
|
||||
# 发言频率(每小时预期发言次数)
|
||||
posts_per_hour: float = 1.0
|
||||
comments_per_hour: float = 2.0
|
||||
|
||||
# 活跃时间段(24小时制,0-23)
|
||||
active_hours: List[int] = field(default_factory=lambda: list(range(8, 23)))
|
||||
|
||||
# 响应速度(对热点事件的反应延迟,单位:模拟分钟)
|
||||
response_delay_min: int = 5
|
||||
response_delay_max: int = 60
|
||||
|
||||
# 情感倾向 (-1.0到1.0,负面到正面)
|
||||
sentiment_bias: float = 0.0
|
||||
|
||||
# 立场(对特定话题的态度)
|
||||
stance: str = "neutral" # supportive, opposing, neutral, observer
|
||||
|
||||
# 影响力权重(决定其发言被其他Agent看到的概率)
|
||||
influence_weight: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeSimulationConfig:
|
||||
"""时间模拟配置"""
|
||||
# 模拟总时长(模拟小时数)
|
||||
total_simulation_hours: int = 72 # 默认模拟72小时(3天)
|
||||
|
||||
# 每轮代表的时间(模拟分钟)
|
||||
minutes_per_round: int = 30
|
||||
|
||||
# 每小时激活的Agent数量范围
|
||||
agents_per_hour_min: int = 5
|
||||
agents_per_hour_max: int = 20
|
||||
|
||||
# 高峰时段(活跃度提升)
|
||||
peak_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 14, 15, 20, 21, 22])
|
||||
peak_activity_multiplier: float = 1.5
|
||||
|
||||
# 低谷时段(活跃度降低)
|
||||
off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6])
|
||||
off_peak_activity_multiplier: float = 0.3
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventConfig:
|
||||
"""事件配置"""
|
||||
# 初始事件(模拟开始时的触发事件)
|
||||
initial_posts: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# 定时事件(在特定时间触发的事件)
|
||||
scheduled_events: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
# 热点话题关键词
|
||||
hot_topics: List[str] = field(default_factory=list)
|
||||
|
||||
# 舆论引导方向
|
||||
narrative_direction: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""平台特定配置"""
|
||||
platform: str # twitter or reddit
|
||||
|
||||
# 推荐算法权重
|
||||
recency_weight: float = 0.4 # 时间新鲜度
|
||||
popularity_weight: float = 0.3 # 热度
|
||||
relevance_weight: float = 0.3 # 相关性
|
||||
|
||||
# 病毒传播阈值(达到多少互动后触发扩散)
|
||||
viral_threshold: int = 10
|
||||
|
||||
# 回声室效应强度(相似观点聚集程度)
|
||||
echo_chamber_strength: float = 0.5
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationParameters:
|
||||
"""完整的模拟参数配置"""
|
||||
# 基础信息
|
||||
simulation_id: str
|
||||
project_id: str
|
||||
graph_id: str
|
||||
simulation_requirement: str
|
||||
|
||||
# 时间配置
|
||||
time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig)
|
||||
|
||||
# Agent配置列表
|
||||
agent_configs: List[AgentActivityConfig] = field(default_factory=list)
|
||||
|
||||
# 事件配置
|
||||
event_config: EventConfig = field(default_factory=EventConfig)
|
||||
|
||||
# 平台配置
|
||||
twitter_config: Optional[PlatformConfig] = None
|
||||
reddit_config: Optional[PlatformConfig] = None
|
||||
|
||||
# LLM配置
|
||||
llm_model: str = ""
|
||||
llm_base_url: str = ""
|
||||
|
||||
# 生成元数据
|
||||
generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
generation_reasoning: str = "" # LLM的推理说明
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"project_id": self.project_id,
|
||||
"graph_id": self.graph_id,
|
||||
"simulation_requirement": self.simulation_requirement,
|
||||
"time_config": asdict(self.time_config),
|
||||
"agent_configs": [asdict(a) for a in self.agent_configs],
|
||||
"event_config": asdict(self.event_config),
|
||||
"twitter_config": asdict(self.twitter_config) if self.twitter_config else None,
|
||||
"reddit_config": asdict(self.reddit_config) if self.reddit_config else None,
|
||||
"llm_model": self.llm_model,
|
||||
"llm_base_url": self.llm_base_url,
|
||||
"generated_at": self.generated_at,
|
||||
"generation_reasoning": self.generation_reasoning,
|
||||
}
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
"""转换为JSON字符串"""
|
||||
return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
|
||||
|
||||
|
||||
class SimulationConfigGenerator:
|
||||
"""
|
||||
模拟配置智能生成器
|
||||
|
||||
使用LLM分析模拟需求、文档内容、图谱实体信息,
|
||||
自动生成最佳的模拟参数配置
|
||||
"""
|
||||
|
||||
# 上下文最大字符数
|
||||
MAX_CONTEXT_LENGTH = 50000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
model_name: Optional[str] = None
|
||||
):
|
||||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
self.model_name = model_name or Config.LLM_MODEL_NAME
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
def generate_config(
|
||||
self,
|
||||
simulation_id: str,
|
||||
project_id: str,
|
||||
graph_id: str,
|
||||
simulation_requirement: str,
|
||||
document_text: str,
|
||||
entities: List[EntityNode],
|
||||
enable_twitter: bool = True,
|
||||
enable_reddit: bool = True,
|
||||
) -> SimulationParameters:
|
||||
"""
|
||||
智能生成完整的模拟配置
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
project_id: 项目ID
|
||||
graph_id: 图谱ID
|
||||
simulation_requirement: 模拟需求描述
|
||||
document_text: 原始文档内容
|
||||
entities: 过滤后的实体列表
|
||||
enable_twitter: 是否启用Twitter
|
||||
enable_reddit: 是否启用Reddit
|
||||
|
||||
Returns:
|
||||
SimulationParameters: 完整的模拟参数
|
||||
"""
|
||||
logger.info(f"开始智能生成模拟配置: simulation_id={simulation_id}")
|
||||
|
||||
# 1. 构建上下文信息(截断到50000字符)
|
||||
context = self._build_context(
|
||||
simulation_requirement=simulation_requirement,
|
||||
document_text=document_text,
|
||||
entities=entities
|
||||
)
|
||||
|
||||
# 2. 调用LLM生成配置
|
||||
llm_result = self._generate_config_with_llm(
|
||||
context=context,
|
||||
entities=entities,
|
||||
enable_twitter=enable_twitter,
|
||||
enable_reddit=enable_reddit
|
||||
)
|
||||
|
||||
# 3. 构建SimulationParameters对象
|
||||
params = self._build_parameters(
|
||||
simulation_id=simulation_id,
|
||||
project_id=project_id,
|
||||
graph_id=graph_id,
|
||||
simulation_requirement=simulation_requirement,
|
||||
entities=entities,
|
||||
llm_result=llm_result,
|
||||
enable_twitter=enable_twitter,
|
||||
enable_reddit=enable_reddit
|
||||
)
|
||||
|
||||
logger.info(f"模拟配置生成完成: {len(params.agent_configs)} 个Agent配置")
|
||||
|
||||
return params
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
simulation_requirement: str,
|
||||
document_text: str,
|
||||
entities: List[EntityNode]
|
||||
) -> str:
|
||||
"""构建LLM上下文,截断到最大长度"""
|
||||
|
||||
# 实体摘要
|
||||
entity_summary = self._summarize_entities(entities)
|
||||
|
||||
# 构建上下文
|
||||
context_parts = [
|
||||
f"## 模拟需求\n{simulation_requirement}",
|
||||
f"\n## 实体信息 ({len(entities)}个)\n{entity_summary}",
|
||||
]
|
||||
|
||||
current_length = sum(len(p) for p in context_parts)
|
||||
remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量
|
||||
|
||||
if remaining_length > 0 and document_text:
|
||||
doc_text = document_text[:remaining_length]
|
||||
if len(document_text) > remaining_length:
|
||||
doc_text += "\n...(文档已截断)"
|
||||
context_parts.append(f"\n## 原始文档内容\n{doc_text}")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
def _summarize_entities(self, entities: List[EntityNode]) -> str:
|
||||
"""生成实体摘要"""
|
||||
lines = []
|
||||
|
||||
# 按类型分组
|
||||
by_type: Dict[str, List[EntityNode]] = {}
|
||||
for e in entities:
|
||||
t = e.get_entity_type() or "Unknown"
|
||||
if t not in by_type:
|
||||
by_type[t] = []
|
||||
by_type[t].append(e)
|
||||
|
||||
for entity_type, type_entities in by_type.items():
|
||||
lines.append(f"\n### {entity_type} ({len(type_entities)}个)")
|
||||
for e in type_entities[:10]: # 每类最多显示10个
|
||||
summary_preview = (e.summary[:100] + "...") if len(e.summary) > 100 else e.summary
|
||||
lines.append(f"- {e.name}: {summary_preview}")
|
||||
if len(type_entities) > 10:
|
||||
lines.append(f" ... 还有 {len(type_entities) - 10} 个")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_config_with_llm(
|
||||
self,
|
||||
context: str,
|
||||
entities: List[EntityNode],
|
||||
enable_twitter: bool,
|
||||
enable_reddit: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""调用LLM生成配置"""
|
||||
|
||||
# 构建实体列表用于Agent配置
|
||||
entity_list = []
|
||||
for i, e in enumerate(entities):
|
||||
entity_list.append({
|
||||
"agent_id": i,
|
||||
"entity_uuid": e.uuid,
|
||||
"entity_name": e.name,
|
||||
"entity_type": e.get_entity_type() or "Unknown",
|
||||
"summary": e.summary[:200] if e.summary else ""
|
||||
})
|
||||
|
||||
prompt = f"""你是一个社交媒体舆论模拟专家。请根据以下信息,生成详细的模拟参数配置。
|
||||
|
||||
{context}
|
||||
|
||||
## 实体列表(需要为每个实体生成活动配置)
|
||||
```json
|
||||
{json.dumps(entity_list, ensure_ascii=False, indent=2)}
|
||||
```
|
||||
|
||||
## 任务
|
||||
请生成一个JSON配置,包含以下部分:
|
||||
|
||||
1. **time_config** - 时间模拟配置
|
||||
- total_simulation_hours: 模拟总时长(小时),根据事件性质决定(短期热点24-72小时,长期舆论168-336小时)
|
||||
- minutes_per_round: 每轮代表的时间(分钟),建议15-60
|
||||
- agents_per_hour_min/max: 每小时激活的Agent数量范围
|
||||
- peak_hours: 高峰时段列表(0-23)
|
||||
- off_peak_hours: 低谷时段列表
|
||||
|
||||
2. **agent_configs** - 每个Agent的活动配置(必须为每个实体生成)
|
||||
对于每个agent_id,设置:
|
||||
- activity_level: 活跃度(0.0-1.0),官方机构通常0.1-0.3,媒体0.3-0.5,个人0.5-0.9
|
||||
- posts_per_hour: 每小时发帖频率,官方机构0.05-0.2,媒体0.5-2,个人0.1-1
|
||||
- comments_per_hour: 每小时评论频率
|
||||
- active_hours: 活跃时间段列表,官方通常工作时间,个人更分散
|
||||
- response_delay_min/max: 响应延迟(模拟分钟),官方较慢(30-180),个人较快(1-30)
|
||||
- sentiment_bias: 情感倾向(-1到1),根据实体立场设置
|
||||
- stance: 立场(supportive/opposing/neutral/observer)
|
||||
- influence_weight: 影响力权重,知名人物和媒体较高
|
||||
|
||||
3. **event_config** - 事件配置
|
||||
- initial_posts: 初始帖子列表,包含content和poster_agent_id
|
||||
- hot_topics: 热点话题关键词列表
|
||||
- narrative_direction: 舆论发展方向描述
|
||||
|
||||
4. **platform_configs** - 平台配置(如果启用)
|
||||
- viral_threshold: 病毒传播阈值
|
||||
- echo_chamber_strength: 回声室效应强度(0-1)
|
||||
|
||||
5. **reasoning** - 你的推理说明,解释为什么这样设置参数
|
||||
|
||||
## 重要原则
|
||||
- 官方机构(University、GovernmentAgency)发言频率低但影响力大
|
||||
- 媒体(MediaOutlet)发言频率中等,传播速度快
|
||||
- 个人(Student、PublicFigure)发言频率高但影响力分散
|
||||
- 根据模拟需求判断各实体的立场和情感倾向
|
||||
- 时间配置要符合真实社交媒体的使用规律
|
||||
|
||||
请返回JSON格式,不要包含markdown代码块标记。"""
|
||||
|
||||
try:
|
||||
# 使用重试机制调用LLM API
|
||||
from ..utils.retry import RetryableAPIClient
|
||||
|
||||
retry_client = RetryableAPIClient(max_retries=3, initial_delay=2.0, max_delay=60.0)
|
||||
|
||||
def call_llm():
|
||||
return self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是社交媒体舆论模拟专家,擅长设计真实的模拟参数。返回纯JSON格式,不要markdown。"
|
||||
},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
)
|
||||
|
||||
response = retry_client.call_with_retry(call_llm)
|
||||
result = json.loads(response.choices[0].message.content)
|
||||
logger.info(f"LLM配置生成成功")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM配置生成失败(已重试): {str(e)}")
|
||||
# 返回默认配置
|
||||
return self._generate_default_config(entities)
|
||||
|
||||
def _generate_default_config(self, entities: List[EntityNode]) -> Dict[str, Any]:
|
||||
"""生成默认配置(LLM失败时的fallback)"""
|
||||
agent_configs = []
|
||||
|
||||
for i, e in enumerate(entities):
|
||||
entity_type = (e.get_entity_type() or "Unknown").lower()
|
||||
|
||||
# 根据实体类型设置默认参数
|
||||
if entity_type in ["university", "governmentagency", "ngo"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.2,
|
||||
"posts_per_hour": 0.1,
|
||||
"comments_per_hour": 0.05,
|
||||
"active_hours": list(range(9, 18)),
|
||||
"response_delay_min": 60,
|
||||
"response_delay_max": 240,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 3.0
|
||||
}
|
||||
elif entity_type in ["mediaoutlet"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.6,
|
||||
"posts_per_hour": 1.0,
|
||||
"comments_per_hour": 0.5,
|
||||
"active_hours": list(range(6, 24)),
|
||||
"response_delay_min": 5,
|
||||
"response_delay_max": 30,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "observer",
|
||||
"influence_weight": 2.5
|
||||
}
|
||||
elif entity_type in ["publicfigure", "expert"]:
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.5,
|
||||
"posts_per_hour": 0.3,
|
||||
"comments_per_hour": 0.5,
|
||||
"active_hours": list(range(8, 23)),
|
||||
"response_delay_min": 10,
|
||||
"response_delay_max": 60,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 2.0
|
||||
}
|
||||
else: # Student, Person, etc.
|
||||
config = {
|
||||
"agent_id": i,
|
||||
"activity_level": 0.7,
|
||||
"posts_per_hour": 0.5,
|
||||
"comments_per_hour": 1.0,
|
||||
"active_hours": list(range(7, 24)),
|
||||
"response_delay_min": 1,
|
||||
"response_delay_max": 20,
|
||||
"sentiment_bias": 0.0,
|
||||
"stance": "neutral",
|
||||
"influence_weight": 1.0
|
||||
}
|
||||
|
||||
agent_configs.append(config)
|
||||
|
||||
return {
|
||||
"time_config": {
|
||||
"total_simulation_hours": 72,
|
||||
"minutes_per_round": 30,
|
||||
"agents_per_hour_min": max(1, len(entities) // 10),
|
||||
"agents_per_hour_max": max(5, len(entities) // 3),
|
||||
"peak_hours": [9, 10, 11, 14, 15, 20, 21, 22],
|
||||
"off_peak_hours": [0, 1, 2, 3, 4, 5]
|
||||
},
|
||||
"agent_configs": agent_configs,
|
||||
"event_config": {
|
||||
"initial_posts": [],
|
||||
"hot_topics": [],
|
||||
"narrative_direction": ""
|
||||
},
|
||||
"reasoning": "使用默认配置(LLM生成失败)"
|
||||
}
|
||||
|
||||
def _build_parameters(
|
||||
self,
|
||||
simulation_id: str,
|
||||
project_id: str,
|
||||
graph_id: str,
|
||||
simulation_requirement: str,
|
||||
entities: List[EntityNode],
|
||||
llm_result: Dict[str, Any],
|
||||
enable_twitter: bool,
|
||||
enable_reddit: bool
|
||||
) -> SimulationParameters:
|
||||
"""根据LLM结果构建SimulationParameters对象"""
|
||||
|
||||
# 时间配置
|
||||
time_cfg = llm_result.get("time_config", {})
|
||||
time_config = TimeSimulationConfig(
|
||||
total_simulation_hours=time_cfg.get("total_simulation_hours", 72),
|
||||
minutes_per_round=time_cfg.get("minutes_per_round", 30),
|
||||
agents_per_hour_min=time_cfg.get("agents_per_hour_min", 5),
|
||||
agents_per_hour_max=time_cfg.get("agents_per_hour_max", 20),
|
||||
peak_hours=time_cfg.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]),
|
||||
off_peak_hours=time_cfg.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
|
||||
peak_activity_multiplier=time_cfg.get("peak_activity_multiplier", 1.5),
|
||||
off_peak_activity_multiplier=time_cfg.get("off_peak_activity_multiplier", 0.3)
|
||||
)
|
||||
|
||||
# Agent配置
|
||||
agent_configs = []
|
||||
llm_agent_configs = {cfg["agent_id"]: cfg for cfg in llm_result.get("agent_configs", [])}
|
||||
|
||||
for i, entity in enumerate(entities):
|
||||
cfg = llm_agent_configs.get(i, {})
|
||||
|
||||
agent_config = AgentActivityConfig(
|
||||
agent_id=i,
|
||||
entity_uuid=entity.uuid,
|
||||
entity_name=entity.name,
|
||||
entity_type=entity.get_entity_type() or "Unknown",
|
||||
activity_level=cfg.get("activity_level", 0.5),
|
||||
posts_per_hour=cfg.get("posts_per_hour", 0.5),
|
||||
comments_per_hour=cfg.get("comments_per_hour", 1.0),
|
||||
active_hours=cfg.get("active_hours", list(range(8, 23))),
|
||||
response_delay_min=cfg.get("response_delay_min", 5),
|
||||
response_delay_max=cfg.get("response_delay_max", 60),
|
||||
sentiment_bias=cfg.get("sentiment_bias", 0.0),
|
||||
stance=cfg.get("stance", "neutral"),
|
||||
influence_weight=cfg.get("influence_weight", 1.0)
|
||||
)
|
||||
agent_configs.append(agent_config)
|
||||
|
||||
# 事件配置
|
||||
event_cfg = llm_result.get("event_config", {})
|
||||
event_config = EventConfig(
|
||||
initial_posts=event_cfg.get("initial_posts", []),
|
||||
scheduled_events=event_cfg.get("scheduled_events", []),
|
||||
hot_topics=event_cfg.get("hot_topics", []),
|
||||
narrative_direction=event_cfg.get("narrative_direction", "")
|
||||
)
|
||||
|
||||
# 平台配置
|
||||
twitter_config = None
|
||||
reddit_config = None
|
||||
|
||||
platform_cfgs = llm_result.get("platform_configs", {})
|
||||
|
||||
if enable_twitter:
|
||||
tw_cfg = platform_cfgs.get("twitter", {})
|
||||
twitter_config = PlatformConfig(
|
||||
platform="twitter",
|
||||
recency_weight=tw_cfg.get("recency_weight", 0.4),
|
||||
popularity_weight=tw_cfg.get("popularity_weight", 0.3),
|
||||
relevance_weight=tw_cfg.get("relevance_weight", 0.3),
|
||||
viral_threshold=tw_cfg.get("viral_threshold", 10),
|
||||
echo_chamber_strength=tw_cfg.get("echo_chamber_strength", 0.5)
|
||||
)
|
||||
|
||||
if enable_reddit:
|
||||
rd_cfg = platform_cfgs.get("reddit", {})
|
||||
reddit_config = PlatformConfig(
|
||||
platform="reddit",
|
||||
recency_weight=rd_cfg.get("recency_weight", 0.3),
|
||||
popularity_weight=rd_cfg.get("popularity_weight", 0.4),
|
||||
relevance_weight=rd_cfg.get("relevance_weight", 0.3),
|
||||
viral_threshold=rd_cfg.get("viral_threshold", 15),
|
||||
echo_chamber_strength=rd_cfg.get("echo_chamber_strength", 0.6)
|
||||
)
|
||||
|
||||
return SimulationParameters(
|
||||
simulation_id=simulation_id,
|
||||
project_id=project_id,
|
||||
graph_id=graph_id,
|
||||
simulation_requirement=simulation_requirement,
|
||||
time_config=time_config,
|
||||
agent_configs=agent_configs,
|
||||
event_config=event_config,
|
||||
twitter_config=twitter_config,
|
||||
reddit_config=reddit_config,
|
||||
llm_model=self.model_name,
|
||||
llm_base_url=self.base_url,
|
||||
generation_reasoning=llm_result.get("reasoning", "")
|
||||
)
|
||||
|
||||
|
||||
546
backend/app/services/simulation_manager.py
Normal file
546
backend/app/services/simulation_manager.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""
|
||||
OASIS模拟管理器
|
||||
管理Twitter和Reddit双平台并行模拟
|
||||
使用预设脚本 + LLM智能生成配置参数
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import ZepEntityReader, FilteredEntities
|
||||
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
|
||||
|
||||
logger = get_logger('mirofish.simulation')
|
||||
|
||||
|
||||
class SimulationStatus(str, Enum):
|
||||
"""模拟状态"""
|
||||
CREATED = "created"
|
||||
PREPARING = "preparing"
|
||||
READY = "ready"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PlatformType(str, Enum):
|
||||
"""平台类型"""
|
||||
TWITTER = "twitter"
|
||||
REDDIT = "reddit"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationState:
|
||||
"""模拟状态"""
|
||||
simulation_id: str
|
||||
project_id: str
|
||||
graph_id: str
|
||||
|
||||
# 平台启用状态
|
||||
enable_twitter: bool = True
|
||||
enable_reddit: bool = True
|
||||
|
||||
# 状态
|
||||
status: SimulationStatus = SimulationStatus.CREATED
|
||||
|
||||
# 准备阶段数据
|
||||
entities_count: int = 0
|
||||
profiles_count: int = 0
|
||||
entity_types: List[str] = field(default_factory=list)
|
||||
|
||||
# 配置生成信息
|
||||
config_generated: bool = False
|
||||
config_reasoning: str = ""
|
||||
|
||||
# 运行时数据
|
||||
current_round: int = 0
|
||||
twitter_status: str = "not_started"
|
||||
reddit_status: str = "not_started"
|
||||
|
||||
# 时间戳
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
# 错误信息
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""完整状态字典(内部使用)"""
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"project_id": self.project_id,
|
||||
"graph_id": self.graph_id,
|
||||
"enable_twitter": self.enable_twitter,
|
||||
"enable_reddit": self.enable_reddit,
|
||||
"status": self.status.value,
|
||||
"entities_count": self.entities_count,
|
||||
"profiles_count": self.profiles_count,
|
||||
"entity_types": self.entity_types,
|
||||
"config_generated": self.config_generated,
|
||||
"config_reasoning": self.config_reasoning,
|
||||
"current_round": self.current_round,
|
||||
"twitter_status": self.twitter_status,
|
||||
"reddit_status": self.reddit_status,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
def to_simple_dict(self) -> Dict[str, Any]:
|
||||
"""简化状态字典(API返回使用)"""
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"project_id": self.project_id,
|
||||
"graph_id": self.graph_id,
|
||||
"status": self.status.value,
|
||||
"entities_count": self.entities_count,
|
||||
"profiles_count": self.profiles_count,
|
||||
"entity_types": self.entity_types,
|
||||
"config_generated": self.config_generated,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
class SimulationManager:
|
||||
"""
|
||||
模拟管理器
|
||||
|
||||
核心功能:
|
||||
1. 从Zep图谱读取实体并过滤
|
||||
2. 生成OASIS Agent Profile
|
||||
3. 使用LLM智能生成模拟配置参数
|
||||
4. 准备预设脚本所需的所有文件
|
||||
"""
|
||||
|
||||
# 模拟数据存储目录
|
||||
SIMULATION_DATA_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
# 预设脚本目录
|
||||
SCRIPTS_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../scripts'
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
# 确保目录存在
|
||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||
|
||||
# 内存中的模拟状态缓存
|
||||
self._simulations: Dict[str, SimulationState] = {}
|
||||
|
||||
def _get_simulation_dir(self, simulation_id: str) -> str:
|
||||
"""获取模拟数据目录"""
|
||||
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
||||
os.makedirs(sim_dir, exist_ok=True)
|
||||
return sim_dir
|
||||
|
||||
def _save_simulation_state(self, state: SimulationState):
|
||||
"""保存模拟状态到文件"""
|
||||
sim_dir = self._get_simulation_dir(state.simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
state.updated_at = datetime.now().isoformat()
|
||||
|
||||
with open(state_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
self._simulations[state.simulation_id] = state
|
||||
|
||||
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""从文件加载模拟状态"""
|
||||
if simulation_id in self._simulations:
|
||||
return self._simulations[simulation_id]
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
if not os.path.exists(state_file):
|
||||
return None
|
||||
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=data.get("project_id", ""),
|
||||
graph_id=data.get("graph_id", ""),
|
||||
enable_twitter=data.get("enable_twitter", True),
|
||||
enable_reddit=data.get("enable_reddit", True),
|
||||
status=SimulationStatus(data.get("status", "created")),
|
||||
entities_count=data.get("entities_count", 0),
|
||||
profiles_count=data.get("profiles_count", 0),
|
||||
entity_types=data.get("entity_types", []),
|
||||
config_generated=data.get("config_generated", False),
|
||||
config_reasoning=data.get("config_reasoning", ""),
|
||||
current_round=data.get("current_round", 0),
|
||||
twitter_status=data.get("twitter_status", "not_started"),
|
||||
reddit_status=data.get("reddit_status", "not_started"),
|
||||
created_at=data.get("created_at", datetime.now().isoformat()),
|
||||
updated_at=data.get("updated_at", datetime.now().isoformat()),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
self._simulations[simulation_id] = state
|
||||
return state
|
||||
|
||||
def create_simulation(
|
||||
self,
|
||||
project_id: str,
|
||||
graph_id: str,
|
||||
enable_twitter: bool = True,
|
||||
enable_reddit: bool = True,
|
||||
) -> SimulationState:
|
||||
"""
|
||||
创建新的模拟
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
graph_id: Zep图谱ID
|
||||
enable_twitter: 是否启用Twitter模拟
|
||||
enable_reddit: 是否启用Reddit模拟
|
||||
|
||||
Returns:
|
||||
SimulationState
|
||||
"""
|
||||
import uuid
|
||||
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=project_id,
|
||||
graph_id=graph_id,
|
||||
enable_twitter=enable_twitter,
|
||||
enable_reddit=enable_reddit,
|
||||
status=SimulationStatus.CREATED,
|
||||
)
|
||||
|
||||
self._save_simulation_state(state)
|
||||
logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")
|
||||
|
||||
return state
|
||||
|
||||
def prepare_simulation(
|
||||
self,
|
||||
simulation_id: str,
|
||||
simulation_requirement: str,
|
||||
document_text: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
use_llm_for_profiles: bool = True,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> SimulationState:
|
||||
"""
|
||||
准备模拟环境(全程自动化)
|
||||
|
||||
步骤:
|
||||
1. 从Zep图谱读取并过滤实体
|
||||
2. 为每个实体生成OASIS Agent Profile(可选LLM增强)
|
||||
3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等)
|
||||
4. 保存配置文件和Profile文件
|
||||
5. 复制预设脚本到模拟目录
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
simulation_requirement: 模拟需求描述(用于LLM生成配置)
|
||||
document_text: 原始文档内容(用于LLM理解背景)
|
||||
defined_entity_types: 预定义的实体类型(可选)
|
||||
use_llm_for_profiles: 是否使用LLM生成详细人设
|
||||
progress_callback: 进度回调函数 (stage, progress, message)
|
||||
|
||||
Returns:
|
||||
SimulationState
|
||||
"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
try:
|
||||
state.status = SimulationStatus.PREPARING
|
||||
self._save_simulation_state(state)
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
|
||||
# ========== 阶段1: 读取并过滤实体 ==========
|
||||
if progress_callback:
|
||||
progress_callback("reading", 0, "正在连接Zep图谱...")
|
||||
|
||||
reader = ZepEntityReader()
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("reading", 30, "正在读取节点数据...")
|
||||
|
||||
filtered = reader.filter_defined_entities(
|
||||
graph_id=state.graph_id,
|
||||
defined_entity_types=defined_entity_types,
|
||||
enrich_with_edges=True
|
||||
)
|
||||
|
||||
state.entities_count = filtered.filtered_count
|
||||
state.entity_types = list(filtered.entity_types)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"reading", 100,
|
||||
f"完成,共 {filtered.filtered_count} 个实体",
|
||||
current=filtered.filtered_count,
|
||||
total=filtered.filtered_count
|
||||
)
|
||||
|
||||
if filtered.filtered_count == 0:
|
||||
state.status = SimulationStatus.FAILED
|
||||
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
|
||||
self._save_simulation_state(state)
|
||||
return state
|
||||
|
||||
# ========== 阶段2: 生成Agent Profile ==========
|
||||
total_entities = len(filtered.entities)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 0,
|
||||
"开始生成...",
|
||||
current=0,
|
||||
total=total_entities
|
||||
)
|
||||
|
||||
generator = OasisProfileGenerator()
|
||||
|
||||
def profile_progress(current, total, msg):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles",
|
||||
int(current / total * 100),
|
||||
msg,
|
||||
current=current,
|
||||
total=total,
|
||||
item_name=msg
|
||||
)
|
||||
|
||||
profiles = generator.generate_profiles_from_entities(
|
||||
entities=filtered.entities,
|
||||
use_llm=use_llm_for_profiles,
|
||||
progress_callback=profile_progress
|
||||
)
|
||||
|
||||
state.profiles_count = len(profiles)
|
||||
|
||||
# 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 95,
|
||||
"保存Profile文件...",
|
||||
current=total_entities,
|
||||
total=total_entities
|
||||
)
|
||||
|
||||
if state.enable_reddit:
|
||||
generator.save_profiles(
|
||||
profiles=profiles,
|
||||
file_path=os.path.join(sim_dir, "reddit_profiles.json"),
|
||||
platform="reddit"
|
||||
)
|
||||
|
||||
if state.enable_twitter:
|
||||
# Twitter使用CSV格式!这是OASIS的要求
|
||||
generator.save_profiles(
|
||||
profiles=profiles,
|
||||
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
||||
platform="twitter"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 100,
|
||||
f"完成,共 {len(profiles)} 个Profile",
|
||||
current=len(profiles),
|
||||
total=len(profiles)
|
||||
)
|
||||
|
||||
# ========== 阶段3: LLM智能生成模拟配置 ==========
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 0,
|
||||
"正在分析模拟需求...",
|
||||
current=0,
|
||||
total=3
|
||||
)
|
||||
|
||||
config_generator = SimulationConfigGenerator()
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 30,
|
||||
"正在调用LLM生成配置...",
|
||||
current=1,
|
||||
total=3
|
||||
)
|
||||
|
||||
sim_params = config_generator.generate_config(
|
||||
simulation_id=simulation_id,
|
||||
project_id=state.project_id,
|
||||
graph_id=state.graph_id,
|
||||
simulation_requirement=simulation_requirement,
|
||||
document_text=document_text,
|
||||
entities=filtered.entities,
|
||||
enable_twitter=state.enable_twitter,
|
||||
enable_reddit=state.enable_reddit
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 70,
|
||||
"正在保存配置文件...",
|
||||
current=2,
|
||||
total=3
|
||||
)
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
f.write(sim_params.to_json())
|
||||
|
||||
state.config_generated = True
|
||||
state.config_reasoning = sim_params.generation_reasoning
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 100,
|
||||
"配置生成完成",
|
||||
current=3,
|
||||
total=3
|
||||
)
|
||||
|
||||
# ========== 阶段4: 复制预设脚本 ==========
|
||||
script_files = ["run_twitter_simulation.py", "run_reddit_simulation.py",
|
||||
"run_parallel_simulation.py", "action_logger.py"]
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"copying_scripts", 0,
|
||||
"开始准备脚本...",
|
||||
current=0,
|
||||
total=len(script_files)
|
||||
)
|
||||
|
||||
self._copy_preset_scripts(sim_dir)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"copying_scripts", 100,
|
||||
f"完成,共 {len(script_files)} 个脚本",
|
||||
current=len(script_files),
|
||||
total=len(script_files)
|
||||
)
|
||||
|
||||
# 更新状态
|
||||
state.status = SimulationStatus.READY
|
||||
self._save_simulation_state(state)
|
||||
|
||||
logger.info(f"模拟准备完成: {simulation_id}, "
|
||||
f"entities={state.entities_count}, profiles={state.profiles_count}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
state.status = SimulationStatus.FAILED
|
||||
state.error = str(e)
|
||||
self._save_simulation_state(state)
|
||||
raise
|
||||
|
||||
def _copy_preset_scripts(self, sim_dir: str):
|
||||
"""复制预设脚本到模拟目录"""
|
||||
scripts = [
|
||||
"run_twitter_simulation.py",
|
||||
"run_reddit_simulation.py",
|
||||
"run_parallel_simulation.py"
|
||||
]
|
||||
|
||||
for script in scripts:
|
||||
src = os.path.join(self.SCRIPTS_DIR, script)
|
||||
dst = os.path.join(sim_dir, script)
|
||||
|
||||
if os.path.exists(src):
|
||||
shutil.copy2(src, dst)
|
||||
logger.debug(f"复制脚本: {script}")
|
||||
else:
|
||||
logger.warning(f"预设脚本不存在: {src}")
|
||||
|
||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""获取模拟状态"""
|
||||
return self._load_simulation_state(simulation_id)
|
||||
|
||||
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
||||
"""列出所有模拟"""
|
||||
simulations = []
|
||||
|
||||
if os.path.exists(self.SIMULATION_DATA_DIR):
|
||||
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
||||
state = self._load_simulation_state(sim_id)
|
||||
if state:
|
||||
if project_id is None or state.project_id == project_id:
|
||||
simulations.append(state)
|
||||
|
||||
return simulations
|
||||
|
||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||
"""获取模拟的Agent Profile"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
||||
|
||||
if not os.path.exists(profile_path):
|
||||
return []
|
||||
|
||||
with open(profile_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取模拟配置"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
||||
"""获取运行说明"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
return {
|
||||
"simulation_dir": sim_dir,
|
||||
"config_file": config_path,
|
||||
"commands": {
|
||||
"twitter": f"python run_twitter_simulation.py --config simulation_config.json",
|
||||
"reddit": f"python run_reddit_simulation.py --config simulation_config.json",
|
||||
"parallel": f"python run_parallel_simulation.py --config simulation_config.json",
|
||||
},
|
||||
"instructions": (
|
||||
f"1. 进入模拟目录: cd {sim_dir}\n"
|
||||
f"2. 激活conda环境: conda activate MiroFish\n"
|
||||
f"3. 运行模拟:\n"
|
||||
f" - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
|
||||
f" - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
|
||||
f" - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
|
||||
)
|
||||
}
|
||||
670
backend/app/services/simulation_runner.py
Normal file
670
backend/app/services/simulation_runner.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""
|
||||
OASIS模拟运行器
|
||||
在后台运行模拟并记录每个Agent的动作,支持实时状态监控
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import threading
|
||||
import subprocess
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from queue import Queue
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.simulation_runner')
|
||||
|
||||
|
||||
class RunnerStatus(str, Enum):
|
||||
"""运行器状态"""
|
||||
IDLE = "idle"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Agent动作记录"""
|
||||
round_num: int
|
||||
timestamp: str
|
||||
platform: str # twitter / reddit
|
||||
agent_id: int
|
||||
agent_name: str
|
||||
action_type: str # CREATE_POST, LIKE_POST, etc.
|
||||
action_args: Dict[str, Any] = field(default_factory=dict)
|
||||
result: Optional[str] = None
|
||||
success: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"round_num": self.round_num,
|
||||
"timestamp": self.timestamp,
|
||||
"platform": self.platform,
|
||||
"agent_id": self.agent_id,
|
||||
"agent_name": self.agent_name,
|
||||
"action_type": self.action_type,
|
||||
"action_args": self.action_args,
|
||||
"result": self.result,
|
||||
"success": self.success,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoundSummary:
|
||||
"""每轮摘要"""
|
||||
round_num: int
|
||||
start_time: str
|
||||
end_time: Optional[str] = None
|
||||
simulated_hour: int = 0
|
||||
twitter_actions: int = 0
|
||||
reddit_actions: int = 0
|
||||
active_agents: List[int] = field(default_factory=list)
|
||||
actions: List[AgentAction] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"round_num": self.round_num,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"simulated_hour": self.simulated_hour,
|
||||
"twitter_actions": self.twitter_actions,
|
||||
"reddit_actions": self.reddit_actions,
|
||||
"active_agents": self.active_agents,
|
||||
"actions_count": len(self.actions),
|
||||
"actions": [a.to_dict() for a in self.actions],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationRunState:
|
||||
"""模拟运行状态(实时)"""
|
||||
simulation_id: str
|
||||
runner_status: RunnerStatus = RunnerStatus.IDLE
|
||||
|
||||
# 进度信息
|
||||
current_round: int = 0
|
||||
total_rounds: int = 0
|
||||
simulated_hours: int = 0
|
||||
total_simulation_hours: int = 0
|
||||
|
||||
# 平台状态
|
||||
twitter_running: bool = False
|
||||
reddit_running: bool = False
|
||||
twitter_actions_count: int = 0
|
||||
reddit_actions_count: int = 0
|
||||
|
||||
# 每轮摘要
|
||||
rounds: List[RoundSummary] = field(default_factory=list)
|
||||
|
||||
# 最近动作(用于前端实时展示)
|
||||
recent_actions: List[AgentAction] = field(default_factory=list)
|
||||
max_recent_actions: int = 50
|
||||
|
||||
# 时间戳
|
||||
started_at: Optional[str] = None
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
# 错误信息
|
||||
error: Optional[str] = None
|
||||
|
||||
# 进程ID(用于停止)
|
||||
process_pid: Optional[int] = None
|
||||
|
||||
def add_action(self, action: AgentAction):
|
||||
"""添加动作到最近动作列表"""
|
||||
self.recent_actions.insert(0, action)
|
||||
if len(self.recent_actions) > self.max_recent_actions:
|
||||
self.recent_actions = self.recent_actions[:self.max_recent_actions]
|
||||
|
||||
if action.platform == "twitter":
|
||||
self.twitter_actions_count += 1
|
||||
else:
|
||||
self.reddit_actions_count += 1
|
||||
|
||||
self.updated_at = datetime.now().isoformat()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"runner_status": self.runner_status.value,
|
||||
"current_round": self.current_round,
|
||||
"total_rounds": self.total_rounds,
|
||||
"simulated_hours": self.simulated_hours,
|
||||
"total_simulation_hours": self.total_simulation_hours,
|
||||
"progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1),
|
||||
"twitter_running": self.twitter_running,
|
||||
"reddit_running": self.reddit_running,
|
||||
"twitter_actions_count": self.twitter_actions_count,
|
||||
"reddit_actions_count": self.reddit_actions_count,
|
||||
"total_actions_count": self.twitter_actions_count + self.reddit_actions_count,
|
||||
"started_at": self.started_at,
|
||||
"updated_at": self.updated_at,
|
||||
"completed_at": self.completed_at,
|
||||
"error": self.error,
|
||||
"process_pid": self.process_pid,
|
||||
}
|
||||
|
||||
def to_detail_dict(self) -> Dict[str, Any]:
|
||||
"""包含最近动作的详细信息"""
|
||||
result = self.to_dict()
|
||||
result["recent_actions"] = [a.to_dict() for a in self.recent_actions]
|
||||
result["rounds_count"] = len(self.rounds)
|
||||
return result
|
||||
|
||||
|
||||
class SimulationRunner:
|
||||
"""
|
||||
模拟运行器
|
||||
|
||||
负责:
|
||||
1. 在后台进程中运行OASIS模拟
|
||||
2. 解析运行日志,记录每个Agent的动作
|
||||
3. 提供实时状态查询接口
|
||||
4. 支持暂停/停止/恢复操作
|
||||
"""
|
||||
|
||||
# 运行状态存储目录
|
||||
RUN_STATE_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
# 内存中的运行状态
|
||||
_run_states: Dict[str, SimulationRunState] = {}
|
||||
_processes: Dict[str, subprocess.Popen] = {}
|
||||
_action_queues: Dict[str, Queue] = {}
|
||||
_monitor_threads: Dict[str, threading.Thread] = {}
|
||||
|
||||
@classmethod
|
||||
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
|
||||
"""获取运行状态"""
|
||||
if simulation_id in cls._run_states:
|
||||
return cls._run_states[simulation_id]
|
||||
|
||||
# 尝试从文件加载
|
||||
state = cls._load_run_state(simulation_id)
|
||||
if state:
|
||||
cls._run_states[simulation_id] = state
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
|
||||
"""从文件加载运行状态"""
|
||||
state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
|
||||
if not os.path.exists(state_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
state = SimulationRunState(
|
||||
simulation_id=simulation_id,
|
||||
runner_status=RunnerStatus(data.get("runner_status", "idle")),
|
||||
current_round=data.get("current_round", 0),
|
||||
total_rounds=data.get("total_rounds", 0),
|
||||
simulated_hours=data.get("simulated_hours", 0),
|
||||
total_simulation_hours=data.get("total_simulation_hours", 0),
|
||||
twitter_running=data.get("twitter_running", False),
|
||||
reddit_running=data.get("reddit_running", False),
|
||||
twitter_actions_count=data.get("twitter_actions_count", 0),
|
||||
reddit_actions_count=data.get("reddit_actions_count", 0),
|
||||
started_at=data.get("started_at"),
|
||||
updated_at=data.get("updated_at", datetime.now().isoformat()),
|
||||
completed_at=data.get("completed_at"),
|
||||
error=data.get("error"),
|
||||
process_pid=data.get("process_pid"),
|
||||
)
|
||||
|
||||
# 加载最近动作
|
||||
actions_data = data.get("recent_actions", [])
|
||||
for a in actions_data:
|
||||
state.recent_actions.append(AgentAction(
|
||||
round_num=a.get("round_num", 0),
|
||||
timestamp=a.get("timestamp", ""),
|
||||
platform=a.get("platform", ""),
|
||||
agent_id=a.get("agent_id", 0),
|
||||
agent_name=a.get("agent_name", ""),
|
||||
action_type=a.get("action_type", ""),
|
||||
action_args=a.get("action_args", {}),
|
||||
result=a.get("result"),
|
||||
success=a.get("success", True),
|
||||
))
|
||||
|
||||
return state
|
||||
except Exception as e:
|
||||
logger.error(f"加载运行状态失败: {str(e)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _save_run_state(cls, state: SimulationRunState):
|
||||
"""保存运行状态到文件"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
|
||||
os.makedirs(sim_dir, exist_ok=True)
|
||||
state_file = os.path.join(sim_dir, "run_state.json")
|
||||
|
||||
data = state.to_detail_dict()
|
||||
|
||||
with open(state_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
cls._run_states[state.simulation_id] = state
|
||||
|
||||
@classmethod
|
||||
def start_simulation(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
platform: str = "parallel" # twitter / reddit / parallel
|
||||
) -> SimulationRunState:
|
||||
"""
|
||||
启动模拟
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
platform: 运行平台 (twitter/reddit/parallel)
|
||||
|
||||
Returns:
|
||||
SimulationRunState
|
||||
"""
|
||||
# 检查是否已在运行
|
||||
existing = cls.get_run_state(simulation_id)
|
||||
if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
|
||||
raise ValueError(f"模拟已在运行中: {simulation_id}")
|
||||
|
||||
# 加载模拟配置
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口")
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# 初始化运行状态
|
||||
time_config = config.get("time_config", {})
|
||||
total_hours = time_config.get("total_simulation_hours", 72)
|
||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
total_rounds = int(total_hours * 60 / minutes_per_round)
|
||||
|
||||
state = SimulationRunState(
|
||||
simulation_id=simulation_id,
|
||||
runner_status=RunnerStatus.STARTING,
|
||||
total_rounds=total_rounds,
|
||||
total_simulation_hours=total_hours,
|
||||
started_at=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
cls._save_run_state(state)
|
||||
|
||||
# 确定运行哪个脚本
|
||||
if platform == "twitter":
|
||||
script_name = "run_twitter_simulation.py"
|
||||
state.twitter_running = True
|
||||
elif platform == "reddit":
|
||||
script_name = "run_reddit_simulation.py"
|
||||
state.reddit_running = True
|
||||
else:
|
||||
script_name = "run_parallel_simulation.py"
|
||||
state.twitter_running = True
|
||||
state.reddit_running = True
|
||||
|
||||
script_path = os.path.join(sim_dir, script_name)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
raise ValueError(f"脚本不存在: {script_path}")
|
||||
|
||||
# 创建动作队列
|
||||
action_queue = Queue()
|
||||
cls._action_queues[simulation_id] = action_queue
|
||||
|
||||
# 启动模拟进程
|
||||
try:
|
||||
# 构建运行命令
|
||||
cmd = [
|
||||
sys.executable, # Python解释器
|
||||
script_path,
|
||||
"--config", "simulation_config.json",
|
||||
"--action-log", "actions.jsonl", # 动作日志文件
|
||||
]
|
||||
|
||||
# 设置工作目录为模拟目录
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=sim_dir,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
state.process_pid = process.pid
|
||||
state.runner_status = RunnerStatus.RUNNING
|
||||
cls._processes[simulation_id] = process
|
||||
cls._save_run_state(state)
|
||||
|
||||
# 启动监控线程
|
||||
monitor_thread = threading.Thread(
|
||||
target=cls._monitor_simulation,
|
||||
args=(simulation_id,),
|
||||
daemon=True
|
||||
)
|
||||
monitor_thread.start()
|
||||
cls._monitor_threads[simulation_id] = monitor_thread
|
||||
|
||||
logger.info(f"模拟启动成功: {simulation_id}, pid={process.pid}, platform={platform}")
|
||||
|
||||
except Exception as e:
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
state.error = str(e)
|
||||
cls._save_run_state(state)
|
||||
raise
|
||||
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def _monitor_simulation(cls, simulation_id: str):
|
||||
"""监控模拟进程,解析动作日志"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
actions_log = os.path.join(sim_dir, "actions.jsonl")
|
||||
|
||||
process = cls._processes.get(simulation_id)
|
||||
state = cls.get_run_state(simulation_id)
|
||||
|
||||
if not process or not state:
|
||||
return
|
||||
|
||||
last_position = 0
|
||||
|
||||
try:
|
||||
while process.poll() is None: # 进程仍在运行
|
||||
# 读取动作日志
|
||||
if os.path.exists(actions_log):
|
||||
with open(actions_log, 'r', encoding='utf-8') as f:
|
||||
f.seek(last_position)
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
action_data = json.loads(line)
|
||||
action = AgentAction(
|
||||
round_num=action_data.get("round", 0),
|
||||
timestamp=action_data.get("timestamp", datetime.now().isoformat()),
|
||||
platform=action_data.get("platform", "unknown"),
|
||||
agent_id=action_data.get("agent_id", 0),
|
||||
agent_name=action_data.get("agent_name", ""),
|
||||
action_type=action_data.get("action_type", ""),
|
||||
action_args=action_data.get("action_args", {}),
|
||||
result=action_data.get("result"),
|
||||
success=action_data.get("success", True),
|
||||
)
|
||||
state.add_action(action)
|
||||
|
||||
# 更新轮次
|
||||
if action.round_num > state.current_round:
|
||||
state.current_round = action.round_num
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
last_position = f.tell()
|
||||
|
||||
# 定期保存状态
|
||||
cls._save_run_state(state)
|
||||
time.sleep(1) # 每秒检查一次
|
||||
|
||||
# 进程结束
|
||||
exit_code = process.returncode
|
||||
|
||||
if exit_code == 0:
|
||||
state.runner_status = RunnerStatus.COMPLETED
|
||||
state.completed_at = datetime.now().isoformat()
|
||||
logger.info(f"模拟完成: {simulation_id}")
|
||||
else:
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
stderr = process.stderr.read() if process.stderr else ""
|
||||
state.error = f"进程退出码: {exit_code}, 错误: {stderr[:500]}"
|
||||
logger.error(f"模拟失败: {simulation_id}, error={state.error}")
|
||||
|
||||
state.twitter_running = False
|
||||
state.reddit_running = False
|
||||
cls._save_run_state(state)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"监控线程异常: {simulation_id}, error={str(e)}")
|
||||
state.runner_status = RunnerStatus.FAILED
|
||||
state.error = str(e)
|
||||
cls._save_run_state(state)
|
||||
|
||||
finally:
|
||||
# 清理
|
||||
cls._processes.pop(simulation_id, None)
|
||||
cls._action_queues.pop(simulation_id, None)
|
||||
|
||||
@classmethod
|
||||
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
|
||||
"""停止模拟"""
|
||||
state = cls.get_run_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]:
|
||||
raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")
|
||||
|
||||
state.runner_status = RunnerStatus.STOPPING
|
||||
cls._save_run_state(state)
|
||||
|
||||
# 终止进程
|
||||
process = cls._processes.get(simulation_id)
|
||||
if process:
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
|
||||
state.runner_status = RunnerStatus.STOPPED
|
||||
state.twitter_running = False
|
||||
state.reddit_running = False
|
||||
state.completed_at = datetime.now().isoformat()
|
||||
cls._save_run_state(state)
|
||||
|
||||
logger.info(f"模拟已停止: {simulation_id}")
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def get_actions(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
platform: Optional[str] = None,
|
||||
agent_id: Optional[int] = None,
|
||||
round_num: Optional[int] = None
|
||||
) -> List[AgentAction]:
|
||||
"""
|
||||
获取动作历史
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
platform: 过滤平台
|
||||
agent_id: 过滤Agent
|
||||
round_num: 过滤轮次
|
||||
|
||||
Returns:
|
||||
动作列表
|
||||
"""
|
||||
sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
|
||||
actions_log = os.path.join(sim_dir, "actions.jsonl")
|
||||
|
||||
if not os.path.exists(actions_log):
|
||||
return []
|
||||
|
||||
actions = []
|
||||
|
||||
with open(actions_log, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
|
||||
# 过滤
|
||||
if platform and data.get("platform") != platform:
|
||||
continue
|
||||
if agent_id is not None and data.get("agent_id") != agent_id:
|
||||
continue
|
||||
if round_num is not None and data.get("round") != round_num:
|
||||
continue
|
||||
|
||||
actions.append(AgentAction(
|
||||
round_num=data.get("round", 0),
|
||||
timestamp=data.get("timestamp", ""),
|
||||
platform=data.get("platform", ""),
|
||||
agent_id=data.get("agent_id", 0),
|
||||
agent_name=data.get("agent_name", ""),
|
||||
action_type=data.get("action_type", ""),
|
||||
action_args=data.get("action_args", {}),
|
||||
result=data.get("result"),
|
||||
success=data.get("success", True),
|
||||
))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# 按时间倒序排列
|
||||
actions.reverse()
|
||||
|
||||
# 分页
|
||||
return actions[offset:offset + limit]
|
||||
|
||||
@classmethod
|
||||
def get_timeline(
|
||||
cls,
|
||||
simulation_id: str,
|
||||
start_round: int = 0,
|
||||
end_round: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取模拟时间线(按轮次汇总)
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
start_round: 起始轮次
|
||||
end_round: 结束轮次
|
||||
|
||||
Returns:
|
||||
每轮的汇总信息
|
||||
"""
|
||||
actions = cls.get_actions(simulation_id, limit=10000)
|
||||
|
||||
# 按轮次分组
|
||||
rounds: Dict[int, Dict[str, Any]] = {}
|
||||
|
||||
for action in actions:
|
||||
round_num = action.round_num
|
||||
|
||||
if round_num < start_round:
|
||||
continue
|
||||
if end_round is not None and round_num > end_round:
|
||||
continue
|
||||
|
||||
if round_num not in rounds:
|
||||
rounds[round_num] = {
|
||||
"round_num": round_num,
|
||||
"twitter_actions": 0,
|
||||
"reddit_actions": 0,
|
||||
"active_agents": set(),
|
||||
"action_types": {},
|
||||
"first_action_time": action.timestamp,
|
||||
"last_action_time": action.timestamp,
|
||||
}
|
||||
|
||||
r = rounds[round_num]
|
||||
|
||||
if action.platform == "twitter":
|
||||
r["twitter_actions"] += 1
|
||||
else:
|
||||
r["reddit_actions"] += 1
|
||||
|
||||
r["active_agents"].add(action.agent_id)
|
||||
r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1
|
||||
r["last_action_time"] = action.timestamp
|
||||
|
||||
# 转换为列表
|
||||
result = []
|
||||
for round_num in sorted(rounds.keys()):
|
||||
r = rounds[round_num]
|
||||
result.append({
|
||||
"round_num": round_num,
|
||||
"twitter_actions": r["twitter_actions"],
|
||||
"reddit_actions": r["reddit_actions"],
|
||||
"total_actions": r["twitter_actions"] + r["reddit_actions"],
|
||||
"active_agents_count": len(r["active_agents"]),
|
||||
"active_agents": list(r["active_agents"]),
|
||||
"action_types": r["action_types"],
|
||||
"first_action_time": r["first_action_time"],
|
||||
"last_action_time": r["last_action_time"],
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取每个Agent的统计信息
|
||||
|
||||
Returns:
|
||||
Agent统计列表
|
||||
"""
|
||||
actions = cls.get_actions(simulation_id, limit=10000)
|
||||
|
||||
agent_stats: Dict[int, Dict[str, Any]] = {}
|
||||
|
||||
for action in actions:
|
||||
agent_id = action.agent_id
|
||||
|
||||
if agent_id not in agent_stats:
|
||||
agent_stats[agent_id] = {
|
||||
"agent_id": agent_id,
|
||||
"agent_name": action.agent_name,
|
||||
"total_actions": 0,
|
||||
"twitter_actions": 0,
|
||||
"reddit_actions": 0,
|
||||
"action_types": {},
|
||||
"first_action_time": action.timestamp,
|
||||
"last_action_time": action.timestamp,
|
||||
}
|
||||
|
||||
stats = agent_stats[agent_id]
|
||||
stats["total_actions"] += 1
|
||||
|
||||
if action.platform == "twitter":
|
||||
stats["twitter_actions"] += 1
|
||||
else:
|
||||
stats["reddit_actions"] += 1
|
||||
|
||||
stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1
|
||||
stats["last_action_time"] = action.timestamp
|
||||
|
||||
# 按总动作数排序
|
||||
result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True)
|
||||
|
||||
return result
|
||||
|
||||
386
backend/app/services/zep_entity_reader.py
Normal file
386
backend/app/services/zep_entity_reader.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityNode:
|
||||
"""实体节点数据结构"""
|
||||
uuid: str
|
||||
name: str
|
||||
labels: List[str]
|
||||
summary: str
|
||||
attributes: Dict[str, Any]
|
||||
# 相关的边信息
|
||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# 相关的其他节点信息
|
||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"name": self.name,
|
||||
"labels": self.labels,
|
||||
"summary": self.summary,
|
||||
"attributes": self.attributes,
|
||||
"related_edges": self.related_edges,
|
||||
"related_nodes": self.related_nodes,
|
||||
}
|
||||
|
||||
def get_entity_type(self) -> Optional[str]:
|
||||
"""获取实体类型(排除默认的Entity标签)"""
|
||||
for label in self.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
return label
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilteredEntities:
|
||||
"""过滤后的实体集合"""
|
||||
entities: List[EntityNode]
|
||||
entity_types: Set[str]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entities": [e.to_dict() for e in self.entities],
|
||||
"entity_types": list(self.entity_types),
|
||||
"total_count": self.total_count,
|
||||
"filtered_count": self.filtered_count,
|
||||
}
|
||||
|
||||
|
||||
class ZepEntityReader:
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
|
||||
主要功能:
|
||||
1. 从Zep图谱读取所有节点
|
||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||
3. 获取每个实体的相关边和关联节点信息
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
nodes_data.append({
|
||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
"name": node.name or "",
|
||||
"labels": node.labels or [],
|
||||
"summary": node.summary or "",
|
||||
"attributes": node.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
||||
return nodes_data
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
||||
return edges_data
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定节点的所有相关边
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
try:
|
||||
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
return edges_data
|
||||
except Exception as e:
|
||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def filter_defined_entities(
|
||||
self,
|
||||
graph_id: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
enrich_with_edges: bool = True
|
||||
) -> FilteredEntities:
|
||||
"""
|
||||
筛选出符合预定义实体类型的节点
|
||||
|
||||
筛选逻辑:
|
||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
||||
|
||||
Returns:
|
||||
FilteredEntities: 过滤后的实体集合
|
||||
"""
|
||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
||||
|
||||
# 获取所有节点
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
total_count = len(all_nodes)
|
||||
|
||||
# 获取所有边(用于后续关联查找)
|
||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||
|
||||
# 构建节点UUID到节点数据的映射
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 筛选符合条件的实体
|
||||
filtered_entities = []
|
||||
entity_types_found = set()
|
||||
|
||||
for node in all_nodes:
|
||||
labels = node.get("labels", [])
|
||||
|
||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||
|
||||
if not custom_labels:
|
||||
# 只有默认标签,跳过
|
||||
continue
|
||||
|
||||
# 如果指定了预定义类型,检查是否匹配
|
||||
if defined_entity_types:
|
||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||
if not matching_labels:
|
||||
continue
|
||||
entity_type = matching_labels[0]
|
||||
else:
|
||||
entity_type = custom_labels[0]
|
||||
|
||||
entity_types_found.add(entity_type)
|
||||
|
||||
# 创建实体节点对象
|
||||
entity = EntityNode(
|
||||
uuid=node["uuid"],
|
||||
name=node["name"],
|
||||
labels=labels,
|
||||
summary=node["summary"],
|
||||
attributes=node["attributes"],
|
||||
)
|
||||
|
||||
# 获取相关边和节点
|
||||
if enrich_with_edges:
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
for edge in all_edges:
|
||||
if edge["source_node_uuid"] == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
elif edge["target_node_uuid"] == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
entity.related_edges = related_edges
|
||||
|
||||
# 获取关联节点的基本信息
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
entity.related_nodes = related_nodes
|
||||
|
||||
filtered_entities.append(entity)
|
||||
|
||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
||||
f"实体类型: {entity_types_found}")
|
||||
|
||||
return FilteredEntities(
|
||||
entities=filtered_entities,
|
||||
entity_types=entity_types_found,
|
||||
total_count=total_count,
|
||||
filtered_count=len(filtered_entities),
|
||||
)
|
||||
|
||||
def get_entity_with_context(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_uuid: str
|
||||
) -> Optional[EntityNode]:
|
||||
"""
|
||||
获取单个实体及其完整上下文(边和关联节点)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_uuid: 实体UUID
|
||||
|
||||
Returns:
|
||||
EntityNode或None
|
||||
"""
|
||||
try:
|
||||
# 获取节点
|
||||
node = self.client.graph.node.get(uuid_=entity_uuid)
|
||||
|
||||
if not node:
|
||||
return None
|
||||
|
||||
# 获取节点的边
|
||||
edges = self.get_node_edges(entity_uuid)
|
||||
|
||||
# 获取所有节点用于关联查找
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 处理相关边和节点
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
for edge in edges:
|
||||
if edge["source_node_uuid"] == entity_uuid:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
else:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
# 获取关联节点信息
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
return EntityNode(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {},
|
||||
related_edges=related_edges,
|
||||
related_nodes=related_nodes,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_entities_by_type(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_type: str,
|
||||
enrich_with_edges: bool = True
|
||||
) -> List[EntityNode]:
|
||||
"""
|
||||
获取指定类型的所有实体
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
||||
enrich_with_edges: 是否获取相关边信息
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
result = self.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
defined_entity_types=[entity_type],
|
||||
enrich_with_edges=enrich_with_edges
|
||||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
Reference in New Issue
Block a user