feat(docker): containerize application and add TTS integration

This commit is contained in:
gameloader
2025-10-07 22:52:04 +08:00
parent 45470fd13d
commit 315dbfed90
19 changed files with 878 additions and 80 deletions

147
haystack_rag/main.py Normal file
View File

@ -0,0 +1,147 @@
# main.py
import sys
from haystack import Document
# 需要 OpenAIDocumentEmbedder 来嵌入要写入的文档
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret
# 导入所需的配置和构建函数
from config import (
DEFAULT_USER_ID,
OPENAI_API_KEY_FROM_CONFIG,
OPENAI_API_BASE_URL_CONFIG,
OPENAI_EMBEDDING_MODEL,
OPENAI_EMBEDDING_KEY,
OPENAI_EMBEDDING_BASE,
)
from .rag_pipeline import build_rag_pipeline # 构建 RAG 查询管道
# 辅助函数:初始化 Document Embedder (与 embedding.py 中的类似)
def initialize_document_embedder() -> OpenAIDocumentEmbedder:
"""初始化用于嵌入文档的 OpenAIDocumentEmbedder。"""
if not OPENAI_API_KEY_FROM_CONFIG or "YOUR_API_KEY" in OPENAI_API_KEY_FROM_CONFIG:
print("警告: OpenAI API Key 未在 config.py 中有效配置。")
# raise ValueError("OpenAI API Key not configured correctly in config.py")
print(f"Initializing OpenAI Document Embedder with model: {OPENAI_EMBEDDING_MODEL}")
if OPENAI_API_BASE_URL_CONFIG:
print(f"Using custom API base URL from config: {OPENAI_API_BASE_URL_CONFIG}")
else:
print("Using default OpenAI API base URL (None specified in config).")
document_embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token(OPENAI_EMBEDDING_KEY),
model=OPENAI_EMBEDDING_MODEL,
api_base_url=OPENAI_EMBEDDING_BASE,
# meta_fields_to_embed=["name"] # 如果需要嵌入元数据字段
# embedding_batch_size=10 # 可以调整批处理大小
)
print("OpenAI Document Embedder initialized.")
return document_embedder
def run_chat_session(user_id: str):
"""
运行 RAG 聊天会话主循环。
每次用户输入时,先将其嵌入并添加到 Milvus然后运行 RAG 管道生成回复。
"""
print(f"--- Starting Chat Session for User: {user_id} ---")
# 构建 RAG 查询管道和获取 DocumentStore 实例
rag_query_pipeline, document_store = build_rag_pipeline(user_id)
# 初始化用于写入用户输入的 Document Embedder
document_embedder = initialize_document_embedder()
print("\nChatbot is ready! Type your questions or 'exit' to quit.")
# 打印使用的模型信息
try:
pass
# print(f"Using LLM: {rag_query_pipeline.get_component('generator').model}")
# 注意 RAG pipeline 中 query embedder 的名字是 'text_embedder'
# print(f"Using Query Embedder: {rag_query_pipeline.get_component('text_embedder').model}")
# print(f"Using Document Embedder (for writing): {document_embedder.model}")
except Exception as e:
print(f"Warning: Could not get component model names - {e}")
while True:
try:
query = input(f"[{user_id}] You: ")
if query.lower() == "exit":
print("Exiting chat session. Goodbye!")
break
if not query.strip():
continue
# --- 步骤 1: 嵌入用户输入并写入 Milvus ---
# print(f"[Workflow] Embedding user input as a document...")
# 将用户输入包装成 Haystack Document
user_doc_to_write = Document(content=query, meta={"user_id": user_id})
# 使用 OpenAIDocumentEmbedder 运行嵌入
# 它需要一个列表作为输入,即使只有一个文档
embedding_result = document_embedder.run([user_doc_to_write])
embedded_docs = embedding_result.get(
"documents", []
) # 获取带有嵌入的文档列表
if embedded_docs:
# print(f"[Workflow] Writing embedded document to Milvus for user {user_id}...")
# 将带有嵌入的文档写入 DocumentStore
document_store.write_documents(embedded_docs)
# print("[Workflow] Document written to Milvus.")
else:
print("[Workflow] Warning: Failed to embed document, skipping write.")
# 可以在这里添加错误处理或日志记录
# --- 步骤 2: 使用 RAG 查询管道生成回复 ---
# print("[Workflow] Running RAG query pipeline...")
# 准备 RAG 管道的输入
# text_embedder 需要原始查询文本
# prompt_builder 也需要原始查询文本(在模板中用作 {{query}}
pipeline_input = {
"text_embedder": {"text": query},
"prompt_builder": {"query": query},
}
# 运行 RAG 查询管道
results = rag_query_pipeline.run(pipeline_input)
# --- 步骤 3: 处理并打印结果 ---
# 根据文档示例,生成器的输出在 'generator' 组件的 'replies' 键中
if "llm" in results and results["llm"]["replies"]:
answer = results["llm"]["replies"][0]
# 尝试获取 token 使用量(可能在 meta 中)
total_tokens = "N/A"
try:
# meta 结构可能因版本或配置而异,需要检查确认
if (
"meta" in results["llm"]
and isinstance(results["llm"]["meta"], list)
and results["llm"]["meta"]
):
usage_info = results["llm"]["meta"][0].get("usage", {})
total_tokens = usage_info.get("total_tokens", "N/A")
except Exception:
pass # 忽略获取 token 的错误
print(f"Chatbot: {answer} (Tokens: {total_tokens})")
else:
print("Chatbot: Sorry, I couldn't generate an answer for that.")
print("Debug Info (Pipeline Results):", results) # 打印完整结果以供调试
except KeyboardInterrupt:
print("\nExiting chat session. Goodbye!")
break
except Exception as e:
print(f"\nAn error occurred: {e}")
import traceback
traceback.print_exc() # 打印详细的回溯信息
if __name__ == "__main__":
current_user_id = DEFAULT_USER_ID
run_chat_session(current_user_id)