build_knowledge_base.py 4.6 KB

import os
import time
import traceback

import requests
import chromadb
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Configuration
data_directory = "data"
embed_api_url = "http://192.168.0.33:11434/api/embed"
model_name = "deepseek-r1:14b"  # make sure the model name is correct
chromadb_persist_dir = "./chroma_db"  # persistence directory

# Verbose logging
print(f"Scanning directory: {os.path.abspath(data_directory)}")
sql_files = []
for root, _, files in os.walk(data_directory):
    for file in files:
        if file.endswith('.sql'):
            sql_files.append(os.path.join(root, file))
print(f"Found {len(sql_files)} SQL files")
# Load documents, handling encoding issues
documents = []
for file_path in sql_files:
    try:
        loader = TextLoader(file_path, encoding='utf-8')
        documents.extend(loader.load())
        print(f"Loaded: {file_path}")
    except UnicodeDecodeError:
        # Fall back to GBK for files that are not UTF-8 encoded
        try:
            loader = TextLoader(file_path, encoding='gbk')
            documents.extend(loader.load())
            print(f"Loaded (GBK): {file_path}")
        except Exception as e:
            print(f"Failed to load file {file_path}: {str(e)}")
    except Exception as e:
        print(f"Error while loading file {file_path}: {str(e)}")
print(f"Loaded {len(documents)} documents in total")
# Text splitting tuned for SQL
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10000,
    chunk_overlap=20,
    separators=["\n\n", "\n", ";", " ", ""]  # split on SQL statement boundaries
)
texts = text_splitter.split_documents(documents)
print(f"Got {len(texts)} text chunks after splitting")
# Embedding generation with retries and exponential backoff
def get_embeddings(texts, max_retries=3):
    headers = {"Content-Type": "application/json"}
    valid_texts = []
    embeddings = []
    for idx, text in enumerate(texts):
        content = text.page_content.strip()
        if not content:
            continue
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    embed_api_url,
                    json={"model": model_name, "input": content},
                    headers=headers,
                    timeout=600
                )
                if response.ok:
                    embedding = response.json().get("embeddings")
                    if isinstance(embedding, list) and len(embedding) > 0:
                        # The /api/embed endpoint returns a list of vectors; unwrap the single vector
                        if isinstance(embedding[0], list):
                            embedding = embedding[0]
                        embeddings.append(embedding)
                        valid_texts.append(text)
                        print(f"✓ Chunk {idx + 1}/{len(texts)} embedded")
                        break
                    else:
                        print(f"× Chunk {idx + 1} invalid response: {response.text}")
                else:
                    print(f"× Chunk {idx + 1} request failed [{response.status_code}]: {response.text}")
            except Exception as e:
                print(f"× Chunk {idx + 1} error: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff
            else:
                print(f"× Chunk {idx + 1} exceeded the maximum number of retries")
    print(f"Generated {len(embeddings)}/{len(texts)} embedding vectors")
    return valid_texts, embeddings
try:
    # Generate embedding vectors
    print("\nGenerating embeddings...")
    valid_texts, embeddings = get_embeddings(texts)

    # Create a persistent client
    client = chromadb.PersistentClient(path=chromadb_persist_dir)
    collection = client.get_or_create_collection(
        name="sql_knowledge",
        metadata={"hnsw:space": "cosine"}
    )

    # Prepare data
    ids = [str(i) for i in range(len(valid_texts))]
    documents = [t.page_content for t in valid_texts]

    # Insert in batches to avoid memory issues
    batch_size = 100
    for i in range(0, len(ids), batch_size):
        batch_ids = ids[i:i + batch_size]
        batch_embeddings = embeddings[i:i + batch_size]
        batch_docs = documents[i:i + batch_size]
        # Make sure every embedding is a plain list of floats
        batch_embeddings = [list(map(float, e)) for e in batch_embeddings]
        collection.upsert(
            ids=batch_ids,
            embeddings=batch_embeddings,
            documents=batch_docs
        )
        print(f"Inserted {i + len(batch_ids)}/{len(ids)} records")
    print("Knowledge base build complete! Persisted to:", os.path.abspath(chromadb_persist_dir))
except Exception as e:
    print(f"Main flow error: {str(e)}")
    traceback.print_exc()