|
@@ -0,0 +1,180 @@
|
|
|
+import os
|
|
|
+import traceback
|
|
|
+import requests
|
|
|
+import json
|
|
|
+import chromadb
|
|
|
+import time
|
|
|
+
|
|
|
+import mysql.connector
|
|
|
+from mysql.connector import Error
|
|
|
+
|
|
|
+# Configuration: HTTP endpoints for the embedding and chat APIs, the model name sent to both, and the vector-store directory
|
|
|
+embed_api_url = "http://192.168.0.33:11434/api/embed"
|
|
|
+chat_api_url = "http://192.168.0.33:11434/api/chat"
|
|
|
+model_name = "deepseek-r1:14b" # Confirm that the model name is correct
|
|
|
+chromadb_persist_dir = "./chroma_db" # Directory where the vector database is persisted
|
|
|
+
|
|
|
+# Load the persisted vector database (runs at import time; the "sql_knowledge" collection is fetched, not created)
|
|
|
+client = chromadb.PersistentClient(path=chromadb_persist_dir)
|
|
|
+collection = client.get_collection(name="sql_knowledge")
|
|
|
+
|
|
|
+
|
|
|
# Retrieval-augmented question answering
def ask_question(question: str, top_k: int = 3, max_retries=3):
    """Answer *question* using context retrieved from the vector store.

    The question is embedded through ``embed_api_url``, the ``top_k``
    nearest documents are fetched from ``collection``, and the chat model
    behind ``chat_api_url`` is asked to answer from that context. The
    answer is echoed to stdout while it streams in.

    Args:
        question: Natural-language question to answer.
        top_k: Number of documents to retrieve as context.
        max_retries: Maximum number of attempts for the embedding request
            and, separately, for the chat request.

    Returns:
        The complete answer text (possibly empty), or ``None`` when the
        embedding, the retrieval, or every chat attempt fails.
    """
    query_embedding = _embed_question(question, max_retries)
    if query_embedding is None:
        return None

    context = _retrieve_context(query_embedding, top_k)
    if context is None:
        return None

    # The payload does not change between attempts, so build it once.
    data = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": f"基于以下json格式表结构知识库上下文回答,如果不知道就说未知:\n{context}\n\n问题:{question}"
            }
        ],
        "stream": True,  # ask the server for a streamed reply
        "options": {"temperature": 0.3},
    }

    for attempt in range(max_retries):
        try:
            print(f"调用 /api/chat 的请求数据: {data}")
            answer = _stream_chat_answer(data)
            if answer is not None:
                return answer
        except requests.Timeout:
            print("请求聊天 API 超时,请检查 API 服务是否正常。")
        except Exception as e:
            # Best-effort boundary: report the failure and retry.
            print(f"生成回答异常: {str(e)}")
        # Back off after every failed attempt, HTTP errors included.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff

    print("生成回答超过最大重试次数")
    return None


def _embed_question(question, max_retries):
    """Return the embedding of *question* as a flat list of floats, or None if every attempt fails."""
    for attempt in range(max_retries):
        try:
            response = requests.post(
                embed_api_url,
                json={"model": model_name, "input": question},
                timeout=120,
            )
            if response.ok:
                embedding = response.json()["embeddings"]
                # The field may hold a list of vectors; keep the first so
                # the result is one-dimensional.
                if isinstance(embedding[0], list):
                    embedding = embedding[0]
                return [float(value) for value in embedding]
            print(f"问题嵌入失败[{response.status_code}]: {response.text}")
        except Exception as e:
            # Best-effort boundary: report the failure and retry.
            print(f"问题嵌入异常: {str(e)}")
        # Back off after HTTP errors as well as after exceptions.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff

    print("问题嵌入超过最大重试次数")
    return None


def _retrieve_context(query_embedding, top_k):
    """Return the top_k retrieved documents joined by blank lines, or None if the query fails."""
    try:
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "distances"],
        )
        context = "\n\n".join(results["documents"][0])
        print("检索到相关上下文:")
        print(context[:500] + "...")  # only the first 500 characters, to keep the log short
        return context
    except Exception as e:
        print(f"向量检索失败: {str(e)}")
        return None


def _stream_chat_answer(data):
    """Send one streamed chat request; return the answer text, or None on a non-200 status.

    Network errors and timeouts propagate to the caller, which owns the
    retry policy.
    """
    # The context manager closes the streamed response (and releases its
    # connection) on every exit path, including early breaks and errors.
    with requests.post(
        chat_api_url,
        json=data,
        stream=True,
        timeout=600,
    ) as response:
        if response.status_code != 200:
            print(f"生成回答失败,状态码: {response.status_code},错误信息: {response.text}")
            return None

        pieces = []
        print("回答内容:")
        for line in response.iter_lines():
            if not line:
                continue
            try:
                chunk = line.decode('utf-8').strip()
                if chunk == "[DONE]":
                    break
                try:
                    chunk_data = json.loads(chunk)
                except json.JSONDecodeError:
                    print(f"无法解析响应行: {chunk}")
                    continue
                answer_chunk = chunk_data.get("message", {}).get("content", "") or ""
                pieces.append(answer_chunk)
                print(answer_chunk, end='', flush=True)
                # Stop when the chunk itself says the stream is finished.
                # NOTE(review): the Ollama-style API is expected to mark
                # its last chunk with "done": true - confirm for this server.
                if chunk_data.get("done"):
                    break
            except Exception as e:
                # A malformed chunk should not discard the text gathered so far.
                print(f"处理流式响应异常: {str(e)}")
        print()
        return "".join(pieces)
|
|
|
+
|
|
|
+
|
|
|
# Run a SQL statement against MySQL and print the resulting rows
def query_bysql(sqlstr: str):
    """Execute *sqlstr* as a sub-query and print every non-deleted row.

    The statement is wrapped as
    ``select * from (<sqlstr>) a where a.DELETED='0'``, so it must be a
    single SELECT without a trailing semicolon whose result exposes a
    ``DELETED`` column.

    Connection settings default to the values previously hard-coded here
    and can be overridden with the ``MYSQL_HOST``, ``MYSQL_USER``,
    ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE`` environment variables.

    Args:
        sqlstr: SELECT statement to run.

    Returns:
        None. Rows and diagnostics are printed; MySQL errors are caught
        and printed rather than raised.
    """
    if not sqlstr or not sqlstr.strip():
        print("SQL语句为空,跳过查询")
        return

    # NOTE(review): sqlstr is concatenated into the statement and cannot be
    # parameterized. It comes from model output, so treat it as untrusted
    # and prefer a read-only, least-privilege account over root.
    wrapped_sql = "select * from (" + sqlstr + ") a where a.DELETED='0'"

    connection = None
    cursor = None
    try:
        # NOTE(review): the default password is still present in source;
        # move it to the environment or a secret store and drop the default.
        connection = mysql.connector.connect(
            host=os.environ.get("MYSQL_HOST", "192.168.0.31"),
            user=os.environ.get("MYSQL_USER", "root"),
            password=os.environ.get("MYSQL_PASSWORD", "Irongwei@1"),
            database=os.environ.get("MYSQL_DATABASE", "cx_aps"),
        )

        if connection.is_connected():
            print("成功连接到MySQL数据库")

            cursor = connection.cursor()

            # Server version, printed as a connectivity check.
            cursor.execute("SELECT VERSION()")
            version = cursor.fetchone()
            print(f"MySQL数据库版本: {version}")

            print(wrapped_sql)
            cursor.execute(wrapped_sql)

            for row in cursor.fetchall():
                print(row)

    except Error as e:
        print(f"连接错误: {e}")
    finally:
        # Close only what was actually opened; the cursor may not exist if
        # connecting failed.
        if cursor is not None:
            try:
                cursor.close()
            except Error as e:
                print(f"关闭游标失败: {e}")
        if connection is not None and connection.is_connected():
            connection.close()
            print("数据库连接已关闭")
|
|
|
+
|
|
|
+
|
|
|
# Smoke test: ask the model for SQL, extract it from the answer, and run it
test_questions = [
    "帮我查询坯料计划表预计来料日期是2025年4月份的数据",
]


def extract_sql(answer):
    """Return the SQL inside the first ```sql fenced block of *answer*.

    Only the text between the opening ```sql fence and the next ``` fence
    is kept (or up to the end of the text when the closing fence is
    missing). Surrounding whitespace and trailing semicolons are removed so
    the statement can be embedded as a sub-query; semicolons elsewhere are
    left untouched.

    Args:
        answer: Model output, or ``None`` when no answer was produced.

    Returns:
        The extracted statement, or ``None`` when *answer* is empty, has no
        ```sql fence, or the fenced block is empty.
    """
    if not answer:
        return None
    _, fence, after_fence = answer.partition("```sql")
    if not fence:
        return None
    fenced_body = after_fence.partition("```")[0]
    sql = fenced_body.strip().rstrip(";").strip()
    return sql or None


# Guarded so that importing this module does not trigger model and database calls.
if __name__ == "__main__":
    for q in test_questions:
        print(f"\n问题:{q}")
        start = time.time()
        answer = ask_question(q + "。只需要生成SQL不需要解释")
        sqlstr = extract_sql(answer)
        if sqlstr is None:
            # No answer, or no SQL block in it: nothing to execute.
            print("未能从回答中提取SQL,跳过数据库查询")
        else:
            print("-------sql start-------")
            print(answer[answer.index("```sql"):])
            print("-------sql 处理后-------")
            print(sqlstr)
            print("-------sql end-------")
            query_bysql(sqlstr)
        print(f"回答耗时({time.time() - start:.1f}s)")
|