Initialization
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,3 @@
-import asyncio
-
 from autogen_agentchat.agents import AssistantAgent
 from autogen_agentchat.messages import MultiModalMessage, TextMessage
 from autogen_core import Image
@@ -7,12 +5,11 @@ from autogen_core.models import ModelFamily
 from autogen_ext.models.ollama import OllamaChatCompletionClient
 
 from . import config
-from .mcp_tools import load_mcp_tools
 
 
 class AvatarAgentService:
     def __init__(self) -> None:
-        self._model_client = OllamaChatCompletionClient(
+        model_client = OllamaChatCompletionClient(
             model=config.OLLAMA_MODEL,
             model_info={
                 "vision": True,
@@ -22,34 +19,18 @@ class AvatarAgentService:
                 "structured_output": True,
             },
         )
-        self._agent: AssistantAgent | None = None
-        self._agent_lock = asyncio.Lock()
-
-    async def _create_agent(self) -> AssistantAgent:
-        tools = await load_mcp_tools()
-        return AssistantAgent(
+        self._agent = AssistantAgent(
             name="avatar",
-            model_client=self._model_client,
+            model_client=model_client,
             system_message=config.SYSTEM_MESSAGE,
-            tools=tools or None,
-            reflect_on_tool_use=bool(tools),
         )
-
-    async def _get_agent(self) -> AssistantAgent:
-        if self._agent is not None:
-            return self._agent
-        async with self._agent_lock:
-            if self._agent is None:
-                self._agent = await self._create_agent()
-        return self._agent
 
     async def reply(self, user_text: str, image_b64: str) -> str:
-        agent = await self._get_agent()
         user_image = Image.from_base64(image_b64)
         multimodal_task = MultiModalMessage(source="user", content=[user_text, user_image])
 
         ai_response = ""
-        async for message in agent.run_stream(task=multimodal_task):
+        async for message in self._agent.run_stream(task=multimodal_task):
             if isinstance(message, TextMessage) and message.source == "avatar":
                 ai_response = message.content
 
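For reference, a minimal caller for the simplified service could look like the sketch below. The module name avatar_agent and the sample frame path are assumptions for illustration, not part of this commit:

import asyncio
import base64

from avatar_agent import AvatarAgentService  # module name is an assumption


async def main() -> None:
    service = AvatarAgentService()
    # Any base64-encoded camera frame works; "frame.jpg" is a placeholder.
    with open("frame.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("ascii")
    print(await service.reply("你好,你看到了什么?", image_b64))


if __name__ == "__main__":
    asyncio.run(main())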
@@ -1,25 +1,6 @@
-import os
-import shlex
-
-
-def _env_bool(name: str, default: bool) -> bool:
-    value = os.getenv(name)
-    if value is None:
-        return default
-    return value.strip().lower() in {"1", "true", "yes", "on"}
-
-
-def _env_args(name: str, default: str = "") -> list[str]:
-    value = os.getenv(name, default)
-    if not value.strip():
-        return []
-    return shlex.split(value)
-
-
 SYSTEM_MESSAGE = (
     "你是一个友好、幽默的AI虚拟主播。你可以看到用户摄像头传来的画面,也能听到他们的话。"
     "请用简短、自然、热情的中文口语回答,每次回答控制在两三句话以内,不要输出任何 Markdown 格式。"
-    "当用户询问实时天气、最新新闻或网页信息时,优先使用可用工具先查询再回答。"
 )
 
 WHISPER_MODEL_NAME = "base"
@@ -33,12 +14,3 @@ OLLAMA_MODEL = "qwen3-vl:latest"
 
 SERVER_HOST = "0.0.0.0"
 SERVER_PORT = 8000
-
-ENABLE_MCP_TOOLS = _env_bool("ENABLE_MCP_TOOLS", True)
-MCP_SERVER_READ_TIMEOUT_SECONDS = float(os.getenv("MCP_SERVER_READ_TIMEOUT_SECONDS", "30"))
-
-MCP_WEATHER_SERVER_COMMAND = os.getenv("MCP_WEATHER_SERVER_COMMAND", "")
-MCP_WEATHER_SERVER_ARGS = _env_args("MCP_WEATHER_SERVER_ARGS")
-
-MCP_WEBSEARCH_SERVER_COMMAND = os.getenv("MCP_WEBSEARCH_SERVER_COMMAND", "")
-MCP_WEBSEARCH_SERVER_ARGS = _env_args("MCP_WEBSEARCH_SERVER_ARGS")
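The deleted _env_bool and _env_args helpers parsed feature flags and argv lists from environment variables. A standalone sketch of their behavior, where the server package name is a made-up placeholder:

import os
import shlex

os.environ["ENABLE_MCP_TOOLS"] = "yes"
os.environ["MCP_WEATHER_SERVER_ARGS"] = "-y some-weather-server --units metric"

# _env_bool: truthy iff the value is one of {"1", "true", "yes", "on"}
enabled = os.environ["ENABLE_MCP_TOOLS"].strip().lower() in {"1", "true", "yes", "on"}
print(enabled)  # True

# _env_args: shell-style splitting via shlex
print(shlex.split(os.environ["MCP_WEATHER_SERVER_ARGS"]))
# ['-y', 'some-weather-server', '--units', 'metric']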
@@ -1,73 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any
-
-from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools
-
-from . import config
-
-
-@dataclass(frozen=True)
-class MCPServerConfig:
-    name: str
-    command: str
-    args: list[str]
-
-
-def _configured_servers() -> list[MCPServerConfig]:
-    if not config.ENABLE_MCP_TOOLS:
-        return []
-
-    servers: list[MCPServerConfig] = []
-    if config.MCP_WEATHER_SERVER_COMMAND:
-        servers.append(
-            MCPServerConfig(
-                name="weather",
-                command=config.MCP_WEATHER_SERVER_COMMAND,
-                args=config.MCP_WEATHER_SERVER_ARGS,
-            )
-        )
-    if config.MCP_WEBSEARCH_SERVER_COMMAND:
-        servers.append(
-            MCPServerConfig(
-                name="websearch",
-                command=config.MCP_WEBSEARCH_SERVER_COMMAND,
-                args=config.MCP_WEBSEARCH_SERVER_ARGS,
-            )
-        )
-    return servers
-
-
-async def load_mcp_tools() -> list[Any]:
-    configured_servers = _configured_servers()
-    if not configured_servers:
-        print("ℹ️ MCP 工具未配置,跳过加载。")
-        return []
-
-    loaded_tools: list[Any] = []
-    tool_names: set[str] = set()
-
-    for server in configured_servers:
-        params = StdioServerParams(
-            command=server.command,
-            args=server.args,
-            read_timeout_seconds=config.MCP_SERVER_READ_TIMEOUT_SECONDS,
-        )
-        try:
-            server_tools = await mcp_server_tools(params)
-            for tool in server_tools:
-                if tool.name in tool_names:
-                    print(f"⚠️ MCP 工具重名,已跳过: {tool.name}")
-                    continue
-                loaded_tools.append(tool)
-                tool_names.add(tool.name)
-            print(f"✅ MCP 服务已加载: {server.name} ({len(server_tools)} tools)")
-        except Exception as exc:
-            print(f"⚠️ MCP 服务加载失败: {server.name}, error={exc}")
-
-    if loaded_tools:
-        print(f"✅ MCP 工具总数: {len(loaded_tools)}")
-    else:
-        print("ℹ️ 未加载到任何 MCP 工具。")
-    return loaded_tools
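Before this deletion, the loader would have been driven roughly as in the sketch below. The launcher command, package name, and import path are hypothetical, not taken from the repository:

import asyncio
import os

# config reads these at import time, so set them before importing the module.
os.environ["MCP_WEATHER_SERVER_COMMAND"] = "uvx"                    # hypothetical launcher
os.environ["MCP_WEATHER_SERVER_ARGS"] = "some-weather-mcp-server"   # hypothetical package

from mcp_tools import load_mcp_tools  # import path assumed

tools = asyncio.run(load_mcp_tools())
print([tool.name for tool in tools])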