add embed_messages api for only embed user messages
This commit is contained in:
75
api.py
75
api.py
@ -142,6 +142,81 @@ async def retrieve_documents(request: RetrieveRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Add these imports at the top if not already there
|
||||||
|
from typing import List, Optional, Literal
|
||||||
|
|
||||||
|
|
||||||
|
# Define message models for the new endpoint
|
||||||
|
class Message(BaseModel):
|
||||||
|
content: str
|
||||||
|
role: str
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedMessagesRequest(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
messages: List[Message]
|
||||||
|
meta: Optional[dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/embed_messages", response_model=dict)
|
||||||
|
async def embed_messages(
|
||||||
|
request: EmbedMessagesRequest, embedder=Depends(get_document_embedder)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process a messages array, extract content from user messages,
|
||||||
|
concatenate them with newlines, then embed and store in Milvus.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Initialize document store for the user
|
||||||
|
document_store = get_document_store(request.user_id)
|
||||||
|
|
||||||
|
# Filter messages to keep only those with role="user"
|
||||||
|
user_messages = [msg for msg in request.messages if msg.role == "user"]
|
||||||
|
|
||||||
|
# Extract content from each user message
|
||||||
|
user_contents = [msg.content for msg in user_messages]
|
||||||
|
|
||||||
|
# Join contents with newline character
|
||||||
|
concatenated_content = "\n".join(user_contents)
|
||||||
|
|
||||||
|
if not concatenated_content.strip():
|
||||||
|
return {
|
||||||
|
"status": "warning",
|
||||||
|
"message": "No user messages found or all user messages were empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a document with concatenated content
|
||||||
|
meta = request.meta.copy()
|
||||||
|
meta["user_id"] = request.user_id # Ensure user_id is in meta
|
||||||
|
user_doc = Document(content=concatenated_content, meta=meta)
|
||||||
|
|
||||||
|
# Embed the document
|
||||||
|
logger.info(f"Embedding concatenated user messages for user {request.user_id}")
|
||||||
|
embedding_result = embedder.run([user_doc])
|
||||||
|
embedded_docs = embedding_result.get("documents", [])
|
||||||
|
|
||||||
|
if not embedded_docs:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to embed document")
|
||||||
|
|
||||||
|
# Write to document store
|
||||||
|
logger.info(f"Writing embedded document to Milvus for user {request.user_id}")
|
||||||
|
document_store.write_documents(embedded_docs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"message": f"User messages embedded and stored for user {request.user_id}",
|
||||||
|
"processed_messages_count": len(user_messages),
|
||||||
|
"concatenated_length": len(concatenated_content),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding messages: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Error embedding messages: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user