import os
import traceback
import requests
import json
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
import time
# Configuration
data_directory = "data"
embed_api_url = "http://192.168.0.33:11434/api/embed"
model_name = "deepseek-r1:14b"  # make sure the model name is correct
chromadb_persist_dir = "./chroma_db"  # persistence directory

# Verbose logging
print(f"Scanning directory: {os.path.abspath(data_directory)}")
sql_files = []
for root, _, files in os.walk(data_directory):
    for file in files:
        if file.endswith('.sql'):
            sql_files.append(os.path.join(root, file))
print(f"Found {len(sql_files)} SQL files")
# Load documents and handle encoding issues
documents = []
for file_path in sql_files:
    try:
        loader = TextLoader(file_path, encoding='utf-8')
        documents.extend(loader.load())
        print(f"Loaded: {file_path}")
    except UnicodeDecodeError:
        # Retry with GBK for files that are not UTF-8 encoded
        try:
            loader = TextLoader(file_path, encoding='gbk')
            documents.extend(loader.load())
            print(f"Loaded (GBK): {file_path}")
        except Exception as e:
            print(f"Failed to load file {file_path}: {str(e)}")
    except Exception as e:
        print(f"Exception while loading file {file_path}: {str(e)}")
print(f"Loaded {len(documents)} documents in total")
# Text splitting tuned for SQL
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10000,
    chunk_overlap=20,
    separators=["\n\n", "\n", ";", " ", ""]  # prefer splitting at SQL statement boundaries
)
texts = text_splitter.split_documents(documents)
print(f"Split into {len(texts)} text chunks")
# Embedding generation (with retry and exponential backoff)
def get_embeddings(texts, max_retries=3):
    headers = {"Content-Type": "application/json"}
    valid_texts = []
    embeddings = []
    for idx, text in enumerate(texts):
        content = text.page_content.strip()
        if not content:
            continue
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    embed_api_url,
                    json={"model": model_name, "input": content},
                    headers=headers,
                    timeout=600
                )
                if response.ok:
                    embedding = response.json().get("embeddings")
                    if isinstance(embedding, list) and len(embedding) > 0:
                        # Unwrap the single vector if the API returned a list of vectors
                        if isinstance(embedding[0], list):
                            embedding = embedding[0]
                        embeddings.append(embedding)
                        valid_texts.append(text)
                        print(f"✓ Chunk {idx + 1}/{len(texts)} embedded successfully")
                        break
                    else:
                        print(f"× Chunk {idx + 1} invalid response: {response.text}")
                else:
                    print(f"× Chunk {idx + 1} request failed [{response.status_code}]: {response.text}")
            except Exception as e:
                print(f"× Chunk {idx + 1} exception: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff
            else:
                print(f"× Chunk {idx + 1} exceeded the maximum number of retries")
    print(f"Generated {len(embeddings)}/{len(texts)} embedding vectors successfully")
    return valid_texts, embeddings
try:
    # Generate embedding vectors
    print("\nGenerating embeddings...")
    valid_texts, embeddings = get_embeddings(texts)

    # Create a persistent ChromaDB client and collection
    client = chromadb.PersistentClient(path=chromadb_persist_dir)
    collection = client.get_or_create_collection(
        name="sql_knowledge",
        metadata={"hnsw:space": "cosine"}
    )

    # Prepare the data to insert
    ids = [str(i) for i in range(len(valid_texts))]
    documents = [t.page_content for t in valid_texts]

    # Insert in batches to avoid memory issues
    batch_size = 100
    for i in range(0, len(ids), batch_size):
        batch_ids = ids[i:i + batch_size]
        batch_embeddings = embeddings[i:i + batch_size]
        batch_docs = documents[i:i + batch_size]
        # Make sure every embedding is a plain list of floats
        batch_embeddings = [list(map(float, e)) for e in batch_embeddings]
        collection.upsert(
            ids=batch_ids,
            embeddings=batch_embeddings,
            documents=batch_docs
        )
        print(f"Inserted {i + len(batch_ids)}/{len(ids)} records")

    print("Knowledge base built! Persisted at:", os.path.abspath(chromadb_persist_dir))
except Exception as e:
    print(f"Main flow failed: {str(e)}")
    traceback.print_exc()
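
# Illustrative sketch only: one possible way to query the persisted collection
# afterwards, reusing the embed_api_url / model_name / chromadb_persist_dir values
# defined above and assuming the endpoint returns an "embeddings" list, as handled
# in get_embeddings(). The helper name query_knowledge_base is hypothetical.
def query_knowledge_base(question, n_results=3):
    response = requests.post(
        embed_api_url,
        json={"model": model_name, "input": question},
        headers={"Content-Type": "application/json"},
        timeout=600
    )
    response.raise_for_status()
    query_embedding = response.json()["embeddings"][0]
    client = chromadb.PersistentClient(path=chromadb_persist_dir)
    collection = client.get_or_create_collection(name="sql_knowledge")
    return collection.query(query_embeddings=[query_embedding], n_results=n_results)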