# Add these imports at the top if not already there
from typing import List, Optional, Literal


# Define message models for the new endpoint
class Message(BaseModel):
    """A single chat message submitted to the ``/embed_messages`` endpoint."""

    content: str
    role: str  # only messages whose role is exactly "user" get embedded
    name: Optional[str] = None


class EmbedMessagesRequest(BaseModel):
    """Request body for the ``/embed_messages`` endpoint."""

    user_id: str
    messages: List[Message]
    # Declared Optional, so a client may send an explicit null; the handler
    # below treats a missing or null value as an empty mapping.
    meta: Optional[dict] = {}


@app.post("/embed_messages", response_model=dict)
async def embed_messages(
    request: EmbedMessagesRequest, embedder=Depends(get_document_embedder)
):
    """
    Embed the user-authored part of a conversation and store it.

    Keeps only the messages whose role is "user", joins their contents with
    newlines into a single document, embeds that document and writes it to
    the document store obtained for ``request.user_id``.

    Args:
        request: The user id, the messages array and optional metadata.
        embedder: Document embedder injected via ``get_document_embedder``.

    Returns:
        A dict with ``status`` "warning" when there is no non-blank user
        content to embed, otherwise a dict with ``status`` "success" plus
        the number of user messages processed and the length of the
        concatenated text.

    Raises:
        HTTPException: With status 500 when the embedder returns no
            documents or when any other step fails.
    """
    try:
        # Initialize document store for the user
        document_store = get_document_store(request.user_id)

        # Keep only user-authored messages and join their contents.
        user_messages = [msg for msg in request.messages if msg.role == "user"]
        concatenated_content = "\n".join(msg.content for msg in user_messages)

        if not concatenated_content.strip():
            return {
                "status": "warning",
                "message": "No user messages found or all user messages were empty",
            }

        # ``meta`` may be None (explicit null in the payload); copy so the
        # request model itself is never mutated.
        meta = dict(request.meta or {})
        meta["user_id"] = request.user_id  # Ensure user_id is in meta
        user_doc = Document(content=concatenated_content, meta=meta)

        # Embed the document
        logger.info(
            "Embedding concatenated user messages for user %s", request.user_id
        )
        embedding_result = embedder.run([user_doc])
        embedded_docs = embedding_result.get("documents", [])

        if not embedded_docs:
            raise HTTPException(status_code=500, detail="Failed to embed document")

        # Write to document store
        logger.info(
            "Writing embedded document to Milvus for user %s", request.user_id
        )
        document_store.write_documents(embedded_docs)

        return {
            "status": "success",
            "message": f"User messages embedded and stored for user {request.user_id}",
            "processed_messages_count": len(user_messages),
            "concatenated_length": len(concatenated_content),
        }

    except HTTPException:
        # Raised deliberately above; let it through unchanged instead of
        # having the generic handler below re-wrap it.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error embedding messages: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error embedding messages: {str(e)}"
        ) from e