From 85bcbe45295c237ba80a2c004ddb3b67c81393d6 Mon Sep 17 00:00:00 2001 From: JiajunLI Date: Wed, 4 Mar 2026 15:35:57 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=A8=A1=E5=9D=97=E8=A7=A3?= =?UTF-8?q?=E8=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brain.py | 85 ++++++++++++++ config.py | 16 +++ main.py | 298 +++++------------------------------------------ profile_store.py | 39 +++++++ voice_io.py | 128 ++++++++++++++++++++ 5 files changed, 298 insertions(+), 268 deletions(-) create mode 100644 brain.py create mode 100644 config.py create mode 100644 profile_store.py create mode 100644 voice_io.py diff --git a/brain.py b/brain.py new file mode 100644 index 0000000..e00157f --- /dev/null +++ b/brain.py @@ -0,0 +1,85 @@ +import sys +from typing import Annotated + +from autogen_agentchat.agents import AssistantAgent +from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.models.openai import _openai_client as openai_client_module +from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools + +from config import MODEL_API_KEY, MODEL_BASE_URL, MODEL_NAME + + +def patch_autogen_tool_schema_for_vllm() -> None: + """ + vLLM 目前会对 OpenAI 工具定义中的 `strict` 字段告警(即便 strict=False)。 + 这里做最小补丁:保留工具定义,移除该字段,避免无意义警告。 + """ + if getattr(openai_client_module.convert_tools, "_strict_removed_patch", False): + return + + original_convert_tools = openai_client_module.convert_tools + + def convert_tools_without_strict(tools): + converted = original_convert_tools(tools) + for tool in converted: + fn = tool.get("function") + if isinstance(fn, dict): + fn.pop("strict", None) + return converted + + convert_tools_without_strict._strict_removed_patch = True + openai_client_module.convert_tools = convert_tools_without_strict + + +async def set_expression( + expression: Annotated[str, "机器人要展示的表情,如:开心、疑惑、难过、待机"], + intensity: Annotated[int, "表情强度 1-10"] = 5, +) -> str: + """[模拟面部] 控制机器人头部的表情展示。""" + print(f"\n>>>>>> 🤖 表情更新: 【{expression}】 (强度: {intensity}/10) <<<<<<") + return f"已切换到【{expression}】表情。" + + +async def create_brain() -> tuple[AssistantAgent, OpenAIChatCompletionClient]: + patch_autogen_tool_schema_for_vllm() + + # 连接 MCP Server,只取写入类工具(读档案由 Python 层直接处理) + server_params = StdioServerParams( + command=sys.executable, + args=["robot_mcp_server.py"], + ) + all_mcp_tools = await mcp_server_tools(server_params) + mcp_tools = [t for t in all_mcp_tools if getattr(t, "name", "") != "get_user_profile"] + + model_client = OpenAIChatCompletionClient( + model=MODEL_NAME, + base_url=MODEL_BASE_URL, + api_key=MODEL_API_KEY, + model_info={ + "vision": True, + "function_calling": True, + "json_output": True, + # 关闭 structured_output,避免向 vLLM 发送 strict 字段告警。 + "structured_output": False, + "family": "unknown", + "context_length": 32000, + }, + ) + + brain = AssistantAgent( + name="RobotBrain", + model_client=model_client, + tools=[set_expression] + mcp_tools, + reflect_on_tool_use=True, + system_message="""你是一个有情感的机器人伙伴。每条消息包含: +[用户档案] 历史记录(已由系统预加载) +[视觉状态] 当前摄像头观察 +[用户说] 用户这句话 + +回应规则: +1. 同一轮内调用:set_expression(表情)+ 所有需要的查询工具(get_weather/get_location/web_search)+ 需要的用户信息工具(upsert_user/set_preference) +2. 工具执行完毕后,用简短、温暖、自然的语言直接回答用户——这段文字就是你的语音输出。 +3. 不要说"我去查一下"之类的过渡语,直接完成任务并给出结果。""", + ) + return brain, model_client + diff --git a/config.py b/config.py new file mode 100644 index 0000000..8b09d0a --- /dev/null +++ b/config.py @@ -0,0 +1,16 @@ +import os +from pathlib import Path + +BASE_DIR = Path(__file__).resolve().parent +USER_DB_PATH = BASE_DIR / "users.db" + +MODEL_CALL_TIMEOUT_SECONDS = 45 +ASR_LANGUAGE = "zh-CN" + +MODEL_NAME = os.getenv("VLM_MODEL", "Qwen/Qwen3-VL-8B-Instruct") +MODEL_BASE_URL = os.getenv("VLM_BASE_URL", "http://220.248.114.28:8000/v1") +MODEL_API_KEY = os.getenv("VLM_API_KEY", "EMPTY") + +# edge-tts Yunxi 音色 +TTS_VOICE = os.getenv("TTS_VOICE", "zh-CN-YunxiNeural") + diff --git a/main.py b/main.py index a23bd3d..09cb642 100644 --- a/main.py +++ b/main.py @@ -1,250 +1,24 @@ import asyncio -import json -import os -import shutil -import sqlite3 -import subprocess -import sys -import tempfile -from pathlib import Path -from typing import Annotated -from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken -from autogen_ext.models.openai import OpenAIChatCompletionClient -from autogen_ext.models.openai import _openai_client as openai_client_module -from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools -try: - import speech_recognition as sr -except ImportError: - sr = None - -try: - import edge_tts -except ImportError: - edge_tts = None - -BASE_DIR = Path(__file__).resolve().parent -USER_DB_PATH = BASE_DIR / "users.db" -MODEL_CALL_TIMEOUT_SECONDS = 45 -ASR_LANGUAGE = "zh-CN" -MODEL_NAME = os.getenv("VLM_MODEL", "Qwen/Qwen3-VL-8B-Instruct") -MODEL_BASE_URL = os.getenv("VLM_BASE_URL", "http://220.248.114.28:8000/v1") -MODEL_API_KEY = os.getenv("VLM_API_KEY", "EMPTY") -TTS_VOICE = os.getenv("TTS_VOICE", "zh-CN-YunxiNeural") - -# --- 第一部分:本地工具(面部 + 语音,以后接硬件)--- +from brain import create_brain +from config import MODEL_BASE_URL, MODEL_CALL_TIMEOUT_SECONDS, MODEL_NAME +from profile_store import load_user_profile +from voice_io import ( + async_console_input, + async_speak, + find_audio_player, + get_user_input, + has_asr, + has_tts, +) -def _patch_autogen_tool_schema_for_vllm() -> None: - """ - vLLM 目前会对 OpenAI 工具定义中的 `strict` 字段告警(即便 strict=False)。 - 这里做最小补丁:保留工具定义,移除该字段,避免无意义警告。 - """ - if getattr(openai_client_module.convert_tools, "_strict_removed_patch", False): - return +async def start_simulated_head() -> None: + brain, model_client = await create_brain() - original_convert_tools = openai_client_module.convert_tools - - def convert_tools_without_strict(tools): - converted = original_convert_tools(tools) - for tool in converted: - fn = tool.get("function") - if isinstance(fn, dict): - fn.pop("strict", None) - return converted - - convert_tools_without_strict._strict_removed_patch = True - openai_client_module.convert_tools = convert_tools_without_strict - - -async def _async_console_input(prompt: str) -> str: - """在线程中执行阻塞 input,避免阻塞事件循环。""" - return await asyncio.to_thread(input, prompt) - - -def _find_audio_player() -> list[str] | None: - """查找可用播放器,优先 ffplay。""" - if shutil.which("ffplay"): - return ["ffplay", "-nodisp", "-autoexit", "-loglevel", "error"] - if shutil.which("mpg123"): - return ["mpg123", "-q"] - if shutil.which("afplay"): - return ["afplay"] - return None - - -def _play_audio_file_blocking(audio_path: str, player_cmd: list[str]) -> bool: - """阻塞播放音频文件。""" - try: - subprocess.run( - [*player_cmd, audio_path], - check=True, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - return True - except Exception: - return False - - -async def _async_speak(text: str) -> bool: - """使用 edge-tts 生成 Yunxi 语音并播放。""" - if not text or edge_tts is None: - return False - - player_cmd = _find_audio_player() - if player_cmd is None: - return False - - with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp: - audio_path = fp.name - try: - communicate = edge_tts.Communicate(text=text, voice=TTS_VOICE) - await communicate.save(audio_path) - return await asyncio.to_thread(_play_audio_file_blocking, audio_path, player_cmd) - except Exception: - return False - finally: - try: - Path(audio_path).unlink(missing_ok=True) - except Exception: - pass - - -def _listen_once_blocking( - language: str = ASR_LANGUAGE, - timeout: int = 8, - phrase_time_limit: int = 20, -) -> str: - """阻塞式麦克风识别,返回识别文本。""" - if sr is None: - raise RuntimeError("缺少 speech_recognition 依赖") - - recognizer = sr.Recognizer() - with sr.Microphone(sample_rate=16000) as source: - print(">>>>>> 🎤 请说话... <<<<<<") - recognizer.adjust_for_ambient_noise(source, duration=0.4) - audio = recognizer.listen( - source, - timeout=timeout, - phrase_time_limit=phrase_time_limit, - ) - return recognizer.recognize_google(audio, language=language).strip() - - -async def _async_listen_once() -> str: - """在线程中执行语音识别,避免阻塞事件循环。""" - return await asyncio.to_thread(_listen_once_blocking) - - -async def _get_user_input(io_mode: str) -> str: - """ - 统一用户输入入口: - - text: 纯文本输入 - - voice: 回车后语音输入,也允许直接键入文字 - """ - if io_mode == "text": - return (await _async_console_input("你说: ")).strip() - - typed = (await _async_console_input("你说(回车=语音, 直接输入=文本): ")).strip() - if typed: - return typed - - try: - spoken = await _async_listen_once() - except Exception as e: - print(f">>>>>> ⚠️ 语音识别失败:{e} <<<<<<\n") - return "" - - if spoken: - print(f"[语音识别]: {spoken}") - return spoken - - -async def set_expression( - expression: Annotated[str, "机器人要展示的表情,如:开心、疑惑、难过、待机"], - intensity: Annotated[int, "表情强度 1-10"] = 5 -) -> str: - """[模拟面部] 控制机器人头部的表情展示。""" - print(f"\n>>>>>> 🤖 表情更新: 【{expression}】 (强度: {intensity}/10) <<<<<<") - return f"已切换到【{expression}】表情。" - -# --- 第二部分:直接读取用户档案(不经过 MCP,避免多轮工具调用)--- - -def _load_user_profile(user_name: str, db_path: str | Path = USER_DB_PATH) -> str: - """在 Python 层直接读档案,注入到消息上下文,模型无需主动调用 get_user_profile。""" - try: - with sqlite3.connect(db_path) as conn: - conn.row_factory = sqlite3.Row - user = conn.execute( - "SELECT * FROM users WHERE name = ?", (user_name,) - ).fetchone() - if not user: - return f"用户 {user_name} 尚无历史记录,这是第一次见面。" - prefs = conn.execute( - "SELECT category, content FROM preferences WHERE user_name = ?", - (user_name,) - ).fetchall() - conn.execute( - "UPDATE users SET last_seen = datetime('now') WHERE name = ?", - (user_name,) - ) - return json.dumps({ - "基本信息": {"姓名": user["name"], "年龄": user["age"], "上次见面": user["last_seen"]}, - "偏好习惯": {p["category"]: p["content"] for p in prefs}, - }, ensure_ascii=False) - except Exception as e: - return f"档案读取失败({e}),当作第一次见面。" - -# --- 第三部分:启动大脑 --- - -async def start_simulated_head(): - _patch_autogen_tool_schema_for_vllm() - - # 连接 MCP Server,只取写入类工具(读档案由 Python 层直接处理) - server_params = StdioServerParams( - command=sys.executable, - args=["robot_mcp_server.py"], - ) - all_mcp_tools = await mcp_server_tools(server_params) - # 过滤掉 get_user_profile,模型无需主动调用它 - mcp_tools = [t for t in all_mcp_tools if getattr(t, "name", "") != "get_user_profile"] - - model_client = OpenAIChatCompletionClient( - model=MODEL_NAME, - base_url=MODEL_BASE_URL, - api_key=MODEL_API_KEY, - model_info={ - "vision": True, - "function_calling": True, - "json_output": True, - # 关闭 structured_output,避免向 vLLM 发送 strict 字段告警。 - "structured_output": False, - "family": "unknown", - "context_length": 32000, - } - ) - - brain = AssistantAgent( - name="RobotBrain", - model_client=model_client, - tools=[set_expression] + mcp_tools, - reflect_on_tool_use=True, - system_message="""你是一个有情感的机器人伙伴。每条消息包含: -[用户档案] 历史记录(已由系统预加载) -[视觉状态] 当前摄像头观察 -[用户说] 用户这句话 - -回应规则: -1. 同一轮内调用:set_expression(表情)+ 所有需要的查询工具(get_weather/get_location/web_search)+ 需要的用户信息工具(upsert_user/set_preference) -2. 工具执行完毕后,用简短、温暖、自然的语言直接回答用户——这段文字就是你的语音输出。 -3. 不要说"我去查一下"之类的过渡语,直接完成任务并给出结果。""", - ) - - # --- 第四部分:交互循环 --- print("=" * 50) print(" 机器人已上线!输入 'quit' 退出") print(f" 模型: {MODEL_NAME}") @@ -252,22 +26,17 @@ async def start_simulated_head(): print("=" * 50) try: - user_name = (await _async_console_input("请输入你的名字: ")).strip() or "用户" + user_name = (await async_console_input("请输入你的名字: ")).strip() or "用户" except (EOFError, KeyboardInterrupt): print("\n机器人下线,再见!") return - has_asr = sr is not None - has_tts = edge_tts is not None - if has_asr and has_tts: - mode_tip = "voice" - else: - mode_tip = "text" + asr_ready = has_asr() + tts_ready = has_tts() + mode_tip = "voice" if (asr_ready and tts_ready) else "text" try: io_mode = ( - await _async_console_input( - f"输入模式 voice/text(默认 {mode_tip}): " - ) + await async_console_input(f"输入模式 voice/text(默认 {mode_tip}): ") ).strip().lower() or mode_tip except (EOFError, KeyboardInterrupt): print("\n机器人下线,再见!") @@ -275,60 +44,53 @@ async def start_simulated_head(): if io_mode not in ("voice", "text"): io_mode = mode_tip - if io_mode == "voice" and not has_asr: + if io_mode == "voice" and not asr_ready: print(">>>>>> ⚠️ 未安装 speech_recognition,已降级为文本输入。 <<<<<<") io_mode = "text" - if io_mode == "voice" and not has_tts: + if io_mode == "voice" and not tts_ready: print(">>>>>> ⚠️ 未安装 edge-tts,将仅文本输出,不播报语音。 <<<<<<") - if io_mode == "voice" and has_tts and _find_audio_player() is None: + if io_mode == "voice" and tts_ready and find_audio_player() is None: print(">>>>>> ⚠️ 未检测到播放器(ffplay/mpg123/afplay),将仅文本输出。 <<<<<<") print( "\n[语音依赖状态] " - f"ASR={'ok' if has_asr else 'missing'}, " - f"TTS={'ok' if has_tts else 'missing'}" + f"ASR={'ok' if asr_ready else 'missing'}, " + f"TTS={'ok' if tts_ready else 'missing'}" ) - if not has_asr or not has_tts: + if not asr_ready or not tts_ready: print("可安装: pip install SpeechRecognition pyaudio edge-tts") visual_context = "视觉输入:用户坐在电脑前,表情平静,看着屏幕。" - print(f"\n[当前视觉状态]: {visual_context}") print("提示:输入 'v <描述>' 可以更新视觉状态,例如: v 用户在笑\n") - history = [] + history: list[TextMessage] = [] try: while True: try: - user_input = await _get_user_input(io_mode) + user_input = await get_user_input(io_mode) except (EOFError, KeyboardInterrupt): print("\n机器人下线,再见!") break if not user_input: continue - if user_input.lower() in ("quit", "exit", "退出"): print("机器人下线,再见!") break - if user_input.lower().startswith("v "): visual_context = f"视觉输入:{user_input[2:].strip()}。" print(f"[视觉状态已更新]: {visual_context}\n") continue - # Python 层直接读取档案并注入消息,模型无需发起额外工具调用 - profile = _load_user_profile(user_name) + profile = load_user_profile(user_name) combined_input = ( f"[用户档案]\n{profile}\n\n" f"[视觉状态] {visual_context}\n" f"[用户说] {user_input}" ) history.append(TextMessage(content=combined_input, source="user")) - - # 只保留最近 6 条消息(3轮对话),防止超出 token 上限 - # 用户档案每轮从数据库重新注入,不依赖长历史 if len(history) > 6: history = history[-6:] @@ -344,19 +106,19 @@ async def start_simulated_head(): print(f">>>>>> ⚠️ 本轮处理失败:{e} <<<<<<\n") continue - # 模型的文字回复就是语音输出(reflect_on_tool_use=True 保证这里是 TextMessage) speech = response.chat_message.content if speech and isinstance(speech, str): print(f">>>>>> 🔊 机器人说: {speech} <<<<<<\n") if io_mode == "voice": - spoken_ok = await _async_speak(speech) + spoken_ok = await async_speak(speech) if not spoken_ok: print(">>>>>> ⚠️ TTS 不可用,当前仅文本输出。 <<<<<<\n") - # 只把最终回复加入历史,inner_messages 是事件对象不能序列化回模型 history.append(response.chat_message) finally: model_client.close() + if __name__ == "__main__": asyncio.run(start_simulated_head()) + diff --git a/profile_store.py b/profile_store.py new file mode 100644 index 0000000..549a213 --- /dev/null +++ b/profile_store.py @@ -0,0 +1,39 @@ +import json +import sqlite3 +from pathlib import Path + +from config import USER_DB_PATH + + +def load_user_profile(user_name: str, db_path: str | Path = USER_DB_PATH) -> str: + """在 Python 层直接读档案,注入到消息上下文,模型无需主动调用 get_user_profile。""" + try: + with sqlite3.connect(db_path) as conn: + conn.row_factory = sqlite3.Row + user = conn.execute( + "SELECT * FROM users WHERE name = ?", (user_name,) + ).fetchone() + if not user: + return f"用户 {user_name} 尚无历史记录,这是第一次见面。" + prefs = conn.execute( + "SELECT category, content FROM preferences WHERE user_name = ?", + (user_name,) + ).fetchall() + conn.execute( + "UPDATE users SET last_seen = datetime('now') WHERE name = ?", + (user_name,) + ) + return json.dumps( + { + "基本信息": { + "姓名": user["name"], + "年龄": user["age"], + "上次见面": user["last_seen"], + }, + "偏好习惯": {p["category"]: p["content"] for p in prefs}, + }, + ensure_ascii=False, + ) + except Exception as e: + return f"档案读取失败({e}),当作第一次见面。" + diff --git a/voice_io.py b/voice_io.py new file mode 100644 index 0000000..1a59767 --- /dev/null +++ b/voice_io.py @@ -0,0 +1,128 @@ +import asyncio +import shutil +import subprocess +import tempfile +from pathlib import Path + +from config import ASR_LANGUAGE, TTS_VOICE + +try: + import speech_recognition as sr +except ImportError: + sr = None + +try: + import edge_tts +except ImportError: + edge_tts = None + + +async def async_console_input(prompt: str) -> str: + """在线程中执行阻塞 input,避免阻塞事件循环。""" + return await asyncio.to_thread(input, prompt) + + +def has_asr() -> bool: + return sr is not None + + +def has_tts() -> bool: + return edge_tts is not None + + +def find_audio_player() -> list[str] | None: + """查找可用播放器,优先 ffplay。""" + if shutil.which("ffplay"): + return ["ffplay", "-nodisp", "-autoexit", "-loglevel", "error"] + if shutil.which("mpg123"): + return ["mpg123", "-q"] + if shutil.which("afplay"): + return ["afplay"] + return None + + +def _play_audio_file_blocking(audio_path: str, player_cmd: list[str]) -> bool: + try: + subprocess.run( + [*player_cmd, audio_path], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return True + except Exception: + return False + + +async def async_speak(text: str) -> bool: + """使用 edge-tts 生成 Yunxi 语音并播放。""" + if not text or edge_tts is None: + return False + + player_cmd = find_audio_player() + if player_cmd is None: + return False + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp: + audio_path = fp.name + try: + communicate = edge_tts.Communicate(text=text, voice=TTS_VOICE) + await communicate.save(audio_path) + return await asyncio.to_thread(_play_audio_file_blocking, audio_path, player_cmd) + except Exception: + return False + finally: + try: + Path(audio_path).unlink(missing_ok=True) + except Exception: + pass + + +def _listen_once_blocking( + language: str = ASR_LANGUAGE, + timeout: int = 8, + phrase_time_limit: int = 20, +) -> str: + """阻塞式麦克风识别,返回识别文本。""" + if sr is None: + raise RuntimeError("缺少 speech_recognition 依赖") + + recognizer = sr.Recognizer() + with sr.Microphone(sample_rate=16000) as source: + print(">>>>>> 🎤 请说话... <<<<<<") + recognizer.adjust_for_ambient_noise(source, duration=0.4) + audio = recognizer.listen( + source, + timeout=timeout, + phrase_time_limit=phrase_time_limit, + ) + return recognizer.recognize_google(audio, language=language).strip() + + +async def _async_listen_once() -> str: + return await asyncio.to_thread(_listen_once_blocking) + + +async def get_user_input(io_mode: str) -> str: + """ + 统一用户输入入口: + - text: 纯文本输入 + - voice: 回车后语音输入,也允许直接键入文字 + """ + if io_mode == "text": + return (await async_console_input("你说: ")).strip() + + typed = (await async_console_input("你说(回车=语音, 直接输入=文本): ")).strip() + if typed: + return typed + + try: + spoken = await _async_listen_once() + except Exception as e: + print(f">>>>>> ⚠️ 语音识别失败:{e} <<<<<<\n") + return "" + + if spoken: + print(f"[语音识别]: {spoken}") + return spoken +