feat(docker): containerize application and add TTS integration
This commit is contained in:
147
haystack_rag/main.py
Normal file
147
haystack_rag/main.py
Normal file
@ -0,0 +1,147 @@
|
||||
# main.py
|
||||
import sys
|
||||
from haystack import Document
|
||||
|
||||
# 需要 OpenAIDocumentEmbedder 来嵌入要写入的文档
|
||||
from haystack.components.embedders import OpenAIDocumentEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
# 导入所需的配置和构建函数
|
||||
from config import (
|
||||
DEFAULT_USER_ID,
|
||||
OPENAI_API_KEY_FROM_CONFIG,
|
||||
OPENAI_API_BASE_URL_CONFIG,
|
||||
OPENAI_EMBEDDING_MODEL,
|
||||
OPENAI_EMBEDDING_KEY,
|
||||
OPENAI_EMBEDDING_BASE,
|
||||
)
|
||||
from .rag_pipeline import build_rag_pipeline # 构建 RAG 查询管道
|
||||
|
||||
|
||||
# Helper: initialize the Document Embedder (similar to the one in embedding.py).
def initialize_document_embedder() -> OpenAIDocumentEmbedder:
    """Initialize the OpenAIDocumentEmbedder used to embed user-input documents.

    Reads the embedding model, API key and base URL from ``config.py`` and
    returns a ready-to-use embedder for writing documents to the store.

    Returns:
        OpenAIDocumentEmbedder: embedder configured with
        ``OPENAI_EMBEDDING_MODEL`` / ``OPENAI_EMBEDDING_KEY`` /
        ``OPENAI_EMBEDDING_BASE``.
    """
    if not OPENAI_API_KEY_FROM_CONFIG or "YOUR_API_KEY" in OPENAI_API_KEY_FROM_CONFIG:
        print("警告: OpenAI API Key 未在 config.py 中有效配置。")
        # raise ValueError("OpenAI API Key not configured correctly in config.py")

    # BUGFIX: the embedder below authenticates with OPENAI_EMBEDDING_KEY, not
    # OPENAI_API_KEY_FROM_CONFIG, so validate the key that is actually used too.
    if not OPENAI_EMBEDDING_KEY or "YOUR_API_KEY" in OPENAI_EMBEDDING_KEY:
        print("警告: OpenAI API Key 未在 config.py 中有效配置。")

    print(f"Initializing OpenAI Document Embedder with model: {OPENAI_EMBEDDING_MODEL}")
    if OPENAI_API_BASE_URL_CONFIG:
        print(f"Using custom API base URL from config: {OPENAI_API_BASE_URL_CONFIG}")
    else:
        print("Using default OpenAI API base URL (None specified in config).")

    document_embedder = OpenAIDocumentEmbedder(
        api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
        model=OPENAI_EMBEDDING_MODEL,
        api_base_url=OPENAI_EMBEDDING_BASE,
        # meta_fields_to_embed=["name"]  # embed metadata fields if needed
        # embedding_batch_size=10        # batch size is tunable
    )
    print("OpenAI Document Embedder initialized.")
    return document_embedder
|
||||
|
||||
def _write_user_input(document_embedder, document_store, query: str, user_id: str) -> None:
    """Embed the raw user input and persist it to the document store (Milvus).

    Wraps *query* in a Haystack ``Document`` tagged with ``user_id``, embeds it
    with *document_embedder*, and writes it via *document_store*.
    """
    user_doc_to_write = Document(content=query, meta={"user_id": user_id})
    # The embedder expects a list as input, even for a single document.
    embedding_result = document_embedder.run([user_doc_to_write])
    embedded_docs = embedding_result.get("documents", [])
    if embedded_docs:
        document_store.write_documents(embedded_docs)
    else:
        print("[Workflow] Warning: Failed to embed document, skipping write.")


def _print_reply(results) -> None:
    """Print the generator's reply and (best-effort) token usage.

    The generator component in the pipeline is named ``llm``; its output lives
    under ``results["llm"]["replies"]``. Token usage is pulled from ``meta``
    when available; its structure may vary by version/configuration, so any
    failure there is ignored.
    """
    if "llm" in results and results["llm"]["replies"]:
        answer = results["llm"]["replies"][0]
        total_tokens = "N/A"
        try:
            meta = results["llm"].get("meta")
            if isinstance(meta, list) and meta:
                usage_info = meta[0].get("usage", {})
                total_tokens = usage_info.get("total_tokens", "N/A")
        except Exception:
            pass  # ignore token-extraction errors
        print(f"Chatbot: {answer} (Tokens: {total_tokens})")
    else:
        print("Chatbot: Sorry, I couldn't generate an answer for that.")
        print("Debug Info (Pipeline Results):", results)  # full results for debugging


def run_chat_session(user_id: str):
    """Run the main RAG chat-session loop for *user_id*.

    Each time the user enters text, the input is first embedded and written to
    Milvus, then the RAG query pipeline is run to generate a reply. The loop
    ends on ``exit`` or Ctrl-C; any other exception is reported with a
    traceback and the loop continues.
    """
    print(f"--- Starting Chat Session for User: {user_id} ---")

    # Build the RAG query pipeline and obtain the DocumentStore instance.
    rag_query_pipeline, document_store = build_rag_pipeline(user_id)

    # Embedder used to write each user input into the store.
    document_embedder = initialize_document_embedder()

    print("\nChatbot is ready! Type your questions or 'exit' to quit.")
    # NOTE: the original dead `try: pass` block (commented-out model-info
    # prints) was removed — it contained no executable code.

    while True:
        try:
            query = input(f"[{user_id}] You: ")
            # BUGFIX: strip surrounding whitespace so inputs like "exit " quit too.
            if query.strip().lower() == "exit":
                print("Exiting chat session. Goodbye!")
                break
            if not query.strip():
                continue

            # --- Step 1: embed the user input and write it to Milvus ---
            _write_user_input(document_embedder, document_store, query, user_id)

            # --- Step 2: run the RAG query pipeline ---
            # Both the query embedder ('text_embedder') and the prompt builder
            # need the raw query text (used as {{query}} in the template).
            pipeline_input = {
                "text_embedder": {"text": query},
                "prompt_builder": {"query": query},
            }
            results = rag_query_pipeline.run(pipeline_input)

            # --- Step 3: process and print the result ---
            _print_reply(results)

        except KeyboardInterrupt:
            print("\nExiting chat session. Goodbye!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            import traceback

            traceback.print_exc()  # print the full traceback
||||
|
||||
if __name__ == "__main__":
    # Launch an interactive chat session for the configured default user.
    run_chat_session(DEFAULT_USER_ID)
|
Reference in New Issue
Block a user