feat(docker): containerize application and add TTS integration
This commit is contained in:
223
haystack_rag/api.py
Normal file
223
haystack_rag/api.py
Normal file
@ -0,0 +1,223 @@
|
||||
# haystack_rag/api.py
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
from haystack import Document
|
||||
|
||||
# Project-local components: document store, embedders, and retriever factories
|
||||
from .data_handling import initialize_milvus_lite
|
||||
from .main import initialize_document_embedder
|
||||
from .retrieval import initialize_vector_retriever
|
||||
from .embedding import initialize_text_embedder
|
||||
|
||||
# Setup logging
|
||||
# Root logger is configured at INFO; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="Document Embedding and Retrieval API")
|
||||
|
||||
|
||||
# Define request and response models
|
||||
class EmbedRequest(BaseModel):
    # Request body for POST /embed.

    # Identifies the user; passed to get_document_store() by the handler.
    user_id: str
    # Raw text that will be wrapped in a Document and embedded.
    content: str
    # Extra metadata stored with the document. Declared Optional, so an
    # explicit null is accepted as well as a dict; defaults to an empty dict.
    meta: Optional[dict] = {}
|
||||
|
||||
|
||||
class RetrieveRequest(BaseModel):
    # Request body for POST /retrieve.

    # Identifies the user whose document store is searched.
    user_id: str
    # Free-text query that is embedded and used for similarity retrieval.
    query: str
|
||||
|
||||
|
||||
class DocumentResponse(BaseModel):
    # One retrieved document as returned to the API client.

    # Text content of the retrieved document.
    content: str
    # Score reported by the retriever, or None when none is available.
    score: Optional[float] = None
    # Metadata carried by the document; defaults to an empty dict.
    meta: Optional[dict] = {}
|
||||
|
||||
|
||||
class RetrieveResponse(BaseModel):
    # Response body for POST /retrieve.

    # Retrieved documents, in the order the retriever returned them.
    documents: List[DocumentResponse]
    # The query string echoed back from the request.
    query: str
    # Optional generated answer; the /retrieve handler always sets this to None.
    answer: Optional[str] = None
|
||||
|
||||
|
||||
# Helper functions
|
||||
def get_document_embedder():
    """Build the document embedder injected into the embedding endpoints.

    Used as a FastAPI dependency; simply delegates to
    ``initialize_document_embedder`` and hands back its result.
    """
    document_embedder = initialize_document_embedder()
    return document_embedder
|
||||
|
||||
|
||||
def get_document_store(user_id: str):
    """Return the document store for ``user_id``.

    Delegates to ``initialize_milvus_lite`` with the given user id and
    hands back its result unchanged.
    """
    store = initialize_milvus_lite(user_id)
    return store
|
||||
|
||||
|
||||
@app.post("/embed", response_model=dict)
async def embed_document(
    request: EmbedRequest, embedder=Depends(get_document_embedder)
):
    """
    Embed content and store it in a Milvus collection for the specified user.

    Responds with HTTP 500 when the embedder returns no documents or when
    embedding or storage raises an error.
    """
    try:
        # Initialize document store for the user
        document_store = get_document_store(request.user_id)

        # `meta` is Optional, so an explicit JSON null arrives as None. Fall
        # back to an empty dict, and copy so the request model is not mutated.
        meta = dict(request.meta or {})
        meta["user_id"] = request.user_id  # Ensure user_id is in meta
        user_doc = Document(content=request.content, meta=meta)

        # Embed the document
        logger.info(f"Embedding document for user {request.user_id}")
        embedding_result = embedder.run([user_doc])
        embedded_docs = embedding_result.get("documents", [])

        if not embedded_docs:
            raise HTTPException(status_code=500, detail="Failed to embed document")

        # Write to document store
        logger.info(f"Writing embedded document to Milvus for user {request.user_id}")
        document_store.write_documents(embedded_docs)

        return {
            "status": "success",
            "message": f"Document embedded and stored for user {request.user_id}",
        }

    except HTTPException:
        # Already a well-formed API error: re-raise untouched so its detail
        # is not wrapped a second time by the generic handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error embedding document: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error embedding document: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_documents(request: RetrieveRequest):
    """
    Retrieve similar documents for a user based on a query without LLM generation.
    Only retrieves documents using vector similarity.

    Responds with HTTP 500 when the query cannot be embedded or when
    embedding or retrieval raises an error.
    """
    try:
        # Get document store for the user
        document_store = get_document_store(request.user_id)

        # Initialize text embedder for query embedding
        text_embedder = initialize_text_embedder()

        # Initialize retriever
        retriever = initialize_vector_retriever(document_store)

        # Embed the query
        logger.info(f"Embedding query for user {request.user_id}: '{request.query}'")
        embedding_result = text_embedder.run(text=request.query)
        query_embedding = embedding_result.get("embedding")

        # Explicit None/length checks instead of truthiness, so an array-like
        # embedding (whose truth value is ambiguous) does not raise here.
        if query_embedding is None or len(query_embedding) == 0:
            raise HTTPException(status_code=500, detail="Failed to embed query")

        # Retrieve similar documents
        logger.info(f"Retrieving documents for query: '{request.query}'")
        retriever_result = retriever.run(query_embedding=query_embedding)
        retrieved_docs = retriever_result.get("documents", [])

        # Convert to response format; a document without a score yields None
        documents = [
            DocumentResponse(
                content=doc.content,
                score=getattr(doc, "score", None),
                meta=doc.meta,
            )
            for doc in retrieved_docs
        ]

        return RetrieveResponse(documents=documents, query=request.query, answer=None)

    except HTTPException:
        # Already a well-formed API error: re-raise untouched so its detail
        # is not wrapped a second time by the generic handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error retrieving documents: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error retrieving documents: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
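# Typing imports for the message models below (List and Optional are already
# imported at the top of the file; this import could be merged into that one)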
|
||||
from typing import List, Optional, Literal
|
||||
|
||||
|
||||
# Define message models for the new endpoint
|
||||
class Message(BaseModel):
    # A single chat message inside an /embed_messages request.

    # Text of the message.
    content: str
    # Speaker role; the /embed_messages handler keeps only role == "user".
    role: str
    # Optional name attached to the message.
    name: Optional[str] = None
|
||||
|
||||
|
||||
class EmbedMessagesRequest(BaseModel):
    # Request body for POST /embed_messages.

    # Identifies the user; passed to get_document_store() by the handler.
    user_id: str
    # Messages to process, in the order supplied by the client.
    messages: List[Message]
    # Extra metadata stored with the document. Declared Optional, so an
    # explicit null is accepted as well as a dict; defaults to an empty dict.
    meta: Optional[dict] = {}
|
||||
|
||||
|
||||
@app.post("/embed_messages", response_model=dict)
async def embed_messages(
    request: EmbedMessagesRequest, embedder=Depends(get_document_embedder)
):
    """
    Process a messages array, extract content from user messages,
    concatenate them with newlines, then embed and store in Milvus.

    Returns a "warning" status without storing anything when there is no
    non-blank user content. Responds with HTTP 500 when the embedder returns
    no documents or when embedding or storage raises an error.
    """
    try:
        # Filter messages to keep only those with role="user"
        user_messages = [msg for msg in request.messages if msg.role == "user"]

        # Join the content of each user message with a newline character
        concatenated_content = "\n".join(msg.content for msg in user_messages)

        # Nothing to embed: return before touching the document store.
        if not concatenated_content.strip():
            return {
                "status": "warning",
                "message": "No user messages found or all user messages were empty",
            }

        # Initialize document store for the user
        document_store = get_document_store(request.user_id)

        # `meta` is Optional, so an explicit JSON null arrives as None. Fall
        # back to an empty dict, and copy so the request model is not mutated.
        meta = dict(request.meta or {})
        meta["user_id"] = request.user_id  # Ensure user_id is in meta
        user_doc = Document(content=concatenated_content, meta=meta)

        # Embed the document
        logger.info(f"Embedding concatenated user messages for user {request.user_id}")
        embedding_result = embedder.run([user_doc])
        embedded_docs = embedding_result.get("documents", [])

        if not embedded_docs:
            raise HTTPException(status_code=500, detail="Failed to embed document")

        # Write to document store
        logger.info(f"Writing embedded document to Milvus for user {request.user_id}")
        document_store.write_documents(embedded_docs)

        return {
            "status": "success",
            "message": f"User messages embedded and stored for user {request.user_id}",
            "processed_messages_count": len(user_messages),
            "concatenated_length": len(concatenated_content),
        }

    except HTTPException:
        # Already a well-formed API error: re-raise untouched so its detail
        # is not wrapped a second time by the generic handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error embedding messages: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error embedding messages: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7999.
    server_options = {"host": "0.0.0.0", "port": 7999}
    uvicorn.run(app, **server_options)
|
Reference in New Issue
Block a user