460 lines
20 KiB
Python
460 lines
20 KiB
Python
import streamlit as st
|
||
import cv2
|
||
import time
|
||
import os
|
||
import numpy as np
|
||
import torch
|
||
from rtmlib import Body, draw_skeleton
|
||
|
||
from audio_player import AudioPlayer
|
||
from pose_analyzer import PoseSimilarityAnalyzer
|
||
from config import REALSENSE_AVAILABLE
|
||
|
||
if REALSENSE_AVAILABLE:
|
||
import pyrealsense2 as rs
|
||
|
||
class MotionComparisonApp:
    """Main application class for motion comparison.

    Plays a reference ("standard") video next to a live camera feed (Intel
    RealSense when ``REALSENSE_AVAILABLE`` and it starts, otherwise USB webcam 0),
    draws rtmlib ``Body`` skeletons on both, and scores how similar the poses are
    with ``PoseSimilarityAnalyzer``. All UI is rendered with Streamlit.
    """

    # Display presets as (width, height). Dict order is also the order in which
    # the "switch resolution" button cycles through them.
    _RESOLUTION_MODES = {'high': (1280, 800), 'medium': (960, 720), 'low': (640, 480)}
    # Size returned when the configured mode is missing or unknown.
    _FALLBACK_RESOLUTION = (960, 720)
    # Keypoint confidence threshold passed to draw_skeleton.
    _KPT_THRESHOLD = 0.43
    # Consecutive failed camera reads tolerated before the comparison stops.
    _MAX_CAMERA_ERRORS = 100
    # Score similarity every N loop iterations (for performance).
    _SIMILARITY_INTERVAL = 5
    # Redraw the similarity chart at most every N loop iterations (for performance).
    _PLOT_INTERVAL = 20

    def __init__(self):
        self.body_detector = None  # rtmlib Body, created lazily by initialize_detector()
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None  # cv2.VideoCapture for the USB fallback camera
        self.standard_cap = None  # cv2.VideoCapture for the reference video
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0  # loop iterations since start_comparison began
        self.audio_player = AudioPlayer()

        # NOTE(review): target_width/target_height are not read in this class;
        # the effective size always comes from get_display_resolution().
        self.display_settings = {'resolution_mode': 'high', 'target_width': 960, 'target_height': 720}
        self.realsense_pipeline = None
        self.is_realsense_active = False
        self.last_error_time = 0
        self.error_count = 0  # consecutive failed camera reads

        # Control flags kept in the Streamlit session so they survive reruns.
        if 'comparison_state' not in st.session_state:
            st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}

    def get_display_resolution(self):
        """Return the (width, height) preset for the current ``resolution_mode``.

        A missing or unknown mode falls back to 960x720.
        """
        mode = self.display_settings.get('resolution_mode', 'medium')
        return self._RESOLUTION_MODES.get(mode, self._FALLBACK_RESOLUTION)

    def _cycle_resolution_mode(self):
        """Advance ``resolution_mode`` to the next preset (high -> medium -> low -> high).

        Returns:
            The newly selected mode name. An unknown current mode restarts the
            cycle at the first preset instead of raising ``ValueError``.
        """
        modes = list(self._RESOLUTION_MODES)
        current = self.display_settings.get('resolution_mode', 'high')
        next_idx = (modes.index(current) + 1) % len(modes) if current in modes else 0
        self.display_settings['resolution_mode'] = modes[next_idx]
        return modes[next_idx]

    def initialize_detector(self):
        """Create the rtmlib ``Body`` detector once (on CUDA if available, else CPU).

        Returns:
            True if a detector is available; False if creation failed (an error is
            shown in the UI).
        """
        if self.body_detector is not None:
            return True
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.body_detector = Body(mode='lightweight', to_openpose=True, backend='onnxruntime', device=device)
        except Exception as e:
            st.error(f"Detector initialization failed: {e}")
            return False
        st.success(f"Keypoint detector initialized on device: {device}")
        return True

    def initialize_camera(self):
        """Open the live camera: RealSense colour stream if available, else USB webcam 0.

        Returns:
            True if a camera is ready, False otherwise.
        """
        if not REALSENSE_AVAILABLE:
            return self._initialize_webcam()

        pipeline = None
        started = False
        try:
            pipeline = rs.pipeline()
            config = rs.config()
            width, height = self.get_display_resolution()
            config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
            profile = pipeline.start(config)
            started = True
            device = profile.get_device().get_info(rs.camera_info.name)
        except Exception as e:
            if started:
                # Don't leave a running pipeline holding the device.
                try:
                    pipeline.stop()
                except Exception:
                    pass
            # Never keep a half-initialised pipeline around.
            self.realsense_pipeline = None
            self.is_realsense_active = False
            st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
            return self._initialize_webcam()

        self.realsense_pipeline = pipeline
        self.is_realsense_active = True
        st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
        return True

    def _initialize_webcam(self):
        """Open USB webcam 0 at the current display resolution, requesting 30 FPS.

        Returns:
            True if the webcam opened; False otherwise (an error is shown in the UI).
        """
        # Release any previous capture so the device handle is not leaked.
        if self.webcam_cap is not None:
            self.webcam_cap.release()
            self.webcam_cap = None
        try:
            self.webcam_cap = cv2.VideoCapture(0)
            if not self.webcam_cap.isOpened():
                self.webcam_cap.release()
                self.webcam_cap = None
                st.error("❌ 无法打开USB摄像头")
                return False
            width, height = self.get_display_resolution()
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
            self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
            # The driver may not honour the request; report what we actually got.
            actual_w = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            actual_h = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
            return True
        except Exception as e:
            st.error(f"❌ USB摄像头初始化失败: {e}")
            return False

    def read_camera_frame(self):
        """Read one frame from the active camera.

        Returns:
            ``(ok, frame)``. For RealSense, a failed or timed-out read (1 s) gives
            ``(False, None)``. For the USB webcam, ``VideoCapture.read()`` is passed
            through. ``(False, None)`` is also returned when no camera is open.
        """
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
                color_frame = frames.get_color_frame()
                if not color_frame:
                    return False, None
                return True, np.asanyarray(color_frame.get_data())
            except Exception:
                return False, None
        elif self.webcam_cap and self.webcam_cap.isOpened():
            return self.webcam_cap.read()
        return False, None

    def _detect_keypoints(self, frame):
        """Run the body detector on a BGR frame.

        Returns:
            ``(keypoints, scores)`` from the detector, or ``(None, None)`` when there
            is no detector or detection raised (the raw frame is shown instead).
        """
        if not self.body_detector:
            return None, None
        try:
            keypoints, scores = self.body_detector(frame)
        except Exception:
            return None, None
        return keypoints, scores

    def _draw_keypoints(self, frame, keypoints, scores):
        """Return a copy of ``frame`` with the skeleton drawn on it.

        Falls back to a plain copy when keypoints are missing or drawing fails.
        The input frame is never modified.
        """
        if keypoints is None or scores is None:
            return frame.copy()
        try:
            return draw_skeleton(frame.copy(), keypoints, scores,
                                 openpose_skeleton=True, kpt_thr=self._KPT_THRESHOLD)
        except Exception:
            return frame.copy()

    def get_camera_preview_frame(self):
        """Return one mirrored RGB camera frame with the skeleton overlaid, or None.

        The skeleton is drawn only when a detector exists and detection succeeds;
        otherwise the plain mirrored frame is returned. Returns None if no frame
        could be read.
        """
        ret, frame = self.read_camera_frame()
        if not ret or frame is None:
            return None

        frame = cv2.flip(frame, 1)  # mirror, so the preview behaves like a mirror
        keypoints, scores = self._detect_keypoints(frame)
        return cv2.cvtColor(self._draw_keypoints(frame, keypoints, scores), cv2.COLOR_BGR2RGB)

    def cleanup(self):
        """Release captures, the RealSense pipeline and audio, and mark the run stopped.

        Safe to call repeatedly: handles are reset to None after release, and the
        RealSense flag is cleared so the next run re-initialises the camera.
        """
        if self.standard_cap is not None:
            self.standard_cap.release()
            self.standard_cap = None
        if self.webcam_cap is not None:
            self.webcam_cap.release()
            self.webcam_cap = None
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                self.realsense_pipeline.stop()
            except Exception:
                pass  # already stopped or device gone; nothing left to release
        # A stopped pipeline cannot deliver frames, so force re-initialisation.
        self.realsense_pipeline = None
        self.is_realsense_active = False
        self.audio_player.cleanup()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    @staticmethod
    def _final_rating(average):
        """Map an average similarity (percent) to ``(label, Streamlit text colour)``.

        The colour must be one of Streamlit's named markdown colours (e.g. green,
        blue, orange); names such as "success" are not valid and render literally.
        """
        if average >= 80:
            return "非常棒! 👏", "green"
        if average >= 60:
            return "整体不错! 👍", "blue"
        return "需要改进! 💪", "orange"

    @staticmethod
    def _live_rating(score):
        """Map a live similarity score (percent) to ``(indicator emoji, label)``."""
        if score >= 80:
            return "🟢", "优秀"
        if score >= 60:
            return "🟡", "良好"
        return "🔴", "需要改进"

    def show_final_statistics(self):
        """Display the overall rating and average/max/min similarity after a run.

        Does nothing if no similarity score was recorded.
        """
        history = self.similarity_analyzer.similarity_history
        if not history:
            return

        final_avg = sum(history) / len(history)
        level, color = self._final_rating(final_avg)

        st.success("🎉 比较完成!")
        st.markdown(f"**整体表现**: :{color}[{level}]")

        col1, col2, col3 = st.columns(3)
        col1.metric("平均相似度", f"{final_avg:.1f}%")
        col2.metric("最高相似度", f"{max(history):.1f}%")
        col3.metric("最低相似度", f"{min(history):.1f}%")

        if final_avg < 60:
            with st.expander("💡 改善建议"):
                st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
                            "- 尽量匹配标准视频的节奏和动作幅度\n"
                            "- 确保光线充足且稳定")

    def _build_comparison_ui(self, audio_loaded):
        """Lay out the comparison page and handle this run's control-button clicks.

        Stop/restart clicks set the corresponding ``comparison_state`` flags, and
        the resolution button cycles ``resolution_mode``.

        Returns:
            Dict of placeholders: 'standard' and 'webcam' (video panes), 'score',
            'avg' and 'plot' (similarity widgets), 'progress' and 'status'.
        """
        st.markdown("### 📺 视频比较")
        video_col1, video_col2 = st.columns(2, gap="small")
        # Headings first, so they sit above the video placeholders filled later.
        video_col1.markdown("#### 🎯 标准动作视频")
        standard_placeholder = video_col1.empty()
        camera_type = "RealSense摄像头" if self.is_realsense_active else "USB摄像头"
        video_col2.markdown(f"#### 📹 {camera_type}实时影像")
        webcam_placeholder = video_col2.empty()

        # Compact control-button row.
        st.markdown("---")
        with st.container():
            control_col1, control_col2, control_col3, control_col4 = st.columns([1, 1, 1, 1])
            with control_col1:
                stop_button = st.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
            with control_col2:
                restart_button = st.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
            with control_col3:
                if audio_loaded and st.button("🔊 音频状态", use_container_width=True, key="audio_status"):
                    st.info(f"音频: {'播放中' if self.audio_player.is_playing else '已停止'}")
            with control_col4:
                if st.button("📐 切换分辨率", use_container_width=True, key="resolution_toggle"):
                    new_mode = self._cycle_resolution_mode()
                    st.info(f"分辨率模式: {new_mode}")

        # Similarity widgets.
        st.markdown("---")
        st.markdown("### 📊 动作相似度分析")
        sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
        ui = {
            'standard': standard_placeholder,
            'webcam': webcam_placeholder,
            'score': sim_col1.empty(),
            'avg': sim_col2.empty(),
            'plot': sim_col3.empty(),
        }

        if stop_button:
            st.session_state.comparison_state['should_stop'] = True
        if restart_button:
            st.session_state.comparison_state['should_restart'] = True

        # Status area.
        with st.container():
            ui['progress'] = st.progress(0)
            ui['status'] = st.empty()
        return ui

    def _score_similarity(self, std_kpts, std_scores, cam_kpts, cam_scores, start_time):
        """Score one reference/camera pose pair and append it to the analyzer history.

        Scoring is best-effort and must never stop playback.

        Returns:
            The similarity score, or None when joint angles could not be extracted
            or the analyzer raised.
        """
        try:
            std_angles = self.similarity_analyzer.extract_joint_angles(std_kpts, std_scores)
            cam_angles = self.similarity_analyzer.extract_joint_angles(cam_kpts, cam_scores)
            if not (std_angles and cam_angles):
                return None
            score = self.similarity_analyzer.calculate_similarity(std_angles, cam_angles)
            self.similarity_analyzer.add_similarity_score(score, time.time() - start_time)
        except Exception:
            return None
        return score

    def _render_similarity(self, ui, current_similarity, plot_due):
        """Update the live similarity metrics and, when due, the history chart.

        Display errors are ignored so a UI glitch never interrupts the loop.

        Returns:
            True if the chart was redrawn.
        """
        history = self.similarity_analyzer.similarity_history
        try:
            avg_similarity = sum(history) / len(history)
            emoji, level = self._live_rating(current_similarity)
            # delta is a label here, not a change, so don't colour it as up/down.
            ui['score'].metric("当前相似度", f"{emoji} {current_similarity:.1f}%",
                               delta=level, delta_color="off")
            ui['avg'].metric("平均相似度", f"{avg_similarity:.1f}%")
        except Exception:
            return False

        if not plot_due or len(history) < 2:
            return False
        try:
            similarity_plot = self.similarity_analyzer.get_similarity_plot()
            if not similarity_plot:
                return False
            ui['plot'].plotly_chart(
                similarity_plot,
                use_container_width=True,
                key=f"similarity_plot_{int(time.time() * 1000)}",  # unique key per redraw
            )
        except Exception:
            return False
        return True

    def _render_progress(self, ui, video_frame_idx, total_frames, start_time, target_width, target_height):
        """Update the progress bar and status line (best-effort, errors ignored)."""
        try:
            progress = min(video_frame_idx / total_frames, 1.0) if total_frames > 0 else 0
            ui['progress'].progress(progress)
            elapsed_time = time.time() - start_time
            # video_frame_idx and start_time are reset together, so this is the
            # real display rate of the current pass.
            fps_actual = video_frame_idx / elapsed_time if elapsed_time > 0 else 0
            ui['status'].text(
                f"进度: {video_frame_idx}/{total_frames} | "
                f"实际FPS: {fps_actual:.1f} | "
                f"分辨率: {target_width}x{target_height} | "
                f"模式: {self.display_settings['resolution_mode']}"
            )
        except Exception:
            pass  # status is informational only

    def _run_comparison_loop(self, audio_loaded):
        """Build the comparison UI and stream frames until stopped.

        See start_comparison for the stop conditions.
        """
        total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        frame_delay = 1.0 / video_fps if video_fps > 0 else 1.0 / 30.0

        ui = self._build_comparison_ui(audio_loaded)
        # Read the size only after the controls ran: the resolution button
        # changes display_settings during this same script run.
        target_width, target_height = self.get_display_resolution()
        target_size = (target_width, target_height)
        resolution_info = f" | {target_width}x{target_height}"

        video_frame_idx = 0  # reference frames shown in the current pass
        start_time = time.time()
        current_similarity = 0
        last_plot_update = 0

        if audio_loaded:
            self.audio_player.play()

        while (st.session_state.comparison_state['is_running']
               and not st.session_state.comparison_state['should_stop']):
            loop_start = time.time()
            self.frame_counter += 1

            if st.session_state.comparison_state['should_restart']:
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                self.similarity_analyzer.reset()
                start_time = time.time()
                video_frame_idx = 0
                if audio_loaded:
                    self.audio_player.restart()
                st.session_state.comparison_state['should_restart'] = False
                continue

            # Read the camera before the reference video so that a failed camera
            # read does not silently consume (skip) a reference frame.
            ret_webcam, webcam_frame = self.read_camera_frame()
            if not ret_webcam or webcam_frame is None:
                self.error_count += 1
                if self.error_count > self._MAX_CAMERA_ERRORS:
                    st.error("摄像头连接出现问题,停止比较")
                    break
                # Back off one frame period instead of busy-spinning through the
                # whole error budget in a few milliseconds.
                time.sleep(frame_delay)
                continue
            self.error_count = 0

            ret_standard, standard_frame = self.standard_cap.read()
            if not ret_standard:
                # End of the reference video: rewind and start a fresh scoring pass.
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                video_frame_idx = 0
                self.similarity_analyzer.reset()
                start_time = time.time()
                if audio_loaded:
                    self.audio_player.restart()
                continue
            video_frame_idx += 1

            # Mirror the live view, then bring both frames to the display size.
            webcam_frame = cv2.resize(cv2.flip(webcam_frame, 1), target_size)
            standard_frame = cv2.resize(standard_frame, target_size)

            standard_keypoints, standard_scores = self._detect_keypoints(standard_frame)
            webcam_keypoints, webcam_scores = self._detect_keypoints(webcam_frame)

            if (self.frame_counter % self._SIMILARITY_INTERVAL == 0
                    and standard_keypoints is not None
                    and webcam_keypoints is not None):
                score = self._score_similarity(standard_keypoints, standard_scores,
                                               webcam_keypoints, webcam_scores, start_time)
                if score is not None:
                    current_similarity = score

            standard_rgb = cv2.cvtColor(
                self._draw_keypoints(standard_frame, standard_keypoints, standard_scores),
                cv2.COLOR_BGR2RGB)
            webcam_rgb = cv2.cvtColor(
                self._draw_keypoints(webcam_frame, webcam_keypoints, webcam_scores),
                cv2.COLOR_BGR2RGB)

            current_time = time.time() - start_time
            frame_info = f"时间: {current_time:.1f}s | 帧: {video_frame_idx}/{total_frames}"
            audio_info = f" | 音频: {'🔊' if self.audio_player.is_playing else '🔇'}" if audio_loaded else ""
            ui['standard'].image(
                standard_rgb,
                caption=f"标准动作 - {frame_info}{audio_info}{resolution_info}",
                use_container_width=True,
            )
            ui['webcam'].image(
                webcam_rgb,
                caption=f"您的动作 - 实时画面{resolution_info}",
                use_container_width=True,
            )

            if len(self.similarity_analyzer.similarity_history) > 0:
                plot_due = self.frame_counter - last_plot_update >= self._PLOT_INTERVAL
                if self._render_similarity(ui, current_similarity, plot_due):
                    last_plot_update = self.frame_counter

            self._render_progress(ui, video_frame_idx, total_frames, start_time,
                                  target_width, target_height)

            # Pace the loop to the reference video's frame rate.
            sleep_time = frame_delay - (time.time() - loop_start)
            if sleep_time > 0:
                time.sleep(sleep_time)

    def start_comparison(self, video_path):
        """The main loop for comparing motion.

        Plays ``video_path`` beside the live camera and keeps looping the video
        until one of these happens:

        - ``comparison_state['is_running']`` is cleared;
        - ``comparison_state['should_stop']`` is set (the stop button);
        - the camera fails more than ``_MAX_CAMERA_ERRORS`` times in a row.

        ``cleanup()`` always runs: on early returns, on exceptions, and when
        Streamlit interrupts the script for a rerun. Final statistics are shown
        only after the loop exits normally.

        Args:
            video_path: Path to the reference video. Its audio is also played
                when ``AudioPlayer.load_audio`` succeeds.
        """
        self.is_running = True
        st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})

        self.standard_video_path = video_path
        self.frame_counter = 0
        self.error_count = 0
        self.similarity_analyzer.reset()

        try:
            audio_loaded = self.audio_player.load_audio(video_path)
            if audio_loaded:
                st.success("✅ 音频加载成功")
            else:
                st.info("ℹ️ 将不会播放音频")

            self.standard_cap = cv2.VideoCapture(video_path)
            if not self.standard_cap.isOpened():
                st.error("无法打开标准视频文件")
                return

            camera_ready = self.is_realsense_active or (
                self.webcam_cap is not None and self.webcam_cap.isOpened())
            if not camera_ready and not self.initialize_camera():
                return
            # Without a detector no keypoints (and so no similarity) can be produced.
            if not self.initialize_detector():
                return

            self._run_comparison_loop(audio_loaded)
        finally:
            self.cleanup()

        self.show_final_statistics()
|