Files
rag_chat/haystack_rag/main.py

148 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# main.py
import sys
from haystack import Document
# 需要 OpenAIDocumentEmbedder 来嵌入要写入的文档
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret
# 导入所需的配置和构建函数
from config import (
DEFAULT_USER_ID,
OPENAI_API_KEY_FROM_CONFIG,
OPENAI_API_BASE_URL_CONFIG,
OPENAI_EMBEDDING_MODEL,
OPENAI_EMBEDDING_KEY,
OPENAI_EMBEDDING_BASE,
)
from .rag_pipeline import build_rag_pipeline # 构建 RAG 查询管道
# 辅助函数:初始化 Document Embedder (与 embedding.py 中的类似)
def initialize_document_embedder() -> OpenAIDocumentEmbedder:
    """Initialize the OpenAIDocumentEmbedder used to embed documents before writing.

    Returns:
        A configured ``OpenAIDocumentEmbedder`` built from the embedding-specific
        settings (``OPENAI_EMBEDDING_KEY`` / ``OPENAI_EMBEDDING_MODEL`` /
        ``OPENAI_EMBEDDING_BASE``) in ``config.py``.
    """
    # Fix: the original warning inspected OPENAI_API_KEY_FROM_CONFIG and the
    # base-URL log printed OPENAI_API_BASE_URL_CONFIG, but the embedder below
    # is actually constructed from OPENAI_EMBEDDING_KEY / OPENAI_EMBEDDING_BASE.
    # Validate and report the values that are really used.
    if not OPENAI_EMBEDDING_KEY or "YOUR_API_KEY" in OPENAI_EMBEDDING_KEY:
        print("Warning: OpenAI embedding API key is not configured correctly in config.py.")
        # raise ValueError(...) deliberately left disabled: startup is best-effort.
    print(f"Initializing OpenAI Document Embedder with model: {OPENAI_EMBEDDING_MODEL}")
    if OPENAI_EMBEDDING_BASE:
        print(f"Using custom API base URL from config: {OPENAI_EMBEDDING_BASE}")
    else:
        print("Using default OpenAI API base URL (None specified in config).")
    document_embedder = OpenAIDocumentEmbedder(
        api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
        model=OPENAI_EMBEDDING_MODEL,
        api_base_url=OPENAI_EMBEDDING_BASE,
        # meta_fields_to_embed=["name"]  # enable to embed metadata fields as well
        # embedding_batch_size=10        # tune the batch size if needed
    )
    print("OpenAI Document Embedder initialized.")
    return document_embedder
def run_chat_session(user_id: str):
    """Run the interactive RAG chat loop for one user.

    Each turn: (1) the user's input is embedded and written to the Milvus
    document store, then (2) the RAG query pipeline is run on the same input
    and the generated reply is printed, with token usage when available.

    Args:
        user_id: Identifier used to build the per-user pipeline/store and to
            tag every document written from this session.
    """
    print(f"--- Starting Chat Session for User: {user_id} ---")
    # Build the RAG query pipeline and obtain the DocumentStore instance.
    rag_query_pipeline, document_store = build_rag_pipeline(user_id)
    # Embedder used to embed user inputs before persisting them.
    document_embedder = initialize_document_embedder()
    print("\nChatbot is ready! Type your questions or 'exit' to quit.")
    while True:
        try:
            # Strip once so that trailing/leading whitespace neither defeats
            # the 'exit' check nor gets stored in the document content.
            query = input(f"[{user_id}] You: ").strip()
            if query.lower() == "exit":
                print("Exiting chat session. Goodbye!")
                break
            if not query:
                continue
            # --- Step 1: embed the user input and write it to Milvus ---
            user_doc_to_write = Document(content=query, meta={"user_id": user_id})
            # The embedder expects a list of documents, even for a single one.
            embedding_result = document_embedder.run([user_doc_to_write])
            embedded_docs = embedding_result.get("documents", [])
            if embedded_docs:
                document_store.write_documents(embedded_docs)
            else:
                print("[Workflow] Warning: Failed to embed document, skipping write.")
            # --- Step 2: run the RAG query pipeline ---
            # text_embedder needs the raw query text; prompt_builder uses the
            # same text as {{query}} inside its prompt template.
            pipeline_input = {
                "text_embedder": {"text": query},
                "prompt_builder": {"query": query},
            }
            results = rag_query_pipeline.run(pipeline_input)
            # --- Step 3: print the reply ---
            # NOTE: the generator component in this pipeline is named "llm"
            # (a stale comment previously claimed "generator").
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]
                total_tokens = "N/A"
                try:
                    # Meta layout can vary by haystack version/config; guard it.
                    if (
                        "meta" in results["llm"]
                        and isinstance(results["llm"]["meta"], list)
                        and results["llm"]["meta"]
                    ):
                        usage_info = results["llm"]["meta"][0].get("usage", {})
                        total_tokens = usage_info.get("total_tokens", "N/A")
                except Exception:
                    pass  # token accounting is best-effort only
                print(f"Chatbot: {answer} (Tokens: {total_tokens})")
            else:
                print("Chatbot: Sorry, I couldn't generate an answer for that.")
                print("Debug Info (Pipeline Results):", results)
        except KeyboardInterrupt:
            print("\nExiting chat session. Goodbye!")
            break
        except Exception as e:
            # Top-level loop boundary: report the error and keep the session alive.
            print(f"\nAn error occurred: {e}")
            import traceback
            traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: start an interactive chat session for the default user.
    run_chat_session(DEFAULT_USER_ID)