# ask_questions.py
#
# Answers natural-language questions with SQL: embeds the question, retrieves
# related table-structure documents from a Chroma collection, asks a chat model
# to generate SQL, and runs that SQL against MySQL.
import json
import os
import re
import time
import traceback

import chromadb
import mysql.connector
import requests
from mysql.connector import Error
  9. # 配置参数
  10. embed_api_url = "http://192.168.0.33:11434/api/embed"
  11. chat_api_url = "http://192.168.0.33:11434/api/chat"
  12. model_name = "deepseek-r1:14b" # 确认模型名称正确
  13. chromadb_persist_dir = "./chroma_db" # 持久化目录
  14. # 加载持久化的向量数据库
  15. client = chromadb.PersistentClient(path=chromadb_persist_dir)
  16. collection = client.get_collection(name="sql_knowledge")
  17. # 增强问答函数
  18. def ask_question(question: str, top_k: int = 3, max_retries=3):
  19. # 生成问题嵌入
  20. for attempt in range(max_retries):
  21. try:
  22. response = requests.post(
  23. embed_api_url,
  24. json={"model": model_name, "input": question},
  25. timeout=120
  26. )
  27. if not response.ok:
  28. print(f"问题嵌入失败[{response.status_code}]: {response.text}")
  29. continue
  30. query_embedding = response.json()["embeddings"]
  31. # 确保 query_embedding 是一维的浮点数列表
  32. if isinstance(query_embedding[0], list):
  33. query_embedding = query_embedding[0]
  34. query_embedding = list(map(float, query_embedding))
  35. break
  36. except Exception as e:
  37. print(f"问题嵌入异常: {str(e)}")
  38. if attempt < max_retries - 1:
  39. time.sleep(2 ** attempt) # 指数退避
  40. else:
  41. print("问题嵌入超过最大重试次数")
  42. return None
  43. # 向量检索
  44. try:
  45. results = collection.query(
  46. query_embeddings=[query_embedding],
  47. n_results=top_k,
  48. include=["documents", "distances"]
  49. )
  50. context = "\n\n".join(results["documents"][0])
  51. print("检索到相关上下文:")
  52. print(context[:500] + "...") # 打印前500字符避免刷屏
  53. except Exception as e:
  54. print(f"向量检索失败: {str(e)}")
  55. return None
  56. # 生成回答(流式模式)
  57. for attempt in range(max_retries):
  58. try:
  59. data = {
  60. "model": model_name,
  61. "messages": [
  62. {
  63. "role": "user",
  64. "content": f"基于以下json格式表结构知识库上下文回答,如果不知道就说未知:\n{context}\n\n问题:{question}"
  65. }
  66. ],
  67. "stream": True, # 开启流式模式
  68. "options": {"temperature": 0.3}
  69. }
  70. print(f"调用 /api/chat 的请求数据: {data}")
  71. response = requests.post(
  72. chat_api_url,
  73. json=data,
  74. stream=True, # 开启流式响应
  75. timeout=600
  76. )
  77. if response.status_code == 200:
  78. full_answer = ""
  79. print("回答内容:")
  80. for line in response.iter_lines():
  81. if line:
  82. try:
  83. chunk = line.decode('utf-8')
  84. # print(f"原始响应行: {chunk}") # 添加日志输出
  85. chunk = chunk.strip()
  86. # if chunk.startswith('data:'):
  87. # chunk = chunk[5:]
  88. if chunk == "[DONE]":
  89. break
  90. try:
  91. chunk_data = json.loads(chunk)
  92. except json.JSONDecodeError:
  93. print(f"无法解析响应行: {chunk}")
  94. continue
  95. answer_chunk = chunk_data.get("message", {}).get("content", "")
  96. full_answer += answer_chunk
  97. print(answer_chunk, end='', flush=True)
  98. except Exception as e:
  99. print(f"处理流式响应异常: {str(e)}")
  100. print()
  101. return full_answer
  102. else:
  103. print(f"生成回答失败,状态码: {response.status_code},错误信息: {response.text}")
  104. except requests.Timeout:
  105. print("请求聊天 API 超时,请检查 API 服务是否正常。")
  106. except Exception as e:
  107. print(f"生成回答异常: {str(e)}")
  108. if attempt < max_retries - 1:
  109. time.sleep(2 ** attempt) # 指数退避
  110. print("生成回答超过最大重试次数")
  111. return None
  112. # 根据SQL语句查询数据并格式化
  113. def query_bysql(sqlstr: str):
  114. try:
  115. # 建立数据库连接
  116. connection = mysql.connector.connect(
  117. host="192.168.0.31", # 数据库主机地址
  118. user="root", # 数据库用户名
  119. password="Irongwei@1", # 数据库密码
  120. database="cx_aps" # 数据库名称
  121. )
  122. if connection.is_connected():
  123. print("成功连接到MySQL数据库")
  124. # 创建游标对象
  125. cursor = connection.cursor()
  126. # 执行SQL查询
  127. cursor.execute("SELECT VERSION()")
  128. version = cursor.fetchone()
  129. print(f"MySQL数据库版本: {version}")
  130. # 示例:查询数据
  131. print("select * from (" + sqlstr + ") a where a.DELETED='0'")
  132. cursor.execute("select * from ("+sqlstr+") a where a.DELETED='0'")
  133. rows = cursor.fetchall()
  134. for row in rows:
  135. print(row)
  136. except Error as e:
  137. print(f"连接错误: {e}")
  138. finally:
  139. # 关闭连接
  140. if 'connection' in locals() and connection.is_connected():
  141. cursor.close()
  142. connection.close()
  143. print("数据库连接已关闭")
  144. # 测试问答
  145. test_questions = [
  146. "帮我查询坯料计划表预计来料日期是2025年4月份的数据",
  147. ]
  148. for q in test_questions:
  149. print(f"\n问题:{q}")
  150. start = time.time()
  151. answer = ask_question(q+"。只需要生成SQL不需要解释")
  152. print("-------sql start-------")
  153. print(answer[answer.index("```sql"):])
  154. sqlstr = answer[answer.index("```sql"):].replace("```sql","")
  155. sqlstr = sqlstr.replace("```", "")
  156. sqlstr = sqlstr.replace(";", "")
  157. print("-------sql 处理后-------")
  158. print(sqlstr)
  159. print("-------sql end-------")
  160. query_bysql(sqlstr)
  161. print(f"回答耗时({time.time() - start:.1f}s)")