# app.py
"""HTTP API for embedding user content into a per-user document store and
retrieving similar documents by vector similarity.

Endpoints:
    POST /embed           -- embed one piece of content and store it.
    POST /retrieve        -- embed a query and return similar documents.
    POST /embed_messages  -- concatenate the "user" messages of a chat
                             transcript, then embed and store the result.
"""
import logging
from typing import List, Literal, Optional

from fastapi import Depends, FastAPI, HTTPException
from haystack import Document
from pydantic import BaseModel

# Import necessary components from the provided code
from .data_handling import initialize_milvus_lite
from .embedding import initialize_text_embedder
from .main import initialize_document_embedder
from .retrieval import initialize_vector_retriever

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Document Embedding and Retrieval API")


# Define request and response models
class EmbedRequest(BaseModel):
    """Body of POST /embed."""

    user_id: str
    content: str
    meta: Optional[dict] = {}


class RetrieveRequest(BaseModel):
    """Body of POST /retrieve."""

    user_id: str
    query: str


class DocumentResponse(BaseModel):
    """A single retrieved document as returned to the client."""

    content: str
    score: Optional[float] = None
    meta: Optional[dict] = {}


class RetrieveResponse(BaseModel):
    """Response of POST /retrieve; ``answer`` is always None for this endpoint."""

    documents: List[DocumentResponse]
    query: str
    answer: Optional[str] = None


# Helper functions
def get_document_embedder():
    """Return a document embedder (used as a FastAPI dependency)."""
    return initialize_document_embedder()


def get_document_store(user_id: str):
    """Return the document store for ``user_id``."""
    return initialize_milvus_lite(user_id)


def _embed_and_store(
    embedder, document_store, user_id: str, content: str, meta: Optional[dict]
) -> None:
    """Embed ``content`` as a single document and write it to ``document_store``.

    Args:
        embedder: Object whose ``run([Document])`` returns a mapping with a
            ``"documents"`` entry.
        document_store: Object exposing ``write_documents(docs)``.
        user_id: Owner of the document; always recorded in the document meta.
        content: Text to embed.
        meta: Optional client-supplied metadata; ``None`` is treated as empty.

    Raises:
        HTTPException: 500 if the embedder returns no documents.
    """
    # Copy so the request model is never mutated; tolerate an explicit null.
    doc_meta = dict(meta or {})
    doc_meta["user_id"] = user_id  # Ensure user_id is in meta
    user_doc = Document(content=content, meta=doc_meta)

    embedding_result = embedder.run([user_doc])
    embedded_docs = embedding_result.get("documents", [])
    if not embedded_docs:
        raise HTTPException(status_code=500, detail="Failed to embed document")

    logger.info("Writing embedded document to Milvus for user %s", user_id)
    document_store.write_documents(embedded_docs)


# NOTE(review): the endpoints below are ``async def`` but call what look like
# synchronous embedder/store methods; if those block, they will stall the
# event loop -- confirm and consider plain ``def`` endpoints.
@app.post("/embed", response_model=dict)
async def embed_document(
    request: EmbedRequest, embedder=Depends(get_document_embedder)
):
    """
    Embed content and store it in a Milvus collection for the specified user.

    Returns a dict with ``status`` and ``message`` on success.

    Raises:
        HTTPException: 500 if embedding yields no documents or any other
            error occurs while embedding or storing.
    """
    try:
        # Initialize document store for the user
        document_store = get_document_store(request.user_id)

        logger.info("Embedding document for user %s", request.user_id)
        _embed_and_store(
            embedder,
            document_store,
            request.user_id,
            request.content,
            request.meta,
        )

        return {
            "status": "success",
            "message": f"Document embedded and stored for user {request.user_id}",
        }
    except HTTPException:
        # Deliberately raised HTTP errors must not be re-wrapped by the
        # generic handler below.
        raise
    except Exception as e:
        logger.exception("Error embedding document: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error embedding document: {e}"
        ) from e


@app.post("/retrieve", response_model=RetrieveResponse)
async def retrieve_documents(request: RetrieveRequest):
    """
    Retrieve similar documents for a user based on a query without LLM generation.
    Only retrieves documents using vector similarity.

    Raises:
        HTTPException: 500 if the query cannot be embedded or any other
            error occurs during retrieval.
    """
    try:
        # Get document store for the user
        document_store = get_document_store(request.user_id)

        # Initialize text embedder for query embedding
        text_embedder = initialize_text_embedder()

        # Initialize retriever
        retriever = initialize_vector_retriever(document_store)

        # Embed the query
        logger.info(
            "Embedding query for user %s: '%s'", request.user_id, request.query
        )
        embedding_result = text_embedder.run(text=request.query)
        query_embedding = embedding_result.get("embedding")
        if not query_embedding:
            raise HTTPException(status_code=500, detail="Failed to embed query")

        # Retrieve similar documents
        logger.info("Retrieving documents for query: '%s'", request.query)
        retriever_result = retriever.run(query_embedding=query_embedding)
        retrieved_docs = retriever_result.get("documents", [])

        # Convert to response format; score is optional on the document.
        documents = [
            DocumentResponse(
                content=doc.content,
                score=getattr(doc, "score", None),
                meta=doc.meta,
            )
            for doc in retrieved_docs
        ]

        return RetrieveResponse(documents=documents, query=request.query, answer=None)
    except HTTPException:
        # Deliberately raised HTTP errors must not be re-wrapped by the
        # generic handler below.
        raise
    except Exception as e:
        logger.exception("Error retrieving documents: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error retrieving documents: {e}"
        ) from e


# Define message models for the new endpoint
class Message(BaseModel):
    """One chat message; only messages with ``role == "user"`` are embedded."""

    content: str
    role: str
    name: Optional[str] = None


class EmbedMessagesRequest(BaseModel):
    """Body of POST /embed_messages."""

    user_id: str
    messages: List[Message]
    meta: Optional[dict] = {}


@app.post("/embed_messages", response_model=dict)
async def embed_messages(
    request: EmbedMessagesRequest, embedder=Depends(get_document_embedder)
):
    """
    Process a messages array, extract content from user messages,
    concatenate them with newlines, then embed and store in Milvus.

    Returns a dict with ``status`` and ``message``; on success it also
    carries ``processed_messages_count`` and ``concatenated_length``. When
    there is no non-blank user content, a ``"warning"`` status is returned
    and nothing is embedded or stored.

    Raises:
        HTTPException: 500 if embedding yields no documents or any other
            error occurs while embedding or storing.
    """
    try:
        # Initialize document store for the user
        document_store = get_document_store(request.user_id)

        # Keep only messages with role="user" and join their contents.
        user_messages = [msg for msg in request.messages if msg.role == "user"]
        concatenated_content = "\n".join(msg.content for msg in user_messages)

        if not concatenated_content.strip():
            return {
                "status": "warning",
                "message": "No user messages found or all user messages were empty",
            }

        logger.info(
            "Embedding concatenated user messages for user %s", request.user_id
        )
        _embed_and_store(
            embedder,
            document_store,
            request.user_id,
            concatenated_content,
            request.meta,
        )

        return {
            "status": "success",
            "message": f"User messages embedded and stored for user {request.user_id}",
            "processed_messages_count": len(user_messages),
            "concatenated_length": len(concatenated_content),
        }
    except HTTPException:
        # Deliberately raised HTTP errors must not be re-wrapped by the
        # generic handler below.
        raise
    except Exception as e:
        logger.exception("Error embedding messages: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error embedding messages: {e}"
        ) from e


if __name__ == "__main__":
    # NOTE(review): this module uses relative imports, so it has to be run as
    # part of its package for this entry point to work -- confirm how it is
    # launched.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7999)