Enhance backend functionality with OASIS simulation features
- Updated README.md to include new simulation scripts and configuration details for OASIS, including API retry mechanisms and environment variable settings. - Added simulation management and configuration generation services to streamline the simulation process across Twitter and Reddit platforms. - Introduced new API routes for simulation-related operations, including entity retrieval and simulation status management. - Implemented a robust retry mechanism for external API calls to improve system stability. - Enhanced task management model to include detailed progress tracking. - Added logging capabilities for action tracking during simulations. - Included new scripts for running parallel simulations and testing profile formats.
This commit is contained in:
386
backend/app/services/zep_entity_reader.py
Normal file
386
backend/app/services/zep_entity_reader.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityNode:
|
||||
"""实体节点数据结构"""
|
||||
uuid: str
|
||||
name: str
|
||||
labels: List[str]
|
||||
summary: str
|
||||
attributes: Dict[str, Any]
|
||||
# 相关的边信息
|
||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# 相关的其他节点信息
|
||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"name": self.name,
|
||||
"labels": self.labels,
|
||||
"summary": self.summary,
|
||||
"attributes": self.attributes,
|
||||
"related_edges": self.related_edges,
|
||||
"related_nodes": self.related_nodes,
|
||||
}
|
||||
|
||||
def get_entity_type(self) -> Optional[str]:
|
||||
"""获取实体类型(排除默认的Entity标签)"""
|
||||
for label in self.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
return label
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilteredEntities:
|
||||
"""过滤后的实体集合"""
|
||||
entities: List[EntityNode]
|
||||
entity_types: Set[str]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entities": [e.to_dict() for e in self.entities],
|
||||
"entity_types": list(self.entity_types),
|
||||
"total_count": self.total_count,
|
||||
"filtered_count": self.filtered_count,
|
||||
}
|
||||
|
||||
|
||||
class ZepEntityReader:
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
|
||||
主要功能:
|
||||
1. 从Zep图谱读取所有节点
|
||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||
3. 获取每个实体的相关边和关联节点信息
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
nodes_data.append({
|
||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
"name": node.name or "",
|
||||
"labels": node.labels or [],
|
||||
"summary": node.summary or "",
|
||||
"attributes": node.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
||||
return nodes_data
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
||||
return edges_data
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定节点的所有相关边
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
try:
|
||||
edges = self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
return edges_data
|
||||
except Exception as e:
|
||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def filter_defined_entities(
|
||||
self,
|
||||
graph_id: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
enrich_with_edges: bool = True
|
||||
) -> FilteredEntities:
|
||||
"""
|
||||
筛选出符合预定义实体类型的节点
|
||||
|
||||
筛选逻辑:
|
||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
||||
|
||||
Returns:
|
||||
FilteredEntities: 过滤后的实体集合
|
||||
"""
|
||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
||||
|
||||
# 获取所有节点
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
total_count = len(all_nodes)
|
||||
|
||||
# 获取所有边(用于后续关联查找)
|
||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||
|
||||
# 构建节点UUID到节点数据的映射
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 筛选符合条件的实体
|
||||
filtered_entities = []
|
||||
entity_types_found = set()
|
||||
|
||||
for node in all_nodes:
|
||||
labels = node.get("labels", [])
|
||||
|
||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||
|
||||
if not custom_labels:
|
||||
# 只有默认标签,跳过
|
||||
continue
|
||||
|
||||
# 如果指定了预定义类型,检查是否匹配
|
||||
if defined_entity_types:
|
||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||
if not matching_labels:
|
||||
continue
|
||||
entity_type = matching_labels[0]
|
||||
else:
|
||||
entity_type = custom_labels[0]
|
||||
|
||||
entity_types_found.add(entity_type)
|
||||
|
||||
# 创建实体节点对象
|
||||
entity = EntityNode(
|
||||
uuid=node["uuid"],
|
||||
name=node["name"],
|
||||
labels=labels,
|
||||
summary=node["summary"],
|
||||
attributes=node["attributes"],
|
||||
)
|
||||
|
||||
# 获取相关边和节点
|
||||
if enrich_with_edges:
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
for edge in all_edges:
|
||||
if edge["source_node_uuid"] == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
elif edge["target_node_uuid"] == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
entity.related_edges = related_edges
|
||||
|
||||
# 获取关联节点的基本信息
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
entity.related_nodes = related_nodes
|
||||
|
||||
filtered_entities.append(entity)
|
||||
|
||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
||||
f"实体类型: {entity_types_found}")
|
||||
|
||||
return FilteredEntities(
|
||||
entities=filtered_entities,
|
||||
entity_types=entity_types_found,
|
||||
total_count=total_count,
|
||||
filtered_count=len(filtered_entities),
|
||||
)
|
||||
|
||||
def get_entity_with_context(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_uuid: str
|
||||
) -> Optional[EntityNode]:
|
||||
"""
|
||||
获取单个实体及其完整上下文(边和关联节点)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_uuid: 实体UUID
|
||||
|
||||
Returns:
|
||||
EntityNode或None
|
||||
"""
|
||||
try:
|
||||
# 获取节点
|
||||
node = self.client.graph.node.get(uuid_=entity_uuid)
|
||||
|
||||
if not node:
|
||||
return None
|
||||
|
||||
# 获取节点的边
|
||||
edges = self.get_node_edges(entity_uuid)
|
||||
|
||||
# 获取所有节点用于关联查找
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 处理相关边和节点
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
for edge in edges:
|
||||
if edge["source_node_uuid"] == entity_uuid:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
else:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
# 获取关联节点信息
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
return EntityNode(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {},
|
||||
related_edges=related_edges,
|
||||
related_nodes=related_nodes,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_entities_by_type(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_type: str,
|
||||
enrich_with_edges: bool = True
|
||||
) -> List[EntityNode]:
|
||||
"""
|
||||
获取指定类型的所有实体
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
||||
enrich_with_edges: 是否获取相关边信息
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
result = self.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
defined_entity_types=[entity_type],
|
||||
enrich_with_edges=enrich_with_edges
|
||||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
Reference in New Issue
Block a user