Add initial implementation of txt2graph tool for knowledge graph generation

- Created a new Streamlit application for visualizing knowledge graphs.
- Implemented text extraction from PDF, Markdown, and TXT files.
- Developed graph building logic using Zep Cloud API.
- Added support for custom entity types and relationships.
- Included interactive HTML visualization for generated graphs.
- Updated .gitignore to include new directories and files.
- Added example environment configuration file (.env.example) for API key setup.
- Created README.md with installation and usage instructions.
- Introduced various utility scripts and styles for enhanced functionality.
This commit is contained in:
666ghj
2025-11-28 14:07:42 +08:00
parent 38e3d05b1d
commit 9657061b26
21 changed files with 3115 additions and 1 deletions

415
txt2graph/graph_builder.py Normal file
View File

@@ -0,0 +1,415 @@
"""
Zep图谱构建模块
负责与Zep云服务交互构建知识图谱
"""
import os
import time
import uuid
from typing import Optional, Callable
from dataclasses import dataclass
from zep_cloud.client import Zep
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ontology import ENTITY_TYPES, EDGE_TYPES
@dataclass
class GraphNode:
    """One node of a knowledge graph, as returned by the graph service."""

    # Identifier of the node.
    uuid: str
    # Display name of the node.
    name: str
    # Free-text summary; the builder stores "" when the service gives none.
    summary: str
    # Labels attached to the node; the builder stores [] when there are none.
    labels: list[str]
    # Additional key/value attributes; the builder stores {} when there are none.
    attributes: dict
@dataclass
class GraphEdge:
    """One directed edge (relationship) of a knowledge graph."""

    # Identifier of the edge.
    uuid: str
    # Name of the edge; the builder stores "" when the service gives none.
    name: str
    # Textual fact carried by the edge; "" when the service gives none.
    fact: str
    # Identifier of the node the edge starts from.
    source_node_uuid: str
    # Identifier of the node the edge points to.
    target_node_uuid: str
    # Additional key/value attributes; the builder stores {} when there are none.
    attributes: dict
@dataclass
class GraphData:
    """Complete contents of one graph: its id plus every node and edge."""

    # Identifier of the graph the nodes and edges belong to.
    graph_id: str
    # All nodes read from the graph.
    nodes: list[GraphNode]
    # All edges read from the graph.
    edges: list[GraphEdge]
class ZepGraphBuilder:
    """Build knowledge graphs through the Zep Cloud API.

    The class wraps a ``Zep`` client and exposes the individual pipeline
    steps: create a graph, register the custom ontology, upload text in
    batches, wait for the ingestion tasks, and read nodes and edges back.
    """

    # Allowed (source entity, target entity) pairs for each custom edge type.
    # Every key must also exist in ``EDGE_TYPES``; the insertion order here is
    # the order in which the definitions are handed to Zep.
    _EDGE_SOURCE_TARGETS = {
        # A person works for an organization or a company.
        "WORKS_FOR": [
            ("Person", "Organization"),
            ("Person", "Company"),
        ],
        # Several entity kinds can be located in a location.
        "LOCATED_IN": [
            ("Person", "Location"),
            ("Organization", "Location"),
            ("Company", "Location"),
            ("Event", "Location"),
        ],
        # Organization within organization, company within company.
        "PART_OF": [
            ("Organization", "Organization"),
            ("Company", "Company"),
        ],
        # A company or organization produces a product.
        "PRODUCES": [
            ("Company", "Product"),
            ("Organization", "Product"),
        ],
        # People, organizations and companies take part in events.
        "PARTICIPATES_IN": [
            ("Person", "Event"),
            ("Organization", "Event"),
            ("Company", "Event"),
        ],
        # Cooperation between entities.
        "COLLABORATES": [
            ("Person", "Person"),
            ("Company", "Company"),
            ("Organization", "Organization"),
            ("Company", "Organization"),
        ],
        # Competition between companies.
        "COMPETES": [
            ("Company", "Company"),
        ],
        # Media reporting on other entities.
        "REPORTS": [
            ("Media", "Person"),
            ("Media", "Company"),
            ("Media", "Organization"),
            ("Media", "Event"),
        ],
    }

    # Pause between two consecutive batch uploads, to avoid sending too fast.
    _BATCH_PAUSE_SECONDS = 1
    # Pause between two polling rounds while tasks are still pending.
    _POLL_INTERVAL_SECONDS = 3

    def __init__(self, api_key: Optional[str] = None):
        """Create the builder and its Zep client.

        Args:
            api_key: Zep API key. When omitted (or empty) the value of the
                ``ZEP_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.environ.get("ZEP_API_KEY")
        if not self.api_key:
            raise ValueError("需要提供ZEP_API_KEY可以通过参数传入或设置环境变量")
        self.client = Zep(api_key=self.api_key)

    def create_graph(self, graph_id: Optional[str] = None, name: str = "Knowledge Graph") -> str:
        """Create a new graph and return its id.

        Args:
            graph_id: Id for the new graph. When ``None`` an id of the form
                ``graph_<16 hex chars>`` is generated.
            name: Human-readable graph name.

        Returns:
            The id of the created graph.
        """
        if graph_id is None:
            graph_id = f"graph_{uuid.uuid4().hex[:16]}"
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="Knowledge graph generated by txt2graph",
        )
        return graph_id

    def set_ontology(self, graph_id: str):
        """Register the custom entity and edge types on a graph.

        The edge definitions are built from ``_EDGE_SOURCE_TARGETS``: each
        edge name maps to a tuple of its model from ``EDGE_TYPES`` and the
        list of allowed source/target entity pairs.

        Args:
            graph_id: Id of the graph to configure.
        """
        edge_definitions = {
            edge_name: (
                EDGE_TYPES[edge_name],
                [
                    EntityEdgeSourceTarget(source=source, target=target)
                    for source, target in pairs
                ],
            )
            for edge_name, pairs in self._EDGE_SOURCE_TARGETS.items()
        }
        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ENTITY_TYPES,
            edges=edge_definitions,
        )

    def add_text_to_graph(
        self,
        graph_id: str,
        text_chunks: list[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable[[str], None]] = None
    ) -> list[str]:
        """Upload text chunks to a graph in batches.

        Args:
            graph_id: Id of the target graph.
            text_chunks: Text chunks to upload; each becomes one text episode.
            batch_size: Number of chunks sent per request; must be >= 1.
            progress_callback: Optional callable receiving progress messages.

        Returns:
            The task ids collected from the batch responses (one per batch
            whose first returned item carries a task id).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Any error raised while sending a batch is reported to
                ``progress_callback`` and then re-raised.
        """
        # A zero step makes range() fail with an unrelated message and a
        # negative step would silently upload nothing, so reject both early.
        if batch_size < 1:
            raise ValueError(f"batch_size 必须为正整数: {batch_size}")

        task_ids: list[str] = []
        total_chunks = len(text_chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, total_chunks, batch_size), start=1):
            batch_chunks = text_chunks[start:start + batch_size]
            if progress_callback:
                progress_callback(f"发送第 {batch_num}/{total_batches} 批数据 ({len(batch_chunks)} 块)...")

            episodes = [EpisodeData(data=chunk, type="text") for chunk in batch_chunks]
            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes,
                )
                # Only the first returned item is inspected for a task id.
                # NOTE(review): presumably every item of one batch shares the
                # same task id - confirm against the Zep API documentation.
                if batch_result and batch_result[0].task_id:
                    task_ids.append(batch_result[0].task_id)
            except Exception as e:
                if progress_callback:
                    progress_callback(f"批次 {batch_num} 发送失败: {e}")
                raise

            # Throttle between requests; nothing follows the last batch, so
            # there is no reason to pause after it.
            if batch_num < total_batches:
                time.sleep(self._BATCH_PAUSE_SECONDS)

        return task_ids

    def wait_for_tasks(
        self,
        task_ids: list[str],
        timeout: int = 600,
        progress_callback: Optional[Callable[[str], None]] = None
    ):
        """Poll the given tasks until each one completed or failed.

        Tasks whose status is ``"completed"`` or ``"failed"`` stop being
        polled. Errors raised while checking a task are reported and the task
        is retried in the next round. When the timeout elapses a warning is
        reported and the method returns without the final completion message.

        Args:
            task_ids: Ids of the tasks to wait for; duplicates are ignored.
            timeout: Maximum time to wait, in seconds.
            progress_callback: Optional callable receiving progress messages.
        """
        if not task_ids:
            return

        # dict.fromkeys removes duplicates while keeping the caller's order,
        # so the "completed/total" counters can actually reach the total.
        pending = dict.fromkeys(task_ids)
        total = len(pending)
        completed = 0
        timed_out = False
        # Monotonic clock: elapsed time is unaffected by system clock changes.
        start_time = time.monotonic()

        while pending:
            if time.monotonic() - start_time > timeout:
                timed_out = True
                if progress_callback:
                    progress_callback(f"警告: 部分任务超时,已完成 {completed}/{total}")
                break

            for task_id in list(pending):
                try:
                    task = self.client.task.get(task_id=task_id)
                    if task.status == "completed":
                        del pending[task_id]
                        completed += 1
                    elif task.status == "failed":
                        del pending[task_id]
                        if progress_callback:
                            progress_callback(f"任务失败: {task.error}")
                except Exception as e:
                    # Best effort: report the error and poll again next round.
                    if progress_callback:
                        progress_callback(f"检查任务状态出错: {e}")

            if pending:
                if progress_callback:
                    elapsed = int(time.monotonic() - start_time)
                    progress_callback(f"等待处理中... 已完成 {completed}/{total} ({elapsed}秒)")
                time.sleep(self._POLL_INTERVAL_SECONDS)

        # After a timeout the warning above is the final word; claiming that
        # all tasks were processed would be misleading.
        if progress_callback and not timed_out:
            progress_callback(f"所有任务处理完成: {completed}/{total}")

    def get_graph_data(self, graph_id: str) -> "GraphData":
        """Read all nodes and edges of a graph.

        Missing summaries, names and facts are replaced by ``""``, missing
        labels by ``[]`` and missing attributes by ``{}``.

        Args:
            graph_id: Id of the graph to read.

        Returns:
            A ``GraphData`` holding the graph id, its nodes and its edges.
        """
        # NOTE(review): verify whether get_by_graph_id paginates or limits its
        # result; if it does, large graphs would be returned only partially.
        raw_nodes = self.client.graph.node.get_by_graph_id(graph_id=graph_id)
        nodes = [
            GraphNode(
                uuid=node.uuid_,
                name=node.name,
                summary=node.summary or "",
                labels=node.labels or [],
                attributes=node.attributes or {},
            )
            for node in raw_nodes
        ]

        raw_edges = self.client.graph.edge.get_by_graph_id(graph_id=graph_id)
        edges = [
            GraphEdge(
                uuid=edge.uuid_,
                name=edge.name or "",
                fact=edge.fact or "",
                source_node_uuid=edge.source_node_uuid,
                target_node_uuid=edge.target_node_uuid,
                attributes=edge.attributes or {},
            )
            for edge in raw_edges
        ]

        return GraphData(graph_id=graph_id, nodes=nodes, edges=edges)

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given id."""
        self.client.graph.delete(graph_id=graph_id)
def build_graph_from_text(
    text: str,
    graph_name: str = "Knowledge Graph",
    api_key: Optional[str] = None,
    chunk_size: int = 2000,
    progress_callback: Optional[Callable] = None,
    batch_size: int = 3
) -> GraphData:
    """Convenience function: build a knowledge graph from a text.

    Runs the whole pipeline: create a graph, set the ontology, split the
    text into chunks, upload them in batches, wait for the ingestion tasks
    and finally read the graph back.

    Args:
        text: Input text.
        graph_name: Name given to the new graph.
        api_key: Zep API key; passed through to ``ZepGraphBuilder``.
        chunk_size: Maximum chunk size handed to ``split_text_into_chunks``
            (2000 by default).
        progress_callback: Optional callable receiving progress messages.
        batch_size: Number of chunks uploaded per request (3 by default,
            which was previously a fixed value).

    Returns:
        The ``GraphData`` of the newly built graph.
    """
    # Imported here rather than at module level, as in the original code.
    from text_extractor import split_text_into_chunks

    builder = ZepGraphBuilder(api_key=api_key)

    # Create the graph.
    graph_id = builder.create_graph(name=graph_name)
    if progress_callback:
        progress_callback(f"创建图谱: {graph_id}")

    # Register the entity and edge types.
    builder.set_ontology(graph_id)
    if progress_callback:
        progress_callback("设置实体类型...")

    # Split the text into chunks.
    chunks = split_text_into_chunks(text, max_chunk_size=chunk_size)
    if progress_callback:
        progress_callback(f"文本分为 {len(chunks)} 个块")

    # Upload the chunks batch by batch.
    task_ids = builder.add_text_to_graph(
        graph_id=graph_id,
        text_chunks=chunks,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )

    # Wait until the ingestion tasks are done (nothing to wait for otherwise).
    if task_ids:
        builder.wait_for_tasks(task_ids, progress_callback=progress_callback)

    # Read the resulting graph back and return it.
    return builder.get_graph_data(graph_id)
if __name__ == "__main__":
    # Manual smoke test: build a graph from a short sample text and print
    # the resulting nodes and edges.
    from dotenv import load_dotenv

    # Load environment variables from a .env file before the builder reads them.
    load_dotenv()

    # The empty first and last entries reproduce the leading and trailing
    # newline of the sample text.
    sample_lines = [
        "",
        "武汉大学是中国著名的高等学府,位于湖北省武汉市。",
        "该校的樱花季每年吸引大量游客。",
        "马化腾是腾讯公司的创始人,腾讯总部位于深圳。",
        "",
    ]
    test_text = "\n".join(sample_lines)

    result = build_graph_from_text(
        text=test_text,
        graph_name="测试图谱",
        progress_callback=print,
    )

    # Node overview.
    node_list = result.nodes
    print(f"\n节点数: {len(node_list)}")
    for node in node_list:
        print(f" - {node.name} ({node.labels})")

    # Edge overview.
    edge_list = result.edges
    print(f"\n边数: {len(edge_list)}")
    for edge in edge_list:
        print(f" - {edge.fact}")