148 lines
6.1 KiB
Python
148 lines
6.1 KiB
Python
# main.py
|
||
import sys
|
||
from haystack import Document
|
||
|
||
# 需要 OpenAIDocumentEmbedder 来嵌入要写入的文档
|
||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
||
from haystack.utils import Secret
|
||
|
||
# 导入所需的配置和构建函数
|
||
from config import (
|
||
DEFAULT_USER_ID,
|
||
OPENAI_API_KEY_FROM_CONFIG,
|
||
OPENAI_API_BASE_URL_CONFIG,
|
||
OPENAI_EMBEDDING_MODEL,
|
||
OPENAI_EMBEDDING_KEY,
|
||
OPENAI_EMBEDDING_BASE,
|
||
)
|
||
from .rag_pipeline import build_rag_pipeline # 构建 RAG 查询管道
|
||
|
||
|
||
# 辅助函数:初始化 Document Embedder (与 embedding.py 中的类似)
|
||
def initialize_document_embedder() -> OpenAIDocumentEmbedder:
|
||
"""初始化用于嵌入文档的 OpenAIDocumentEmbedder。"""
|
||
if not OPENAI_API_KEY_FROM_CONFIG or "YOUR_API_KEY" in OPENAI_API_KEY_FROM_CONFIG:
|
||
print("警告: OpenAI API Key 未在 config.py 中有效配置。")
|
||
# raise ValueError("OpenAI API Key not configured correctly in config.py")
|
||
|
||
print(f"Initializing OpenAI Document Embedder with model: {OPENAI_EMBEDDING_MODEL}")
|
||
if OPENAI_API_BASE_URL_CONFIG:
|
||
print(f"Using custom API base URL from config: {OPENAI_API_BASE_URL_CONFIG}")
|
||
else:
|
||
print("Using default OpenAI API base URL (None specified in config).")
|
||
|
||
document_embedder = OpenAIDocumentEmbedder(
|
||
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
|
||
model=OPENAI_EMBEDDING_MODEL,
|
||
api_base_url=OPENAI_EMBEDDING_BASE,
|
||
# meta_fields_to_embed=["name"] # 如果需要嵌入元数据字段
|
||
# embedding_batch_size=10 # 可以调整批处理大小
|
||
)
|
||
print("OpenAI Document Embedder initialized.")
|
||
return document_embedder
|
||
|
||
|
||
def run_chat_session(user_id: str):
|
||
"""
|
||
运行 RAG 聊天会话主循环。
|
||
每次用户输入时,先将其嵌入并添加到 Milvus,然后运行 RAG 管道生成回复。
|
||
"""
|
||
print(f"--- Starting Chat Session for User: {user_id} ---")
|
||
|
||
# 构建 RAG 查询管道和获取 DocumentStore 实例
|
||
rag_query_pipeline, document_store = build_rag_pipeline(user_id)
|
||
|
||
# 初始化用于写入用户输入的 Document Embedder
|
||
document_embedder = initialize_document_embedder()
|
||
|
||
print("\nChatbot is ready! Type your questions or 'exit' to quit.")
|
||
# 打印使用的模型信息
|
||
try:
|
||
pass
|
||
# print(f"Using LLM: {rag_query_pipeline.get_component('generator').model}")
|
||
# 注意 RAG pipeline 中 query embedder 的名字是 'text_embedder'
|
||
# print(f"Using Query Embedder: {rag_query_pipeline.get_component('text_embedder').model}")
|
||
# print(f"Using Document Embedder (for writing): {document_embedder.model}")
|
||
except Exception as e:
|
||
print(f"Warning: Could not get component model names - {e}")
|
||
|
||
while True:
|
||
try:
|
||
query = input(f"[{user_id}] You: ")
|
||
if query.lower() == "exit":
|
||
print("Exiting chat session. Goodbye!")
|
||
break
|
||
if not query.strip():
|
||
continue
|
||
|
||
# --- 步骤 1: 嵌入用户输入并写入 Milvus ---
|
||
# print(f"[Workflow] Embedding user input as a document...")
|
||
# 将用户输入包装成 Haystack Document
|
||
user_doc_to_write = Document(content=query, meta={"user_id": user_id})
|
||
|
||
# 使用 OpenAIDocumentEmbedder 运行嵌入
|
||
# 它需要一个列表作为输入,即使只有一个文档
|
||
embedding_result = document_embedder.run([user_doc_to_write])
|
||
embedded_docs = embedding_result.get(
|
||
"documents", []
|
||
) # 获取带有嵌入的文档列表
|
||
|
||
if embedded_docs:
|
||
# print(f"[Workflow] Writing embedded document to Milvus for user {user_id}...")
|
||
# 将带有嵌入的文档写入 DocumentStore
|
||
document_store.write_documents(embedded_docs)
|
||
# print("[Workflow] Document written to Milvus.")
|
||
else:
|
||
print("[Workflow] Warning: Failed to embed document, skipping write.")
|
||
# 可以在这里添加错误处理或日志记录
|
||
|
||
# --- 步骤 2: 使用 RAG 查询管道生成回复 ---
|
||
# print("[Workflow] Running RAG query pipeline...")
|
||
# 准备 RAG 管道的输入
|
||
# text_embedder 需要原始查询文本
|
||
# prompt_builder 也需要原始查询文本(在模板中用作 {{query}})
|
||
pipeline_input = {
|
||
"text_embedder": {"text": query},
|
||
"prompt_builder": {"query": query},
|
||
}
|
||
|
||
# 运行 RAG 查询管道
|
||
results = rag_query_pipeline.run(pipeline_input)
|
||
|
||
# --- 步骤 3: 处理并打印结果 ---
|
||
# 根据文档示例,生成器的输出在 'generator' 组件的 'replies' 键中
|
||
if "llm" in results and results["llm"]["replies"]:
|
||
answer = results["llm"]["replies"][0]
|
||
# 尝试获取 token 使用量(可能在 meta 中)
|
||
total_tokens = "N/A"
|
||
try:
|
||
# meta 结构可能因版本或配置而异,需要检查确认
|
||
if (
|
||
"meta" in results["llm"]
|
||
and isinstance(results["llm"]["meta"], list)
|
||
and results["llm"]["meta"]
|
||
):
|
||
usage_info = results["llm"]["meta"][0].get("usage", {})
|
||
total_tokens = usage_info.get("total_tokens", "N/A")
|
||
except Exception:
|
||
pass # 忽略获取 token 的错误
|
||
|
||
print(f"Chatbot: {answer} (Tokens: {total_tokens})")
|
||
else:
|
||
print("Chatbot: Sorry, I couldn't generate an answer for that.")
|
||
print("Debug Info (Pipeline Results):", results) # 打印完整结果以供调试
|
||
|
||
except KeyboardInterrupt:
|
||
print("\nExiting chat session. Goodbye!")
|
||
break
|
||
except Exception as e:
|
||
print(f"\nAn error occurred: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc() # 打印详细的回溯信息
|
||
|
||
|
||
if __name__ == "__main__":
|
||
current_user_id = DEFAULT_USER_ID
|
||
run_chat_session(current_user_id)
|