Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings.
- Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms.
- Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management.
- Implemented a robust retry mechanism for external API calls to improve system stability.
- Enhanced task management model to include detailed progress tracking.
- Added logging capabilities for action tracking during simulations.
- Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
546
backend/app/services/simulation_manager.py
Normal file
546
backend/app/services/simulation_manager.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""
|
||||
OASIS simulation manager.

Manages parallel simulations on the Twitter and Reddit platforms,
using preset scripts plus LLM-generated configuration parameters.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import ZepEntityReader, FilteredEntities
|
||||
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
|
||||
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
|
||||
|
||||
logger = get_logger('mirofish.simulation')
|
||||
|
||||
|
||||
class SimulationStatus(str, Enum):
    """Lifecycle status of a simulation.

    Members are also ``str`` instances, so they compare equal to their raw
    values and the ``.value`` strings are what gets written to ``state.json``.
    """
    # Assigned when a simulation is first created.
    CREATED = "created"
    # Assigned while prepare_simulation is running.
    PREPARING = "preparing"
    # Assigned once preparation finished successfully.
    READY = "ready"
    # NOTE(review): RUNNING / PAUSED / COMPLETED are not assigned anywhere in
    # this module; presumably set by the component that runs the simulation.
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    # Assigned when preparation finds no entities or raises.
    FAILED = "failed"
|
||||
|
||||
|
||||
class PlatformType(str, Enum):
    """Social platform a simulation can target.

    Members are also ``str`` instances, so they compare equal to the plain
    ``"twitter"`` / ``"reddit"`` strings used elsewhere in this module.
    """
    TWITTER = "twitter"
    REDDIT = "reddit"
|
||||
|
||||
|
||||
@dataclass
class SimulationState:
    """Mutable record describing one simulation and its preparation progress."""
    simulation_id: str
    project_id: str
    graph_id: str

    # Which platforms take part in the simulation
    enable_twitter: bool = True
    enable_reddit: bool = True

    # Lifecycle status
    status: SimulationStatus = SimulationStatus.CREATED

    # Results of the preparation phase
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)

    # Outcome of configuration generation
    config_generated: bool = False
    config_reasoning: str = ""

    # Runtime data
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"

    # ISO-8601 timestamps (local time, taken when the instance is built)
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Error message, if any
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict (internal use).

        ``status`` is emitted as its string value; all other fields are
        passed through unchanged.
        """
        field_names = (
            "simulation_id",
            "project_id",
            "graph_id",
            "enable_twitter",
            "enable_reddit",
            "status",
            "entities_count",
            "profiles_count",
            "entity_types",
            "config_generated",
            "config_reasoning",
            "current_round",
            "twitter_status",
            "reddit_status",
            "created_at",
            "updated_at",
            "error",
        )
        snapshot: Dict[str, Any] = {name: getattr(self, name) for name in field_names}
        # Overwrite in place so the key keeps its position in the dict.
        snapshot["status"] = self.status.value
        return snapshot

    def to_simple_dict(self) -> Dict[str, Any]:
        """Return the reduced view of the state used for API responses.

        ``status`` is emitted as its string value.
        """
        field_names = (
            "simulation_id",
            "project_id",
            "graph_id",
            "status",
            "entities_count",
            "profiles_count",
            "entity_types",
            "config_generated",
            "error",
        )
        summary: Dict[str, Any] = {name: getattr(self, name) for name in field_names}
        summary["status"] = self.status.value
        return summary
|
||||
|
||||
|
||||
class SimulationManager:
    """
    Prepare and track OASIS simulations.

    Core responsibilities:
    1. Read entities from a Zep graph and filter them.
    2. Generate an OASIS agent profile for each entity.
    3. Generate simulation configuration parameters via the LLM-backed
       config generator.
    4. Put every file the preset run scripts need into a per-simulation
       directory below ``SIMULATION_DATA_DIR``.

    State is persisted as ``state.json`` inside each simulation directory
    and cached in memory per manager instance.
    """

    # Root directory holding one sub-directory per simulation
    SIMULATION_DATA_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    # Directory containing the preset run scripts
    SCRIPTS_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../scripts'
    )

    # Single source of truth for the scripts copied into a simulation
    # directory; progress reporting and the copy step both read this.
    PRESET_SCRIPTS = (
        "run_twitter_simulation.py",
        "run_reddit_simulation.py",
        "run_parallel_simulation.py",
        "action_logger.py",
    )

    def __init__(self):
        # Make sure the data root exists
        os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)

        # In-memory cache of simulation states, keyed by simulation ID
        self._simulations: Dict[str, SimulationState] = {}

    @staticmethod
    def _is_valid_simulation_id(simulation_id: str) -> bool:
        """Return True when the ID is a single, ordinary path component.

        Rejects empty strings, ``.``, ``..`` and anything containing a path
        separator, so an ID can never resolve outside the data root.
        """
        if not isinstance(simulation_id, str) or simulation_id in ("", ".", ".."):
            return False
        return os.path.basename(simulation_id) == simulation_id

    def _get_simulation_dir(self, simulation_id: str, create: bool = True) -> str:
        """Return the data directory of a simulation.

        Args:
            simulation_id: Simulation ID; must be a single path component.
            create: Create the directory when it is missing. Read-only
                callers pass ``False`` so lookups of unknown IDs do not
                leave empty directories behind.

        Raises:
            ValueError: If ``simulation_id`` is not a valid ID.
        """
        if not self._is_valid_simulation_id(simulation_id):
            raise ValueError(f"无效的模拟ID: {simulation_id}")

        sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
        if create:
            os.makedirs(sim_dir, exist_ok=True)
        return sim_dir

    def _save_simulation_state(self, state: SimulationState):
        """Write the state to ``state.json`` and refresh the in-memory cache.

        ``state.updated_at`` is set to the current time before writing.
        """
        sim_dir = self._get_simulation_dir(state.simulation_id)
        state_file = os.path.join(sim_dir, "state.json")

        state.updated_at = datetime.now().isoformat()

        with open(state_file, 'w', encoding='utf-8') as f:
            json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)

        self._simulations[state.simulation_id] = state

    def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
        """Return the state of a simulation, from cache or from ``state.json``.

        Returns ``None`` when the ID is invalid or no state file exists. The
        lookup never creates directories.
        """
        if not self._is_valid_simulation_id(simulation_id):
            return None

        if simulation_id in self._simulations:
            return self._simulations[simulation_id]

        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        state_file = os.path.join(sim_dir, "state.json")

        if not os.path.exists(state_file):
            return None

        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Missing keys fall back to the same defaults as a fresh state.
        state = SimulationState(
            simulation_id=simulation_id,
            project_id=data.get("project_id", ""),
            graph_id=data.get("graph_id", ""),
            enable_twitter=data.get("enable_twitter", True),
            enable_reddit=data.get("enable_reddit", True),
            status=SimulationStatus(data.get("status", "created")),
            entities_count=data.get("entities_count", 0),
            profiles_count=data.get("profiles_count", 0),
            entity_types=data.get("entity_types", []),
            config_generated=data.get("config_generated", False),
            config_reasoning=data.get("config_reasoning", ""),
            current_round=data.get("current_round", 0),
            twitter_status=data.get("twitter_status", "not_started"),
            reddit_status=data.get("reddit_status", "not_started"),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            error=data.get("error"),
        )

        self._simulations[simulation_id] = state
        return state

    def create_simulation(
        self,
        project_id: str,
        graph_id: str,
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationState:
        """
        Create a new simulation and persist its initial state.

        Args:
            project_id: Project ID
            graph_id: Zep graph ID
            enable_twitter: Whether the Twitter simulation is enabled
            enable_reddit: Whether the Reddit simulation is enabled

        Returns:
            The new SimulationState, in status CREATED.
        """
        import uuid
        simulation_id = f"sim_{uuid.uuid4().hex[:12]}"

        state = SimulationState(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit,
            status=SimulationStatus.CREATED,
        )

        self._save_simulation_state(state)
        logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")

        return state

    def prepare_simulation(
        self,
        simulation_id: str,
        simulation_requirement: str,
        document_text: str,
        defined_entity_types: Optional[List[str]] = None,
        use_llm_for_profiles: bool = True,
        progress_callback: Optional[callable] = None
    ) -> SimulationState:
        """
        Prepare the simulation environment end to end.

        Steps:
        1. Read and filter entities from the Zep graph.
        2. Generate an OASIS agent profile per entity (optionally via LLM).
        3. Generate the simulation configuration parameters.
        4. Save the configuration and profile files.
        5. Copy the preset scripts into the simulation directory.

        Args:
            simulation_id: Simulation ID
            simulation_requirement: Requirement description passed to the
                config generator
            document_text: Original document text passed to the config
                generator
            defined_entity_types: Optional predefined entity types
            use_llm_for_profiles: Whether profiles are generated with the LLM
            progress_callback: Called as ``(stage, progress, message)``; most
                calls also pass the keyword arguments ``current`` and
                ``total``, and per-profile updates add ``item_name``, so the
                callable must accept extra keyword arguments.

        Returns:
            The updated SimulationState. When no entity matches, the state is
            returned with status FAILED instead of raising.

        Raises:
            ValueError: If the simulation does not exist.
            Exception: Any error raised during preparation is re-raised after
                the state has been saved as FAILED.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        def notify(stage: str, progress: int, message: str, **extra: Any) -> None:
            # Forward a progress update only when a callback was supplied.
            if progress_callback:
                progress_callback(stage, progress, message, **extra)

        try:
            state.status = SimulationStatus.PREPARING
            self._save_simulation_state(state)

            sim_dir = self._get_simulation_dir(simulation_id)

            # ========== Stage 1: read and filter entities ==========
            notify("reading", 0, "正在连接Zep图谱...")

            reader = ZepEntityReader()

            notify("reading", 30, "正在读取节点数据...")

            filtered = reader.filter_defined_entities(
                graph_id=state.graph_id,
                defined_entity_types=defined_entity_types,
                enrich_with_edges=True
            )

            state.entities_count = filtered.filtered_count
            state.entity_types = list(filtered.entity_types)

            notify(
                "reading", 100,
                f"完成,共 {filtered.filtered_count} 个实体",
                current=filtered.filtered_count,
                total=filtered.filtered_count
            )

            if filtered.filtered_count == 0:
                state.status = SimulationStatus.FAILED
                state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
                self._save_simulation_state(state)
                return state

            # ========== Stage 2: generate agent profiles ==========
            total_entities = len(filtered.entities)

            notify(
                "generating_profiles", 0,
                "开始生成...",
                current=0,
                total=total_entities
            )

            generator = OasisProfileGenerator()

            def profile_progress(current, total, msg):
                # Guard against a zero total so reporting cannot raise.
                percent = int(current / total * 100) if total else 0
                notify(
                    "generating_profiles",
                    percent,
                    msg,
                    current=current,
                    total=total,
                    item_name=msg
                )

            profiles = generator.generate_profiles_from_entities(
                entities=filtered.entities,
                use_llm=use_llm_for_profiles,
                progress_callback=profile_progress
            )

            state.profiles_count = len(profiles)

            # Save the profile files: Reddit as JSON, Twitter as CSV.
            notify(
                "generating_profiles", 95,
                "保存Profile文件...",
                current=total_entities,
                total=total_entities
            )

            if state.enable_reddit:
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                    platform="reddit"
                )

            if state.enable_twitter:
                # Twitter profiles use CSV; the original author notes this is
                # an OASIS requirement.
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                    platform="twitter"
                )

            notify(
                "generating_profiles", 100,
                f"完成,共 {len(profiles)} 个Profile",
                current=len(profiles),
                total=len(profiles)
            )

            # ========== Stage 3: generate the simulation config ==========
            notify(
                "generating_config", 0,
                "正在分析模拟需求...",
                current=0,
                total=3
            )

            config_generator = SimulationConfigGenerator()

            notify(
                "generating_config", 30,
                "正在调用LLM生成配置...",
                current=1,
                total=3
            )

            sim_params = config_generator.generate_config(
                simulation_id=simulation_id,
                project_id=state.project_id,
                graph_id=state.graph_id,
                simulation_requirement=simulation_requirement,
                document_text=document_text,
                entities=filtered.entities,
                enable_twitter=state.enable_twitter,
                enable_reddit=state.enable_reddit
            )

            notify(
                "generating_config", 70,
                "正在保存配置文件...",
                current=2,
                total=3
            )

            # Save the configuration file
            config_path = os.path.join(sim_dir, "simulation_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(sim_params.to_json())

            state.config_generated = True
            state.config_reasoning = sim_params.generation_reasoning

            notify(
                "generating_config", 100,
                "配置生成完成",
                current=3,
                total=3
            )

            # ========== Stage 4: copy the preset scripts ==========
            script_total = len(self.PRESET_SCRIPTS)

            notify(
                "copying_scripts", 0,
                "开始准备脚本...",
                current=0,
                total=script_total
            )

            copied_scripts = self._copy_preset_scripts(sim_dir)

            # Report how many scripts were actually copied, not how many
            # were expected.
            notify(
                "copying_scripts", 100,
                f"完成,共 {len(copied_scripts)} 个脚本",
                current=len(copied_scripts),
                total=script_total
            )

            # Update the state
            state.status = SimulationStatus.READY
            self._save_simulation_state(state)

            logger.info(f"模拟准备完成: {simulation_id}, "
                        f"entities={state.entities_count}, profiles={state.profiles_count}")

            return state

        except Exception as e:
            logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            state.status = SimulationStatus.FAILED
            state.error = str(e)
            self._save_simulation_state(state)
            raise

    def _copy_preset_scripts(self, sim_dir: str) -> List[str]:
        """Copy every script in ``PRESET_SCRIPTS`` into ``sim_dir``.

        A missing source script is logged as a warning and skipped.

        Returns:
            Names of the scripts that were copied.
        """
        copied: List[str] = []

        for script in self.PRESET_SCRIPTS:
            src = os.path.join(self.SCRIPTS_DIR, script)
            dst = os.path.join(sim_dir, script)

            if os.path.exists(src):
                shutil.copy2(src, dst)
                copied.append(script)
                logger.debug(f"复制脚本: {script}")
            else:
                logger.warning(f"预设脚本不存在: {src}")

        return copied

    def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
        """Return the state of a simulation, or ``None`` if it does not exist."""
        return self._load_simulation_state(simulation_id)

    def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
        """List stored simulations, optionally restricted to one project.

        Entries in the data root that are not directories are ignored, and a
        simulation whose state file cannot be read is skipped with a warning
        so that one bad entry does not break the whole listing.
        """
        simulations: List[SimulationState] = []

        if not os.path.isdir(self.SIMULATION_DATA_DIR):
            return simulations

        for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
            # Stray files (e.g. placeholder or OS metadata files) are not
            # simulations.
            if not os.path.isdir(os.path.join(self.SIMULATION_DATA_DIR, sim_id)):
                continue

            try:
                state = self._load_simulation_state(sim_id)
            except (OSError, ValueError) as exc:
                # ValueError covers malformed JSON and unknown status values.
                logger.warning(f"跳过无法读取的模拟状态: {sim_id}, error={exc}")
                continue

            if state and (project_id is None or state.project_id == project_id):
                simulations.append(state)

        return simulations

    def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
        """Return the agent profiles saved for a simulation.

        Twitter profiles are read from ``twitter_profiles.csv`` (one dict per
        row, all values as strings); every other platform is read from
        ``{platform}_profiles.json``.

        Returns:
            The profiles, or an empty list when the file does not exist.

        Raises:
            ValueError: If the simulation does not exist.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        sim_dir = self._get_simulation_dir(simulation_id, create=False)

        if platform == "twitter":
            import csv

            csv_path = os.path.join(sim_dir, "twitter_profiles.csv")
            if not os.path.exists(csv_path):
                return []

            # utf-8-sig reads plain UTF-8 and also drops a BOM if present.
            with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
                return [dict(row) for row in csv.DictReader(f)]

        profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")

        if not os.path.exists(profile_path):
            return []

        with open(profile_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
        """Return the saved simulation config, or ``None`` if there is none.

        An invalid ID also yields ``None``. The lookup never creates
        directories.
        """
        if not self._is_valid_simulation_id(simulation_id):
            return None

        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        config_path = os.path.join(sim_dir, "simulation_config.json")

        if not os.path.exists(config_path):
            return None

        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_run_instructions(self, simulation_id: str) -> Dict[str, Any]:
        """Return paths, commands and instructions for running a simulation.

        Returns:
            A dict with ``simulation_dir``, ``config_file``, ``commands``
            (a dict keyed by ``twitter`` / ``reddit`` / ``parallel``) and
            ``instructions`` (multi-line text).

        Raises:
            ValueError: If ``simulation_id`` is not a valid ID.
        """
        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        config_path = os.path.join(sim_dir, "simulation_config.json")

        return {
            "simulation_dir": sim_dir,
            "config_file": config_path,
            "commands": {
                "twitter": "python run_twitter_simulation.py --config simulation_config.json",
                "reddit": "python run_reddit_simulation.py --config simulation_config.json",
                "parallel": "python run_parallel_simulation.py --config simulation_config.json",
            },
            "instructions": (
                f"1. 进入模拟目录: cd {sim_dir}\n"
                "2. 激活conda环境: conda activate MiroFish\n"
                "3. 运行模拟:\n"
                " - 单独运行Twitter: python run_twitter_simulation.py --config simulation_config.json\n"
                " - 单独运行Reddit: python run_reddit_simulation.py --config simulation_config.json\n"
                " - 并行运行双平台: python run_parallel_simulation.py --config simulation_config.json"
            )
        }
|
||||
Reference in New Issue
Block a user