import os
import time
import traceback

import requests
import chromadb
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Configuration
data_directory = "data"
embed_api_url = "http://192.168.0.33:11434/api/embed"
model_name = "deepseek-r1:14b"  # must match a model pulled in Ollama; a dedicated embedding model (e.g. nomic-embed-text) usually retrieves better than a chat model
chromadb_persist_dir = "./chroma_db"  # persistence directory

# Scan for .sql files
print(f"Scanning directory: {os.path.abspath(data_directory)}")
sql_files = []
for root, _, files in os.walk(data_directory):
    for file in files:
        if file.endswith('.sql'):
            sql_files.append(os.path.join(root, file))
print(f"Found {len(sql_files)} SQL files")

# Load documents, falling back to GBK when UTF-8 decoding fails
documents = []
for file_path in sql_files:
    try:
        loader = TextLoader(file_path, encoding='utf-8')
        documents.extend(loader.load())
        print(f"Loaded: {file_path}")
    except UnicodeDecodeError:
        try:
            loader = TextLoader(file_path, encoding='gbk')
            documents.extend(loader.load())
            print(f"Loaded (GBK): {file_path}")
        except Exception as e:
            print(f"Failed to load {file_path}: {e}")
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
print(f"Loaded {len(documents)} documents in total")

# Text splitting tuned for SQL: prefer breaking at statement boundaries
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=10000,
    chunk_overlap=20,
    separators=["\n\n", "\n", ";", " ", ""]  # ";" keeps SQL statements intact
)
texts = text_splitter.split_documents(documents)
print(f"Split into {len(texts)} text chunks")

# Embedding generation with retries and exponential backoff
def get_embeddings(texts, max_retries=3):
    headers = {"Content-Type": "application/json"}
    valid_texts = []
    embeddings = []
    for idx, text in enumerate(texts):
        content = text.page_content.strip()
        if not content:  # skip empty chunks
            continue
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    embed_api_url,
                    json={"model": model_name, "input": content},
                    headers=headers,
                    timeout=600
                )
                if response.ok:
                    embedding = response.json().get("embeddings")
                    if isinstance(embedding, list) and len(embedding) > 0:
                        # /api/embed returns {"embeddings": [[...]]}; unwrap the outer list
                        if isinstance(embedding[0], list):
                            embedding = embedding[0]
                        embeddings.append(embedding)
                        valid_texts.append(text)
                        print(f"✓ Chunk {idx + 1}/{len(texts)} embedded")
                        break
                    else:
                        print(f"× Chunk {idx + 1} invalid response: {response.text}")
                else:
                    print(f"× Chunk {idx + 1} request failed [{response.status_code}]: {response.text}")
            except Exception as e:
                print(f"× Chunk {idx + 1} error: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff
            else:
                print(f"× Chunk {idx + 1} exceeded max retries; skipping")
    print(f"Generated {len(embeddings)}/{len(texts)} embeddings")
    return valid_texts, embeddings

try:
    print("\nGenerating embeddings...")
    valid_texts, embeddings = get_embeddings(texts)

    # Create a persistent ChromaDB client and a cosine-similarity collection
    client = chromadb.PersistentClient(path=chromadb_persist_dir)
    collection = client.get_or_create_collection(
        name="sql_knowledge",
        metadata={"hnsw:space": "cosine"}
    )

    # Prepare ids and raw text (doc_texts avoids shadowing `documents` above)
    ids = [str(i) for i in range(len(valid_texts))]
    doc_texts = [t.page_content for t in valid_texts]

    # Upsert in batches to keep memory usage bounded
    batch_size = 100
    for i in range(0, len(ids), batch_size):
        batch_ids = ids[i:i + batch_size]
        batch_embeddings = embeddings[i:i + batch_size]
        batch_docs = doc_texts[i:i + batch_size]
        # Coerce embeddings to plain lists of floats, the format Chroma expects
        batch_embeddings = [list(map(float, e)) for e in batch_embeddings]
        collection.upsert(
            ids=batch_ids,
            embeddings=batch_embeddings,
            documents=batch_docs
        )
        print(f"Upserted {i + len(batch_ids)}/{len(ids)} records")

    print("Knowledge base built. Persisted at:", os.path.abspath(chromadb_persist_dir))
except Exception as e:
    print(f"Main pipeline error: {e}")
    traceback.print_exc()
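
# --- Retrieval usage sketch ---
# A minimal example of querying the collection built above, not part of the
# original build flow. query_knowledge_base is a hypothetical helper and the
# sample question is illustrative; it reuses embed_api_url, model_name, and
# chromadb_persist_dir as configured earlier so query vectors are produced in
# the same embedding space as the indexed documents.
def query_knowledge_base(question, n_results=3):
    # Embed the query with the same Ollama endpoint used for indexing
    resp = requests.post(
        embed_api_url,
        json={"model": model_name, "input": question},
        timeout=600
    )
    resp.raise_for_status()
    query_embedding = resp.json()["embeddings"][0]

    client = chromadb.PersistentClient(path=chromadb_persist_dir)
    collection = client.get_or_create_collection(name="sql_knowledge")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results
    )
    # Chroma returns one hit list per query embedding; we sent a single query
    return results["documents"][0]

# Example (hypothetical question):
# for doc in query_knowledge_base("Which table stores user orders?"):
#     print(doc[:200])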