初始化
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .app import app
|
||||
|
||||
__all__ = ["app"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.messages import MultiModalMessage, TextMessage
|
||||
from autogen_core import Image
|
||||
from autogen_core.models import ModelFamily
|
||||
from autogen_ext.models.ollama import OllamaChatCompletionClient
|
||||
|
||||
from . import config
|
||||
from .mcp_tools import load_mcp_tools
|
||||
|
||||
|
||||
class AvatarAgentService:
    """Answer multimodal (text + image) user turns with one lazily built agent.

    The Ollama model client is created eagerly; the ``AssistantAgent`` (and
    the MCP tools it uses) is created on the first call that needs it and
    then reused for every later call.
    """

    def __init__(self) -> None:
        """Create the Ollama model client and the locks guarding the agent."""
        self._model_client = OllamaChatCompletionClient(
            model=config.OLLAMA_MODEL,
            model_info={
                "vision": True,
                "function_calling": True,
                "json_output": True,
                "family": ModelFamily.UNKNOWN,
                "structured_output": True,
            },
        )
        self._agent: AssistantAgent | None = None
        # Guards the one-time creation of the agent.
        self._agent_lock = asyncio.Lock()
        # Serializes runs of the shared agent: AutoGen documents AssistantAgent
        # as not coroutine-safe, so its methods must not be called concurrently.
        self._run_lock = asyncio.Lock()

    async def _create_agent(self) -> AssistantAgent:
        """Build the agent, attaching whatever MCP tools could be loaded."""
        tools = await load_mcp_tools()
        return AssistantAgent(
            name="avatar",
            model_client=self._model_client,
            system_message=config.SYSTEM_MESSAGE,
            # An empty tool list is passed as None, and reflection on tool
            # results is only enabled when there is at least one tool.
            tools=tools or None,
            reflect_on_tool_use=bool(tools),
        )

    async def _get_agent(self) -> AssistantAgent:
        """Return the shared agent, creating it on first use."""
        if self._agent is not None:
            return self._agent
        async with self._agent_lock:
            # Re-check under the lock: another task may have created the agent
            # while this one was waiting.
            if self._agent is None:
                self._agent = await self._create_agent()
            return self._agent

    async def reply(self, user_text: str, image_b64: str) -> str:
        """Run one user turn through the agent and return its answer.

        Args:
            user_text: What the user said.
            image_b64: Base64-encoded image sent along with the text.

        Returns:
            The content of the last ``TextMessage`` emitted by the "avatar"
            source during the run, or an empty string if there was none.
        """
        agent = await self._get_agent()
        user_image = Image.from_base64(image_b64)
        multimodal_task = MultiModalMessage(
            source="user",
            content=[user_text, user_image],
        )

        ai_response = ""
        # NOTE(review): the agent instance is reused across calls, so callers
        # that share this service presumably also share its conversation
        # state - confirm that this is intended.
        async with self._run_lock:
            async for message in agent.run_stream(task=multimodal_task):
                if isinstance(message, TextMessage) and message.source == "avatar":
                    ai_response = message.content

        return ai_response
|
||||
@@ -0,0 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .ws import router as ws_router
|
||||
|
||||
# The ASGI application object, with the WebSocket router mounted on it.
app = FastAPI()
app.include_router(ws_router)
|
||||
@@ -0,0 +1,44 @@
|
||||
import os
|
||||
import shlex
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or blank.

    Returns:
        True when the value is one of "1", "true", "yes" or "on" (ignoring
        case and surrounding whitespace), False for any other non-blank
        value, and ``default`` when the variable is unset or blank.
    """
    value = os.getenv(name)
    if value is None:
        return default

    normalized = value.strip().lower()
    if not normalized:
        # A variable that is set but empty (e.g. `FLAG=`) is treated like an
        # unset one instead of silently turning the flag off.
        return default
    return normalized in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _env_args(name: str, default: str = "") -> list[str]:
    """Read a shell-style argument string from the environment as a list.

    Args:
        name: Environment variable to read.
        default: Raw argument string used when the variable is unset.

    Returns:
        The value split with ``shlex.split``; an empty list when the value is
        empty or only whitespace.

    Raises:
        ValueError: If the value cannot be split (for example an unterminated
            quote). The message names the offending variable.
    """
    value = os.getenv(name, default)
    if not value.strip():
        return []
    try:
        return shlex.split(value)
    except ValueError as exc:
        # shlex's own message does not say which variable was malformed.
        raise ValueError(
            f"Invalid shell-style arguments in {name}: {exc}"
        ) from exc
|
||||
|
||||
|
||||
# System prompt handed to the agent (one string built from adjacent literals).
SYSTEM_MESSAGE = (
    "你是一个友好、幽默的AI虚拟主播。你可以看到用户摄像头传来的画面,也能听到他们的话。"
    "请用简短、自然、热情的中文口语回答,每次回答控制在两三句话以内,不要输出任何 Markdown 格式。"
    "当用户询问实时天气、最新新闻或网页信息时,优先使用可用工具先查询再回答。"
)

# Speech-recognition (Whisper) settings.
WHISPER_MODEL_NAME = "base"
WHISPER_DEVICE = "cpu"
WHISPER_COMPUTE_TYPE = "int8"
WHISPER_LANGUAGE = "zh"
WHISPER_BEAM_SIZE = 5

# Text-to-speech voice and the Ollama model name.
TTS_VOICE = "zh-CN-XiaoxiaoNeural"
OLLAMA_MODEL = "qwen3-vl:latest"

# Server bind address and port.
SERVER_HOST = "0.0.0.0"
SERVER_PORT = 8000

# MCP tool settings, all overridable through environment variables.
ENABLE_MCP_TOOLS = _env_bool("ENABLE_MCP_TOOLS", True)
# Raises ValueError at import time if the variable is not a valid float.
MCP_SERVER_READ_TIMEOUT_SECONDS = float(os.getenv("MCP_SERVER_READ_TIMEOUT_SECONDS", "30"))

# An empty command means the corresponding MCP server is not configured.
MCP_WEATHER_SERVER_COMMAND = os.getenv("MCP_WEATHER_SERVER_COMMAND", "")
MCP_WEATHER_SERVER_ARGS = _env_args("MCP_WEATHER_SERVER_ARGS")

MCP_WEBSEARCH_SERVER_COMMAND = os.getenv("MCP_WEBSEARCH_SERVER_COMMAND", "")
MCP_WEBSEARCH_SERVER_ARGS = _env_args("MCP_WEBSEARCH_SERVER_ARGS")
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MCPServerConfig:
    """Launch settings for one stdio MCP server."""

    # Short label used in log output.
    name: str
    # Executable used to start the server.
    command: str
    # Command-line arguments passed to the executable.
    args: list[str]
|
||||
|
||||
|
||||
def _configured_servers() -> list[MCPServerConfig]:
    """List the MCP servers that are enabled and have a command configured.

    Returns:
        An empty list when ``config.ENABLE_MCP_TOOLS`` is false; otherwise one
        entry per known server ("weather", then "websearch") whose command is
        non-empty.
    """
    if not config.ENABLE_MCP_TOOLS:
        return []

    # (label, command, args) for every server this module knows about.
    candidates = (
        ("weather", config.MCP_WEATHER_SERVER_COMMAND, config.MCP_WEATHER_SERVER_ARGS),
        ("websearch", config.MCP_WEBSEARCH_SERVER_COMMAND, config.MCP_WEBSEARCH_SERVER_ARGS),
    )
    return [
        MCPServerConfig(name=label, command=command, args=args)
        for label, command, args in candidates
        if command
    ]
|
||||
|
||||
|
||||
async def load_mcp_tools() -> list[Any]:
    """Load tools from every configured MCP server.

    Loading is best-effort: a server that fails to load is reported and
    skipped, and a tool whose name was already seen is dropped so that tool
    names stay unique.

    Returns:
        The collected tools, possibly empty.
    """
    configured_servers = _configured_servers()
    if not configured_servers:
        print("ℹ️ MCP 工具未配置,跳过加载。")
        return []

    loaded_tools: list[Any] = []
    tool_names: set[str] = set()

    for server in configured_servers:
        params = StdioServerParams(
            command=server.command,
            args=server.args,
            read_timeout_seconds=config.MCP_SERVER_READ_TIMEOUT_SECONDS,
        )
        try:
            server_tools = await mcp_server_tools(params)
            for tool in server_tools:
                if tool.name in tool_names:
                    print(f"⚠️ MCP 工具重名,已跳过: {tool.name}")
                    continue
                loaded_tools.append(tool)
                tool_names.add(tool.name)
            print(f"✅ MCP 服务已加载: {server.name} ({len(server_tools)} tools)")
        except Exception as exc:
            # Deliberately broad: one broken server must not prevent the
            # remaining servers from loading.
            print(f"⚠️ MCP 服务加载失败: {server.name}, error={exc}")

    if loaded_tools:
        print(f"✅ MCP 工具总数: {len(loaded_tools)}")
    else:
        print("ℹ️ 未加载到任何 MCP 工具。")
    return loaded_tools
|
||||
@@ -0,0 +1,35 @@
|
||||
import base64
|
||||
|
||||
import edge_tts
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
class SpeechService:
    """Speech-to-text with a local Whisper model and text-to-speech via edge-tts."""

    def __init__(self) -> None:
        """Load the Whisper model described by the config module."""
        print("⏳ 正在加载本地语音识别模型 (首次启动可能需要下载)...")
        self._whisper_model = WhisperModel(
            config.WHISPER_MODEL_NAME,
            device=config.WHISPER_DEVICE,
            compute_type=config.WHISPER_COMPUTE_TYPE,
        )
        print("✅ 本地语音模型加载完毕!")

    def transcribe(self, audio_path: str) -> str:
        """Transcribe an audio file and return the concatenated segment text.

        This call is synchronous.

        Args:
            audio_path: Path of the audio file to transcribe.
        """
        segments, _ = self._whisper_model.transcribe(
            audio_path,
            beam_size=config.WHISPER_BEAM_SIZE,
            language=config.WHISPER_LANGUAGE,
        )
        return "".join(segment.text for segment in segments)

    async def synthesize_audio_data_url(self, text: str) -> str:
        """Synthesize speech for ``text`` and return it as a base64 data URL.

        Args:
            text: Text to speak with the configured TTS voice.

        Returns:
            A ``data:audio/mp3;base64,...`` URL containing the audio chunks of
            the stream, in order.
        """
        communicate = edge_tts.Communicate(text, config.TTS_VOICE)
        # Collect the chunks and join once instead of growing a bytes object
        # with += on every chunk.
        audio_chunks = [
            chunk["data"]
            async for chunk in communicate.stream()
            if chunk["type"] == "audio"
        ]
        audio_b64 = base64.b64encode(b"".join(audio_chunks)).decode("utf-8")
        return f"data:audio/mp3;base64,{audio_b64}"
|
||||
@@ -0,0 +1,67 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from .agent_service import AvatarAgentService
|
||||
from .speech import SpeechService
|
||||
from .ws_messages import send_audio_message, send_text_message
|
||||
|
||||
router = APIRouter()

# Module-level service instances, constructed once when this module is
# imported and used by the WebSocket handler below.
speech_service = SpeechService()
agent_service = AvatarAgentService()
|
||||
|
||||
|
||||
def _save_audio_to_temp_file(audio_b64: str) -> str:
    """Decode base64 audio into a new ``.webm`` temp file and return its path.

    The file is created with ``delete=False``, so the caller is responsible
    for removing it.

    Args:
        audio_b64: Base64-encoded audio payload.

    Returns:
        Path of the temporary file that was written.

    Raises:
        binascii.Error: If ``audio_b64`` is not valid base64 (a ``ValueError``
            subclass). No file is created in that case.
    """
    # Decode first: a decoding error must not leave an orphaned temp file.
    audio_bytes = base64.b64decode(audio_b64)

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".webm")
    try:
        with temp_file:
            temp_file.write(audio_bytes)
    except Exception:
        # The write failed, so the caller never learns the path; clean up here.
        try:
            os.remove(temp_file.name)
        except OSError:
            pass
        raise
    return temp_file.name
|
||||
|
||||
|
||||
def _parse_user_input(message_text: str) -> "tuple[str, str] | None":
    """Extract the base64 audio and image payloads from one client message.

    Returns:
        ``(audio_b64, image_b64)`` for a JSON object whose "type" is
        "user_input" and whose "audio" and "image" fields are strings;
        ``None`` for every other message.
    """
    try:
        data = json.loads(message_text)
    except json.JSONDecodeError:
        print("⚠️ 收到无法解析的消息,已忽略。")
        return None

    if not isinstance(data, dict) or data.get("type") != "user_input":
        return None

    audio = data.get("audio")
    image = data.get("image")
    if not isinstance(audio, str) or not isinstance(image, str):
        print("⚠️ 收到无法解析的消息,已忽略。")
        return None

    # Keep only the text after the last comma, which drops a leading
    # "data:...;base64," prefix when one is present.
    return audio.split(",")[-1], image.split(",")[-1]


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one WebSocket client: speech in, spoken answer out.

    For every "user_input" message the handler transcribes the audio, asks the
    agent for a reply based on the transcript and the image, sends the reply
    as text, then sends synthesized speech as an audio data URL. Progress
    notes are sent as small HTML fragments along the way. Messages that are
    not well-formed "user_input" messages are ignored.
    """
    # Imported locally so this handler does not depend on the module's import
    # block providing them.
    import asyncio
    import html

    await websocket.accept()
    print("✅ WebSocket 连接成功!准备就绪。")

    try:
        while True:
            message_text = await websocket.receive_text()
            payload = _parse_user_input(message_text)
            if payload is None:
                continue
            audio_b64, image_b64 = payload

            audio_path = _save_audio_to_temp_file(audio_b64)
            try:
                await send_text_message(websocket, "<i>[👂 正在辨识语音...]</i><br>")
                # transcribe() is synchronous; run it in a worker thread so it
                # does not stall the event loop for other connections.
                user_text = await asyncio.to_thread(
                    speech_service.transcribe, audio_path
                )
            finally:
                if os.path.exists(audio_path):
                    os.remove(audio_path)

            if not user_text.strip():
                await send_text_message(websocket, "<i>[没听清你说什么...]</i><br>")
                continue

            # The transcript and the model reply are embedded in HTML
            # fragments, so escape them; the raw reply is still used for TTS.
            await send_text_message(
                websocket, f"<b>你说:</b>{html.escape(user_text)}<br>"
            )
            await send_text_message(websocket, "<i>[🧠 正在看图思考...]</i><br>")

            ai_response = await agent_service.reply(user_text, image_b64)
            await send_text_message(
                websocket, f"<b>AI主播:</b>{html.escape(ai_response)}<br><br>"
            )

            await send_text_message(websocket, "<i>[🗣️ 正在生成语音...]</i><br>")
            audio_data_url = await speech_service.synthesize_audio_data_url(ai_response)
            await send_audio_message(websocket, audio_data_url)

    except WebSocketDisconnect:
        print("❌ 前端页面已关闭或断开连接")
    except Exception as exc:
        # Top-level boundary for this connection: report and end the handler.
        print(f"⚠️ 发生错误: {exc}")
|
||||
@@ -0,0 +1,18 @@
|
||||
import json
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
async def send_text_message(websocket: WebSocket, content: str) -> None:
    """Send ``content`` to the client as a JSON message of type "text".

    Args:
        websocket: Connection to send on.
        content: Payload placed in the message's "content" field.
    """
    payload = {"type": "text", "content": content}
    await websocket.send_text(json.dumps(payload))
|
||||
|
||||
|
||||
async def send_audio_message(websocket: WebSocket, audio_data_url: str) -> None:
    """Send an audio data URL to the client as a JSON message of type "audio".

    Args:
        websocket: Connection to send on.
        audio_data_url: Payload placed in the message's "content" field.
    """
    payload = {"type": "audio", "content": audio_data_url}
    await websocket.send_text(json.dumps(payload))
|
||||
Reference in New Issue
Block a user