import os
import traceback
import requests
import json
import chromadb
import time
import mysql.connector
from mysql.connector import Error

# Configuration
embed_api_url = "http://192.168.0.33:11434/api/embed"
chat_api_url = "http://192.168.0.33:11434/api/chat"
# Used for both embedding and chat. NOTE(review): the embedding model must be
# the same one that was used to build the "sql_knowledge" collection — confirm.
model_name = "deepseek-r1:14b"  # make sure the model name is correct
chromadb_persist_dir = "./chroma_db"  # persistence directory

# MySQL connection settings. Each value can be overridden by an environment
# variable; the fallbacks are the previously hard-coded values.
# NOTE(review): a plaintext root password in source is a credential leak —
# rotate it, use a read-only account (the SQL executed below is LLM-generated),
# and drop the fallback once the environment is configured.
mysql_config = {
    "host": os.environ.get("MYSQL_HOST", "192.168.0.31"),
    "user": os.environ.get("MYSQL_USER", "root"),
    "password": os.environ.get("MYSQL_PASSWORD", "Irongwei@1"),
    "database": os.environ.get("MYSQL_DATABASE", "cx_aps"),
}

# Load the persisted vector database.
client = chromadb.PersistentClient(path=chromadb_persist_dir)
collection = client.get_collection(name="sql_knowledge")


def _embed_question(question: str, max_retries: int):
    """Return the embedding of *question* as a flat list of floats, or None.

    Retries up to *max_retries* times with exponential backoff after both
    non-2xx responses and exceptions; prints a message once retries run out.
    """
    for attempt in range(max_retries):
        try:
            response = requests.post(
                embed_api_url,
                json={"model": model_name, "input": question},
                timeout=120,
            )
            if response.ok:
                embedding = response.json()["embeddings"]
                # "embeddings" may be nested (one vector per input); unwrap to
                # the single 1-D vector Chroma expects for one query.
                if isinstance(embedding[0], list):
                    embedding = embedding[0]
                return [float(x) for x in embedding]
            print(f"问题嵌入失败[{response.status_code}]: {response.text}")
        except Exception as e:
            print(f"问题嵌入异常: {str(e)}")
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff
    print("问题嵌入超过最大重试次数")
    return None


def _retrieve_context(query_embedding, top_k: int):
    """Query the collection and return the top-k documents joined by blank lines, or None on failure."""
    try:
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "distances"],
        )
        context = "\n\n".join(results["documents"][0])
        print("检索到相关上下文:")
        print(context[:500] + "...")  # print only the first 500 chars to avoid flooding
        return context
    except Exception as e:
        print(f"向量检索失败: {str(e)}")
        return None


def _read_chat_stream(response) -> str:
    """Consume a streamed /api/chat response line by line, echoing and accumulating message content."""
    parts = []
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        try:
            chunk = raw_line.decode("utf-8").strip()
            if chunk == "[DONE]":
                break
            try:
                chunk_data = json.loads(chunk)
            except json.JSONDecodeError:
                print(f"无法解析响应行: {chunk}")
                continue
            if "error" in chunk_data:
                # A mid-stream failure arrives as an object with an "error" key.
                print(f"模型返回错误: {chunk_data['error']}")
                continue
            answer_chunk = (chunk_data.get("message") or {}).get("content", "")
            parts.append(answer_chunk)
            print(answer_chunk, end='', flush=True)
            if chunk_data.get("done"):
                break
        except Exception as e:
            print(f"处理流式响应异常: {str(e)}")
    return "".join(parts)


def _generate_answer(question: str, context: str, max_retries: int):
    """Ask the chat model to answer *question* from *context* (streamed); return the full answer or None."""
    data = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": f"基于以下json格式表结构知识库上下文回答,如果不知道就说未知:\n{context}\n\n问题:{question}"
            }
        ],
        "stream": True,  # enable streaming mode
        "options": {"temperature": 0.3},
    }
    for attempt in range(max_retries):
        try:
            print(f"调用 /api/chat 的请求数据: {data}")
            # The context manager guarantees the streamed connection is released
            # on every path (success, HTTP error, or exception mid-stream).
            with requests.post(chat_api_url, json=data, stream=True, timeout=600) as response:
                if response.status_code == 200:
                    print("回答内容:")
                    full_answer = _read_chat_stream(response)
                    print()
                    return full_answer
                print(f"生成回答失败,状态码: {response.status_code},错误信息: {response.text}")
        except requests.Timeout:
            print("请求聊天 API 超时,请检查 API 服务是否正常。")
        except Exception as e:
            print(f"生成回答异常: {str(e)}")
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff
    print("生成回答超过最大重试次数")
    return None


# Retrieval-augmented question answering.
def ask_question(question: str, top_k: int = 3, max_retries=3):
    """Answer *question* using the top_k most similar knowledge-base documents as context.

    Args:
        question: The user question.
        top_k: Number of documents to retrieve from the vector store.
        max_retries: Attempts for the embedding call and for the chat call.

    Returns:
        The model's full answer text, or None if embedding, retrieval, or
        generation failed.
    """
    query_embedding = _embed_question(question, max_retries)
    if query_embedding is None:
        return None
    context = _retrieve_context(query_embedding, top_k)
    if context is None:
        return None
    return _generate_answer(question, context, max_retries)


def extract_sql(answer):
    """Return the SQL inside the first ```sql fence of the model's final reply, or None.

    Only text after the last ``</think>`` tag is searched: reasoning models may
    put draft SQL inside their think section. Text after the closing fence is
    discarded. All semicolons are removed because the statement is embedded as
    a subquery by ``query_bysql``; this also prevents stacking extra statements.
    """
    if not answer:
        return None
    # rpartition returns ('', '', answer) when the tag is absent.
    reply = answer.rpartition("</think>")[2]
    start = reply.find("```sql")
    if start == -1:
        return None
    body = reply[start + len("```sql"):]
    end = body.find("```")
    if end != -1:
        body = body[:end]
    sql = body.replace(";", "").strip()
    return sql or None


# Run a SQL statement and print the rows.
def query_bysql(sqlstr: str):
    """Run *sqlstr* as ``select * from (<sqlstr>) a where a.DELETED='0'`` and print each row.

    The result set of *sqlstr* must therefore contain a ``DELETED`` column.
    *sqlstr* is executed verbatim (it is typically LLM-generated), so the
    configured account should have read-only privileges.

    Returns:
        The fetched rows, or None if connecting or querying failed.
    """
    wrapped_sql = "select * from (" + sqlstr + ") a where a.DELETED='0'"
    connection = None
    cursor = None
    rows = None
    try:
        connection = mysql.connector.connect(**mysql_config)
        if connection.is_connected():
            print("成功连接到MySQL数据库")
            cursor = connection.cursor()
            cursor.execute("SELECT VERSION()")
            version = cursor.fetchone()
            print(f"MySQL数据库版本: {version}")
            print(wrapped_sql)
            cursor.execute(wrapped_sql)
            rows = cursor.fetchall()
            for row in rows:
                print(row)
    except Error as e:
        print(f"连接错误: {e}")
    finally:
        # Close whatever was actually opened; neither may exist if connect() failed.
        if cursor is not None:
            cursor.close()
        if connection is not None and connection.is_connected():
            connection.close()
            print("数据库连接已关闭")
    return rows


# Sample questions for a manual end-to-end run.
test_questions = [
    "帮我查询坯料计划表预计来料日期是2025年4月份的数据",
]


def main():
    """Ask each test question, extract the generated SQL, and run it against MySQL."""
    for q in test_questions:
        print(f"\n问题:{q}")
        start = time.perf_counter()
        try:
            answer = ask_question(q + "。只需要生成SQL不需要解释")
            sqlstr = extract_sql(answer)
            if sqlstr is None:
                print("未能从回答中提取SQL")
            else:
                print("-------sql start-------")
                print(sqlstr)
                print("-------sql end-------")
                query_bysql(sqlstr)
        except Exception:
            # Keep going with the remaining questions, but show what went wrong.
            traceback.print_exc()
        print(f"回答耗时({time.perf_counter() - start:.1f}s)")


if __name__ == "__main__":
    main()