import streamlit as st import os import cv2 import tempfile import torch # Import the main app class and config flags from motion_app import MotionComparisonApp from config import REALSENSE_AVAILABLE, PYGAME_AVAILABLE torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)] # Set page config at module level st.set_page_config(page_title="动作比较", page_icon="🏃", layout="wide") def main(): try: st.title("🏃 动作比较与姿态分析系统") st.markdown("---") # Initialize the app object in session state if 'app' not in st.session_state: st.session_state.app = MotionComparisonApp() app = st.session_state.app # --- Sidebar UI --- with st.sidebar: st.header("🎛️ 控制面板") # Display settings resolution_options = {"高清 (1280x800)": "high", "中等 (960x720)": "medium", "标准 (640x480)": "low"} selected_res = st.selectbox("显示分辨率", list(resolution_options.keys()), index=1) app.display_settings['resolution_mode'] = resolution_options[selected_res] st.markdown("---") # Video Source Selection video_source = st.radio("视频来源", ["预设视频", "上传视频"]) video_path = None if video_source == "预设视频": # Define preset video mapping (display name -> filename) preset_videos = { "六字诀": "liuzi.mp4", # Add more preset videos here: # "太极拳": "taiji.mp4", # "八段锦": "baduanjin.mp4", } preset_folder = "preset_videos" if os.path.exists(preset_folder): # Check which preset videos actually exist in the folder available_presets = {} for display_name, filename in preset_videos.items(): full_path = os.path.join(preset_folder, filename) if os.path.exists(full_path): available_presets[display_name] = filename if available_presets: selected_display_name = st.selectbox("选择预设视频", list(available_presets.keys())) selected_filename = available_presets[selected_display_name] video_path = os.path.join(preset_folder, selected_filename) st.success(f"✅ 已选择视频: {selected_display_name} ({selected_filename})") else: st.error("❌ preset_videos 文件夹中未找到预设视频文件") st.info(f"请确保以下文件存在: {', '.join(preset_videos.values())}") else: st.error("❌ preset_videos 文件夹不存在") else: uploaded_file = st.file_uploader("上传视频文件", type=['mp4', 'avi', 'mov', 'mkv']) if uploaded_file: with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file: tmp_file.write(uploaded_file.read()) video_path = tmp_file.name st.markdown("---") # System Initialization st.subheader("⚙️ 系统初始化") if st.button("🚀 初始化系统", use_container_width=True): with st.spinner("正在初始化检测器和摄像头..."): app.initialize_detector() app.initialize_camera() # System Status st.subheader("ℹ️ 系统状态") st.info(f"计算设备: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}") st.info(f"摄像头: {'RealSense' if REALSENSE_AVAILABLE else 'USB 摄像头'}") st.info(f"音频: {'启用' if PYGAME_AVAILABLE else '禁用'}") # --- Main Page UI --- if video_path is not None: try: # 显示视频基本信息(紧凑显示) cap = cv2.VideoCapture(video_path) if cap.isOpened(): fps = cap.get(cv2.CAP_PROP_FPS) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = frame_count / fps if fps > 0 else 0 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) info_col1, info_col2, info_col3, info_col4 = st.columns(4) with info_col1: st.metric("⏱️ 时长", f"{duration:.1f}s") with info_col2: st.metric("🔄 帧率", f"{fps:.1f}") with info_col3: st.metric("📐 原始分辨率", f"{width}×{height}") with info_col4: current_res = app.get_display_resolution() st.metric("🖥️ 显示分辨率", f"{current_res[0]}×{current_res[1]}") cap.release() st.markdown("---") # 初始化session_state if 'show_preview' not in st.session_state: st.session_state.show_preview = False # 预览区域 if st.session_state.show_preview: st.markdown("### 📷 摄像头预览") # 初始化摄像头(如果还没有初始化) if not app.is_realsense_active and (app.webcam_cap is None or not app.webcam_cap.isOpened()): app.initialize_realsense() # 获取预览帧 preview_frame = app.get_camera_preview_frame() if preview_frame is not None: # 使用大尺寸预览 preview_col1, preview_col2, preview_col3 = st.columns([1, 3, 1]) with preview_col2: st.image(preview_frame, caption="摄像头预览 (已镜像)", use_container_width=True) # 显示提示信息 camera_type = "RealSense摄像头" if app.is_realsense_active else "USB摄像头" st.info(f"🎯 正在使用: {camera_type} - 请调整您的位置,确保全身在画面中清晰可见") # 预览控制按钮 preview_btn_col1, preview_btn_col2, preview_btn_col3 = st.columns(3) with preview_btn_col1: if st.button("🔄 刷新画面", key="refresh_preview", use_container_width=True): st.rerun() with preview_btn_col2: if st.button("⏹️ 停止预览", key="stop_preview", use_container_width=True): st.session_state.show_preview = False st.rerun() with preview_btn_col3: if st.button("🚀 开始比较", key="start_from_preview", use_container_width=True): if app.body_detector is None: st.error("⚠️ 请先初始化关键点检测器") else: st.session_state.show_preview = False app.start_comparison(video_path) else: st.error("❌ 无法获取摄像头画面,请检查摄像头连接") st.session_state.show_preview = False else: # 主控制按钮(大按钮) main_btn_col1, main_btn_col2, main_btn_col3 = st.columns(3) with main_btn_col1: if st.button("📷 预览摄像头", use_container_width=True, help="先预览摄像头调整位置"): if app.body_detector is None: st.error("⚠️ 请先在侧边栏初始化系统") else: st.session_state.show_preview = True st.rerun() with main_btn_col2: if st.button("🚀 直接开始比较", use_container_width=True, help="直接开始动作比较"): if app.body_detector is None: st.error("⚠️ 请先在侧边栏初始化系统") else: app.start_comparison(video_path) with main_btn_col3: st.markdown("") # 占位 except Exception as e: st.error(f"❌ 处理视频时出现错误: {str(e)}") else: # 无视频时显示说明 st.info("👈 请在左侧选择预设视频或上传标准动作视频") # 显示使用说明(优化布局) with st.expander("📖 使用说明", expanded=True): usage_col1, usage_col2 = st.columns(2) with usage_col1: st.markdown(""" #### 🚀 快速开始: 1. **选择视频**: 侧边栏选择预设视频或上传视频 2. **调整分辨率**: 根据设备性能选择合适分辨率 3. **初始化系统**: 点击"初始化系统"按钮 4. **预览摄像头**: 调整位置确保全身可见 5. **开始比较**: 跟随标准视频做动作 #### ⚙️ 分辨率建议: - **高性能设备**: 高清模式 (1280x800) - **一般设备**: 中等模式 (960x720) - **低配设备**: 标准模式 (640x480) """) with usage_col2: st.markdown(""" #### ✨ 主要特性: - 🎯 实时关键点检测和动作分析 - 📹 支持RealSense和USB摄像头 - 🔊 视频音频同步播放 - 📊 实时相似度图表分析 - 🖥️ 大屏幕优化显示 - ⚡ GPU加速支持 #### 📋 系统要求: - 摄像头设备及权限 - 充足的光照条件 - Python 3.7+ 环境 - 推荐GPU支持(可选) """) except Exception as e: st.error(f"❌ 应用程序启动失败: {str(e)}") st.info("💡 请检查所有依赖库是否正确安装") if __name__ == "__main__": # Set environment variables for performance # os.environ['OMP_NUM_THREADS'] = '1' # os.environ['MKL_NUM_THREADS'] = '1' # try: # import torch # torch.set_num_threads(1) # except ImportError: # pass main()