فهرست منبع

AI大模型智能体探索项目初始化

fangpy 4 ماه پیش
کامیت
e5d4635b7f
5فایلهای تغییر یافته به همراه415 افزوده شده و 0 حذف شده
  1. 180 0
      ask_questions.py
  2. 134 0
      build_knowledge_base.py
  3. 82 0
      data/cx_aps.sql
  4. 16 0
      main.py
  5. 3 0
      requirements.txt

+ 180 - 0
ask_questions.py

@@ -0,0 +1,180 @@
+import os
+import traceback
+import requests
+import json
+import chromadb
+import time
+
+import mysql.connector
+from mysql.connector import Error
+
+# Configuration
+# Both endpoints below are called with the same `model_name`.
+embed_api_url = "http://192.168.0.33:11434/api/embed"
+chat_api_url = "http://192.168.0.33:11434/api/chat"
+model_name = "deepseek-r1:14b"  # make sure the model name is correct
+chromadb_persist_dir = "./chroma_db"  # persistence directory
+
+# Load the persisted vector database.
+# The "sql_knowledge" collection is the one build_knowledge_base.py creates.
+client = chromadb.PersistentClient(path=chromadb_persist_dir)
+collection = client.get_collection(name="sql_knowledge")
+
+
+# Retrieval-augmented question answering
+def ask_question(question: str, top_k: int = 3, max_retries=3):
+    """Answer ``question`` using context retrieved from the vector store.
+
+    The question is embedded through ``embed_api_url``, the ``top_k`` nearest
+    documents are fetched from ``collection``, and ``chat_api_url`` is asked
+    to answer from that context. The answer is echoed to stdout while it
+    streams in.
+
+    Args:
+        question: Natural-language question.
+        top_k: Number of documents to retrieve as context.
+        max_retries: Maximum attempts for the embedding request and,
+            separately, for the chat request.
+
+    Returns:
+        The concatenated answer text, or ``None`` when embedding, retrieval
+        or generation failed.
+    """
+    # Embed the question
+    for attempt in range(max_retries):
+        try:
+            response = requests.post(
+                embed_api_url,
+                json={"model": model_name, "input": question},
+                timeout=120
+            )
+            if response.ok:
+                query_embedding = response.json()["embeddings"]
+                # Make sure query_embedding is a flat list of floats
+                if isinstance(query_embedding[0], list):
+                    query_embedding = query_embedding[0]
+                query_embedding = list(map(float, query_embedding))
+                break
+            print(f"问题嵌入失败[{response.status_code}]: {response.text}")
+        except Exception as e:
+            print(f"问题嵌入异常: {str(e)}")
+        # Back off after every failed attempt, HTTP errors included, so a
+        # struggling server is not hit again immediately.
+        if attempt < max_retries - 1:
+            time.sleep(2 ** attempt)  # exponential backoff
+    else:
+        print("问题嵌入超过最大重试次数")
+        return None
+
+    # Vector retrieval
+    try:
+        results = collection.query(
+            query_embeddings=[query_embedding],
+            n_results=top_k,
+            include=["documents", "distances"]
+        )
+        context = "\n\n".join(results["documents"][0])
+        print("检索到相关上下文:")
+        print(context[:500] + "...")  # only the first 500 characters, to keep the log short
+    except Exception as e:
+        print(f"向量检索失败: {str(e)}")
+        return None
+
+    # Generate the answer (streaming mode). The payload does not change
+    # between attempts, so it is built once.
+    data = {
+        "model": model_name,
+        "messages": [
+            {
+                "role": "user",
+                "content": f"基于以下json格式表结构知识库上下文回答,如果不知道就说未知:\n{context}\n\n问题:{question}"
+            }
+        ],
+        "stream": True,  # ask the server to stream
+        "options": {"temperature": 0.3}
+    }
+    for attempt in range(max_retries):
+        try:
+            print(f"调用 /api/chat 的请求数据: {data}")
+            # The context manager releases the streamed connection on every
+            # exit path (return, error status, exception).
+            with requests.post(
+                chat_api_url,
+                json=data,
+                stream=True,  # consume the response incrementally
+                timeout=600
+            ) as response:
+                if response.status_code == 200:
+                    full_answer = ""
+                    print("回答内容:")
+                    for line in response.iter_lines():
+                        if not line:
+                            continue
+                        try:
+                            chunk = line.decode('utf-8').strip()
+                            if chunk == "[DONE]":
+                                break
+                            try:
+                                chunk_data = json.loads(chunk)
+                            except json.JSONDecodeError:
+                                print(f"无法解析响应行: {chunk}")
+                                continue
+                            answer_chunk = chunk_data.get("message", {}).get("content", "")
+                            full_answer += answer_chunk
+                            print(answer_chunk, end='', flush=True)
+                        except Exception as e:
+                            print(f"处理流式响应异常: {str(e)}")
+                    print()
+                    return full_answer
+                print(f"生成回答失败,状态码: {response.status_code},错误信息: {response.text}")
+        except requests.Timeout:
+            print("请求聊天 API 超时,请检查 API 服务是否正常。")
+        except Exception as e:
+            print(f"生成回答异常: {str(e)}")
+        if attempt < max_retries - 1:
+            time.sleep(2 ** attempt)  # exponential backoff
+    print("生成回答超过最大重试次数")
+    return None
+
+
+# Run a SQL statement and print the resulting rows
+def query_bysql(sqlstr: str):
+    """Run ``sqlstr`` as a subquery against MySQL and print the matching rows.
+
+    The statement is wrapped as ``select * from (<sqlstr>) a where
+    a.DELETED='0'``, so it must be a single SELECT without a trailing
+    semicolon whose result exposes a ``DELETED`` column.
+
+    Args:
+        sqlstr: SELECT statement to execute.
+
+    Returns:
+        None. The server version, the final query and every row are printed;
+        ``mysql.connector`` errors are printed rather than raised.
+    """
+    # Pre-bind both handles so the cleanup below never touches an unbound name.
+    connection = None
+    cursor = None
+    try:
+        # Open the database connection.
+        # NOTE(review): credentials are hard-coded in source; they should
+        # come from configuration or the environment.
+        connection = mysql.connector.connect(
+            host="192.168.0.31",  # database host
+            user="root",  # database user
+            password="Irongwei@1",  # database password
+            database="cx_aps"  # database name
+        )
+
+        if connection.is_connected():
+            print("成功连接到MySQL数据库")
+
+            # Create a cursor
+            cursor = connection.cursor()
+
+            # Report the server version
+            cursor.execute("SELECT VERSION()")
+            version = cursor.fetchone()
+            print(f"MySQL数据库版本: {version}")
+
+            # NOTE(review): sqlstr is spliced into the statement verbatim
+            # and cannot be bound as a parameter; only pass trusted or
+            # validated SQL, ideally over a read-only account.
+            query = "select * from (" + sqlstr + ") a where a.DELETED='0'"
+            print(query)
+            cursor.execute(query)
+
+            for row in cursor.fetchall():
+                print(row)
+
+    except Error as e:
+        print(f"连接错误: {e}")
+    finally:
+        # Close whatever was actually opened
+        if connection is not None and connection.is_connected():
+            if cursor is not None:
+                cursor.close()
+            connection.close()
+            print("数据库连接已关闭")
+
+
+# Smoke-test the question answering pipeline
+test_questions = [
+    "帮我查询坯料计划表预计来料日期是2025年4月份的数据",
+
+]
+
+for q in test_questions:
+    print(f"\n问题:{q}")
+    start = time.time()
+    answer = ask_question(q+"。只需要生成SQL不需要解释")
+
+    # ask_question returns None on failure, and the model may answer without
+    # a fenced SQL block; skip the query instead of crashing in either case.
+    fence_open = "```sql"
+    fence_start = answer.find(fence_open) if answer else -1
+    if fence_start == -1:
+        print("未能从回答中提取SQL,跳过查询")
+        print(f"回答耗时({time.time() - start:.1f}s)")
+        continue
+
+    print("-------sql start-------")
+    print(answer[fence_start:])
+    # Keep only the text between the opening fence and its closing fence so
+    # that any commentary after the block does not end up in the statement.
+    sqlstr = answer[fence_start + len(fence_open):]
+    fence_end = sqlstr.find("```")
+    if fence_end != -1:
+        sqlstr = sqlstr[:fence_end]
+    # query_bysql embeds the statement in a subquery, so drop semicolons.
+    sqlstr = sqlstr.replace(";", "")
+    print("-------sql 处理后-------")
+    print(sqlstr)
+    print("-------sql end-------")
+    query_bysql(sqlstr)
+    print(f"回答耗时({time.time() - start:.1f}s)")

+ 134 - 0
build_knowledge_base.py

@@ -0,0 +1,134 @@
+import os
+import traceback
+import requests
+import json
+from langchain_community.document_loaders import TextLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+import chromadb
+import time
+
+# Configuration
+data_directory = "data"  # directory scanned recursively for .sql files
+embed_api_url = "http://192.168.0.33:11434/api/embed"
+model_name = "deepseek-r1:14b"  # make sure the model name is correct
+chromadb_persist_dir = "./chroma_db"  # persistence directory
+
+# Announce the directory being scanned, then collect every .sql file below it
+print(f"正在扫描目录: {os.path.abspath(data_directory)}")
+sql_files = [
+    os.path.join(root, file)
+    for root, _, files in os.walk(data_directory)
+    for file in files
+    if file.endswith('.sql')
+]
+print(f"发现 {len(sql_files)} 个SQL文件")
+
+# Load the documents, falling back to GBK when a file is not valid UTF-8
+documents = []
+for file_path in sql_files:
+    try:
+        loader = TextLoader(file_path, encoding='utf-8')
+        documents.extend(loader.load())
+        print(f"已加载: {file_path}")
+        continue
+    except Exception as e:
+        # TextLoader reports decoding failures as a RuntimeError chained
+        # from the UnicodeDecodeError, so the cause has to be inspected;
+        # catching UnicodeDecodeError alone would never reach the fallback.
+        decode_failed = (
+            isinstance(e, UnicodeDecodeError)
+            or isinstance(e.__cause__, UnicodeDecodeError)
+        )
+        if not decode_failed:
+            print(f"加载文件异常 {file_path}: {str(e)}")
+            continue
+
+    # UTF-8 decoding failed: retry the same file as GBK
+    try:
+        loader = TextLoader(file_path, encoding='gbk')
+        documents.extend(loader.load())
+        print(f"已加载(GBK): {file_path}")
+    except Exception as e:
+        print(f"文件加载失败 {file_path}: {str(e)}")
+
+print(f"共加载 {len(documents)} 个文档")
+
+# Text splitting tuned for SQL files
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=10000,
+    chunk_overlap=20,
+    separators=["\n\n", "\n", ";", " ", ""]  # split along SQL statement boundaries
+)
+# Split every loaded document into chunks; each chunk is embedded separately below
+texts = text_splitter.split_documents(documents)
+print(f"分割后得到 {len(texts)} 个文本块")
+
+# Embedding generation with retries
+def get_embeddings(texts, max_retries=3):
+    """Request an embedding for every non-empty chunk in ``texts``.
+
+    Each chunk's stripped ``page_content`` is posted to ``embed_api_url``.
+    A chunk is attempted up to ``max_retries`` times with exponential
+    backoff between attempts; chunks that never succeed are left out.
+
+    Args:
+        texts: Sequence of objects exposing a ``page_content`` string.
+        max_retries: Maximum number of attempts per chunk.
+
+    Returns:
+        A ``(valid_texts, embeddings)`` pair of equally long lists holding
+        the chunks that were embedded and their vectors, in input order.
+    """
+    request_headers = {"Content-Type": "application/json"}
+    kept_chunks = []
+    vectors = []
+
+    for position, chunk in enumerate(texts, start=1):
+        body = chunk.page_content.strip()
+        if not body:
+            # Nothing to embed for a blank chunk
+            continue
+
+        for attempt in range(max_retries):
+            try:
+                reply = requests.post(
+                    embed_api_url,
+                    json={"model": model_name, "input": body},
+                    headers=request_headers,
+                    timeout=600
+                )
+                if not reply.ok:
+                    print(f"× 文本块 {position} 请求失败[{reply.status_code}]: {reply.text}")
+                else:
+                    vector = reply.json().get("embeddings")
+                    if not isinstance(vector, list) or len(vector) == 0:
+                        print(f"× 文本块 {position} 无效响应: {reply.text}")
+                    else:
+                        # Unwrap a nested list so one flat vector is stored
+                        if isinstance(vector[0], list):
+                            vector = vector[0]
+                        vectors.append(vector)
+                        kept_chunks.append(chunk)
+                        print(f"✓ 文本块 {position}/{len(texts)} 嵌入成功")
+                        break
+            except Exception as err:
+                print(f"× 文本块 {position} 异常: {str(err)}")
+
+            # Reached only after a failed attempt
+            if attempt < max_retries - 1:
+                time.sleep(2 ** attempt)  # exponential backoff
+        else:
+            print(f"× 文本块 {position} 超过最大重试次数")
+
+    print(f"成功生成 {len(vectors)}/{len(texts)} 个嵌入向量")
+    return kept_chunks, vectors
+
+try:
+    # Compute the embeddings
+    print("\n正在生成嵌入向量...")
+    valid_texts, embeddings = get_embeddings(texts)
+
+    # Open the persistent client and the target collection
+    client = chromadb.PersistentClient(path=chromadb_persist_dir)
+    collection = client.get_or_create_collection(
+        name="sql_knowledge",
+        metadata={"hnsw:space": "cosine"}
+    )
+
+    # Prepare the payload: ids are the positions of the embedded chunks
+    documents = [t.page_content for t in valid_texts]
+    ids = list(map(str, range(len(valid_texts))))
+
+    # Upsert in fixed-size batches to keep each request small
+    batch_size = 100
+    for offset in range(0, len(ids), batch_size):
+        window = slice(offset, offset + batch_size)
+        batch_ids = ids[window]
+        batch_docs = documents[window]
+        # Coerce every component to float before handing it to Chroma
+        batch_embeddings = [[float(v) for v in e] for e in embeddings[window]]
+
+        collection.upsert(
+            ids=batch_ids,
+            embeddings=batch_embeddings,
+            documents=batch_docs
+        )
+        print(f"已插入 {offset + len(batch_ids)}/{len(ids)} 条数据")
+
+    print("知识库构建完成!持久化存储于:", os.path.abspath(chromadb_persist_dir))
+
+except Exception as e:
+    print(f"主流程异常: {str(e)}")
+    traceback.print_exc()

+ 82 - 0
data/cx_aps.sql

@@ -0,0 +1,82 @@
+{
+	"tables": [
+		{
+			"name": "aps_blank_order",
+			"note": "生产订单_坯料计划",
+			"columns": [
+				{
+					"name": "ID",
+					"type": "varchar",
+					"precision": "36",
+					"note": "主键"
+				},
+				{
+					"name": "PRODUCTIONORDERID",
+					"type": "varchar",
+					"precision": "36",
+					"note": "生产订单ID"
+				},
+				{
+					"name": "BLANKNUMBER",
+					"type": "varchar",
+					"precision": "36",
+					"note": "坯料计划编号"
+				},
+				{
+					"name": "MATERIALNAME",
+					"type": "varchar",
+					"precision": "50",
+					"note": "物料名称"
+				},
+				{
+					"name": "MATERIALCODE",
+					"type": "varchar",
+					"precision": "50",
+					"note": "物料编码"
+				},
+				{
+					"name": "ALLOY",
+					"type": "varchar",
+					"precision": "50",
+					"note": "合金"
+				},
+				{
+					"name": "ALLOYSTATUS",
+					"type": "varchar",
+					"precision": "50",
+					"note": "合金状态"
+				},
+				{
+					"name": "PLANHAVEMATERIALDATE",
+					"type": "datetime",
+					"precision": "0",
+					"note": "预计来料日期"
+				},
+				{
+					"name": "MAXHEATROLL",
+					"type": "int",
+					"precision": "11",
+					"note": "最大装炉卷数"
+				},
+				{
+					"name": "EXPECTEDDAYS",
+					"type": "varchar",
+					"precision": "11",
+					"note": "期望交货天数"
+				},
+				{
+					"name": "ROLLNUM",
+					"type": "int",
+					"precision": "11",
+					"note": "卷数"
+				},
+				{
+					"name": "PROWIDTH",
+					"type": "decimal",
+					"precision": "20, 3",
+					"note": "宽度(mm)"
+				}
+			]
+		}
+	]
+}

+ 16 - 0
main.py

@@ -0,0 +1,16 @@
+# This is a sample Python script.
+
+# Press Shift+F10 to execute it or replace it with your code.
+# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
+
+
+def print_hi(name):
+    """Print the greeting ``Hi, <name>`` to standard output."""
+    greeting = 'Hi, {}'.format(name)
+    print(greeting)
+
+
+# Script entry point: greet only when the file is executed directly.
+if __name__ == '__main__':
+    print_hi('PyCharm')
+
+# See PyCharm help at https://www.jetbrains.com/help/pycharm/

+ 3 - 0
requirements.txt

@@ -0,0 +1,3 @@
+langchain
+langchain-community
+chromadb
+requests
+mysql-connector-python