feat(graph): implement pagination for fetching nodes and edges; add utility functions for streamlined data retrieval
This commit is contained in:
@@ -15,6 +15,7 @@ from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||||
|
||||
from ..config import Config
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
from .text_processor import TextProcessor
|
||||
|
||||
|
||||
@@ -395,12 +396,12 @@ class GraphBuilderService:
|
||||
|
||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||
"""获取图谱信息"""
|
||||
# 获取节点
|
||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
# 获取边
|
||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
# 获取节点(分页)
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
# 获取边(分页)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 统计实体类型
|
||||
entity_types = set()
|
||||
for node in nodes:
|
||||
@@ -408,7 +409,7 @@ class GraphBuilderService:
|
||||
for label in node.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
entity_types.add(label)
|
||||
|
||||
|
||||
return GraphInfo(
|
||||
graph_id=graph_id,
|
||||
node_count=len(nodes),
|
||||
@@ -426,9 +427,9 @@ class GraphBuilderService:
|
||||
Returns:
|
||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||||
"""
|
||||
nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
|
||||
edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 创建节点映射用于获取节点名称
|
||||
node_map = {}
|
||||
for node in nodes:
|
||||
|
||||
@@ -11,6 +11,7 @@ from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
@@ -125,22 +126,18 @@ class ZepEntityReader:
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点(带重试机制)
|
||||
|
||||
获取图谱的所有节点(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
# 使用重试机制调用Zep API
|
||||
nodes = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取节点(graph={graph_id})"
|
||||
)
|
||||
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
nodes_data.append({
|
||||
@@ -150,28 +147,24 @@ class ZepEntityReader:
|
||||
"summary": node.summary or "",
|
||||
"attributes": node.attributes or {},
|
||||
})
|
||||
|
||||
|
||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
||||
return nodes_data
|
||||
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边(带重试机制)
|
||||
|
||||
获取图谱的所有边(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取边(graph={graph_id})"
|
||||
)
|
||||
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
@@ -182,7 +175,7 @@ class ZepEntityReader:
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
|
||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
||||
return edges_data
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from zep_cloud.client import Zep
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_tools')
|
||||
|
||||
@@ -648,71 +649,67 @@ class ZepToolsService:
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
|
||||
"""
|
||||
获取图谱的所有节点
|
||||
|
||||
获取图谱的所有节点(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取节点(graph={graph_id})"
|
||||
)
|
||||
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
result = []
|
||||
for node in nodes:
|
||||
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
|
||||
result.append(NodeInfo(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
uuid=str(node_uuid) if node_uuid else "",
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {}
|
||||
))
|
||||
|
||||
|
||||
logger.info(f"获取到 {len(result)} 个节点")
|
||||
return result
|
||||
|
||||
|
||||
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
|
||||
"""
|
||||
获取图谱的所有边(包含时间信息)
|
||||
|
||||
获取图谱的所有边(分页获取,包含时间信息)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
include_temporal: 是否包含时间信息(默认True)
|
||||
|
||||
|
||||
Returns:
|
||||
边列表(包含created_at, valid_at, invalid_at, expired_at)
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.edge.get_by_graph_id(graph_id=graph_id),
|
||||
operation_name=f"获取边(graph={graph_id})"
|
||||
)
|
||||
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
result = []
|
||||
for edge in edges:
|
||||
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
|
||||
edge_info = EdgeInfo(
|
||||
uuid=getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
uuid=str(edge_uuid) if edge_uuid else "",
|
||||
name=edge.name or "",
|
||||
fact=edge.fact or "",
|
||||
source_node_uuid=edge.source_node_uuid or "",
|
||||
target_node_uuid=edge.target_node_uuid or ""
|
||||
)
|
||||
|
||||
|
||||
# 添加时间信息
|
||||
if include_temporal:
|
||||
edge_info.created_at = getattr(edge, 'created_at', None)
|
||||
edge_info.valid_at = getattr(edge, 'valid_at', None)
|
||||
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
|
||||
edge_info.expired_at = getattr(edge, 'expired_at', None)
|
||||
|
||||
|
||||
result.append(edge_info)
|
||||
|
||||
|
||||
logger.info(f"获取到 {len(result)} 条边")
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user