Add initial implementation of txt2graph tool for knowledge graph generation

- Created a new Streamlit application for visualizing knowledge graphs.
- Implemented text extraction from PDF, Markdown, and TXT files.
- Developed graph building logic using Zep Cloud API.
- Added support for custom entity types and relationships.
- Included interactive HTML visualization for generated graphs.
- Updated .gitignore to include new directories and files.
- Added example environment configuration file (.env.example) for API key setup.
- Created README.md with installation and usage instructions.
- Introduced various utility scripts and styles for enhanced functionality.
This commit is contained in:
666ghj
2025-11-28 14:07:42 +08:00
parent 38e3d05b1d
commit 9657061b26
21 changed files with 3115 additions and 1 deletions

497
txt2graph/app.py Normal file
View File

@@ -0,0 +1,497 @@
"""
txt2graph 可视化界面
基于Streamlit和PyVis实现知识图谱可视化
"""
import os
import tempfile
import streamlit as st
from pathlib import Path
from pyvis.network import Network
import streamlit.components.v1 as components
from dotenv import load_dotenv
load_dotenv()
from text_extractor import extract_text, split_text_into_chunks
from graph_builder import ZepGraphBuilder, GraphData
# Page configuration: browser-tab title and icon, wide layout, and a sidebar
# that starts expanded.
# NOTE(review): Streamlit expects set_page_config to run before other st.*
# commands, so keep this call first - confirm against the Streamlit docs.
st.set_page_config(
    page_title="txt2graph - 知识图谱生成器",
    page_icon="🕸️",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS injected into the page as a raw <style> element
# (unsafe_allow_html=True is what lets the markup through).
# The stylesheet covers: web fonts, the gradient title, the "stats-card"
# metric boxes, the per-entity-type "entity-*" tag colours (same palette as
# ENTITY_COLORS), and the gradient sidebar buttons.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;500;700&family=JetBrains+Mono&display=swap');
.main {
font-family: 'Noto Sans SC', sans-serif;
}
.stTitle {
font-weight: 700 !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.stats-card {
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
border-radius: 12px;
padding: 20px;
margin: 10px 0;
border: 1px solid rgba(102, 126, 234, 0.3);
}
.stats-number {
font-size: 2.5rem;
font-weight: 700;
color: #667eea;
font-family: 'JetBrains Mono', monospace;
}
.stats-label {
font-size: 0.9rem;
color: #a0a0a0;
text-transform: uppercase;
letter-spacing: 1px;
}
.entity-tag {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
margin: 2px;
font-weight: 500;
}
.entity-Person { background: rgba(255, 107, 107, 0.2); color: #ff6b6b; border: 1px solid #ff6b6b; }
.entity-Company { background: rgba(78, 205, 196, 0.2); color: #4ecdc4; border: 1px solid #4ecdc4; }
.entity-Organization { background: rgba(69, 183, 209, 0.2); color: #45b7d1; border: 1px solid #45b7d1; }
.entity-Location { background: rgba(150, 206, 180, 0.2); color: #96ceb4; border: 1px solid #96ceb4; }
.entity-Product { background: rgba(255, 238, 173, 0.2); color: #ffeead; border: 1px solid #ffeead; }
.entity-Event { background: rgba(220, 198, 224, 0.2); color: #dcc6e0; border: 1px solid #dcc6e0; }
.entity-Media { background: rgba(255, 183, 77, 0.2); color: #ffb74d; border: 1px solid #ffb74d; }
.sidebar .stButton > button {
width: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 0.3s ease;
}
.sidebar .stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.4);
}
</style>
""", unsafe_allow_html=True)
# Hex colour used to draw each known entity type. Types that are not listed
# here fall back to a neutral grey where the graph is rendered.
ENTITY_COLORS = dict(
    Person="#ff6b6b",
    Company="#4ecdc4",
    Organization="#45b7d1",
    Location="#96ceb4",
    Product="#ffeead",
    Event="#dcc6e0",
    Media="#ffb74d",
)
def create_pyvis_graph(graph_data: GraphData) -> str:
    """Render ``graph_data`` as an interactive PyVis network and return its HTML.

    Each node is drawn as a dot coloured by its first label (looked up in
    ``ENTITY_COLORS``, grey when the type is not listed) and gets a tooltip
    with its name, type and at most the first 200 characters of its summary.
    An edge is drawn only when both of its endpoints are among
    ``graph_data.nodes``; its label is the edge name cut to 20 characters.

    Args:
        graph_data: Graph whose ``nodes`` expose ``uuid``, ``name``,
            ``labels`` and ``summary``, and whose ``edges`` expose
            ``source_node_uuid``, ``target_node_uuid``, ``name`` and ``fact``.

    Returns:
        The HTML document written by PyVis, as a string.
    """
    net = Network(
        height="700px",
        width="100%",
        bgcolor="#0e1117",
        font_color="white",
        directed=True,
        select_menu=True,
        filter_menu=True,
    )

    # Node/edge appearance, physics and interaction settings. The literal is
    # runtime text handed to PyVis, so it is kept exactly as it was.
    net.set_options("""
{
"nodes": {
"font": {
"size": 14,
"face": "Noto Sans SC, Arial"
},
"borderWidth": 2,
"shadow": true
},
"edges": {
"color": {
"inherit": false,
"color": "#555555",
"highlight": "#667eea"
},
"arrows": {
"to": {
"enabled": true,
"scaleFactor": 0.5
}
},
"smooth": {
"type": "continuous",
"roundness": 0.2
},
"font": {
"size": 10,
"color": "#888888",
"face": "Noto Sans SC, Arial"
}
},
"physics": {
"enabled": true,
"barnesHut": {
"gravitationalConstant": -5000,
"centralGravity": 0.3,
"springLength": 150,
"springConstant": 0.04,
"damping": 0.09
},
"stabilization": {
"enabled": true,
"iterations": 200
}
},
"interaction": {
"hover": true,
"tooltipDelay": 100,
"navigationButtons": true,
"keyboard": true
}
}
""")

    # Only membership is needed when filtering edges, so keep just the ids.
    known_ids = {node.uuid for node in graph_data.nodes}

    for node in graph_data.nodes:
        node_type = node.labels[0] if node.labels else "Unknown"

        tooltip = f"<b>{node.name}</b><br><i>类型: {node_type}</i><br><br>"
        summary = node.summary
        if summary:
            # Long summaries are truncated so the tooltip stays readable.
            tooltip += summary[:200] + ("..." if len(summary) > 200 else "")

        # Companies/organisations are drawn largest, people medium, rest small.
        if node_type == "Person":
            size = 25
        elif node_type in ("Company", "Organization"):
            size = 30
        else:
            size = 20

        net.add_node(
            node.uuid,
            label=node.name,
            title=tooltip,
            color=ENTITY_COLORS.get(node_type, "#888888"),
            size=size,
            shape="dot",
        )

    for edge in graph_data.edges:
        # Skip edges that point at nodes missing from this graph snapshot.
        if edge.source_node_uuid not in known_ids or edge.target_node_uuid not in known_ids:
            continue
        net.add_edge(
            edge.source_node_uuid,
            edge.target_node_uuid,
            title=edge.fact if edge.fact else edge.name,
            label=edge.name[:20] if edge.name else "",
        )

    # PyVis writes to a path, so reserve a temporary file name and close our
    # own handle before PyVis opens the path itself.
    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as tmp:
        html_path = tmp.name
    try:
        net.save_graph(html_path)
        with open(html_path, "r", encoding="utf-8") as html_file:
            return html_file.read()
    finally:
        # Remove the file even when writing or reading it back fails.
        os.unlink(html_path)
def display_stats(graph_data: GraphData):
    """Show three metric cards: node count, edge count and entity-type count."""

    def card_html(value, caption):
        # Markup for one card; the classes are styled by the page stylesheet.
        return (
            '\n<div class="stats-card">\n'
            f'<div class="stats-number">{value}</div>\n'
            f'<div class="stats-label">{caption}</div>\n'
            '</div>\n'
        )

    nodes_col, edges_col, types_col = st.columns(3)

    with nodes_col:
        st.markdown(card_html(len(graph_data.nodes), "实体节点"), unsafe_allow_html=True)

    with edges_col:
        st.markdown(card_html(len(graph_data.edges), "关系边"), unsafe_allow_html=True)

    # A node's type is its first label; nodes without labels count as "Unknown".
    distinct_types = {
        node.labels[0] if node.labels else "Unknown"
        for node in graph_data.nodes
    }

    with types_col:
        st.markdown(card_html(len(distinct_types), "实体类型"), unsafe_allow_html=True)
def display_entity_list(graph_data: GraphData):
    """List every entity, grouped into one tab per entity type.

    Each entity is a collapsed expander showing its summary and, when present,
    its attributes as JSON.
    """
    st.subheader("实体列表")

    # Group nodes by their first label, in first-seen order; unlabeled nodes
    # go under "Unknown".
    grouped = {}
    for node in graph_data.nodes:
        type_name = node.labels[0] if node.labels else "Unknown"
        grouped.setdefault(type_name, []).append(node)

    # No nodes means no tabs to draw.
    if not grouped:
        return

    tab_handles = st.tabs(list(grouped))
    for tab, members in zip(tab_handles, grouped.values()):
        with tab:
            for entity in members:
                with st.expander(f"{entity.name}", expanded=False):
                    if entity.summary:
                        st.write(entity.summary)
                    if entity.attributes:
                        st.json(entity.attributes)
def _render_sidebar():
    """Draw the sidebar controls and return their current values.

    Returns:
        Tuple ``(api_key, uploaded_file, text_input, chunk_size, graph_name,
        generate_btn)`` in that order.
    """
    with st.sidebar:
        st.header("配置")
        api_key = st.text_input(
            "Zep API Key",
            type="password",
            value=os.environ.get("ZEP_API_KEY", ""),
            help="从 https://app.getzep.com 获取API Key",
        )
        if api_key:
            # Mirror the entered key into the process environment.
            os.environ["ZEP_API_KEY"] = api_key

        st.divider()

        st.header("上传文件")
        uploaded_file = st.file_uploader(
            "支持 .txt, .md, .pdf 文件",
            type=["txt", "md", "pdf"],
            help="上传要转换为知识图谱的文本文件",
        )

        st.divider()

        # Alternative to uploading: paste the text directly.
        st.header("或直接输入文本")
        text_input = st.text_area(
            "输入文本内容",
            height=150,
            placeholder="在此输入或粘贴文本...",
        )

        st.divider()

        with st.expander("高级设置"):
            chunk_size = st.slider(
                "文本分块大小",
                min_value=500,
                max_value=4000,
                value=2000,
                step=500,
                help="较小的块处理更稳定,较大的块包含更多上下文",
            )
            graph_name = st.text_input(
                "图谱名称",
                value="Knowledge Graph",
                help="为生成的图谱命名",
            )

        st.divider()

        generate_btn = st.button("生成知识图谱", type="primary", use_container_width=True)

    return api_key, uploaded_file, text_input, chunk_size, graph_name, generate_btn


def _extract_uploaded_text(uploaded_file):
    """Copy an uploaded file to a temporary path and return its extracted text."""
    # The original extension is kept on the temporary file; presumably
    # extract_text picks its parser from it - verify against text_extractor.
    suffix = Path(uploaded_file.name).suffix
    with st.spinner("正在提取文本..."):
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name
        try:
            return extract_text(tmp_path)
        finally:
            # Always remove the temporary copy, even if extraction fails.
            os.unlink(tmp_path)


def _build_graph(api_key, text_content, chunk_size, graph_name):
    """Run the Zep pipeline for ``text_content`` and store the result.

    On success the graph data and graph id are saved in ``st.session_state``.
    Any exception is reported in the page together with its traceback instead
    of being raised.
    """
    st.info(f"提取了 {len(text_content)} 个字符的文本")

    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(msg):
        # Callback handed to the builder so it can report its own progress.
        status_text.text(msg)

    try:
        builder = ZepGraphBuilder(api_key=api_key)

        status_text.text("创建图谱...")
        progress_bar.progress(10)
        graph_id = builder.create_graph(name=graph_name)

        status_text.text("配置实体类型...")
        progress_bar.progress(20)
        builder.set_ontology(graph_id)

        status_text.text("分割文本...")
        progress_bar.progress(30)
        chunks = split_text_into_chunks(text_content, max_chunk_size=chunk_size)
        st.info(f"文本已分为 {len(chunks)} 个块")

        status_text.text("正在发送数据到Zep...")
        progress_bar.progress(40)
        task_ids = builder.add_text_to_graph(
            graph_id=graph_id,
            text_chunks=chunks,
            batch_size=3,
            progress_callback=update_progress,
        )

        progress_bar.progress(60)
        status_text.text("等待Zep处理数据...")
        if task_ids:
            builder.wait_for_tasks(
                task_ids,
                timeout=600,
                progress_callback=update_progress,
            )

        status_text.text("获取图谱数据...")
        progress_bar.progress(90)
        st.session_state.graph_data = builder.get_graph_data(graph_id)
        st.session_state.graph_id = graph_id

        progress_bar.progress(100)
        status_text.text("完成!")
        st.success(f"知识图谱生成成功! Graph ID: {graph_id}")
    except Exception as e:
        # Top-level UI boundary: show the failure instead of crashing the page.
        import traceback

        st.error(f"生成图谱时出错: {str(e)}")
        st.code(traceback.format_exc())


def _render_results(graph_data):
    """Show the statistics, the interactive graph, and entity/relation lists."""
    display_stats(graph_data)
    st.divider()

    st.subheader("知识图谱可视化")
    if graph_data.nodes:
        with st.spinner("渲染图谱..."):
            html_content = create_pyvis_graph(graph_data)
        components.html(html_content, height=750, scrolling=True)
    else:
        st.warning("图谱中没有节点")

    st.divider()

    entities_col, relations_col = st.columns([1, 1])
    with entities_col:
        display_entity_list(graph_data)

    with relations_col:
        st.subheader("关系列表")
        if graph_data.edges:
            max_rows = 50  # cap the list so the page stays responsive
            for edge in graph_data.edges[:max_rows]:
                st.markdown(f"- **{edge.fact}**" if edge.fact else f"- {edge.name}")
            if len(graph_data.edges) > max_rows:
                st.caption(f"...还有 {len(graph_data.edges) - 50} 条关系")
        else:
            st.info("暂无关系数据")


def main():
    """Streamlit entry point: collect input, build the graph, show the result.

    The sidebar supplies the API key, the source text (an uploaded file takes
    precedence over pasted text) and the chunking options. Pressing the
    generate button validates the input and runs the Zep pipeline. The most
    recently generated graph is kept in ``st.session_state`` and drawn on
    every rerun.
    """
    st.title("txt2graph")
    st.markdown("*将文本转化为知识图谱*")

    (
        api_key,
        uploaded_file,
        text_input,
        chunk_size,
        graph_name,
        generate_btn,
    ) = _render_sidebar()

    if "graph_data" not in st.session_state:
        st.session_state.graph_data = None

    if generate_btn:
        if not api_key:
            st.error("请先配置 Zep API Key")
            return

        if uploaded_file:
            text_content = _extract_uploaded_text(uploaded_file)
        elif text_input:
            text_content = text_input
        else:
            st.warning("请上传文件或输入文本")
            return

        if text_content and text_content.strip():
            _build_graph(api_key, text_content, chunk_size, graph_name)
        else:
            # Tell the user when nothing usable was extracted (for example an
            # image-only PDF or blank input) instead of silently doing nothing.
            st.warning("未能从输入中提取到有效文本,请检查文件内容或直接输入文本")

    if st.session_state.graph_data:
        _render_results(st.session_state.graph_data)
# Start the app when this file is executed as the main module.
if __name__ == "__main__":
    main()