# NOTE(review): removed web-viewer residue ("Files", "196 lines", "Raw Blame
# History", Unicode-warning banner) that was pasted in above the module and
# made the file unparseable. It was not part of the source code.

# rag_pipeline.py
from haystack import Pipeline
from haystack import Document # 导入 Document
from milvus_haystack import MilvusDocumentStore
from .data_handling import initialize_milvus_lite
from .embedding import initialize_text_embedder
from .retrieval import initialize_vector_retriever
from .llm_integration import initialize_llm_and_prompt_builder
from haystack.utils import Secret
def build_rag_pipeline(user_id: str) -> tuple[Pipeline, MilvusDocumentStore]:
    """Build and return the RAG query pipeline and its DocumentStore for a user.

    The DocumentStore is returned alongside the pipeline because the caller
    needs it to write documents into the store.
    """
    print(f"\n--- Building RAG Pipeline for User: {user_id} ---")

    # Per-user document store; the remaining components are shared in nature
    # but re-created on every call for simplicity — they could be initialized
    # once at application startup instead.
    document_store = initialize_milvus_lite(user_id)
    query_embedder = initialize_text_embedder()
    retriever = initialize_vector_retriever(document_store)
    generator, builder = initialize_llm_and_prompt_builder()

    # Assemble the pipeline: embed the question, retrieve matching documents,
    # render the prompt from them, then generate the answer.
    pipeline = Pipeline()
    pipeline.add_component(instance=query_embedder, name="text_embedder")
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=builder, name="prompt_builder")
    pipeline.add_component(instance=generator, name="llm")

    # Wiring: query embedding -> retriever; retrieved docs -> prompt builder;
    # rendered prompt -> LLM. The raw question text is supplied at run time to
    # both the embedder and the prompt builder by the caller.
    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
    pipeline.connect("retriever.documents", "prompt_builder.documents")
    pipeline.connect("prompt_builder.prompt", "llm.prompt")

    print("--- RAG Pipeline Built Successfully ---")
    # Return both: the main program writes documents via the store.
    return pipeline, document_store
# --- Corrected Test Block ---
if __name__ == "__main__":
    import sys
    import traceback  # hoisted: was re-imported inside every except block

    # We need OpenAIDocumentEmbedder to index test documents.
    from haystack.components.embedders import OpenAIDocumentEmbedder

    # Import necessary config for initializing the Document Embedder.
    from config import (
        OPENAI_API_KEY_FROM_CONFIG,
        OPENAI_API_BASE_URL_CONFIG,
        OPENAI_EMBEDDING_MODEL,
    )

    # --- Configuration ---
    test_user = "test_user"
    test_query = "Haystack是什么"
    # Sample documents to index for testing.
    docs_to_index = [
        Document(
            content="Haystack是一个用于构建 NLP 应用程序(如问答系统、语义搜索)的开源框架。",
            meta={"user_id": test_user, "source": "test_doc_1"},
        ),
        Document(
            content="你可以使用 Haystack 连接不同的组件,如文档存储、检索器和生成器。",
            meta={"user_id": test_user, "source": "test_doc_2"},
        ),
        Document(
            content="Milvus 是一个流行的向量数据库,常用于 RAG 系统中存储嵌入。",
            meta={"user_id": test_user, "source": "test_doc_3"},
        ),
    ]

    print(f"--- Running Test for RAG Pipeline (User: {test_user}) ---")

    # --- 1. Check API Key Availability ---
    # Pipeline execution requires OpenAI API calls.
    api_key_configured = (
        OPENAI_API_KEY_FROM_CONFIG and "YOUR_API_KEY" not in OPENAI_API_KEY_FROM_CONFIG
    )
    if not api_key_configured:
        print("\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("! WARNING: OpenAI API Key not configured in config.py. !")
        print("! Skipping RAG pipeline test execution. !")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        # sys.exit over bare exit(): exit() is an interactive-shell helper
        # injected by the site module and is not guaranteed to exist.
        sys.exit()
    else:
        print("\n[Test Setup] OpenAI API Key found in config.")

    # --- 2. Build the RAG Pipeline and get the Document Store ---
    # This function initializes the store (potentially dropping old data)
    # and builds the *querying* pipeline.
    try:
        pipeline, store = build_rag_pipeline(test_user)
    except Exception as e:
        print(f"\nError building RAG pipeline: {e}")
        traceback.print_exc()
        sys.exit()

    # --- 3. Index Test Documents (with embeddings) ---
    print("\n[Test Setup] Initializing Document Embedder for indexing test data...")
    try:
        # Initialize the Document Embedder directly here for the test.
        document_embedder = OpenAIDocumentEmbedder(
            api_key=Secret.from_token(OPENAI_API_KEY_FROM_CONFIG),
            model=OPENAI_EMBEDDING_MODEL,
            api_base_url=OPENAI_API_BASE_URL_CONFIG,
        )
        print("[Test Setup] Document Embedder initialized.")
        print("[Test Setup] Embedding test documents...")
        embedding_result = document_embedder.run(docs_to_index)
        embedded_docs = embedding_result.get("documents", [])
        if embedded_docs:
            print(
                f"[Test Setup] Writing {len(embedded_docs)} embedded documents to Milvus..."
            )
            store.write_documents(embedded_docs)
            print("[Test Setup] Test documents written successfully.")
            # Optional: Verify count
            # print(f"[Test Setup] Document count in store: {store.count_documents()}")
            documents_indexed = True
        else:
            print("[Test Setup] ERROR: Failed to embed test documents.")
            documents_indexed = False
    except Exception as e:
        print(f"\nError during test data indexing: {e}")
        traceback.print_exc()
        documents_indexed = False

    # --- 4. Run the RAG Pipeline (if setup succeeded) ---
    if documents_indexed:
        print(f"\n[Test Run] Running RAG pipeline for query: '{test_query}'")
        # The query text feeds both the embedder and the prompt template.
        pipeline_input = {
            "text_embedder": {"text": test_query},  # Input for the query embedder
            "prompt_builder": {
                "query": test_query
            },  # Input for the prompt builder template
        }
        try:
            results = pipeline.run(pipeline_input)
            print("\n[Test Run] Pipeline Results:")
            # Process and print the generator's answer.
            if "llm" in results and results["llm"]["replies"]:
                answer = results["llm"]["replies"][0]
                print(f"\nGenerated Answer: {answer}")
            else:
                print("\n[Test Run] Could not extract answer from generator.")
                print(
                    "Full Pipeline Output:", results
                )  # Print full output for debugging
        except Exception as e:
            print(f"\n[Test Run] Error running RAG pipeline: {e}")
            traceback.print_exc()
    else:
        print("\n[Test Run] Skipping RAG pipeline execution due to indexing failure.")

    # --- 5. Cleanup Note ---
    # NOTE(review): this path is hard-coded and may not match the actual
    # persist directory configured in config.py — verify against
    # MILVUS_PERSIST_BASE_DIR before relying on it.
    print(
        f"\n[Test Cleanup] Test finished. Consider manually removing data in: ./milvus_user_data_openai_fixed/{test_user}"
    )
    # import shutil
    # from pathlib import Path
    # from config import MILVUS_PERSIST_BASE_DIR
    # test_db_path = MILVUS_PERSIST_BASE_DIR / test_user
    # if test_db_path.exists():
    # print(f"\nAttempting to clean up test data at {test_db_path}...")
    # # shutil.rmtree(test_db_path) # Use with caution
    print("\n--- RAG Pipeline Test Complete ---")