Introduce Project ID for context management, finalizing the stateful API pipeline from file submission to graph construction.
This commit is contained in:
9
backend/app/utils/__init__.py
Normal file
9
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
|
||||
from .file_parser import FileParser
|
||||
from .llm_client import LLMClient
|
||||
|
||||
# Names re-exported as the public API of this package.
__all__ = ['FileParser', 'LLMClient']
|
||||
|
||||
141
backend/app/utils/file_parser.py
Normal file
141
backend/app/utils/file_parser.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
文件解析工具
|
||||
支持PDF、Markdown、TXT文件的文本提取
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class FileParser:
    """Extract plain text from PDF, Markdown and TXT files."""

    # File suffixes (lower-case, including the dot) that extract_text accepts.
    SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}

    @classmethod
    def extract_text(cls, file_path: str) -> str:
        """
        Extract the text content of a single file.

        Args:
            file_path: Path of the file to read.

        Returns:
            The extracted text.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file suffix is not in SUPPORTED_EXTENSIONS.
        """
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")

        # Suffix matching is case-insensitive ('.PDF' is treated as '.pdf').
        suffix = path.suffix.lower()

        if suffix not in cls.SUPPORTED_EXTENSIONS:
            raise ValueError(f"不支持的文件格式: {suffix}")

        if suffix == '.pdf':
            return cls._extract_from_pdf(file_path)
        if suffix in {'.md', '.markdown'}:
            return cls._extract_from_md(file_path)
        if suffix == '.txt':
            return cls._extract_from_txt(file_path)

        # Defensive: only reachable if SUPPORTED_EXTENSIONS gains a suffix
        # that has no matching branch above.
        raise ValueError(f"无法处理的文件格式: {suffix}")

    @staticmethod
    def _extract_from_pdf(file_path: str) -> str:
        """
        Extract text from a PDF using PyMuPDF.

        Pages whose text is empty or whitespace-only are skipped; the
        remaining page texts are joined with a blank line.

        Raises:
            ImportError: If PyMuPDF (``fitz``) is not installed.
        """
        try:
            import fitz  # PyMuPDF
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("需要安装PyMuPDF: pip install PyMuPDF") from exc

        text_parts = []
        with fitz.open(file_path) as doc:
            for page in doc:
                text = page.get_text()
                if text.strip():
                    text_parts.append(text)

        return "\n\n".join(text_parts)

    @staticmethod
    def _read_text(file_path: str) -> str:
        """Read a whole UTF-8 text file, dropping a leading BOM if present."""
        # 'utf-8-sig' decodes plain UTF-8 as well, but strips the BOM that
        # some editors prepend, so it does not leak into the text as '\ufeff'.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            return f.read()

    @staticmethod
    def _extract_from_md(file_path: str) -> str:
        """Return the raw content of a Markdown file (markup is not stripped)."""
        return FileParser._read_text(file_path)

    @staticmethod
    def _extract_from_txt(file_path: str) -> str:
        """Return the content of a plain-text file."""
        return FileParser._read_text(file_path)

    @classmethod
    def extract_from_multiple(cls, file_paths: List[str]) -> str:
        """
        Extract text from several files and merge it into one string.

        Each document is preceded by a header line carrying its 1-based index
        and file name. This is best-effort: a file that fails to extract does
        not abort the run; a header carrying the error message is emitted in
        its place.

        Args:
            file_paths: Paths of the files to read, in output order.

        Returns:
            The per-document sections joined with a blank line.
        """
        all_texts = []

        for i, file_path in enumerate(file_paths, 1):
            try:
                text = cls.extract_text(file_path)
                filename = Path(file_path).name
                all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}")
            except Exception as e:
                # Deliberately broad: any failure is reported inline instead
                # of being raised to the caller.
                all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===")

        return "\n\n".join(all_texts)
|
||||
|
||||
|
||||
def split_text_into_chunks(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> List[str]:
    """
    Split text into chunks of at most ``chunk_size`` characters.

    Each chunk is preferably cut just after a sentence/paragraph separator
    that lies beyond the first 30% of the window; otherwise it is cut at
    ``chunk_size``. Consecutive chunks share up to ``overlap`` characters.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk (must be > 0).
        overlap: Number of characters the next chunk re-reads from the end
            of the previous one.

    Returns:
        The list of chunks. Text that fits in one chunk is returned as-is
        (not stripped); chunks cut from longer text are stripped and empty
        ones are dropped. Blank text yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is not positive and ``text`` is not blank.
    """
    # Blank input produces no chunks, whatever the sizes are.
    if not text.strip():
        return []

    # A non-positive chunk size could never advance through the text.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    text_len = len(text)
    if text_len <= chunk_size:
        return [text]

    # Tried in priority order; the first separator with a late-enough match wins.
    # NOTE(review): '!' and '?' always match wherever '!\n', '?\n', '! ' or
    # '? ' would, so those later entries never decide the cut. The single
    # characters may originally have been full-width punctuation - confirm.
    separators = ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']
    # A boundary in the first 30% of the window would make the chunk too short.
    min_boundary = chunk_size * 0.3

    chunks = []
    start = 0

    while start < text_len:
        end = start + chunk_size

        # Prefer cutting right after a sentence boundary (not needed for the tail).
        if end < text_len:
            window = text[start:end]
            for sep in separators:
                last_sep = window.rfind(sep)
                if last_sep > min_boundary:
                    end = start + last_sep + len(sep)
                    break

        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)

        if end >= text_len:
            break

        # Step back by `overlap`, but always move forward: if the overlap
        # would not advance past the current start (overlap >= chunk length),
        # continue from `end` without overlap instead of looping forever.
        next_start = end - overlap
        if next_start <= start:
            next_start = end
        start = next_start

    return chunks
|
||||
|
||||
91
backend/app/utils/llm_client.py
Normal file
91
backend/app/utils/llm_client.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
LLM客户端封装
|
||||
统一使用OpenAI格式调用
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Dict, Any, List
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
|
||||
|
||||
class LLMClient:
    """Thin wrapper around an OpenAI-format chat-completions client."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None
    ):
        """
        Build the underlying OpenAI client.

        Any argument left as None (or otherwise falsy) falls back to the
        matching ``Config`` value.

        Args:
            api_key: API key; defaults to ``Config.LLM_API_KEY``.
            base_url: API base URL; defaults to ``Config.LLM_BASE_URL``.
            model: Model name; defaults to ``Config.LLM_MODEL_NAME``.

        Raises:
            ValueError: If no API key is supplied or configured.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model = model or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        response_format: Optional[Dict] = None
    ) -> str:
        """
        Send a chat-completion request and return the first choice's text.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            response_format: Optional response format (e.g. JSON mode); only
                sent when truthy.

        Returns:
            The text of the first choice, or an empty string when the API
            returns no content for it.
        """
        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        if response_format:
            kwargs["response_format"] = response_format

        response = self.client.chat.completions.create(**kwargs)
        content = response.choices[0].message.content
        # The content field is optional in the API response; normalise None
        # to "" so the declared `str` return type always holds.
        return content if content is not None else ""

    def chat_json(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 4096
    ) -> Dict[str, Any]:
        """
        Send a chat request in JSON mode and parse the reply.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The parsed JSON value.

        Raises:
            json.JSONDecodeError: If the reply is empty or not valid JSON.
        """
        response = self.chat(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={"type": "json_object"}
        )

        return json.loads(response)
|
||||
|
||||
107
backend/app/utils/logger.py
Normal file
107
backend/app/utils/logger.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
日志配置模块
|
||||
提供统一的日志管理,同时输出到控制台和文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
|
||||
# Log directory: the `logs` folder inside the directory three levels above
# this file (three nested dirname() calls on __file__).
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
|
||||
|
||||
|
||||
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
    """
    Configure and return the named logger.

    The logger level is always set to ``level``. Handlers are attached only
    on the first call for a given name: a rotating file handler (DEBUG and
    above, detailed format) and a console handler (INFO and above, short
    format). The log file is named after the date at setup time.

    Args:
        name: Logger name.
        level: Level applied to the logger itself.

    Returns:
        The configured logger.
    """
    # The log directory must exist before the file handler opens its file.
    os.makedirs(LOG_DIR, exist_ok=True)

    log = logging.getLogger(name)
    log.setLevel(level)

    # Already configured: do not attach a second pair of handlers.
    if log.handlers:
        return log

    # 1. File handler - detailed records, date-named file, size-based rotation.
    today = datetime.now().strftime('%Y-%m-%d')
    to_file = RotatingFileHandler(
        os.path.join(LOG_DIR, today + '.log'),
        maxBytes=10 * 1024 * 1024,  # 10MB
        backupCount=5,
        encoding='utf-8'
    )
    to_file.setLevel(logging.DEBUG)
    to_file.setFormatter(logging.Formatter(
        '[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    ))

    # 2. Console handler - short records, INFO and above.
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(logging.Formatter(
        '[%(asctime)s] %(levelname)s: %(message)s',
        datefmt='%H:%M:%S'
    ))

    # File handler first, console handler second.
    for handler in (to_file, to_console):
        log.addHandler(handler)

    return log
|
||||
|
||||
|
||||
def get_logger(name: str = 'mirofish') -> logging.Logger:
    """
    Return the named logger, configuring it first if it has no handlers.

    Args:
        name: Logger name.

    Returns:
        The logger instance.
    """
    existing = logging.getLogger(name)
    # A logger that already has handlers is returned untouched.
    return existing if existing.handlers else setup_logger(name)
|
||||
|
||||
|
||||
# Create the default logger when this module is imported.
logger = setup_logger()
|
||||
|
||||
|
||||
# Convenience wrappers that log through the default module-level logger.
# Each one raises `stacklevel` by one so that %(funcName)s / %(lineno)d in the
# record point at the code calling the wrapper, not at the wrapper itself
# (the `stacklevel` argument requires Python 3.8+).
def debug(msg, *args, **kwargs):
    """Log ``msg`` at DEBUG level on the default logger."""
    kwargs['stacklevel'] = kwargs.get('stacklevel', 1) + 1
    logger.debug(msg, *args, **kwargs)


def info(msg, *args, **kwargs):
    """Log ``msg`` at INFO level on the default logger."""
    kwargs['stacklevel'] = kwargs.get('stacklevel', 1) + 1
    logger.info(msg, *args, **kwargs)


def warning(msg, *args, **kwargs):
    """Log ``msg`` at WARNING level on the default logger."""
    kwargs['stacklevel'] = kwargs.get('stacklevel', 1) + 1
    logger.warning(msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    """Log ``msg`` at ERROR level on the default logger."""
    kwargs['stacklevel'] = kwargs.get('stacklevel', 1) + 1
    logger.error(msg, *args, **kwargs)


def critical(msg, *args, **kwargs):
    """Log ``msg`` at CRITICAL level on the default logger."""
    kwargs['stacklevel'] = kwargs.get('stacklevel', 1) + 1
    logger.critical(msg, *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user