From c86e2458ef359ba9c548aedf373f80d4db848257 Mon Sep 17 00:00:00 2001
From: JiajunLI
Date: Tue, 3 Mar 2026 17:20:54 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=81=94=E7=BD=91?=
 =?UTF-8?q?=E6=90=9C=E7=B4=A2=E3=80=81=E5=9C=B0=E7=90=86=E4=BD=8D=E7=BD=AE?=
 =?UTF-8?q?=E6=9F=A5=E8=AF=A2=E3=80=81=E5=A4=A9=E6=B0=94=E6=9F=A5=E8=AF=A2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py             | 203 ++++++++++++++++++-------
 robot_mcp_server.py | 358 ++++++++++++++++++++++++++++++++++++++++++++
 start_vllm.sh       |  13 ++
 3 files changed, 523 insertions(+), 51 deletions(-)
 create mode 100644 robot_mcp_server.py
 create mode 100644 start_vllm.sh

diff --git a/main.py b/main.py
index e549a08..e227053 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,50 @@
 import asyncio
+import json
+import sqlite3
+import sys
+from pathlib import Path
 from typing import Annotated
+
 from autogen_agentchat.agents import AssistantAgent
 from autogen_agentchat.messages import TextMessage
 from autogen_core import CancellationToken
 from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.openai import _openai_client as openai_client_module
+from autogen_ext.tools.mcp import StdioServerParams, mcp_server_tools
+
+BASE_DIR = Path(__file__).resolve().parent
+USER_DB_PATH = BASE_DIR / "users.db"
+MODEL_CALL_TIMEOUT_SECONDS = 45
+
+# --- Part 1: local tools (face + voice; hardware to be wired in later) ---
+
+
+def _patch_autogen_tool_schema_for_vllm() -> None:
+    """
+    vLLM currently warns about the `strict` field in OpenAI tool definitions (even when strict=False).
+    Apply a minimal patch: keep each tool definition but drop that field to avoid pointless warnings.
+    """
+    if getattr(openai_client_module.convert_tools, "_strict_removed_patch", False):
+        return
+
+    original_convert_tools = openai_client_module.convert_tools
+
+    def convert_tools_without_strict(tools):
+        converted = original_convert_tools(tools)
+        for tool in converted:
+            fn = tool.get("function")
+            if isinstance(fn, dict):
+                fn.pop("strict", None)
+        return converted
+
+    convert_tools_without_strict._strict_removed_patch = True
+    openai_client_module.convert_tools = convert_tools_without_strict
+
+
+async def _async_console_input(prompt: str) -> str:
+    """Run the blocking input() in a worker thread so the event loop is not blocked."""
+    return await asyncio.to_thread(input, prompt)
 
-# --- 第一部分:工具定义 ---
-# 以后接上机器人时,把 print 替换成串口指令 / TTS 调用即可
 
 async def set_expression(
     expression: Annotated[str, "机器人要展示的表情,如:开心、疑惑、难过、待机"],
@@ -16,15 +54,47 @@ async def set_expression(
     print(f"\n>>>>>> 🤖 表情更新: 【{expression}】 (强度: {intensity}/10) <<<<<<")
     return f"已切换到【{expression}】表情。"
 
-async def speak(
-    text: Annotated[str, "机器人要说的话,简短自然"]
-) -> str:
-    """[模拟 TTS] 机器人开口说话。以后接 TTS 引擎播放语音。"""
-    print(f">>>>>> 🔊 机器人说: {text} <<<<<<\n")
-    return "语音已播放。"
+# --- Part 2: read the user profile directly (bypassing MCP to avoid multi-round tool calls) ---
+
+def _load_user_profile(user_name: str, db_path: str | Path = USER_DB_PATH) -> str:
+    """Read the profile in Python and inject it into the message context so the model need not call get_user_profile."""
+    try:
+        with sqlite3.connect(db_path) as conn:
+            conn.row_factory = sqlite3.Row
+            user = conn.execute(
+                "SELECT * FROM users WHERE name = ?", (user_name,)
+            ).fetchone()
+            if not user:
+                return f"用户 {user_name} 尚无历史记录,这是第一次见面。"
+            prefs = conn.execute(
+                "SELECT category, content FROM preferences WHERE user_name = ?",
+                (user_name,)
+            ).fetchall()
+            conn.execute(
+                "UPDATE users SET last_seen = datetime('now') WHERE name = ?",
+                (user_name,)
+            )
+        return json.dumps({
+            "基本信息": {"姓名": user["name"], "年龄": user["age"], "上次见面": user["last_seen"]},
+            "偏好习惯": {p["category"]: p["content"] for p in prefs},
+        }, ensure_ascii=False)
+    except Exception as e:
+        return f"档案读取失败({e}),当作第一次见面。"
+
+# --- Part 3: start the brain ---
 
-# --- 第二部分:配置 VLM 大脑 ---
 async def start_simulated_head():
+    _patch_autogen_tool_schema_for_vllm()
+
+    # Connect to the MCP server by absolute path (cwd-independent); profile reads are done in Python, not via MCP
+    server_params = StdioServerParams(
+        command=sys.executable,
+        args=[str(BASE_DIR / "robot_mcp_server.py")],
+    )
+    all_mcp_tools = await mcp_server_tools(server_params)
+    # Drop get_user_profile: the model never needs to call it itself
+    mcp_tools = [t for t in all_mcp_tools if getattr(t, "name", "") != "get_user_profile"]
+
     model_client = OpenAIChatCompletionClient(
         model="Qwen/Qwen3-VL-8B-Instruct",
         base_url="http://localhost:8000/v1",
@@ -33,70 +103,101 @@ async def start_simulated_head():
             "vision": True,
             "function_calling": True,
             "json_output": True,
-            "structured_output": True,
+            # Disable structured_output so no strict field (and vLLM warning) is sent.
+            "structured_output": False,
             "family": "unknown",
-            "context_length": 4096,
+            "context_length": 32000,
         }
     )
 
     brain = AssistantAgent(
         name="RobotBrain",
         model_client=model_client,
-        tools=[set_expression, speak],
-        system_message="""你是一个有情感的机器人伙伴,能感知用户状态并进行语言交流。
-每次收到输入时,你必须:
-1. 综合视觉信息和用户说的话,理解当前情境和用户的情绪/需求。
-2. 调用 set_expression 展示合适的表情。
-3. 调用 speak 用简短、温暖、自然的语言回应用户。
-回应风格:像和老朋友聊天,不要太正式,有点个性和幽默感。"""
+        tools=[set_expression] + mcp_tools,
+        reflect_on_tool_use=True,
+        system_message="""你是一个有情感的机器人伙伴。每条消息包含:
+[用户档案] 历史记录(已由系统预加载)
+[视觉状态] 当前摄像头观察
+[用户说] 用户这句话
+
+回应规则:
+1. 同一轮内调用:set_expression(表情)+ 所有需要的查询工具(get_weather/get_location/web_search)+ 需要的用户信息工具(upsert_user/set_preference)
+2. 工具执行完毕后,用简短、温暖、自然的语言直接回答用户——这段文字就是你的语音输出。
+3. 不要说"我去查一下"之类的过渡语,直接完成任务并给出结果。""",
     )
 
-    # --- 第三部分:交互循环 ---
-    # 模拟视觉上下文(真实项目中由摄像头实时提供)
-    visual_context = "视觉输入:用户坐在电脑前,表情平静,看着屏幕。"
-
+    # --- Part 4: interaction loop ---
     print("=" * 50)
     print(" 机器人已上线!输入 'quit' 退出")
     print("=" * 50)
-    print(f"[当前视觉状态]: {visual_context}")
+
+    try:
+        user_name = (await _async_console_input("请输入你的名字: ")).strip() or "用户"
+    except (EOFError, KeyboardInterrupt):
+        print("\n机器人下线,再见!")
+        return
+    visual_context = "视觉输入:用户坐在电脑前,表情平静,看着屏幕。"
+
+    print(f"\n[当前视觉状态]: {visual_context}")
     print("提示:输入 'v <描述>' 可以更新视觉状态,例如: v 用户在笑\n")
 
-    history = [] # 维护完整对话历史,让机器人记住上下文
+    history = []
 
-    while True:
-        try:
-            user_input = input("你说: ").strip()
-        except (EOFError, KeyboardInterrupt):
-            print("\n机器人下线,再见!")
-            break
+    try:
+        while True:
+            try:
+                user_input = (await _async_console_input("你说: ")).strip()
+            except (EOFError, KeyboardInterrupt):
+                print("\n机器人下线,再见!")
+                break
 
-        if not user_input:
-            continue
+            if not user_input:
+                continue
 
-        if user_input.lower() in ("quit", "exit", "退出"):
-            await brain.on_messages(
-                [*history, TextMessage(content=f"{visual_context}\n用户说:「再见」", source="user")],
-                CancellationToken()
+            if user_input.lower() in ("quit", "exit", "退出"):
+                print("机器人下线,再见!")
+                break
+
+            if user_input.lower().startswith("v "):
+                visual_context = f"视觉输入:{user_input[2:].strip()}。"
+                print(f"[视觉状态已更新]: {visual_context}\n")
+                continue
+
+            # Load the profile in Python and inject it into the message; no extra tool call by the model
+            profile = _load_user_profile(user_name)
+            combined_input = (
+                f"[用户档案]\n{profile}\n\n"
+                f"[视觉状态] {visual_context}\n"
+                f"[用户说] {user_input}"
             )
-            print("\n机器人下线,再见!")
-            break
+            history.append(TextMessage(content=combined_input, source="user"))
 
-        # 支持临时更新视觉状态
-        if user_input.lower().startswith("v "):
-            visual_context = f"视觉输入:{user_input[2:].strip()}。"
-            print(f"[视觉状态已更新]: {visual_context}\n")
-            continue
+            # Keep only the 6 most recent messages (3 rounds) to stay within the token limit
+            # The profile is re-injected from the database every turn, so long history is not needed
+            if len(history) > 6:
+                history = history[-6:]
 
-        # 合并视觉 + 语言输入
-        combined_input = f"{visual_context}\n用户说:「{user_input}」"
-        history.append(TextMessage(content=combined_input, source="user"))
+            try:
+                response = await asyncio.wait_for(
+                    brain.on_messages(history, CancellationToken()),
+                    timeout=MODEL_CALL_TIMEOUT_SECONDS,
+                )
+            except asyncio.TimeoutError:
+                print(">>>>>> ⚠️ 请求超时,请稍后重试或简化问题。 <<<<<<\n")
+                continue
+            except Exception as e:
+                print(f">>>>>> ⚠️ 本轮处理失败:{e} <<<<<<\n")
+                continue
 
-        response = await brain.on_messages(history, CancellationToken())
+            # The model's text reply is the speech output (reflect_on_tool_use=True is meant to make this a TextMessage)
+            speech = response.chat_message.content
+            if speech and isinstance(speech, str):
+                print(f">>>>>> 🔊 机器人说: {speech} <<<<<<\n")
 
-        # 把本轮所有消息(工具调用、工具结果、最终回复)加入历史
-        if response.inner_messages:
-            history.extend(response.inner_messages)
-        history.append(response.chat_message)
+            # Append only the final reply; inner_messages are event objects that cannot be fed back to the model
+            history.append(response.chat_message)
+    finally:
+        await model_client.close()  # close() is a coroutine in autogen-ext; without await it never runs
 
 if __name__ == "__main__":
     asyncio.run(start_simulated_head())
diff --git a/robot_mcp_server.py b/robot_mcp_server.py
new file mode 100644
index 0000000..d63c8a9
--- /dev/null
+++ b/robot_mcp_server.py
@@ -0,0 +1,358 @@
+"""
+Robot user-profile MCP server.
+Stores and maintains users' basic information and preferences.
+"""
+import json
+import logging
+import sqlite3
+from pathlib import Path
+
+# Silence the mcp library's INFO logs; keep WARNING and above
+logging.basicConfig(level=logging.WARNING)
+
+from mcp.server.fastmcp import FastMCP
+
+DB_PATH = Path(__file__).parent / "users.db"
+mcp = FastMCP("robot-user-db")
+
+
+def _db_connect() -> sqlite3.Connection:
+    """Single entry point for DB connections; enables foreign-key enforcement on each one."""
+    conn = sqlite3.connect(DB_PATH)
+    conn.execute("PRAGMA foreign_keys = ON")
+    return conn
+
+
+def _create_preferences_table(conn: sqlite3.Connection) -> None:
+    conn.execute("""
+        CREATE TABLE preferences (
+            user_name TEXT NOT NULL,
+            category TEXT NOT NULL,
+            content TEXT NOT NULL,
+            updated_at TEXT DEFAULT (datetime('now')),
+            PRIMARY KEY (user_name, category),
+            FOREIGN KEY (user_name) REFERENCES users(name)
+                ON UPDATE CASCADE
+                ON DELETE CASCADE
+        )
+    """)
+
+
+# --- Initialize the database ---
+def _init_db():
+    with _db_connect() as conn:
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS users (
+                name TEXT PRIMARY KEY,
+                age INTEGER,
+                created_at TEXT DEFAULT (datetime('now')),
+                last_seen TEXT
+            )
+        """)
+
+        preferences_exists = conn.execute(
+            "SELECT 1 FROM sqlite_master WHERE type='table' AND name='preferences'"
+        ).fetchone() is not None
+
+        if not preferences_exists:
+            _create_preferences_table(conn)
+
+
+_init_db()
+
+
+# --- MCP tool definitions ---
+
+@mcp.tool()
+def get_user_profile(user_name: str) -> str:
+    """
+    获取用户档案(基本信息、所有偏好)。
+    在每次对话的第一轮调用,用于了解用户背景。
+    """
+    with _db_connect() as conn:
+        conn.row_factory = sqlite3.Row
+
+        user = conn.execute(
+            "SELECT * FROM users WHERE name = ?", (user_name,)
+        ).fetchone()
+        if not user:
+            return json.dumps(
+                {"found": False, "message": f"用户 {user_name} 尚无档案,这是第一次见面。"},
+                ensure_ascii=False
+            )
+
+        prefs = conn.execute(
+            "SELECT category, content FROM preferences WHERE user_name = ?",
+            (user_name,)
+        ).fetchall()
+        conn.execute(  # bump last_seen only after reading, so the previous visit is what gets reported
+            "UPDATE users SET last_seen = datetime('now') WHERE name = ?",
+            (user_name,)
+        )
+
+    return json.dumps({
+        "found": True,
+        "basic": {
+            "name": user["name"],
+            "age": user["age"],
+            "last_seen": user["last_seen"],
+        },
+        "preferences": {p["category"]: p["content"] for p in prefs},
+    }, ensure_ascii=False, indent=2)
+
+
+@mcp.tool()
+def upsert_user(user_name: str, age: int | None = None) -> str:
+    """
+    创建或更新用户基本信息。
+    当得知用户姓名、年龄等基本信息时调用。
+    """
+    with _db_connect() as conn:
+        existing = conn.execute(
+            "SELECT 1 FROM users WHERE name = ?",
+            (user_name,),
+        ).fetchone() is not None
+        conn.execute(
+            """INSERT INTO users (name, age, last_seen)
+               VALUES (?, ?, datetime('now'))
+               ON CONFLICT(name) DO UPDATE SET
+                   age = COALESCE(excluded.age, users.age),
+                   last_seen = datetime('now')""",
+            (user_name, age),
+        )
+    if existing:
+        return f"已更新用户 {user_name} 的档案。"
+    return f"已为 {user_name} 创建新档案。"
+
+
+@mcp.tool()
+def set_preference(user_name: str, category: str, content: str) -> str:
+    """
+    更新用户的某项偏好或习惯(同一 category 只保留最新值)。
+    category 示例:'话题喜好'、'沟通风格'、'工作习惯'、'忌讳'、'饮食偏好'。
+    在对话中发现新偏好时调用。
+    """
+    with _db_connect() as conn:
+        # Ensure the parent row exists in users to satisfy the preferences foreign key.
+        conn.execute(
+            """INSERT INTO users (name, last_seen)
+               VALUES (?, datetime('now'))
+               ON CONFLICT(name) DO UPDATE SET last_seen = datetime('now')""",
+            (user_name,),
+        )
+        conn.execute(
+            """INSERT INTO preferences (user_name, category, content)
+               VALUES (?, ?, ?)
+               ON CONFLICT(user_name, category)
+               DO UPDATE SET content = excluded.content, updated_at = datetime('now')""",
+            (user_name, category, content)
+        )
+    return f"已更新 {user_name} 的偏好 [{category}]:{content}。"
+
+# ============================================================
+# Network tools: location / weather / search
+# ============================================================
+
+import requests
+
+# WMO weather code -> Chinese description
+_WMO = {
+    0: "晴天", 1: "基本晴朗", 2: "局部多云", 3: "阴天",
+    45: "雾", 48: "冻雾",
+    51: "轻微毛毛雨", 53: "毛毛雨", 55: "密集毛毛雨",
+    61: "小雨", 63: "中雨", 65: "大雨",
+    71: "小雪", 73: "中雪", 75: "大雪", 77: "冰粒",
+    80: "阵雨", 81: "中等阵雨", 82: "强阵雨",
+    85: "阵雪", 86: "强阵雪",
+    95: "雷阵雨", 96: "伴有冰雹的雷阵雨", 99: "强雷阵雨",
+}
+
+
+def _lookup_location_cn(ip: str | None = None) -> dict | None:
+    """
+    Look up geographic info through a China-based IP geolocation endpoint.
+    Returns city/province when available; lat/lon are always None from this source.
+    """
+    params = {"json": "true"}
+    if ip:
+        params["ip"] = ip
+    resp = requests.get(
+        "https://whois.pconline.com.cn/ipJson.jsp",
+        params=params,
+        timeout=8,
+    )
+    # This endpoint commonly responds in GBK encoding
+    resp.encoding = resp.apparent_encoding or "gbk"
+    raw = resp.text.strip()
+    if raw.startswith("var returnCitySN"):
+        raw = raw.split("=", 1)[-1].strip().rstrip(";")
+    data = json.loads(raw)
+
+    city = (data.get("city") or "").strip()
+    region = (data.get("pro") or "").strip()
+    ip_value = (data.get("ip") or ip or "").strip()
+    if not (city or region):
+        return None
+    return {
+        "city": city or region,
+        "region": region,
+        "country": "中国",
+        "lat": None,
+        "lon": None,
+        "ip": ip_value,
+    }
+
+
+def _lookup_location_ipapi() -> dict | None:
+    """Fallback geolocation via ip-api."""
+    data = requests.get(
+        "http://ip-api.com/json/",
+        params={"lang": "zh-CN", "fields": "status,city,regionName,country,lat,lon,query"},
+        timeout=8,
+    ).json()
+    if data.get("status") != "success":
+        return None
+    return {
+        "city": data.get("city") or "",
+        "region": data.get("regionName") or "",
+        "country": data.get("country") or "",
+        "lat": data.get("lat"),
+        "lon": data.get("lon"),
+        "ip": data.get("query") or "",
+    }
+
+
+def _lookup_location() -> dict | None:
+    """Geolocation entry point: the China IP endpoint first; ip-api when that returns no result."""
+    return _lookup_location_cn() or _lookup_location_ipapi()
+
+
+def _geocode_city(city: str) -> tuple[float | None, float | None, str]:
+    """Look up a city's latitude/longitude by name, for use by the weather query."""
+    geo = requests.get(
+        "https://geocoding-api.open-meteo.com/v1/search",
+        params={"name": city, "count": 1, "language": "zh"},
+        timeout=8,
+    ).json()
+    results = geo.get("results")
+    if not results:
+        return None, None, city
+    return (
+        results[0].get("latitude"),
+        results[0].get("longitude"),
+        results[0].get("name", city),
+    )
+
+
+def _resolve_weather_target(
+    city: str | None, lat: float | None, lon: float | None
+) -> tuple[float | None, float | None, str | None, str | None]:
+    """Resolve the weather-query target to (lat, lon, city, error) in one place to avoid duplicated branches."""
+    auto_locate_error = "自动定位失败,请手动传入城市名。"
+    if lat is not None and lon is not None:
+        return lat, lon, city, None
+
+    if city:
+        lat, lon, city = _geocode_city(city)
+        if lat is None or lon is None:
+            return None, None, city, f"找不到城市:{city}"
+        return lat, lon, city, None
+
+    loc = _lookup_location()
+    if not loc:
+        return None, None, None, auto_locate_error
+    city = loc["city"] or city
+    lat, lon = loc.get("lat"), loc.get("lon")
+    if lat is None or lon is None:
+        if not city:
+            return None, None, None, auto_locate_error
+        lat, lon, city = _geocode_city(city)
+        if lat is None or lon is None:
+            return None, None, None, auto_locate_error
+    return lat, lon, city, None
+
+
+@mcp.tool()
+def get_location() -> str:
+    """
+    通过 IP 地址获取当前地理位置(城市、省份、国家、经纬度)。
+    在查询天气前,或需要了解用户所在城市时调用。
+    """
+    try:
+        loc = _lookup_location()
+        if not loc:
+            return "定位失败,请稍后再试。"
+        return json.dumps({
+            "城市": loc["city"],
+            "省份": loc["region"],
+            "国家": loc["country"],
+            "纬度": loc["lat"],
+            "经度": loc["lon"],
+            "IP": loc["ip"],
+        }, ensure_ascii=False)
+    except Exception as e:
+        return f"定位失败:{e}"
+
+
+@mcp.tool()
+def get_weather(city: str | None = None, lat: float | None = None, lon: float | None = None) -> str:
+    """
+    获取实时天气信息。
+    可以传入城市名(city),或经纬度(lat/lon);若都不传则自动定位。
+    返回温度、天气状况、风速。
+    """
+    try:
+        lat, lon, city, err = _resolve_weather_target(city, lat, lon)
+        if err:
+            return err
+
+        # Query the current weather
+        resp = requests.get(
+            "https://api.open-meteo.com/v1/forecast",
+            params={
+                "latitude": lat, "longitude": lon,
+                "current": "temperature_2m,apparent_temperature,weather_code,wind_speed_10m,relative_humidity_2m",
+                "timezone": "auto",
+            },
+            timeout=10,
+        ).json()
+        cur = resp.get("current", {})
+        code = cur.get("weather_code", -1)
+        return json.dumps({
+            "城市": city or f"{lat},{lon}",
+            "天气": _WMO.get(code, f"未知(code={code})"),
+            "温度": f"{cur.get('temperature_2m', '?')}°C",
+            "体感温度": f"{cur.get('apparent_temperature', '?')}°C",
+            "湿度": f"{cur.get('relative_humidity_2m', '?')}%",
+            "风速": f"{cur.get('wind_speed_10m', '?')} km/h",
+        }, ensure_ascii=False)
+    except Exception as e:
+        return f"天气查询失败:{e}"
+
+
+@mcp.tool()
+def web_search(query: str, max_results: int = 5) -> str:
+    """
+    联网搜索,获取实时信息(新闻、百科、价格等)。
+    返回最多 max_results 条结果(标题 + 摘要 + 链接)。
+    """
+    try:
+        query = query.strip()
+        if not query:
+            return "搜索关键词不能为空。"
+        max_results = max(1, min(max_results, 10))
+
+        from ddgs import DDGS
+        results = DDGS().text(query, max_results=max_results)
+        if not results:
+            return "搜索无结果。"
+        output = []
+        for i, r in enumerate(results, 1):
+            output.append(f"{i}. {r['title']}\n {r['body'][:150]}\n {r['href']}")
+        return "\n\n".join(output)
+    except Exception as e:
+        return f"搜索失败:{e}"
+
+
+if __name__ == "__main__":
+    mcp.run()
diff --git a/start_vllm.sh b/start_vllm.sh
new file mode 100644
index 0000000..d400932
--- /dev/null
+++ b/start_vllm.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+# Launch script for the vLLM server
+# Usage: bash start_vllm.sh
+
+python -m vllm.entrypoints.openai.api_server \
+    --model Qwen/Qwen3-VL-8B-Instruct \
+    --trust-remote-code \
+    --port 8000 \
+    --gpu-memory-utilization 0.85 \
+    --max-model-len 32000 \
+    --enable-auto-tool-choice \
+    --tool-call-parser hermes \
+    --uvicorn-log-level warning