251 lines
12 KiB
Python
251 lines
12 KiB
Python
import streamlit as st
|
||
import os
|
||
import cv2
|
||
import tempfile
|
||
import torch
|
||
|
||
# Import the main app class and config flags
|
||
from motion_app import MotionComparisonApp
|
||
from config import REALSENSE_AVAILABLE, PYGAME_AVAILABLE
|
||
|
||
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
|
||
|
||
# Set page config at module level
|
||
st.set_page_config(page_title="动作比较", page_icon="🏃", layout="wide")
|
||
|
||
def main():
|
||
try:
|
||
st.title("🏃 动作比较与姿态分析系统")
|
||
st.markdown("---")
|
||
|
||
# Initialize the app object in session state
|
||
if 'app' not in st.session_state:
|
||
st.session_state.app = MotionComparisonApp()
|
||
app = st.session_state.app
|
||
|
||
# --- Sidebar UI ---
|
||
with st.sidebar:
|
||
st.header("🎛️ 控制面板")
|
||
|
||
# Display settings
|
||
resolution_options = {"高清 (1280x800)": "high", "中等 (960x720)": "medium", "标准 (640x480)": "low"}
|
||
selected_res = st.selectbox("显示分辨率", list(resolution_options.keys()), index=1)
|
||
app.display_settings['resolution_mode'] = resolution_options[selected_res]
|
||
|
||
st.markdown("---")
|
||
|
||
# Video Source Selection
|
||
video_source = st.radio("视频来源", ["预设视频", "上传视频"])
|
||
video_path = None
|
||
|
||
if video_source == "预设视频":
|
||
# Define preset video mapping (display name -> filename)
|
||
preset_videos = {
|
||
"六字诀": "liuzi.mp4",
|
||
"六字诀精简": "liuzi-short.mp4",
|
||
# Add more preset videos here:
|
||
# "太极拳": "taiji.mp4",
|
||
# "八段锦": "baduanjin.mp4",
|
||
}
|
||
|
||
preset_folder = "preset_videos"
|
||
if os.path.exists(preset_folder):
|
||
# Check which preset videos actually exist in the folder
|
||
available_presets = {}
|
||
for display_name, filename in preset_videos.items():
|
||
full_path = os.path.join(preset_folder, filename)
|
||
if os.path.exists(full_path):
|
||
available_presets[display_name] = filename
|
||
|
||
if available_presets:
|
||
selected_display_name = st.selectbox("选择预设视频", list(available_presets.keys()))
|
||
selected_filename = available_presets[selected_display_name]
|
||
video_path = os.path.join(preset_folder, selected_filename)
|
||
st.success(f"✅ 已选择视频: {selected_display_name} ({selected_filename})")
|
||
else:
|
||
st.error("❌ preset_videos 文件夹中未找到预设视频文件")
|
||
st.info(f"请确保以下文件存在: {', '.join(preset_videos.values())}")
|
||
else:
|
||
st.error("❌ preset_videos 文件夹不存在")
|
||
else:
|
||
uploaded_file = st.file_uploader("上传视频文件", type=['mp4', 'avi', 'mov', 'mkv'])
|
||
if uploaded_file:
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||
tmp_file.write(uploaded_file.read())
|
||
video_path = tmp_file.name
|
||
|
||
st.markdown("---")
|
||
|
||
# System Initialization
|
||
st.subheader("⚙️ 系统初始化")
|
||
if st.button("🚀 初始化系统", use_container_width=True):
|
||
with st.spinner("正在初始化检测器和摄像头..."):
|
||
app.initialize_detector()
|
||
app.initialize_camera()
|
||
|
||
# System Status
|
||
st.subheader("ℹ️ 系统状态")
|
||
st.info(f"计算设备: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
|
||
st.info(f"摄像头: {'RealSense' if REALSENSE_AVAILABLE else 'USB 摄像头'}")
|
||
st.info(f"音频: {'启用' if PYGAME_AVAILABLE else '禁用'}")
|
||
|
||
# --- Main Page UI ---
|
||
if video_path is not None:
|
||
try:
|
||
# 显示视频基本信息(紧凑显示)
|
||
cap = cv2.VideoCapture(video_path)
|
||
if cap.isOpened():
|
||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
duration = frame_count / fps if fps > 0 else 0
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
||
info_col1, info_col2, info_col3, info_col4 = st.columns(4)
|
||
with info_col1:
|
||
st.metric("⏱️ 时长", f"{duration:.1f}s")
|
||
with info_col2:
|
||
st.metric("🔄 帧率", f"{fps:.1f}")
|
||
with info_col3:
|
||
st.metric("📐 原始分辨率", f"{width}×{height}")
|
||
with info_col4:
|
||
current_res = app.get_display_resolution()
|
||
st.metric("🖥️ 显示分辨率", f"{current_res[0]}×{current_res[1]}")
|
||
|
||
cap.release()
|
||
|
||
st.markdown("---")
|
||
|
||
# 初始化session_state
|
||
if 'show_preview' not in st.session_state:
|
||
st.session_state.show_preview = False
|
||
|
||
# 预览区域
|
||
if st.session_state.show_preview:
|
||
st.markdown("### 📷 摄像头预览")
|
||
|
||
# 初始化摄像头(如果还没有初始化)
|
||
if not app.is_realsense_active and (app.webcam_cap is None or not app.webcam_cap.isOpened()):
|
||
app.initialize_realsense()
|
||
|
||
# 获取预览帧
|
||
preview_frame = app.get_camera_preview_frame()
|
||
if preview_frame is not None:
|
||
# 使用大尺寸预览
|
||
preview_col1, preview_col2, preview_col3 = st.columns([1, 3, 1])
|
||
with preview_col2:
|
||
st.image(preview_frame, caption="摄像头预览 (已镜像)", use_container_width=True)
|
||
|
||
# 显示提示信息
|
||
camera_type = "RealSense摄像头" if app.is_realsense_active else "USB摄像头"
|
||
st.info(f"🎯 正在使用: {camera_type} - 请调整您的位置,确保全身在画面中清晰可见")
|
||
|
||
# 预览控制按钮
|
||
preview_btn_col1, preview_btn_col2, preview_btn_col3 = st.columns(3)
|
||
with preview_btn_col1:
|
||
if st.button("🔄 刷新画面", key="refresh_preview", use_container_width=True):
|
||
st.rerun()
|
||
with preview_btn_col2:
|
||
if st.button("⏹️ 停止预览", key="stop_preview", use_container_width=True):
|
||
st.session_state.show_preview = False
|
||
st.rerun()
|
||
with preview_btn_col3:
|
||
if st.button("🚀 开始比较", key="start_from_preview", use_container_width=True):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先初始化关键点检测器")
|
||
else:
|
||
st.session_state.show_preview = False
|
||
app.start_comparison(video_path)
|
||
else:
|
||
st.error("❌ 无法获取摄像头画面,请检查摄像头连接")
|
||
st.session_state.show_preview = False
|
||
|
||
else:
|
||
if 'start_main_comparison' not in st.session_state:
|
||
st.session_state.start_main_comparison = False
|
||
# 如果不处于比较状态,则显示主控制按钮
|
||
if not st.session_state.start_main_comparison:
|
||
main_btn_col1, main_btn_col2, main_btn_col3 = st.columns(3)
|
||
|
||
with main_btn_col1:
|
||
if st.button("📷 预览摄像头", use_container_width=True, help="先预览摄像头调整位置"):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||
else:
|
||
st.session_state.show_preview = True
|
||
st.rerun()
|
||
|
||
with main_btn_col2:
|
||
if st.button("🚀 直接开始比较", use_container_width=True, help="直接开始动作比较"):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||
else:
|
||
st.session_state.start_main_comparison = True
|
||
st.rerun()
|
||
|
||
with main_btn_col3:
|
||
st.markdown("") # 占位
|
||
|
||
if st.session_state.start_main_comparison:
|
||
app.start_comparison(video_path)
|
||
# 比较结束后,重置状态,以便下次可以重新显示主按钮
|
||
st.session_state.start_main_comparison = False
|
||
|
||
except Exception as e:
|
||
st.error(f"❌ 处理视频时出现错误: {str(e)}")
|
||
|
||
else:
|
||
# 无视频时显示说明
|
||
st.info("👈 请在左侧选择预设视频或上传标准动作视频")
|
||
|
||
# 显示使用说明(优化布局)
|
||
with st.expander("📖 使用说明", expanded=True):
|
||
usage_col1, usage_col2 = st.columns(2)
|
||
|
||
with usage_col1:
|
||
st.markdown("""
|
||
#### 🚀 快速开始:
|
||
1. **选择视频**: 侧边栏选择预设视频或上传视频
|
||
2. **调整分辨率**: 根据设备性能选择合适分辨率
|
||
3. **初始化系统**: 点击"初始化系统"按钮
|
||
4. **预览摄像头**: 调整位置确保全身可见
|
||
5. **开始比较**: 跟随标准视频做动作
|
||
|
||
#### ⚙️ 分辨率建议:
|
||
- **高性能设备**: 高清模式 (1280x800)
|
||
- **一般设备**: 中等模式 (960x720)
|
||
- **低配设备**: 标准模式 (640x480)
|
||
""")
|
||
|
||
with usage_col2:
|
||
st.markdown("""
|
||
#### ✨ 主要特性:
|
||
- 🎯 实时关键点检测和动作分析
|
||
- 📹 支持RealSense和USB摄像头
|
||
- 🔊 视频音频同步播放
|
||
- 📊 实时相似度图表分析
|
||
- 🖥️ 大屏幕优化显示
|
||
- ⚡ GPU加速支持
|
||
|
||
#### 📋 系统要求:
|
||
- 摄像头设备及权限
|
||
- 充足的光照条件
|
||
- Python 3.7+ 环境
|
||
- 推荐GPU支持(可选)
|
||
""")
|
||
|
||
except Exception as e:
|
||
st.error(f"❌ 应用程序启动失败: {str(e)}")
|
||
st.info("💡 请检查所有依赖库是否正确安装")
|
||
|
||
if __name__ == "__main__":
|
||
# Set environment variables for performance
|
||
# os.environ['OMP_NUM_THREADS'] = '1'
|
||
# os.environ['MKL_NUM_THREADS'] = '1'
|
||
# try:
|
||
# import torch
|
||
# torch.set_num_threads(1)
|
||
# except ImportError:
|
||
# pass
|
||
main()
|