posedet/motion_app.py

483 lines
21 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import cv2
import time
import os
import numpy as np
import torch
import threading
import queue
from rtmlib import Body, draw_skeleton
from audio_player import AudioPlayer
from pose_analyzer import PoseSimilarityAnalyzer
from config import REALSENSE_AVAILABLE
def draw_skeleton_with_similarity(img, keypoints, scores, joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=2):
    """Draw detected skeletons on ``img``, colouring limbs by joint similarity.

    Limbs that correspond to a named joint angle are coloured from
    ``joint_similarities``: green above 90, yellow above 80, red otherwise.
    Every other limb (and every limb when ``joint_similarities`` is None)
    uses the default yellow. Keypoints are drawn as filled magenta dots.
    All detected people are drawn.

    Args:
        img: Image array; it is drawn on in place with cv2. Colour tuples
            below are written in BGR order (callers in this file convert
            BGR -> RGB after drawing).
        keypoints: Array-like of shape (num_people, num_keypoints, 2), or
            (num_keypoints, 2) for a single person.
        scores: Keypoint confidences matching ``keypoints`` without the
            coordinate axis.
        joint_similarities: Optional mapping of angle name -> similarity.
        openpose_skeleton: Accepted for signature compatibility; it is not
            read, the 18-keypoint OpenPose limb list below is always used.
        kpt_thr: Keypoints scoring below this threshold are not drawn.
        line_width: Thickness of limb lines in pixels.

    Returns:
        The same ``img`` object that was passed in.
    """
    if keypoints is None or len(keypoints) == 0 or scores is None:
        return img

    # OpenPose limb connections as (joint_a, joint_b) keypoint indices.
    skeleton = (
        (1, 0), (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7),
        (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13),
        (0, 14), (0, 15), (14, 16), (15, 17),
    )
    # Limb -> name of the joint angle whose similarity colours that limb.
    # NOTE(review): the names must match the keys produced by the similarity
    # analyzer, which is defined outside this file - confirm left/right naming.
    joint_to_angle_mapping = {
        (2, 3): 'left_shoulder', (3, 4): 'left_elbow',
        (5, 6): 'right_shoulder', (6, 7): 'right_elbow',
        (8, 9): 'left_hip', (9, 10): 'left_knee',
        (11, 12): 'right_hip', (12, 13): 'right_knee',
    }
    default_color = (0, 255, 255)  # yellow (BGR)
    keypoint_color = (255, 0, 255)  # magenta

    keypoints = np.asarray(keypoints)
    scores = np.asarray(scores)
    if keypoints.ndim == 2:
        # Single person: add the leading "person" axis to both arrays.
        # ``[None, ...]`` works for 1-D scores, where the previous
        # ``[None, :, :]`` raised IndexError.
        keypoints = keypoints[None, ...]
        scores = scores[None, ...]

    for person_kpts, person_scores in zip(keypoints, scores):
        # Never index past whichever of the two arrays is shorter.
        usable = min(len(person_kpts), len(person_scores))

        # Limbs
        for joint_a, joint_b in skeleton:
            if joint_a >= usable or joint_b >= usable:
                continue
            if person_scores[joint_a] < kpt_thr or person_scores[joint_b] < kpt_thr:
                continue
            x_a, y_a = person_kpts[joint_a][:2]
            x_b, y_b = person_kpts[joint_b][:2]

            color = default_color
            if joint_similarities is not None:
                angle_name = joint_to_angle_mapping.get((joint_a, joint_b))
                if angle_name is not None and angle_name in joint_similarities:
                    color = _similarity_color(joint_similarities[angle_name])
            cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), color, thickness=line_width)

        # Keypoints
        for kpt_id in range(usable):
            if person_scores[kpt_id] < kpt_thr:
                continue
            x, y = person_kpts[kpt_id][:2]
            cv2.circle(img, (int(x), int(y)), 3, keypoint_color, -1)
    return img


def _similarity_color(similarity):
    """Return the BGR limb colour for a joint similarity score."""
    if similarity > 90:
        return (0, 255, 0)  # green
    if similarity > 80:
        return (0, 255, 255)  # yellow
    return (0, 0, 255)  # red
# pyrealsense2 is optional: import it only when the flag imported from
# the config module says it is available.
if REALSENSE_AVAILABLE:
    import pyrealsense2 as rs
class MotionComparisonApp:
    """Main application class for motion comparison.

    Plays a reference ("standard") video next to a live camera feed inside a
    Streamlit page, runs the keypoint detector on both, and scores how similar
    the two poses are. Similarity scoring runs on a background thread that is
    fed through ``pose_data_queue``.
    """

    # Metrics and the similarity plot are refreshed once per this many
    # processed frames.
    _UI_REFRESH_INTERVAL = 30

    def __init__(self):
        """Set up empty state; no camera, detector or video is opened here."""
        self.body_detector = None
        self.is_running = False
        self.standard_video_path = None
        self.webcam_cap = None
        self.standard_cap = None
        self.similarity_analyzer = PoseSimilarityAnalyzer()
        self.frame_counter = 0
        self.audio_player = AudioPlayer()
        self.display_settings = {'resolution_mode': 'high', 'target_width': 960, 'target_height': 720}
        self.realsense_pipeline = None
        self.is_realsense_active = False
        # Bookkeeping for errors that are deliberately swallowed.
        self.last_error_time = 0
        self.error_count = 0
        # Async processing components
        self.pose_data_queue = queue.Queue(maxsize=50)
        self.similarity_thread = None
        self.similarity_stop_flag = threading.Event()
        if 'comparison_state' not in st.session_state:
            st.session_state.comparison_state = {'is_running': False, 'should_stop': False, 'should_restart': False}

    def _record_error(self):
        """Count a swallowed error and remember when it happened."""
        self.error_count += 1
        self.last_error_time = time.time()

    def get_display_resolution(self):
        """Return the (width, height) for the configured resolution mode.

        Falls back to (960, 720) when the mode is missing or unknown.
        """
        # NOTE(review): 'high' is 1024x576 (16:9) while 'medium' is 960x720
        # (4:3) - confirm that this ordering is intended.
        modes = {'high': (1024, 576), 'medium': (960, 720), 'low': (640, 480)}
        mode = self.display_settings.get('resolution_mode', 'medium')
        return modes.get(mode, (960, 720))

    def initialize_detector(self):
        """Create the keypoint detector once; return True when it is ready."""
        if self.body_detector is not None:
            return True
        try:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.body_detector = Body(mode='lightweight', to_openpose=True, backend='onnxruntime', device=device)
        except Exception as e:
            st.error(f"Detector initialization failed: {e}")
            return False
        st.success(f"Keypoint detector initialized on device: {device}")
        return True

    def initialize_camera(self):
        """Open a RealSense camera if available, otherwise a USB webcam.

        Returns:
            True when a camera was opened, False otherwise.
        """
        if not REALSENSE_AVAILABLE:
            return self._initialize_webcam()
        try:
            self.realsense_pipeline = rs.pipeline()
            config = rs.config()
            width, height = self.get_display_resolution()
            config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
            profile = self.realsense_pipeline.start(config)
            device = profile.get_device().get_info(rs.camera_info.name)
            st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
            self.is_realsense_active = True
            return True
        except Exception as e:
            # Do not keep a half-initialised pipeline around.
            self.realsense_pipeline = None
            self.is_realsense_active = False
            st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
            return self._initialize_webcam()

    def _initialize_webcam(self):
        """Open USB camera 0 and request the configured resolution at 30 FPS."""
        try:
            self.webcam_cap = cv2.VideoCapture(0)
            if not self.webcam_cap.isOpened():
                self.webcam_cap.release()
                self.webcam_cap = None
                st.error("❌ 无法打开USB摄像头")
                return False
            self.webcam_cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'))
            width, height = self.get_display_resolution()
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            self.webcam_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
            self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
            # Report what the driver actually granted, not what was requested.
            actual_w = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            actual_h = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
            return True
        except Exception as e:
            st.error(f"❌ USB摄像头初始化失败: {e}")
            return False

    def read_camera_frame(self):
        """Read one frame from the active camera.

        Returns:
            ``(ok, frame)``; ``(False, None)`` when no camera is open or the
            RealSense read fails.
        """
        if self.is_realsense_active and self.realsense_pipeline:
            try:
                frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
                color_frame = frames.get_color_frame()
                if not color_frame:
                    return False, None
                return True, np.asanyarray(color_frame.get_data())
            except Exception:
                return False, None
        if self.webcam_cap and self.webcam_cap.isOpened():
            return self.webcam_cap.read()
        return False, None

    def get_camera_preview_frame(self):
        """Return a mirrored RGB preview frame with the skeleton overlaid.

        Returns None when no frame could be read. If detection or drawing
        fails the plain mirrored frame is returned.
        """
        ret, frame = self.read_camera_frame()
        if not ret or frame is None:
            return None
        frame = cv2.flip(frame, 1)
        if self.body_detector:
            try:
                keypoints, scores = self.body_detector(frame)
                frame = draw_skeleton_with_similarity(frame.copy(), keypoints, scores, joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=1)
            except Exception:
                # Best effort: the preview still shows the raw frame.
                self._record_error()
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def _drain_pose_queue(self):
        """Discard everything currently waiting in ``pose_data_queue``."""
        while True:
            try:
                self.pose_data_queue.get_nowait()
            except queue.Empty:
                return
            self.pose_data_queue.task_done()

    def _enqueue_pose_data(self, item):
        """Queue ``item`` without blocking, dropping the oldest entry if full."""
        try:
            self.pose_data_queue.put_nowait(item)
            return
        except queue.Full:
            pass
        try:
            self.pose_data_queue.get_nowait()
            self.pose_data_queue.task_done()
        except queue.Empty:
            pass
        try:
            self.pose_data_queue.put_nowait(item)
        except queue.Full:
            # Lost the race for the freed slot; dropping one sample is fine.
            pass

    def _stop_similarity_worker(self):
        """Signal the similarity thread to stop and wait briefly for it."""
        self.similarity_stop_flag.set()
        worker = self.similarity_thread
        if worker and worker.is_alive():
            try:
                # Poison pill wakes the worker immediately.
                self.pose_data_queue.put_nowait(None)
            except queue.Full:
                # The stop flag alone ends the worker after its queue timeout.
                pass
            worker.join(timeout=2)
        self.similarity_thread = None

    def cleanup(self):
        """Cleans up all resources."""
        self._stop_similarity_worker()
        if self.standard_cap:
            self.standard_cap.release()
            self.standard_cap = None
        if self.webcam_cap:
            self.webcam_cap.release()
            self.webcam_cap = None
        if self.realsense_pipeline:
            if self.is_realsense_active:
                try:
                    self.realsense_pipeline.stop()
                except Exception:
                    self._record_error()
            self.realsense_pipeline = None
        # Reset the flag so the next comparison re-initialises the camera
        # instead of reading from a stopped pipeline.
        self.is_realsense_active = False
        self.audio_player.cleanup()
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def show_final_statistics(self):
        """Displays final statistics after the comparison ends."""
        st.markdown("---")  # separator
        history = self.similarity_analyzer.similarity_history
        if not history:
            st.warning("没有足够的比较数据来生成分析结果。")
            if st.button("返回主菜单", use_container_width=True):
                st.rerun()
            return

        final_avg = sum(history) / len(history)
        # Rating text and its markdown colour are chosen together.
        if final_avg >= 80:
            level, md_color = "非常棒! 👏", "green"
        elif final_avg >= 60:
            level, md_color = "整体不错! 👍", "blue"
        else:
            level, md_color = "需要改进! 💪", "orange"

        st.success("🎉 比较完成!")
        st.markdown(f"**整体表现**: :{md_color}[{level}]")
        col1, col2, col3 = st.columns(3)
        col1.metric("平均相似度", f"{final_avg:.1f}%")
        col2.metric("最高相似度", f"{max(history):.1f}%")
        col3.metric("最低相似度", f"{min(history):.1f}%")
        if final_avg < 60:
            with st.expander("💡 改善建议"):
                st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
                            "- 尽量匹配标准视频的节奏和动作幅度\n"
                            "- 确保光线充足且稳定")
        # "Back to main menu" button
        st.markdown("---")
        if st.button("返回主菜单", use_container_width=True):
            st.session_state.comparison_state['is_running'] = False  # make sure the state is reset
            st.rerun()

    def _similarity_calculation_worker(self, start_time, video_fps):
        """Background thread body: score queued pose pairs until told to stop.

        Each queue item is ``(elapsed_time, standard_angles, webcam_angles)``;
        a ``None`` item (poison pill) or the stop flag ends the loop.
        ``start_time`` and ``video_fps`` are accepted but not used.
        """
        while not self.similarity_stop_flag.is_set():
            try:
                pose_data = self.pose_data_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            try:
                if pose_data is None:  # poison pill
                    break
                elapsed_time, standard_angles, webcam_angles = pose_data
                if standard_angles and webcam_angles:
                    current_similarity = self.similarity_analyzer.calculate_similarity(standard_angles, webcam_angles)
                    self.similarity_analyzer.add_similarity_score(current_similarity, elapsed_time)
            except Exception:
                # One bad sample must not kill the worker; keep a record.
                self._record_error()
            finally:
                # Every successful get() is matched by exactly one task_done().
                self.pose_data_queue.task_done()

    def _abort_comparison(self):
        """Undo the state set at the top of ``start_comparison`` on early exit."""
        if self.standard_cap:
            self.standard_cap.release()
            self.standard_cap = None
        self.is_running = False
        st.session_state.comparison_state['is_running'] = False

    def _analyze_frames(self, standard_frame, webcam_frame, elapsed_time):
        """Detect keypoints in both frames and queue their joint angles.

        Returns:
            ``(standard_keypoints, standard_scores, webcam_keypoints,
            webcam_scores, joint_similarities)``. All five are None when
            detection fails; ``joint_similarities`` is None when it could not
            be computed.
        """
        try:
            standard_keypoints, standard_scores = self.body_detector(standard_frame)
            webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
        except Exception:
            self._record_error()
            return None, None, None, None, None

        joint_similarities = None
        if standard_keypoints is not None and webcam_keypoints is not None:
            try:
                # Angles are extracted once and shared by the background
                # scorer and the per-joint colouring.
                standard_angles = self.similarity_analyzer.extract_joint_angles(standard_keypoints, standard_scores)
                webcam_angles = self.similarity_analyzer.extract_joint_angles(webcam_keypoints, webcam_scores)
                self._enqueue_pose_data((elapsed_time, standard_angles, webcam_angles))
                if standard_angles and webcam_angles:
                    joint_similarities = self.similarity_analyzer.calculate_joint_similarities(standard_angles, webcam_angles)
            except Exception:
                # Skeletons are still drawn, just without similarity colours.
                self._record_error()
        return standard_keypoints, standard_scores, webcam_keypoints, webcam_scores, joint_similarities

    def start_comparison(self, video_path):
        """The main loop for comparing motion, synchronized with real-time.

        Opens ``video_path``, makes sure a camera and the detector are ready,
        builds the Streamlit layout and then, until the video duration has
        elapsed or a stop is requested, shows the reference frame for the
        current wall-clock time beside the mirrored camera frame. Resources
        are released afterwards and the final statistics are displayed.
        """
        self.is_running = True
        st.session_state.comparison_state.update({'is_running': True, 'should_stop': False, 'should_restart': False})
        self.standard_video_path = video_path
        self.frame_counter = 0
        self.similarity_analyzer.reset()

        audio_loaded = self.audio_player.load_audio(video_path)
        if audio_loaded:
            st.success("✅ 音频加载成功")
        else:
            st.info(" 将不会播放音频")

        self.standard_cap = cv2.VideoCapture(video_path)
        if not self.standard_cap.isOpened():
            st.error("无法打开标准视频文件")
            self._abort_comparison()
            return
        webcam_open = self.webcam_cap is not None and self.webcam_cap.isOpened()
        if not self.is_realsense_active and not webcam_open:
            if not self.initialize_camera():
                self._abort_comparison()
                return
        # No-op when the detector already exists.
        if not self.initialize_detector():
            self._abort_comparison()
            return

        # --- Video information ---
        total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
        if not video_fps > 0:  # also catches NaN; avoids division by zero
            video_fps = 30
        video_duration = total_frames / video_fps
        target_width, target_height = self.get_display_resolution()

        # --- UI placeholders ---
        st.markdown("### 📺 视频比较")
        video_col1, video_col2 = st.columns(2, gap="small")
        standard_placeholder = video_col1.empty()
        webcam_placeholder = video_col2.empty()
        with video_col1:
            st.markdown("#### 🎯 标准动作视频")
        with video_col2:
            st.markdown(f"#### 📹 {'RealSense' if self.is_realsense_active else 'USB'}实时影像")
        st.markdown("---")
        control_col1, control_col2 = st.columns(2)
        stop_button = control_col1.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
        restart_button = control_col2.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
        st.markdown("---")
        st.markdown("### 📊 动作相似度分析")
        sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
        similarity_score_placeholder = sim_col1.empty()
        avg_score_placeholder = sim_col2.empty()
        similarity_plot_placeholder = sim_col3.empty()
        progress_bar = st.progress(0)
        status_text = st.empty()

        # --- Button handling ---
        if stop_button:
            st.session_state.comparison_state['should_stop'] = True
        if restart_button:
            st.session_state.comparison_state['should_restart'] = True

        start_time = time.time()

        # --- Start the background similarity worker ---
        # Drop anything left from a previous run (a stale poison pill would
        # make the new worker exit immediately).
        self._drain_pose_queue()
        self.similarity_stop_flag.clear()
        self.similarity_thread = threading.Thread(
            target=self._similarity_calculation_worker,
            args=(start_time, video_fps),
            daemon=True
        )
        self.similarity_thread.start()

        if audio_loaded:
            self.audio_player.play()

        try:
            while True:
                # --- Exit conditions ---
                elapsed_time = time.time() - start_time
                if elapsed_time >= video_duration or st.session_state.comparison_state['should_stop']:
                    break

                # --- Restart request ---
                if st.session_state.comparison_state['should_restart']:
                    self._drain_pose_queue()
                    self.similarity_analyzer.reset()
                    self.frame_counter = 0  # keeps the reported FPS correct
                    self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    start_time = time.time()
                    if audio_loaded:
                        self.audio_player.restart()
                    st.session_state.comparison_state['should_restart'] = False
                    continue

                # --- Reference frame for the current wall-clock time ---
                target_frame_idx = int(elapsed_time * video_fps)
                self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_idx)
                ret_standard, standard_frame = self.standard_cap.read()
                if not ret_standard or standard_frame is None:
                    time.sleep(0.01)  # do not spin at full speed on read errors
                    continue

                # --- Current camera frame ---
                ret_webcam, webcam_frame = self.read_camera_frame()
                if not ret_webcam or webcam_frame is None:
                    status_text.warning("摄像头画面读取失败,请检查连接...")
                    time.sleep(0.1)
                    continue

                self.frame_counter += 1
                webcam_frame = cv2.flip(webcam_frame, 1)
                standard_frame = cv2.resize(standard_frame, (target_width, target_height))
                webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))

                # --- Keypoint detection and asynchronous scoring ---
                (standard_keypoints, standard_scores,
                 webcam_keypoints, webcam_scores,
                 joint_similarities) = self._analyze_frames(standard_frame, webcam_frame, elapsed_time)

                # --- Drawing and display ---
                try:
                    # Only the camera skeleton is coloured by similarity.
                    standard_with_keypoints = draw_skeleton_with_similarity(standard_frame.copy(), standard_keypoints, standard_scores, joint_similarities=None, openpose_skeleton=True, kpt_thr=0.43, line_width=1)
                    webcam_with_keypoints = draw_skeleton_with_similarity(webcam_frame.copy(), webcam_keypoints, webcam_scores, joint_similarities=joint_similarities, openpose_skeleton=True, kpt_thr=0.43, line_width=1)
                    standard_placeholder.image(cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)
                    webcam_placeholder.image(cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB), use_container_width=True)

                    # Metrics and plot are refreshed at a lower rate.
                    if self.frame_counter % self._UI_REFRESH_INTERVAL == 0:
                        if self.similarity_analyzer.similarity_history:
                            latest_similarity = self.similarity_analyzer.similarity_history[-1]
                            avg_sim = np.mean(self.similarity_analyzer.similarity_history)
                            similarity_score_placeholder.metric("当前相似度", f"{latest_similarity:.1f}%")
                            avg_score_placeholder.metric("平均相似度", f"{avg_sim:.1f}%")
                            plot = self.similarity_analyzer.get_similarity_plot()
                            if plot:
                                similarity_plot_placeholder.plotly_chart(plot, use_container_width=True)

                    # Progress bar and status line
                    progress = min(elapsed_time / video_duration, 1.0)
                    progress_bar.progress(progress)
                    processing_fps = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
                    status_text.text(f"时间: {elapsed_time:.1f}s / {video_duration:.1f}s | "
                                     f"处理帧率: {processing_fps:.1f} FPS | "
                                     f"标准视频帧: {target_frame_idx}/{total_frames}")
                except Exception:
                    # A failed UI update must not end the comparison.
                    self._record_error()
        finally:
            # Runs on normal completion and when the loop is interrupted by an
            # exception, so the camera, video, audio and worker thread are
            # always released.
            if audio_loaded:
                self.audio_player.stop()
            self.cleanup()

        self.show_final_statistics()