Add initial implementation of txt2graph tool for knowledge graph generation
- Created a new Streamlit application for visualizing knowledge graphs.
- Implemented text extraction from PDF, Markdown, and TXT files.
- Developed graph building logic using Zep Cloud API.
- Added support for custom entity types and relationships.
- Included interactive HTML visualization for generated graphs.
- Updated .gitignore to include new directories and files.
- Added example environment configuration file (.env.example) for API key setup.
- Created README.md with installation and usage instructions.
- Introduced various utility scripts and styles for enhanced functionality.
This commit is contained in:
415
txt2graph/graph_builder.py
Normal file
415
txt2graph/graph_builder.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""
|
||||
Zep图谱构建模块
|
||||
负责与Zep云服务交互,构建知识图谱
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||||
|
||||
from ontology import ENTITY_TYPES, EDGE_TYPES
|
||||
|
||||
|
||||
@dataclass
class GraphNode:
    """A single node (entity) read back from a graph."""

    # Identifier of the node.
    uuid: str
    # Display name of the node.
    name: str
    # Free-text summary of the node.
    summary: str
    # Labels attached to the node.
    labels: list[str]
    # Additional key/value attributes of the node.
    attributes: dict
|
||||
|
||||
|
||||
@dataclass
class GraphEdge:
    """A single edge (relationship) read back from a graph."""

    # Identifier of the edge.
    uuid: str
    # Name of the edge.
    name: str
    # Fact text carried by the edge.
    fact: str
    # Identifier of the node the edge starts from.
    source_node_uuid: str
    # Identifier of the node the edge points to.
    target_node_uuid: str
    # Additional key/value attributes of the edge.
    attributes: dict
|
||||
|
||||
|
||||
@dataclass
class GraphData:
    """Complete contents of one graph: its id plus all nodes and edges."""

    # Identifier of the graph the data was read from.
    graph_id: str
    # Every node of the graph.
    nodes: list[GraphNode]
    # Every edge of the graph.
    edges: list[GraphEdge]
|
||||
|
||||
|
||||
class ZepGraphBuilder:
    """Build knowledge graphs through the Zep cloud service.

    Wraps a ``Zep`` client and exposes the steps this module needs: create a
    graph, install the custom ontology, upload text in batches, wait for the
    resulting tasks, and read the nodes and edges back.
    """

    # Allowed (source entity type, target entity type) pairs per edge type.
    # Every key is looked up in ``EDGE_TYPES`` by ``set_ontology``; the dict
    # order is the order in which the definitions are passed to the client.
    _EDGE_SOURCE_TARGETS = {
        # Person -> Organization / Company
        "WORKS_FOR": [
            ("Person", "Organization"),
            ("Person", "Company"),
        ],
        # Several entity types -> Location
        "LOCATED_IN": [
            ("Person", "Location"),
            ("Organization", "Location"),
            ("Company", "Location"),
            ("Event", "Location"),
        ],
        # Organization -> Organization, Company -> Company
        "PART_OF": [
            ("Organization", "Organization"),
            ("Company", "Company"),
        ],
        # Company / Organization -> Product
        "PRODUCES": [
            ("Company", "Product"),
            ("Organization", "Product"),
        ],
        # Person / Organization / Company -> Event
        "PARTICIPATES_IN": [
            ("Person", "Event"),
            ("Organization", "Event"),
            ("Company", "Event"),
        ],
        # Cooperation between various entity types
        "COLLABORATES": [
            ("Person", "Person"),
            ("Company", "Company"),
            ("Organization", "Organization"),
            ("Company", "Organization"),
        ],
        # Competition between companies
        "COMPETES": [
            ("Company", "Company"),
        ],
        # Media reporting on other entities
        "REPORTS": [
            ("Media", "Person"),
            ("Media", "Company"),
            ("Media", "Organization"),
            ("Media", "Event"),
        ],
    }

    def __init__(self, api_key: Optional[str] = None):
        """Create the builder and its Zep client.

        Args:
            api_key: Zep API key. When omitted (or empty) the value of the
                ``ZEP_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.environ.get("ZEP_API_KEY")
        if not self.api_key:
            raise ValueError("需要提供ZEP_API_KEY,可以通过参数传入或设置环境变量")

        self.client = Zep(api_key=self.api_key)

    def create_graph(self, graph_id: Optional[str] = None, name: str = "Knowledge Graph") -> str:
        """Create a new graph.

        Args:
            graph_id: Graph id to use. When ``None`` an id of the form
                ``graph_<16 hex chars>`` is generated.
            name: Name of the graph.

        Returns:
            The id of the created graph.
        """
        if graph_id is None:
            graph_id = f"graph_{uuid.uuid4().hex[:16]}"

        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="Knowledge graph generated by txt2graph",
        )
        return graph_id

    def set_ontology(self, graph_id: str):
        """Install the custom ontology (entity and edge types) on a graph.

        Each edge type in ``_EDGE_SOURCE_TARGETS`` is paired with its model
        from ``EDGE_TYPES`` and the list of allowed source/target entity
        types, then sent together with ``ENTITY_TYPES``.

        Args:
            graph_id: Id of the graph to configure.
        """
        edge_definitions = {
            edge_name: (
                EDGE_TYPES[edge_name],
                [
                    EntityEdgeSourceTarget(source=source, target=target)
                    for source, target in pairs
                ],
            )
            for edge_name, pairs in self._EDGE_SOURCE_TARGETS.items()
        }

        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ENTITY_TYPES,
            edges=edge_definitions,
        )

    def add_text_to_graph(
        self,
        graph_id: str,
        text_chunks: list[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> list[str]:
        """Upload text chunks to a graph in batches.

        Args:
            graph_id: Id of the target graph.
            text_chunks: Text chunks to upload, in order.
            batch_size: Number of chunks sent per request; must be >= 1.
            progress_callback: Optional callable receiving one progress
                message (a string) per event.

        Returns:
            The task ids collected from the batch responses (one per batch
            whose first returned item carries a ``task_id``).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Any error raised while sending a batch is reported to
                ``progress_callback`` and then re-raised.
        """
        # A negative step would make range() empty and silently drop all text.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        task_ids = []
        total_chunks = len(text_chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, total_chunks, batch_size), start=1):
            batch_chunks = text_chunks[start:start + batch_size]

            if progress_callback:
                progress_callback(f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...")

            episodes = [EpisodeData(data=chunk, type="text") for chunk in batch_chunks]

            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes,
                )
                # Only the first returned item is inspected for a task id.
                if batch_result and batch_result[0].task_id:
                    task_ids.append(batch_result[0].task_id)
            except Exception as e:
                if progress_callback:
                    progress_callback(f"批次 {batch_num} 发送失败: {str(e)}")
                raise

            # Brief pause so requests are not fired too quickly.
            time.sleep(1)

        return task_ids

    def wait_for_tasks(
        self,
        task_ids: list[str],
        timeout: int = 600,
        progress_callback: Optional[Callable] = None
    ):
        """Poll tasks until each one has completed or failed, or until timeout.

        Pending tasks are polled every 3 seconds. Failed tasks are reported
        and dropped from the pending set. Errors raised while checking a
        task are reported and the task is retried on the next round.

        Args:
            task_ids: Ids of the tasks to wait for; duplicates are counted once.
            timeout: Maximum time to wait, in seconds.
            progress_callback: Optional callable receiving one progress
                message (a string) per event.
        """
        if not task_ids:
            return

        pending_tasks = set(task_ids)
        completed_tasks = set()
        # Count unique ids so duplicates cannot make the total unreachable.
        total = len(pending_tasks)
        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.monotonic()

        while pending_tasks:
            if time.monotonic() - start_time > timeout:
                if progress_callback:
                    progress_callback(f"警告: 部分任务超时,已完成 {len(completed_tasks)}/{total}")
                # Return here so no "all done" message follows the warning.
                return

            # Iterate over a copy because the set is modified in the loop.
            for task_id in list(pending_tasks):
                try:
                    task = self.client.task.get(task_id=task_id)

                    if task.status == "completed":
                        pending_tasks.remove(task_id)
                        completed_tasks.add(task_id)
                    elif task.status == "failed":
                        pending_tasks.remove(task_id)
                        if progress_callback:
                            progress_callback(f"任务失败: {task.error}")
                except Exception as e:
                    if progress_callback:
                        progress_callback(f"检查任务状态出错: {str(e)}")

            if pending_tasks:
                if progress_callback:
                    elapsed = int(time.monotonic() - start_time)
                    progress_callback(f"等待处理中... 已完成 {len(completed_tasks)}/{total} ({elapsed}秒)")
                time.sleep(3)

        if progress_callback:
            progress_callback(f"所有任务处理完成: {len(completed_tasks)}/{total}")

    def get_graph_data(self, graph_id: str) -> "GraphData":
        """Read the nodes and edges of a graph.

        Missing summaries, names and facts become ``""``; missing labels
        become ``[]``; missing attributes become ``{}``.

        Args:
            graph_id: Id of the graph to read.

        Returns:
            A ``GraphData`` holding the graph id, its nodes and its edges.
        """
        # NOTE(review): each list is fetched with a single call and no paging
        # arguments - confirm whether the API caps the number of items returned.
        raw_nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
        nodes = [
            GraphNode(
                uuid=node.uuid_,
                name=node.name,
                summary=node.summary or "",
                labels=node.labels or [],
                attributes=node.attributes or {},
            )
            # Treat an empty/missing response as "no nodes".
            for node in raw_nodes or []
        ]

        raw_edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
        edges = [
            GraphEdge(
                uuid=edge.uuid_,
                name=edge.name or "",
                fact=edge.fact or "",
                source_node_uuid=edge.source_node_uuid,
                target_node_uuid=edge.target_node_uuid,
                attributes=edge.attributes or {},
            )
            # Treat an empty/missing response as "no edges".
            for edge in raw_edges or []
        ]

        return GraphData(graph_id=graph_id, nodes=nodes, edges=edges)

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given id."""
        self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
|
||||
def build_graph_from_text(
    text: str,
    graph_name: str = "Knowledge Graph",
    api_key: Optional[str] = None,
    chunk_size: int = 2000,
    progress_callback: Optional[Callable] = None,
    *,
    batch_size: int = 3,
    timeout: int = 600,
) -> GraphData:
    """Convenience wrapper: build a knowledge graph from a piece of text.

    Creates a graph, installs the ontology, splits ``text`` into chunks,
    uploads them in batches, waits for the resulting tasks and finally reads
    the graph back.

    Args:
        text: Input text.
        graph_name: Name given to the new graph.
        api_key: Zep API key, forwarded to ``ZepGraphBuilder``.
        chunk_size: Maximum chunk size passed to ``split_text_into_chunks``.
        progress_callback: Optional callable receiving one progress message
            (a string) per event.
        batch_size: Number of chunks uploaded per request.
        timeout: Maximum time in seconds to wait for the upload tasks.

    Returns:
        The ``GraphData`` read back from the newly created graph.
    """
    from text_extractor import split_text_into_chunks

    builder = ZepGraphBuilder(api_key=api_key)

    # Create the graph.
    graph_id = builder.create_graph(name=graph_name)
    if progress_callback:
        progress_callback(f"创建图谱: {graph_id}")

    # Install the ontology.
    builder.set_ontology(graph_id)
    if progress_callback:
        progress_callback("设置实体类型...")

    # Split the text into chunks.
    chunks = split_text_into_chunks(text, max_chunk_size=chunk_size)
    if progress_callback:
        progress_callback(f"文本分为 {len(chunks)} 个块")

    # Upload the chunks in batches.
    task_ids = builder.add_text_to_graph(
        graph_id=graph_id,
        text_chunks=chunks,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )

    # Wait for the upload tasks, if any were returned.
    if task_ids:
        builder.wait_for_tasks(
            task_ids,
            timeout=timeout,
            progress_callback=progress_callback,
        )

    # Read the graph back and return it.
    return builder.get_graph_data(graph_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build a graph from a short sample and print it.
    from dotenv import load_dotenv

    load_dotenv()

    sample_text = """
武汉大学是中国著名的高等学府,位于湖北省武汉市。
该校的樱花季每年吸引大量游客。
马化腾是腾讯公司的创始人,腾讯总部位于深圳。
"""

    graph = build_graph_from_text(
        text=sample_text,
        graph_name="测试图谱",
        progress_callback=print,
    )

    # Node count followed by one line per node.
    node_total = len(graph.nodes)
    print(f"\n节点数: {node_total}")
    for item in graph.nodes:
        print(f" - {item.name} ({item.labels})")

    # Edge count followed by one line per edge fact.
    edge_total = len(graph.edges)
    print(f"\n边数: {edge_total}")
    for link in graph.edges:
        print(f" - {link.fact}")
|
||||
Reference in New Issue
Block a user