483 lines
21 KiB
Python
483 lines
21 KiB
Python
import streamlit as st
|
||
import cv2
|
||
import time
|
||
import os
|
||
import numpy as np
|
||
import torch
|
||
import threading
|
||
import queue
|
||
from rtmlib import Body, draw_skeleton
|
||
|
||
from audio_player import AudioPlayer
|
||
from pose_analyzer import PoseSimilarityAnalyzer
|
||
from config import REALSENSE_AVAILABLE
|
||
|
||
def draw_skeleton_with_similarity(img, keypoints, scores, joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=2):
    """Draw OpenPose-style skeletons on ``img``, coloring limbs by joint similarity.

    Limbs that map to an entry in ``joint_similarities`` are colored (BGR):

    * similarity > 90        -> green
    * 80 < similarity <= 90  -> yellow
    * similarity <= 80       -> red

    All other limbs use the default yellow. Every detected person is drawn.

    Args:
        img: Image passed to ``cv2.line`` / ``cv2.circle``; it is drawn on in place.
        keypoints: Joint coordinates, either ``(people, joints, 2)`` or
            ``(joints, 2)`` for a single person. ``None`` or an empty
            sequence leaves ``img`` untouched.
        scores: Per-joint confidences matching ``keypoints``: ``(people, joints)``
            or, for a single person, any array holding one value per joint.
        joint_similarities: Optional mapping from angle name (for example
            ``'left_elbow'``) to a similarity value used to pick limb colors.
        openpose_skeleton: Accepted for call compatibility; not used, the
            OpenPose limb layout below is always applied.
        kpt_thr: Joints scoring below this threshold are not drawn, and limbs
            touching such a joint are skipped.
        line_width: Thickness of limb lines in pixels.

    Returns:
        The same ``img`` object that was passed in.
    """
    if keypoints is None or len(keypoints) == 0:
        return img

    # OpenPose limb connectivity (pairs of joint indices).
    limbs = (
        (1, 0), (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7),
        (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13),
        (0, 14), (0, 15), (14, 16), (15, 17),
    )

    # Limbs whose color is driven by a named joint angle.
    # NOTE(review): the left/right naming is taken as-is from the original
    # table; confirm it matches the keys produced by the similarity analyzer.
    limb_to_angle = {
        (2, 3): 'left_shoulder', (3, 4): 'left_elbow',
        (5, 6): 'right_shoulder', (6, 7): 'right_elbow',
        (8, 9): 'left_hip', (9, 10): 'left_knee',
        (11, 12): 'right_hip', (12, 13): 'right_knee',
    }

    default_color = (0, 255, 255)  # yellow (BGR)

    keypoints = np.asarray(keypoints)
    scores = np.asarray(scores)

    # Promote a single-person input to the multi-person layout. The scores of
    # a single person are 1-D, so they must be reshaped rather than indexed
    # with two slices (which raised IndexError before).
    if keypoints.ndim == 2:
        keypoints = keypoints[None, ...]
        scores = scores.reshape(1, -1)

    for person_kpts, person_scores in zip(keypoints, scores):
        # Never index past whichever of the two arrays is shorter.
        usable_joints = min(len(person_kpts), len(person_scores))

        # Limbs first, so the joint dots end up on top of the lines.
        for joint_a, joint_b in limbs:
            if joint_a >= usable_joints or joint_b >= usable_joints:
                continue
            if person_scores[joint_a] < kpt_thr or person_scores[joint_b] < kpt_thr:
                continue

            x_a, y_a = person_kpts[joint_a]
            x_b, y_b = person_kpts[joint_b]

            color = default_color
            if joint_similarities is not None:
                angle_name = limb_to_angle.get((joint_a, joint_b))
                if angle_name is not None and angle_name in joint_similarities:
                    similarity = joint_similarities[angle_name]
                    if similarity > 90:
                        color = (0, 255, 0)  # green
                    elif similarity > 80:
                        color = (0, 255, 255)  # yellow
                    else:
                        color = (0, 0, 255)  # red

            cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), color, thickness=line_width)

        # Joint markers; zip() stops at the shorter array.
        for (x, y), score in zip(person_kpts, person_scores):
            if score < kpt_thr:
                continue
            cv2.circle(img, (int(x), int(y)), 3, (255, 0, 255), -1)

    return img
|
||
|
||
if REALSENSE_AVAILABLE:
|
||
import pyrealsense2 as rs
|
||
|
||
class MotionComparisonApp:
|
||
"""Main application class for motion comparison."""
|
||
|
||
def __init__(self):
    """Initialise detector/camera handles, analysis helpers and shared UI state."""
    # Pose detector is created lazily by initialize_detector().
    self.body_detector = None
    self.is_running = False

    # Video sources.
    self.standard_video_path = None
    self.webcam_cap = None
    self.standard_cap = None

    # Scoring and playback helpers.
    self.similarity_analyzer = PoseSimilarityAnalyzer()
    self.frame_counter = 0
    self.audio_player = AudioPlayer()

    self.display_settings = dict(resolution_mode='high', target_width=960, target_height=720)

    # RealSense state and error bookkeeping.
    self.realsense_pipeline = None
    self.is_realsense_active = False
    self.last_error_time = 0
    self.error_count = 0

    # Plumbing for the background similarity worker.
    self.pose_data_queue = queue.Queue(maxsize=50)
    self.similarity_thread = None
    self.similarity_stop_flag = threading.Event()

    # Create the shared comparison flags only once per Streamlit session.
    if 'comparison_state' in st.session_state:
        return
    st.session_state.comparison_state = dict(is_running=False, should_stop=False, should_restart=False)
|
||
|
||
def get_display_resolution(self):
    """Return the ``(width, height)`` preset for the configured resolution mode.

    A missing ``resolution_mode`` setting is treated as ``'medium'``; an
    unrecognised mode falls back to ``(960, 720)``.
    """
    presets = dict(high=(1024, 576), medium=(960, 720), low=(640, 480))
    selected = self.display_settings.get('resolution_mode', 'medium')
    try:
        return presets[selected]
    except KeyError:
        return (960, 720)
|
||
|
||
def initialize_detector(self):
    """Create the keypoint detector on first use.

    Returns:
        True when a detector is available (already present or just created),
        False when construction failed; the failure is reported via ``st.error``.
    """
    # Nothing to do if a detector already exists.
    if self.body_detector is not None:
        return True

    try:
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
        self.body_detector = Body(mode='lightweight', to_openpose=True, backend='onnxruntime', device=device)
        st.success(f"Keypoint detector initialized on device: {device}")
    except Exception as e:
        st.error(f"Detector initialization failed: {e}")
        return False
    return True
|
||
|
||
def initialize_camera(self):
    """Open a camera, preferring RealSense and falling back to the USB webcam.

    When ``REALSENSE_AVAILABLE`` is false the USB webcam is used directly.
    Otherwise a RealSense color stream is started at the display resolution;
    on any failure a warning is shown and the USB webcam is tried instead.

    The pipeline is only stored on ``self`` once it is fully initialised. If a
    step fails after the pipeline has started, the pipeline is stopped so it
    does not keep streaming in the background while the webcam is in use.

    Returns:
        True if a camera (RealSense or USB) was initialised, False otherwise.
    """
    if not REALSENSE_AVAILABLE:
        return self._initialize_webcam()

    started = False
    pipeline = None
    try:
        pipeline = rs.pipeline()
        config = rs.config()
        # NOTE(review): the size comes from the display presets; confirm the
        # connected RealSense color sensor supports it, otherwise start()
        # fails and the USB fallback below is taken.
        width, height = self.get_display_resolution()
        config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
        profile = pipeline.start(config)
        started = True
        device = profile.get_device().get_info(rs.camera_info.name)
        st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
    except Exception as e:
        if started:
            # Best-effort shutdown of a half-initialised pipeline; the
            # original error is what gets reported to the user.
            try:
                pipeline.stop()
            except Exception:
                pass
        self.realsense_pipeline = None
        self.is_realsense_active = False
        st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
        return self._initialize_webcam()

    self.realsense_pipeline = pipeline
    self.is_realsense_active = True
    return True
|
||
|
||
def _initialize_webcam(self):
    """Open USB camera 0 and request MJPG at the display resolution and 30 FPS.

    Returns:
        True if the camera opened, False if it could not be opened or an
        exception occurred (both cases are reported via ``st.error``).
    """
    try:
        capture = cv2.VideoCapture(0)
        self.webcam_cap = capture
        if not capture.isOpened():
            st.error("❌ 无法打开USB摄像头")
            return False

        capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
        width, height = self.get_display_resolution()
        requested = (
            (cv2.CAP_PROP_FRAME_WIDTH, width),
            (cv2.CAP_PROP_FRAME_HEIGHT, height),
            (cv2.CAP_PROP_FPS, 30),
        )
        for prop, value in requested:
            capture.set(prop, value)

        # Report the size read back from the capture, not the requested one.
        actual_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        actual_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
        return True
    except Exception as e:
        st.error(f"❌ USB摄像头初始化失败: {e}")
        return False
|
||
|
||
def read_camera_frame(self):
    """Grab one frame from whichever camera is active.

    The RealSense pipeline is used when it is active; otherwise the USB
    capture is used if it is open.

    Returns:
        ``(ok, frame)``. ``(False, None)`` when no camera is usable, when the
        RealSense read raises, or when it yields no color frame. For the USB
        camera the result of ``webcam_cap.read()`` is returned unchanged.
    """
    no_frame = (False, None)

    if self.is_realsense_active and self.realsense_pipeline:
        try:
            frameset = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
            color = frameset.get_color_frame()
            if color:
                return True, np.asanyarray(color.get_data())
        except Exception:
            pass
        return no_frame

    webcam = self.webcam_cap
    if webcam and webcam.isOpened():
        return webcam.read()
    return no_frame
|
||
|
||
def get_camera_preview_frame(self):
    """Return a mirrored RGB preview frame, with a skeleton overlay when possible.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when no camera frame
        could be read. If a detector exists, its keypoints are drawn on a copy
        of the frame; detector or drawing errors are ignored and the plain
        mirrored frame is returned instead.
    """
    ok, raw = self.read_camera_frame()
    if raw is None or not ok:
        return None

    mirrored = cv2.flip(raw, 1)
    if self.body_detector:
        # The overlay is optional: a failing detector must not break the preview.
        try:
            kpts, confs = self.body_detector(mirrored)
            mirrored = draw_skeleton_with_similarity(
                mirrored.copy(), kpts, confs,
                joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=1,
            )
        except Exception:
            pass

    return cv2.cvtColor(mirrored, cv2.COLOR_BGR2RGB)
|
||
|
||
def cleanup(self):
    """Release all resources and mark the comparison as stopped.

    Stops the similarity worker (waiting up to 2 seconds), releases the
    reference-video and webcam captures, stops the RealSense pipeline, cleans
    up the audio player and clears the running flags.

    Released handles are reset on ``self`` so that the method is safe to call
    more than once and a later run re-initialises the camera instead of
    reading from a stopped RealSense pipeline.
    """
    # Stop the similarity calculation thread.
    worker = self.similarity_thread
    if worker and worker.is_alive():
        self.similarity_stop_flag.set()
        worker.join(timeout=2)

    if self.standard_cap:
        self.standard_cap.release()
        self.standard_cap = None

    if self.webcam_cap:
        self.webcam_cap.release()
        self.webcam_cap = None

    if self.is_realsense_active and self.realsense_pipeline:
        try:
            self.realsense_pipeline.stop()
        finally:
            # Forget the pipeline even if stop() raised: it must not be
            # reported as active (or stopped a second time) afterwards.
            self.realsense_pipeline = None
            self.is_realsense_active = False

    self.audio_player.cleanup()
    self.is_running = False
    st.session_state.comparison_state['is_running'] = False
|
||
|
||
def show_final_statistics(self):
    """Render the end-of-run summary: rating, average/max/min metrics and tips.

    With no recorded similarity scores only a warning and a "back to main
    menu" button are shown. Otherwise the average score selects the rating
    (>= 80, >= 60, below 60), and improvement tips are offered when the
    average is below 60.
    """
    st.markdown("---")  # separator
    scores = self.similarity_analyzer.similarity_history

    # Nothing recorded: warn and offer the way back.
    if not scores:
        st.warning("没有足够的比较数据来生成分析结果。")
        if st.button("返回主菜单", use_container_width=True):
            st.rerun()
        return

    mean_score = sum(scores) / len(scores)

    # (minimum average, rating text, markdown color), checked from best to worst.
    grades = (
        (80, "非常棒! 👏", "green"),
        (60, "整体不错! 👍", "blue"),
    )
    level, md_color = "需要改进! 💪", "orange"
    for floor, text, color_name in grades:
        if mean_score >= floor:
            level, md_color = text, color_name
            break

    st.success("🎉 比较完成!")
    st.markdown(f"**整体表现**: :{md_color}[{level}]")

    summary = (
        ("平均相似度", mean_score),
        ("最高相似度", max(scores)),
        ("最低相似度", min(scores)),
    )
    for column, (label, value) in zip(st.columns(3), summary):
        column.metric(label, f"{value:.1f}%")

    if mean_score < 60:
        tips = "\n".join([
            "- 确保您的全身在摄像头画面中清晰可见",
            "- 尽量匹配标准视频的节奏和动作幅度",
            "- 确保光线充足且稳定",
        ])
        with st.expander("💡 改善建议"):
            st.markdown(tips)

    # Back button.
    st.markdown("---")
    if st.button("返回主菜单", use_container_width=True):
        st.session_state.comparison_state['is_running'] = False  # make sure the state is reset
        st.rerun()
|
||
|
||
def _similarity_calculation_worker(self, start_time, video_fps):
    """Background loop that turns queued pose samples into similarity scores.

    Items are taken from ``self.pose_data_queue`` until the stop flag is set
    or a ``None`` poison pill is received. Each item is an
    ``(elapsed_time, standard_angles, webcam_angles)`` tuple; a score is
    computed and recorded only when both angle collections are non-empty.

    Every item taken from the queue is acknowledged with ``task_done()``,
    including the poison pill and samples whose processing failed, so the
    queue's unfinished-task count stays accurate. Failures are logged and the
    loop continues with the next sample.

    Args:
        start_time: Unused; kept so the thread-start signature stays unchanged.
        video_fps: Unused; kept so the thread-start signature stays unchanged.
    """
    import logging  # local import keeps this method self-contained

    log = logging.getLogger(__name__)
    pending = self.pose_data_queue

    while not self.similarity_stop_flag.is_set():
        # The timeout lets the loop re-check the stop flag once a second.
        try:
            pose_data = pending.get(timeout=1.0)
        except queue.Empty:
            continue

        try:
            if pose_data is None:  # poison pill: stop the worker
                break

            elapsed_time, standard_angles, webcam_angles = pose_data
            if standard_angles and webcam_angles:
                score = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
                self.similarity_analyzer.add_similarity_score(score, elapsed_time)
        except Exception:
            # One bad sample must not kill the worker, but it should be visible.
            log.exception("Similarity calculation failed; sample skipped")
        finally:
            pending.task_done()
|
||
|
||
|
||
def start_comparison(self, video_path):
    """Run the comparison loop, keeping the reference video in step with wall-clock time.

    The method loads the audio track, opens the reference video and (if
    needed) a camera, builds the Streamlit layout, then loops until the video
    duration has elapsed or a stop was requested. Each iteration seeks the
    reference video to the frame matching the elapsed time, grabs a camera
    frame, runs keypoint detection on both, hands the joint angles to the
    background similarity worker and refreshes the display. When the loop
    ends, the worker and audio are stopped, resources are released via
    ``cleanup()`` and the final statistics are shown.

    If the video or the camera cannot be opened, the running flags are reset
    and the video handle is released before returning.

    Args:
        video_path: Path of the reference video; also passed to the audio player.
    """
    # Metrics, plot, progress bar and status text are refreshed every N frames.
    ui_refresh_every = 30

    def abort_startup():
        # Undo the "running" markers so a failed start does not look active.
        if self.standard_cap:
            self.standard_cap.release()
            self.standard_cap = None
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def drain_pose_queue():
        # Drop leftovers (stale samples or an unconsumed poison pill) so they
        # cannot stop or pollute the next worker.
        while True:
            try:
                self.pose_data_queue.get_nowait()
            except queue.Empty:
                return

    self.is_running = True
    st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})

    self.standard_video_path = video_path

    self.frame_counter = 0
    self.similarity_analyzer.reset()

    audio_loaded = self.audio_player.load_audio(video_path)
    if audio_loaded:
        st.success("✅ 音频加载成功")
    else:
        st.info("ℹ️ 将不会播放音频")

    self.standard_cap = cv2.VideoCapture(video_path)
    if not self.standard_cap.isOpened():
        st.error("无法打开标准视频文件")
        abort_startup()
        return

    if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
        if not self.initialize_camera():
            abort_startup()
            return

    # --- Video information and derived settings ---
    total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
    video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
    if not video_fps > 0:
        # Covers 0, negative and NaN readings; avoids a division by zero below.
        video_fps = 30
    video_duration = total_frames / video_fps
    target_width, target_height = self.get_display_resolution()

    # --- UI placeholders ---
    st.markdown("### 📺 视频比较")

    video_col1, video_col2 = st.columns(2, gap="small")
    standard_placeholder = video_col1.empty()
    webcam_placeholder = video_col2.empty()

    with video_col1:
        st.markdown("#### 🎯 标准动作视频")

    with video_col2:
        st.markdown(f"#### 📹 {'RealSense' if self.is_realsense_active else 'USB'}实时影像")

    st.markdown("---")
    control_col1, control_col2 = st.columns(2)
    stop_button = control_col1.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
    restart_button = control_col2.button("🔄 重新开始", use_container_width=True, key="restart_comparison")

    st.markdown("---")
    st.markdown("### 📊 动作相似度分析")
    sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
    similarity_score_placeholder = sim_col1.empty()
    avg_score_placeholder = sim_col2.empty()
    similarity_plot_placeholder = sim_col3.empty()

    progress_bar = st.progress(0)
    status_text = st.empty()

    # --- Button handling ---
    if stop_button:
        st.session_state.comparison_state['should_stop'] = True
    if restart_button:
        st.session_state.comparison_state['should_restart'] = True

    # --- Main loop (time-synchronised, focused on real-time display) ---
    start_time = time.time()

    # --- Start the asynchronous similarity worker ---
    drain_pose_queue()
    self.similarity_stop_flag.clear()
    self.similarity_thread = threading.Thread(
        target=self._similarity_calculation_worker,
        args=(start_time, video_fps),
        daemon=True
    )
    self.similarity_thread.start()

    if audio_loaded:
        self.audio_player.play()

    while True:
        # --- Exit conditions ---
        elapsed_time = time.time() - start_time

        if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
            # Stop audio playback.
            if audio_loaded:
                self.audio_player.stop()
            break

        # --- Restart request ---
        if st.session_state.comparison_state['should_restart']:
            self.similarity_analyzer.reset()
            drain_pose_queue()
            self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind the video
            start_time = time.time()  # reset the clock
            self.frame_counter = 0  # keep the reported FPS consistent with the new clock
            if audio_loaded:
                self.audio_player.restart()
            st.session_state.comparison_state['should_restart'] = False
            continue

        # --- Pick the reference frame that matches the elapsed wall-clock time ---
        target_frame_idx = int(elapsed_time * video_fps)

        self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_idx)
        ret_standard, standard_frame = self.standard_cap.read()
        if not ret_standard:
            continue  # read failed, try again on the next pass

        # --- Current camera frame ---
        ret_webcam, webcam_frame = self.read_camera_frame()
        if not ret_webcam or webcam_frame is None:
            status_text.warning("摄像头画面读取失败,请检查连接...")
            time.sleep(0.1)  # short pause before retrying
            continue

        self.frame_counter += 1
        webcam_frame = cv2.flip(webcam_frame, 1)
        standard_frame = cv2.resize(standard_frame, (target_width, target_height))
        webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))

        # --- Keypoint detection (real time) ---
        # Pre-bind everything the display step reads, so a detector failure
        # degrades to plain frames instead of an unbound-name error.
        standard_keypoints = standard_scores = None
        webcam_keypoints = webcam_scores = None
        standard_angles = webcam_angles = None
        try:
            standard_keypoints, standard_scores = self.body_detector(standard_frame)
            webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)

            # --- Hand the angles to the background worker ---
            if standard_keypoints is not None and webcam_keypoints is not None:
                # Extracted once and reused for the per-joint coloring below.
                standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
                webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)

                sample = (elapsed_time, standard_angles, webcam_angles)
                try:
                    # Non-blocking put: the display loop must never wait on scoring.
                    self.pose_data_queue.put_nowait(sample)
                except queue.Full:
                    # Queue is full: drop the oldest sample to make room.
                    try:
                        self.pose_data_queue.get_nowait()
                        self.pose_data_queue.put_nowait(sample)
                    except (queue.Empty, queue.Full):
                        pass
        except Exception:
            standard_keypoints = standard_scores = None
            webcam_keypoints = webcam_scores = None
            standard_angles = webcam_angles = None

        # --- Drawing and display ---
        try:
            # Per-joint similarity drives the limb colors on the camera view.
            joint_similarities = None
            if standard_angles and webcam_angles:
                joint_similarities = self.similarity_analyzer.calculate_joint_similarities(standard_angles, webcam_angles)

            # Skeletons: thin lines; only the camera view is colored by similarity.
            standard_with_keypoints = draw_skeleton_with_similarity(standard_frame.copy(), standard_keypoints, standard_scores, joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=1)
            webcam_with_keypoints = draw_skeleton_with_similarity(webcam_frame.copy(), webcam_keypoints, webcam_scores, joint_similarities=joint_similarities, openpose_skeleton=True, kpt_thr=0.43, line_width=1)

            # Video panes are refreshed on every frame.
            standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
            webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)

            # The remaining widgets are refreshed at a reduced rate.
            # NOTE(review): progress bar and status text are assumed to share
            # this throttle with the metrics and the plot - confirm.
            if self.frame_counter % ui_refresh_every == 0:
                # Current and average similarity.
                history = self.similarity_analyzer.similarity_history
                if history:
                    latest_similarity = history[-1]
                    avg_sim = np.mean(history)
                    similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%")
                    avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")

                # Similarity plot.
                plot = self.similarity_analyzer.get_similarity_plot()
                if plot:
                    similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)

                # Progress bar and status line.
                progress = min(elapsed_time / video_duration, 1.0)
                progress_bar.progress(progress)
                processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
                status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
                                 f"处理帧率: {processing_fps:.1f} FPS | "
                                 f"标准视频帧: {target_frame_idx}/{total_frames}")

        except Exception:
            # A failed UI refresh must not abort the comparison; the next
            # frame simply tries again.
            pass

    # Stop the similarity worker.
    self.similarity_stop_flag.set()
    if self.similarity_thread and self.similarity_thread.is_alive():
        # Poison pill wakes a worker that is blocked on the queue.
        try:
            self.pose_data_queue.put_nowait(None)
        except queue.Full:
            pass
        self.similarity_thread.join(timeout=2)

    # Make sure audio is stopped.
    if audio_loaded:
        self.audio_player.stop()

    self.cleanup()
    self.show_final_statistics()
|