posedet/motion_app.py

import streamlit as st
import cv2
import time
import os
import numpy as np
import torch
from rtmlib import Body, draw_skeleton
from audio_player import AudioPlayer
from pose_analyzer import PoseSimilarityAnalyzer
from config import REALSENSE_AVAILABLE

if REALSENSE_AVAILABLE:
    import pyrealsense2 as rs


class MotionComparisonApp:
    """Main application class for motion comparison."""

    def __init__(self):
        self.body_detector = None
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None
        self.standard_cap = None
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0
        self.audio_player = AudioPlayer()
        self.display_settings = {'resolution_mode': 'high', 'target_width': 960, 'target_height': 720}
        self.realsense_pipeline = None
        self.is_realsense_active = False
        self.last_error_time = 0
        self.error_count = 0
        if 'comparison_state' not in st.session_state:
            st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}

    def get_display_resolution(self):
        """Return the (width, height) display resolution for the configured mode."""
        modes = {'high': (1024, 576), 'medium': (960, 720), 'low': (640, 480)}
        mode = self.display_settings.get('resolution_mode', 'medium')
        return modes.get(mode, (960, 720))

    def initialize_detector(self):
        """Initialize the rtmlib Body keypoint detector (CUDA if available, else CPU)."""
        if self.body_detector is None:
            try:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                self.body_detector = Body(mode='lightweight', to_openpose=True, backend='onnxruntime', device=device)
                st.success(f"Keypoint detector initialized on device: {device}")
                return True
            except Exception as e:
                st.error(f"Detector initialization failed: {e}")
                return False
        return True

    def initialize_camera(self):
        """Initialize the RealSense camera if available, falling back to a USB webcam."""
        if REALSENSE_AVAILABLE:
            try:
                self.realsense_pipeline = rs.pipeline()
                config = rs.config()
                width, height = self.get_display_resolution()
                config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
                profile = self.realsense_pipeline.start(config)
                device = profile.get_device().get_info(rs.camera_info.name)
                st.success(f"✅ RealSense camera initialized: {device} ({width}x{height})")
                self.is_realsense_active = True
                return True
            except Exception as e:
                st.warning(f"RealSense initialization failed: {e}. Falling back to USB webcam.")
                return self._initialize_webcam()
        else:
            return self._initialize_webcam()

    def _initialize_webcam(self):
        """Initialize the default USB webcam at the configured resolution."""
        try:
            self.webcam_cap = cv2.VideoCapture(0)
            if self.webcam_cap.isOpened():
                self.webcam_cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
                width, height = self.get_display_resolution()
                self.webcam_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
                self.webcam_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
                self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
                actual_w = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                actual_h = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                st.success(f"✅ USB webcam initialized ({actual_w}x{actual_h})")
                return True
            else:
                st.error("❌ Could not open USB webcam")
                return False
        except Exception as e:
            st.error(f"❌ USB webcam initialization failed: {e}")
            return False

    def read_camera_frame(self):
        """Read one BGR frame from the active camera; returns (success, frame)."""
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
                color_frame = frames.get_color_frame()
                if not color_frame:
                    return False, None
                return True, np.asanyarray(color_frame.get_data())
            except Exception:
                return False, None
        elif self.webcam_cap and self.webcam_cap.isOpened():
            return self.webcam_cap.read()
        return False, None

    def get_camera_preview_frame(self):
        """Return a mirrored RGB preview frame with the detected skeleton overlaid."""
        ret, frame = self.read_camera_frame()
        if not ret or frame is None:
            return None
        frame = cv2.flip(frame, 1)
        if self.body_detector:
            try:
                keypoints, scores = self.body_detector(frame)
                frame = draw_skeleton(frame.copy(), keypoints, scores, openpose_skeleton=True, kpt_thr=0.43)
            except Exception:
                pass
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def cleanup(self):
        """Cleans up all resources."""
        if self.standard_cap:
            self.standard_cap.release()
        if self.webcam_cap:
            self.webcam_cap.release()
        if self.is_realsense_active and self.realsense_pipeline:
            self.realsense_pipeline.stop()
        self.audio_player.cleanup()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def show_final_statistics(self):
        """Displays final statistics after the comparison ends."""
        st.markdown("---")  # separator line
        history = self.similarity_analyzer.similarity_history
        if not history:
            st.warning("Not enough comparison data to generate an analysis.")
            if st.button("Back to main menu", use_container_width=True):
                st.rerun()
            return

        final_avg = sum(history) / len(history)
        # Define the rating text and color separately for clearer logic
        if final_avg >= 80:
            level = "Excellent! 👏"
            md_color = "green"
        elif final_avg >= 60:
            level = "Pretty good! 👍"
            md_color = "blue"
        else:
            level = "Needs improvement! 💪"
            md_color = "orange"  # use 'orange' instead of 'warning'

        st.success("🎉 Comparison finished!")
        st.markdown(f"**Overall performance**: :{md_color}[{level}]")
        col1, col2, col3 = st.columns(3)
        col1.metric("Average similarity", f"{final_avg:.1f}%")
        col2.metric("Highest similarity", f"{max(history):.1f}%")
        col3.metric("Lowest similarity", f"{min(history):.1f}%")

        if final_avg < 60:
            with st.expander("💡 Improvement tips"):
                st.markdown("- Make sure your whole body is clearly visible in the camera frame\n"
                            "- Try to match the rhythm and range of motion of the standard video\n"
                            "- Make sure the lighting is adequate and stable")

        # Back-to-menu button
        st.markdown("---")
        if st.button("Back to main menu", use_container_width=True):
            st.session_state.comparison_state['is_running'] = False  # make sure the state is reset
            st.rerun()

    def start_comparison(self, video_path):
        """The main loop for comparing motion, synchronized with real-time."""
        self.is_running = True
        st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})
        self.standard_video_path = video_path
        self.frame_counter = 0
        self.similarity_analyzer.reset()

        audio_loaded = self.audio_player.load_audio(video_path)
        if audio_loaded:
            st.success("✅ Audio loaded successfully")
        else:
            st.info("Audio will not be played")

        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("Could not open the standard video file")
            return
        if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
            if not self.initialize_camera():
                return

        # --- Video info and setup variables ---
        total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        if video_fps == 0:
            video_fps = 30  # guard against division by zero
        video_duration = total_frames / video_fps
        target_width, target_height = self.get_display_resolution()

        # --- UI placeholders ---
        st.markdown("### 📺 Video Comparison")
        video_col1, video_col2 = st.columns(2, gap="small")
        standard_placeholder = video_col1.empty()
        webcam_placeholder = video_col2.empty()
        with video_col1:
            st.markdown("#### 🎯 Standard motion video")
        with video_col2:
            st.markdown(f"#### 📹 {'RealSense' if self.is_realsense_active else 'USB'} live view")

        st.markdown("---")
        control_col1, control_col2 = st.columns(2)
        stop_button = control_col1.button("⏹️ Stop", use_container_width=True, key="stop_comparison")
        restart_button = control_col2.button("🔄 Restart", use_container_width=True, key="restart_comparison")
        st.markdown("---")

        st.markdown("### 📊 Motion Similarity Analysis")
        sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
        similarity_score_placeholder = sim_col1.empty()
        avg_score_placeholder = sim_col2.empty()
        similarity_plot_placeholder = sim_col3.empty()
        progress_bar = st.progress(0)
        status_text = st.empty()

        # --- Button logic ---
        if stop_button:
            st.session_state.comparison_state['should_stop'] = True
        if restart_button:
            st.session_state.comparison_state['should_restart'] = True

        # --- Main loop (synchronized to wall-clock time) ---
        start_time = time.time()
        current_similarity = 0
        last_plot_update = 0
        if audio_loaded:
            self.audio_player.play()

        while True:
            # --- Check exit conditions ---
            elapsed_time = time.time() - start_time
            if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
                break

            # --- Check for restart ---
            if st.session_state.comparison_state['should_restart']:
                self.similarity_analyzer.reset()
                start_time = time.time()  # reset the timer
                if audio_loaded:
                    self.audio_player.restart()
                st.session_state.comparison_state['should_restart'] = False
                continue

            # --- Fetch the standard video frame that matches the real elapsed time ---
            target_frame_idx = int(elapsed_time * video_fps)
            self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_idx)
            ret_standard, standard_frame = self.standard_cap.read()
            if not ret_standard:
                continue  # skip this iteration if the read failed

            # --- Read the current camera frame ---
            ret_webcam, webcam_frame = self.read_camera_frame()
            if not ret_webcam or webcam_frame is None:
                status_text.warning("Failed to read the camera frame, please check the connection...")
                time.sleep(0.1)  # brief wait
                continue

            self.frame_counter += 1
            webcam_frame = cv2.flip(webcam_frame, 1)
            standard_frame = cv2.resize(standard_frame, (target_width, target_height))
            webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))

            # --- Keypoint detection and similarity calculation ---
            try:
                standard_keypoints, standard_scores = self.body_detector(standard_frame)
                webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
                if standard_keypoints is not None and webcam_keypoints is not None:
                    standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
                    webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)
                    if standard_angles and webcam_angles:
                        current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
                        self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
            except Exception:
                standard_keypoints, webcam_keypoints = None, None

            # --- Draw and display (UI updates) ---
            try:
                # Draw skeletons
                standard_with_keypoints = draw_skeleton(standard_frame.copy(), standard_keypoints, standard_scores, openpose_skeleton=True, kpt_thr=0.43)
                webcam_with_keypoints = draw_skeleton(webcam_frame.copy(), webcam_keypoints, webcam_scores, openpose_skeleton=True, kpt_thr=0.43)

                # Update the video panes
                standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
                webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)

                # Update the status metrics
                if self.similarity_analyzer.similarity_history:
                    avg_sim = np.mean(self.similarity_analyzer.similarity_history)
                    similarity_score_placeholder.metric("Current similarity", f"{current_similarity:.1f}%")
                    avg_score_placeholder.metric("Average similarity", f"{avg_sim:.1f}%")

                # Update the plot (throttled for performance)
                if self.frame_counter - last_plot_update > 5:  # refresh the chart roughly every 5 processed frames
                    plot = self.similarity_analyzer.get_similarity_plot()
                    if plot:
                        similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)
                    last_plot_update = self.frame_counter

                # Update the progress bar and status text
                progress = min(elapsed_time / video_duration, 1.0)
                progress_bar.progress(progress)
                processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
                status_text.text(f"Time: {elapsed_time:.1f}s / {video_duration:.1f}s | "
                                 f"Processing rate: {processing_fps:.1f} FPS | "
                                 f"Standard video frame: {target_frame_idx}/{total_frames}")
            except Exception:
                # Avoid crashing the program if a UI update fails
                pass

            # This loop has no time.sleep(); it runs as fast as it can.
            # Synchronization comes from seeking the standard video, on every
            # iteration, to the frame that matches the real elapsed time.

        # --- Post-loop cleanup ---
        self.cleanup()
        self.show_final_statistics()
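
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal way this class might be driven from a Streamlit page; the page
# flow and the 'videos/standard.mp4' path are assumptions for the example,
# not the app's actual entry point.
#
#   app = MotionComparisonApp()
#   if app.initialize_detector() and app.initialize_camera():
#       preview = app.get_camera_preview_frame()
#       if preview is not None:
#           st.image(preview, caption="Camera preview")
#       if st.button("Start comparison"):
#           app.start_comparison("videos/standard.mp4")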