123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- import os
- import traceback
- import requests
- import json
- import chromadb
- import time
- import mysql.connector
- from mysql.connector import Error
# Configuration. Each setting can be overridden by the environment variable
# named in its os.environ.get call; the defaults are the original values, so
# behavior is unchanged when no variable is set.
embed_api_url = os.environ.get(
    "EMBED_API_URL", "http://192.168.0.33:11434/api/embed"
)
chat_api_url = os.environ.get(
    "CHAT_API_URL", "http://192.168.0.33:11434/api/chat"
)
# The same model is sent to both the embedding and the chat endpoint;
# confirm the model name is correct.
model_name = os.environ.get("MODEL_NAME", "deepseek-r1:14b")
# Directory of the persisted ChromaDB store.
chromadb_persist_dir = os.environ.get("CHROMADB_PERSIST_DIR", "./chroma_db")

# Open the persisted vector store. get_collection (not a create call) expects
# the "sql_knowledge" collection to exist already -- presumably built by a
# separate ingestion step; verify.
client = chromadb.PersistentClient(path=chromadb_persist_dir)
collection = client.get_collection(name="sql_knowledge")
# Retrieval-augmented question answering: embed -> vector search -> streamed chat.


def _embed_question(question: str, max_retries: int):
    """Embed *question* and return the embedding as a flat list of floats.

    Retries up to *max_retries* times, sleeping ``2 ** attempt`` seconds
    between attempts, for non-OK HTTP responses as well as request/parse
    errors. Returns None (after printing a message) once all attempts fail.
    """
    for attempt in range(max_retries):
        try:
            response = requests.post(
                embed_api_url,
                json={"model": model_name, "input": question},
                timeout=120,
            )
            if response.ok:
                embeddings = response.json()["embeddings"]
                if not embeddings:
                    raise ValueError("empty 'embeddings' in response")
                # The endpoint may return a list of vectors (one per input);
                # a single input was sent, so keep the first vector.
                if isinstance(embeddings[0], list):
                    embeddings = embeddings[0]
                return [float(value) for value in embeddings]
            print(f"问题嵌入失败[{response.status_code}]: {response.text}")
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            print(f"问题嵌入异常: {str(e)}")
        # Back off after every failed attempt except the last, whether the
        # failure was an HTTP status or an exception.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff
    print("问题嵌入超过最大重试次数")
    return None


def _retrieve_context(query_embedding, top_k: int):
    """Return the *top_k* nearest documents joined by blank lines, or None on failure."""
    try:
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "distances"],
        )
        context = "\n\n".join(results["documents"][0])
    except Exception as e:  # the vector-store client's exception types are not pinned here
        print(f"向量检索失败: {str(e)}")
        return None
    print("检索到相关上下文:")
    print(context[:500] + "...")  # only the first 500 chars, to keep the log short
    return context


def _read_stream(response) -> str:
    """Echo a streamed chat response to stdout and return the concatenated content.

    Each non-empty line is parsed as one JSON object and its
    ``message.content`` is appended. Lines that fail to decode or parse are
    reported and skipped rather than aborting the whole answer.
    """
    parts = []
    for line in response.iter_lines():
        if not line:
            continue
        try:
            chunk = line.decode('utf-8').strip()
            if chunk == "[DONE]":
                break
            try:
                chunk_data = json.loads(chunk)
            except json.JSONDecodeError:
                print(f"无法解析响应行: {chunk}")
                continue
            # `or ""` guards against an explicit null content value.
            answer_chunk = chunk_data.get("message", {}).get("content", "") or ""
            parts.append(answer_chunk)
            print(answer_chunk, end='', flush=True)
        except Exception as e:  # deliberate best-effort per line
            print(f"处理流式响应异常: {str(e)}")
    return "".join(parts)


def _stream_answer(question: str, context: str, max_retries: int):
    """Ask the chat endpoint to answer *question* from *context*, streaming the reply.

    Retries up to *max_retries* times with exponential backoff on non-200
    responses and request errors. Returns the full answer text, or None.
    """
    data = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": f"基于以下json格式表结构知识库上下文回答,如果不知道就说未知:\n{context}\n\n问题:{question}"
            }
        ],
        "stream": True,  # ask the server for a streamed reply
        "options": {"temperature": 0.3}
    }
    for attempt in range(max_retries):
        try:
            print(f"调用 /api/chat 的请求数据: {data}")
            # `with` releases the streamed connection even if reading fails midway.
            with requests.post(
                chat_api_url,
                json=data,
                stream=True,  # iterate the body as it arrives
                timeout=600
            ) as response:
                if response.status_code == 200:
                    print("回答内容:")
                    full_answer = _read_stream(response)
                    print()
                    return full_answer
                print(f"生成回答失败,状态码: {response.status_code},错误信息: {response.text}")
        except requests.Timeout:
            print("请求聊天 API 超时,请检查 API 服务是否正常。")
        except requests.RequestException as e:
            print(f"生成回答异常: {str(e)}")
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)  # exponential backoff
    print("生成回答超过最大重试次数")
    return None


def ask_question(question: str, top_k: int = 3, max_retries: int = 3):
    """Answer *question* with retrieval-augmented generation.

    1. Embed the question via ``embed_api_url``.
    2. Retrieve the ``top_k`` nearest documents from ``collection``.
    3. Stream an answer from ``chat_api_url`` grounded in that context,
       echoing it to stdout as it arrives.

    Args:
        question: The user's question.
        top_k: Number of documents to retrieve as context.
        max_retries: Maximum attempts for the embedding call and, separately,
            for the chat call.

    Returns:
        The full answer text, or None if any stage failed (each failure is
        printed).
    """
    query_embedding = _embed_question(question, max_retries)
    if query_embedding is None:
        return None
    context = _retrieve_context(query_embedding, top_k)
    if context is None:
        return None
    return _stream_answer(question, context, max_retries)
# Run a SQL statement against MySQL, keeping only non-deleted rows.
def query_bysql(sqlstr: str):
    """Execute *sqlstr* as a subquery filtered to ``DELETED='0'`` and print each row.

    The statement is wrapped as ``select * from (<sqlstr>) a where
    a.DELETED='0'``, so *sqlstr* must be a single SELECT (no trailing
    semicolon) whose result includes a ``DELETED`` column.

    Connection settings are read from the ``CX_APS_DB_HOST``,
    ``CX_APS_DB_USER``, ``CX_APS_DB_PASSWORD`` and ``CX_APS_DB_NAME``
    environment variables, falling back to the original literal values.

    Args:
        sqlstr: The SELECT statement to run (in this script, model-generated).

    Returns:
        The rows from ``cursor.fetchall()``, or None if connecting or
        querying failed (the error is printed).
    """
    # NOTE(security): credentials belong in the environment / a secret store,
    # not in source control; the literal defaults are kept only so existing
    # runs keep working. Prefer a read-only account over root.
    # NOTE(security): sqlstr is model output executed verbatim. A whole
    # statement cannot be parameterized, so the database account's privileges
    # are the real safeguard here.
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host=os.environ.get("CX_APS_DB_HOST", "192.168.0.31"),
            user=os.environ.get("CX_APS_DB_USER", "root"),
            password=os.environ.get("CX_APS_DB_PASSWORD", "Irongwei@1"),
            database=os.environ.get("CX_APS_DB_NAME", "cx_aps"),
        )
        if not connection.is_connected():
            return None
        print("成功连接到MySQL数据库")
        cursor = connection.cursor()
        # Server version is logged purely as a connectivity diagnostic.
        cursor.execute("SELECT VERSION()")
        version = cursor.fetchone()
        print(f"MySQL数据库版本: {version}")
        wrapped_sql = "select * from (" + sqlstr + ") a where a.DELETED='0'"
        print(wrapped_sql)
        cursor.execute(wrapped_sql)
        rows = cursor.fetchall()
        for row in rows:
            print(row)
        return rows
    except Error as e:
        print(f"连接错误: {e}")
        return None
    finally:
        # Sentinels avoid touching a cursor/connection that was never created.
        if connection is not None and connection.is_connected():
            if cursor is not None:
                cursor.close()
            connection.close()
            print("数据库连接已关闭")
# Smoke test: ask the model for SQL, extract it from the answer, run it on MySQL.
def extract_sql(answer):
    """Return the SQL inside the first ```sql fenced block of *answer*, or None.

    Only the text between the opening ```sql fence and the next ``` fence is
    kept (an unclosed fence runs to the end of *answer*), so any prose after
    the code block is not sent to the database. Surrounding whitespace and
    trailing semicolons are stripped, because the statement is later embedded
    as a subquery. Semicolons elsewhere (e.g. inside string literals) are
    preserved. Returns None when *answer* is empty/None, has no ```sql fence,
    or the fenced block is empty.
    """
    # NOTE(review): if the model emits a reasoning section that itself contains
    # a ```sql fence, the first fence may not be the final answer -- confirm.
    if not answer:
        return None
    fence = "```sql"
    start = answer.find(fence)
    if start == -1:
        return None
    body = answer[start + len(fence):]
    end = body.find("```")
    if end != -1:
        body = body[:end]
    sql = body.strip().rstrip("; \t\r\n")
    return sql or None


test_questions = [
    "帮我查询坯料计划表预计来料日期是2025年4月份的数据",
]

if __name__ == "__main__":
    for q in test_questions:
        print(f"\n问题:{q}")
        start = time.time()
        answer = ask_question(q+"。只需要生成SQL不需要解释")
        sqlstr = extract_sql(answer)
        if sqlstr is None:
            # ask_question failed, or the answer had no usable ```sql block.
            print("回答中未找到有效的SQL代码块,跳过数据库查询")
        else:
            print("-------sql start-------")
            print(answer[answer.index("```sql"):])
            print("-------sql 处理后-------")
            print(sqlstr)
            print("-------sql end-------")
            query_bysql(sqlstr)
        print(f"回答耗时({time.time() - start:.1f}s)")
|