posedet/motion_app.py
2025-06-20 15:01:39 +08:00

460 lines
20 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import cv2
import time
import os
import numpy as np
import torch
from rtmlib import Body, draw_skeleton
from audio_player import AudioPlayer
from pose_analyzer import PoseSimilarityAnalyzer
from config import REALSENSE_AVAILABLE
if REALSENSE_AVAILABLE:
import pyrealsense2 as rs
class MotionComparisonApp:
    """Main application class for motion comparison.

    Shows a reference ("standard") video next to a live camera feed (Intel
    RealSense when available, otherwise USB webcam 0), runs the rtmlib body
    detector on both, and scores pose similarity with PoseSimilarityAnalyzer.
    The UI is rendered with Streamlit; run-control flags are kept in
    st.session_state.comparison_state so they survive reruns.
    """

    # (width, height) for each display_settings['resolution_mode'];
    # insertion order is also the cycle order of the resolution toggle.
    _RESOLUTION_PRESETS = {'high': (1280, 800), 'medium': (960, 720), 'low': (640, 480)}
    # Keypoint score threshold passed to draw_skeleton (kpt_thr).
    _KPT_THRESHOLD = 0.43

    def __init__(self):
        self.body_detector = None          # created lazily by initialize_detector()
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None             # cv2.VideoCapture for the USB webcam
        self.standard_cap = None           # cv2.VideoCapture for the reference video
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0             # loop iterations of the current comparison
        self.audio_player = AudioPlayer()
        # NOTE(review): target_width/target_height are not read by any method of
        # this class (get_display_resolution() uses resolution_mode only) and do
        # not match the 'high' preset; confirm whether external code relies on them.
        self.display_settings = {'resolution_mode': 'high', 'target_width': 960, 'target_height': 720}
        self.realsense_pipeline = None
        self.is_realsense_active = False   # True only while a RealSense pipeline is streaming
        self.last_error_time = 0           # not read by any method of this class
        self.error_count = 0               # consecutive failed camera reads
        # Run-control flags shared with the Streamlit UI across reruns.
        if 'comparison_state' not in st.session_state:
            st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}

    def get_display_resolution(self):
        """Return the (width, height) of the selected resolution mode.

        Unknown or missing modes fall back to the 'medium' preset (960x720).
        """
        mode = self.display_settings.get('resolution_mode', 'medium')
        return self._RESOLUTION_PRESETS.get(mode, self._RESOLUTION_PRESETS['medium'])

    def initialize_detector(self):
        """Create the rtmlib body detector on first call.

        Uses CUDA when torch reports it available, otherwise CPU.

        Returns:
            True if a detector is available (already or newly created), False
            if construction failed (the error is shown in the UI).
        """
        if self.body_detector is not None:
            return True
        try:
            # NOTE(review): torch seeing CUDA does not necessarily mean the
            # onnxruntime backend has a CUDA provider installed — confirm.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.body_detector = Body(mode='lightweight', to_openpose=True,
                                      backend='onnxruntime', device=device)
            st.success(f"Keypoint detector initialized on device: {device}")
            return True
        except Exception as e:
            st.error(f"Detector initialization failed: {e}")
            return False

    def _stop_realsense(self):
        """Stop the RealSense pipeline if one is streaming and clear its state."""
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                self.realsense_pipeline.stop()
            except Exception:
                # Teardown is best-effort; a failing stop() must not block the rest.
                pass
        self.realsense_pipeline = None
        self.is_realsense_active = False

    def initialize_camera(self):
        """Open the live camera, preferring Intel RealSense when available.

        Tries a 30 FPS BGR8 RealSense color stream at the current display
        resolution. If the SDK is unavailable or any step fails, falls back to
        the USB webcam via _initialize_webcam().

        Returns:
            True if a camera is ready, False otherwise.
        """
        if not REALSENSE_AVAILABLE:
            return self._initialize_webcam()
        # Re-initialising: stop a pipeline left from a previous call so the
        # device is free and the old pipeline is not leaked.
        self._stop_realsense()
        pipeline = None
        started = False
        try:
            pipeline = rs.pipeline()
            rs_config = rs.config()
            width, height = self.get_display_resolution()
            rs_config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
            profile = pipeline.start(rs_config)
            started = True
            device = profile.get_device().get_info(rs.camera_info.name)
            st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
        except Exception as e:
            if started:
                # start() succeeded but a later step failed: stop the stream now,
                # since is_realsense_active stays False and cleanup() would skip it.
                try:
                    pipeline.stop()
                except Exception:
                    pass
            st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
            return self._initialize_webcam()
        self.realsense_pipeline = pipeline
        self.is_realsense_active = True
        return True

    def _initialize_webcam(self):
        """Open the default USB webcam (device 0) at the current display resolution.

        Requests the preset size and 30 FPS, then reports the size the driver
        actually granted.

        Returns:
            True if the webcam is open, False otherwise. A capture that failed
            to open is released rather than kept in webcam_cap.
        """
        # Release a capture from an earlier call so the device is not opened twice.
        if self.webcam_cap is not None:
            self.webcam_cap.release()
            self.webcam_cap = None
        cap = None
        try:
            cap = cv2.VideoCapture(0)
            if not cap.isOpened():
                cap.release()
                st.error("❌ 无法打开USB摄像头")
                return False
            width, height = self.get_display_resolution()
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
            cap.set(cv2.CAP_PROP_FPS, 30)
            # The driver may not honour the request; report what was granted.
            actual_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            actual_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.webcam_cap = cap
            st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
            return True
        except Exception as e:
            if cap is not None and cap is not self.webcam_cap:
                cap.release()
            st.error(f"❌ USB摄像头初始化失败: {e}")
            return False

    def read_camera_frame(self):
        """Grab one BGR frame from whichever camera is active.

        Returns:
            (ok, frame). RealSense: (True, ndarray) or (False, None) on a
            missing color frame or any error (1 s timeout). USB webcam: the
            result of cv2.VideoCapture.read(). No open camera: (False, None).
        """
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
                color_frame = frames.get_color_frame()
                if not color_frame:
                    return False, None
                return True, np.asanyarray(color_frame.get_data())
            except Exception:
                return False, None
        elif self.webcam_cap and self.webcam_cap.isOpened():
            return self.webcam_cap.read()
        return False, None

    def _detect_pose(self, frame):
        """Run the body detector on a BGR frame.

        Returns:
            (keypoints, scores), or (None, None) when no detector is loaded or
            detection raises; callers then show the frame without a skeleton.
        """
        if self.body_detector is None:
            return None, None
        try:
            keypoints, scores = self.body_detector(frame)
            return keypoints, scores
        except Exception:
            return None, None

    def _to_rgb_with_skeleton(self, frame, keypoints, scores):
        """Return an RGB copy of a BGR *frame* with the OpenPose skeleton drawn.

        Falls back to the plain frame when keypoints/scores are missing or
        drawing fails. The input frame is not modified.
        """
        annotated = frame.copy()
        if keypoints is not None and scores is not None:
            try:
                annotated = draw_skeleton(annotated, keypoints, scores,
                                          openpose_skeleton=True, kpt_thr=self._KPT_THRESHOLD)
            except Exception:
                # draw_skeleton may have drawn partially on the copy; start clean.
                annotated = frame.copy()
        return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

    def get_camera_preview_frame(self):
        """Return a mirrored RGB camera frame with the skeleton overlaid.

        Returns None when no frame could be read. Detection and drawing are
        best-effort: on failure the mirrored frame is returned unannotated.
        """
        ret, frame = self.read_camera_frame()
        if not ret or frame is None:
            return None
        mirrored = cv2.flip(frame, 1)
        keypoints, scores = self._detect_pose(mirrored)
        return self._to_rgb_with_skeleton(mirrored, keypoints, scores)

    def cleanup(self):
        """Release all capture resources, stop audio and mark the run as stopped.

        Safe to call more than once.
        """
        if self.standard_cap:
            self.standard_cap.release()
        if self.webcam_cap:
            self.webcam_cap.release()
        # Also clears is_realsense_active, so the next start_comparison()
        # re-initialises the camera instead of reading a stopped pipeline.
        self._stop_realsense()
        try:
            self.audio_player.cleanup()
        finally:
            self.is_running = False
            st.session_state.comparison_state['is_running'] = False

    @staticmethod
    def _performance_rating(average):
        """Map an average similarity (percent) to a label and a Streamlit text color.

        Streamlit's ``:color[text]`` markdown only supports named colors
        (blue, green, orange, red, violet, gray/grey, rainbow, primary), so
        status names like "success" cannot be used there.
        """
        if average >= 80:
            return "非常棒! 👏", "green"
        if average >= 60:
            return "整体不错! 👍", "blue"
        return "需要改进! 💪", "orange"

    def show_final_statistics(self):
        """Displays final statistics after the comparison ends.

        Does nothing when no similarity score was recorded. Shows an overall
        rating, average/max/min similarity, and tips when the average is
        below 60%.
        """
        history = self.similarity_analyzer.similarity_history
        if not history:
            return
        final_avg = sum(history) / len(history)
        level, color = self._performance_rating(final_avg)
        st.success("🎉 比较完成!")
        st.markdown(f"**整体表现**: :{color}[{level}]")
        col1, col2, col3 = st.columns(3)
        col1.metric("平均相似度", f"{final_avg:.1f}%")
        col2.metric("最高相似度", f"{max(history):.1f}%")
        col3.metric("最低相似度", f"{min(history):.1f}%")
        if final_avg < 60:
            with st.expander("💡 改善建议"):
                st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
                            "- 尽量匹配标准视频的节奏和动作幅度\n"
                            "- 确保光线充足且稳定")

    def _abort_start(self):
        """Clear the running state after start_comparison() fails during setup.

        Releases the reference video; camera and audio are left untouched.
        """
        if self.standard_cap:
            self.standard_cap.release()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def _rewind_reference(self, audio_loaded):
        """Seek the reference video to frame 0, clear the score history and restart audio."""
        self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        self.similarity_analyzer.reset()
        if audio_loaded:
            self.audio_player.restart()

    def _build_comparison_layout(self, audio_loaded):
        """Create the comparison page and translate button clicks into run flags.

        Stop/restart clicks set comparison_state['should_stop'] /
        ['should_restart']; the resolution button cycles
        display_settings['resolution_mode'] through high -> medium -> low.

        Args:
            audio_loaded: whether to show the audio-status button.

        Returns:
            dict of placeholders keyed 'standard', 'webcam', 'score', 'avg',
            'plot', 'progress' and 'status'.
        """
        st.markdown("### 📺 视频比较")
        video_col1, video_col2 = st.columns(2, gap="small")
        # Headings are written first so they render above the video placeholders.
        with video_col1:
            st.markdown("#### 🎯 标准动作视频")
        with video_col2:
            camera_type = "RealSense摄像头" if self.is_realsense_active else "USB摄像头"
            st.markdown(f"#### 📹 {camera_type}实时影像")
        standard_placeholder = video_col1.empty()
        webcam_placeholder = video_col2.empty()

        # Control buttons (compact row).
        st.markdown("---")
        with st.container():
            control_col1, control_col2, control_col3, control_col4 = st.columns([1, 1, 1, 1])
            with control_col1:
                stop_button = st.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
            with control_col2:
                restart_button = st.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
            with control_col3:
                if audio_loaded and st.button("🔊 音频状态", use_container_width=True, key="audio_status"):
                    st.info(f"音频: {'播放中' if self.audio_player.is_playing else '已停止'}")
            with control_col4:
                if st.button("📐 切换分辨率", use_container_width=True, key="resolution_toggle"):
                    modes = list(self._RESOLUTION_PRESETS)
                    current_idx = modes.index(self.display_settings.get('resolution_mode', 'high'))
                    next_mode = modes[(current_idx + 1) % len(modes)]
                    self.display_settings['resolution_mode'] = next_mode
                    st.info(f"分辨率模式: {next_mode}")

        # Similarity metrics and trend chart.
        st.markdown("---")
        st.markdown("### 📊 动作相似度分析")
        sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
        ui = {
            'standard': standard_placeholder,
            'webcam': webcam_placeholder,
            'score': sim_col1.empty(),
            'avg': sim_col2.empty(),
            'plot': sim_col3.empty(),
        }

        if stop_button:
            st.session_state.comparison_state['should_stop'] = True
        if restart_button:
            st.session_state.comparison_state['should_restart'] = True

        # Progress / status area.
        with st.container():
            ui['progress'] = st.progress(0)
            ui['status'] = st.empty()
        return ui

    def _render_similarity(self, ui, current_similarity, last_plot_update):
        """Update the score metrics and, at most every 20 iterations, the trend chart.

        Display is best-effort: rendering errors are ignored so the video loop
        keeps running.

        Returns:
            The frame_counter value at which the chart was last redrawn.
        """
        history = self.similarity_analyzer.similarity_history
        if len(history) == 0:
            return last_plot_update
        try:
            avg_similarity = sum(history) / len(history)
            if current_similarity >= 80:
                badge, level = "🟢", "优秀"
            elif current_similarity >= 60:
                badge, level = "🟡", "良好"
            else:
                badge, level = "🔴", "需要改进"
            ui['score'].metric("当前相似度", f"{badge} {current_similarity:.1f}%", delta=f"{level}")
            ui['avg'].metric("平均相似度", f"{avg_similarity:.1f}%")
            if len(history) >= 2 and self.frame_counter - last_plot_update >= 20:
                try:
                    similarity_plot = self.similarity_analyzer.get_similarity_plot()
                    if similarity_plot:
                        ui['plot'].plotly_chart(
                            similarity_plot,
                            use_container_width=True,
                            # Fresh key per redraw to avoid duplicate element IDs.
                            key=f"similarity_plot_{int(time.time() * 1000)}",
                        )
                        last_plot_update = self.frame_counter
                except Exception:
                    pass  # chart errors are ignored; the metrics are already updated
        except Exception:
            pass  # metric rendering is cosmetic
        return last_plot_update

    def start_comparison(self, video_path):
        """The main loop for comparing motion.

        Loads the reference video (and its audio when AudioPlayer can), makes
        sure the detector and camera are ready, builds the layout, and loops
        until comparison_state['is_running'] turns False,
        comparison_state['should_stop'] is set, or the camera fails more than
        100 times in a row. The reference video loops when it ends; the score
        history is reset on each loop and on restart, so the final statistics
        cover only the current pass. Resources are released in all cases.

        Args:
            video_path: reference video, passed to cv2.VideoCapture and
                AudioPlayer.load_audio.
        """
        self.is_running = True
        st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})
        self.standard_video_path = video_path
        self.frame_counter = 0
        self.error_count = 0
        self.similarity_analyzer.reset()

        audio_loaded = self.audio_player.load_audio(video_path)
        if audio_loaded:
            st.success("✅ 音频加载成功")
        else:
            st.info(" 将不会播放音频")

        if self.standard_cap:
            # Do not leak a capture left over from a previous run.
            self.standard_cap.release()
        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("无法打开标准视频文件")
            self._abort_start()
            return
        # Without a detector no frame is annotated and no score is ever computed.
        if not self.initialize_detector():
            self._abort_start()
            return
        if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
            if not self.initialize_camera():
                self._abort_start()
                return

        total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        # Pace the loop at the reference video's frame rate (30 FPS if unknown).
        frame_delay = 1.0 / video_fps if video_fps > 0 else 1.0 / 30.0

        ui = self._build_comparison_layout(audio_loaded)
        # Resolved after the layout so a resolution toggle clicked in this run
        # takes effect immediately.
        target_width, target_height = self.get_display_resolution()
        target_size = (target_width, target_height)
        resolution_info = f" | {target_width}x{target_height}"

        video_frame_idx = 0
        current_similarity = 0
        last_plot_update = 0
        start_time = time.time()
        if audio_loaded:
            self.audio_player.play()

        try:
            while (st.session_state.comparison_state['is_running'] and
                   not st.session_state.comparison_state['should_stop']):
                loop_start = time.time()
                self.frame_counter += 1

                if st.session_state.comparison_state['should_restart']:
                    self._rewind_reference(audio_loaded)
                    start_time = time.time()
                    video_frame_idx = 0
                    st.session_state.comparison_state['should_restart'] = False
                    continue

                # Read the camera first: when it fails no reference frame is
                # consumed, so the two streams stay in step.
                ret_webcam, webcam_frame = self.read_camera_frame()
                if not ret_webcam or webcam_frame is None:
                    self.error_count += 1
                    if self.error_count > 100:  # too many consecutive failures
                        st.error("摄像头连接出现问题,停止比较")
                        break
                    # Pause so an instantly-failing camera does not exhaust the
                    # retry budget within milliseconds.
                    time.sleep(frame_delay)
                    continue
                self.error_count = 0

                ret_standard, standard_frame = self.standard_cap.read()
                if not ret_standard:
                    # End of the reference video: loop back to the start.
                    self._rewind_reference(audio_loaded)
                    start_time = time.time()
                    video_frame_idx = 0
                    continue

                # Mirror the camera and bring both frames to the display size.
                webcam_frame = cv2.resize(cv2.flip(webcam_frame, 1), target_size)
                standard_frame = cv2.resize(standard_frame, target_size)

                standard_keypoints, standard_scores = self._detect_pose(standard_frame)
                webcam_keypoints, webcam_scores = self._detect_pose(webcam_frame)

                # Score only every 5th iteration to save time.
                if (self.frame_counter % 5 == 0
                        and standard_keypoints is not None
                        and webcam_keypoints is not None):
                    try:
                        standard_angles = self.similarity_analyzer.extract_joint_angles(
                            standard_keypoints, standard_scores)
                        webcam_angles = self.similarity_analyzer.extract_joint_angles(
                            webcam_keypoints, webcam_scores)
                        if standard_angles and webcam_angles:
                            current_similarity = self.similarity_analyzer.calculate_similarity(
                                standard_angles, webcam_angles)
                            self.similarity_analyzer.add_similarity_score(
                                current_similarity, time.time() - start_time)
                    except Exception:
                        pass  # a failed score keeps the previous value

                elapsed = time.time() - start_time
                frame_info = f"时间: {elapsed:.1f}s | 帧: {video_frame_idx}/{total_frames}"
                audio_info = f" | 音频: {'🔊' if self.audio_player.is_playing else '🔇'}" if audio_loaded else ""
                ui['standard'].image(
                    self._to_rgb_with_skeleton(standard_frame, standard_keypoints, standard_scores),
                    caption=f"标准动作 - {frame_info}{audio_info}{resolution_info}",
                    use_container_width=True,
                )
                ui['webcam'].image(
                    self._to_rgb_with_skeleton(webcam_frame, webcam_keypoints, webcam_scores),
                    caption=f"您的动作 - 实时画面{resolution_info}",
                    use_container_width=True,
                )

                last_plot_update = self._render_similarity(ui, current_similarity, last_plot_update)

                try:
                    progress = min(video_frame_idx / total_frames, 1.0) if total_frames > 0 else 0
                    ui['progress'].progress(progress)
                    elapsed_time = time.time() - start_time
                    # Count frames since start_time was last reset; frame_counter is
                    # cumulative across loops/restarts and would inflate the rate.
                    fps_actual = (video_frame_idx + 1) / elapsed_time if elapsed_time > 0 else 0
                    ui['status'].text(
                        f"进度: {video_frame_idx}/{total_frames} | "
                        f"实际FPS: {fps_actual:.1f} | "
                        f"分辨率: {target_width}x{target_height} | "
                        f"模式: {self.display_settings['resolution_mode']}"
                    )
                except Exception:
                    pass  # status updates are cosmetic

                video_frame_idx += 1
                # Sleep off whatever remains of this frame's time budget.
                sleep_time = frame_delay - (time.time() - loop_start)
                if sleep_time > 0:
                    time.sleep(sleep_time)
        finally:
            # Runs on normal exit, on break, and when the loop is interrupted by
            # an exception, so capture devices and audio are always released.
            self.cleanup()
        self.show_final_statistics()