1248 lines
50 KiB
Python
1248 lines
50 KiB
Python
import streamlit as st
|
||
import cv2
|
||
import time
|
||
import tempfile
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
from collections import deque
|
||
import math
|
||
import pandas as pd
|
||
import threading
|
||
from pathlib import Path
|
||
|
||
# 设置环境变量和torch配置
|
||
os.environ['OMP_NUM_THREADS'] = '1'
|
||
os.environ['MKL_NUM_THREADS'] = '1'
|
||
|
||
try:
|
||
import torch
|
||
torch.set_num_threads(1)
|
||
except ImportError as e:
|
||
st.error(f"PyTorch导入失败: {e}")
|
||
st.stop()
|
||
|
||
try:
|
||
import matplotlib.pyplot as plt
|
||
import plotly.graph_objects as go
|
||
from plotly.subplots import make_subplots
|
||
except ImportError as e:
|
||
st.error(f"绘图库导入失败: {e}")
|
||
st.stop()
|
||
|
||
try:
|
||
import pygame
|
||
PYGAME_AVAILABLE = True
|
||
except ImportError:
|
||
PYGAME_AVAILABLE = False
|
||
st.warning("Pygame未安装,视频将无声音播放。安装方法: pip install pygame")
|
||
|
||
try:
|
||
from moviepy.editor import VideoFileClip
|
||
MOVIEPY_AVAILABLE = True
|
||
except ImportError:
|
||
MOVIEPY_AVAILABLE = False
|
||
if PYGAME_AVAILABLE:
|
||
st.warning("MoviePy未安装,将尝试其他方式处理音频。安装方法: pip install moviepy")
|
||
|
||
try:
|
||
import pyrealsense2 as rs
|
||
REALSENSE_AVAILABLE = True
|
||
except ImportError:
|
||
REALSENSE_AVAILABLE = False
|
||
st.warning("RealSense库未安装,将使用普通摄像头")
|
||
|
||
# 导入rtmlib
|
||
try:
|
||
from rtmlib import Body, draw_skeleton
|
||
except ImportError:
|
||
st.error("请安装rtmlib库: pip install rtmlib")
|
||
st.stop()
|
||
|
||
class AudioPlayer:
|
||
"""音频播放器类"""
|
||
|
||
def __init__(self):
|
||
self.is_playing = False
|
||
self.audio_file = None
|
||
self.start_time = None
|
||
self.pygame_initialized = False
|
||
|
||
if PYGAME_AVAILABLE:
|
||
try:
|
||
pygame.mixer.pre_init(frequency=22050, size=-16, channels=2, buffer=512)
|
||
pygame.mixer.init()
|
||
self.pygame_initialized = True
|
||
except Exception as e:
|
||
st.warning(f"音频初始化失败: {e}")
|
||
|
||
def extract_audio_from_video(self, video_path):
|
||
"""从视频中提取音频"""
|
||
if not MOVIEPY_AVAILABLE or not self.pygame_initialized:
|
||
return None
|
||
|
||
try:
|
||
# 生成临时音频文件路径
|
||
temp_audio = tempfile.mktemp(suffix='.wav')
|
||
|
||
# 使用moviepy提取音频
|
||
video_clip = VideoFileClip(video_path)
|
||
if video_clip.audio is not None:
|
||
video_clip.audio.write_audiofile(temp_audio, verbose=False, logger=None)
|
||
video_clip.close()
|
||
return temp_audio
|
||
else:
|
||
video_clip.close()
|
||
return None
|
||
except Exception as e:
|
||
return None
|
||
|
||
def load_audio(self, video_path):
|
||
"""加载音频文件"""
|
||
if not self.pygame_initialized:
|
||
return False
|
||
|
||
try:
|
||
# 先尝试提取音频
|
||
audio_file = self.extract_audio_from_video(video_path)
|
||
if audio_file and os.path.exists(audio_file):
|
||
self.audio_file = audio_file
|
||
return True
|
||
return False
|
||
except Exception as e:
|
||
return False
|
||
|
||
def play(self):
|
||
"""开始播放音频"""
|
||
if not self.pygame_initialized or not self.audio_file:
|
||
return False
|
||
|
||
try:
|
||
if not self.is_playing:
|
||
pygame.mixer.music.load(self.audio_file)
|
||
pygame.mixer.music.play()
|
||
self.is_playing = True
|
||
self.start_time = time.time()
|
||
return True
|
||
except Exception as e:
|
||
return False
|
||
return False
|
||
|
||
def stop(self):
|
||
"""停止播放音频"""
|
||
if self.pygame_initialized and self.is_playing:
|
||
try:
|
||
pygame.mixer.music.stop()
|
||
self.is_playing = False
|
||
return True
|
||
except Exception as e:
|
||
return False
|
||
return False
|
||
|
||
def restart(self):
|
||
"""重新开始播放音频"""
|
||
if self.pygame_initialized and self.audio_file:
|
||
try:
|
||
pygame.mixer.music.load(self.audio_file)
|
||
pygame.mixer.music.play()
|
||
self.is_playing = True
|
||
self.start_time = time.time()
|
||
return True
|
||
except Exception as e:
|
||
return False
|
||
return False
|
||
|
||
def cleanup(self):
|
||
"""清理资源"""
|
||
try:
|
||
if self.is_playing:
|
||
self.stop()
|
||
if self.audio_file and os.path.exists(self.audio_file):
|
||
os.unlink(self.audio_file)
|
||
self.audio_file = None
|
||
except Exception as e:
|
||
pass
|
||
|
||
class PoseSimilarityAnalyzer:
|
||
"""姿态相似度分析器"""
|
||
|
||
def __init__(self):
|
||
self.similarity_history = deque(maxlen=500) # 保存最近500个相似度值
|
||
self.frame_timestamps = deque(maxlen=500)
|
||
self.start_time = None
|
||
|
||
# OpenPose关键点索引映射(Body 17个关键点)
|
||
self.keypoint_map = {
|
||
'nose': 0, 'neck': 1,
|
||
'left_shoulder': 2, 'left_elbow': 3, 'left_wrist': 4,
|
||
'right_shoulder': 5, 'right_elbow': 6, 'right_wrist': 7,
|
||
'left_hip': 8, 'left_knee': 9, 'left_ankle': 10,
|
||
'right_hip': 11, 'right_knee': 12, 'right_ankle': 13,
|
||
'left_eye': 14, 'right_eye': 15,
|
||
'left_ear': 16, 'right_ear': 17
|
||
}
|
||
|
||
# 关节角度定义(关节名:[点1, 关节点, 点2])
|
||
self.joint_angles = {
|
||
'left_elbow': ['left_shoulder', 'left_elbow', 'left_wrist'],
|
||
'right_elbow': ['right_shoulder', 'right_elbow', 'right_wrist'],
|
||
'left_shoulder': ['left_elbow', 'left_shoulder', 'neck'],
|
||
'right_shoulder': ['right_elbow', 'right_shoulder', 'neck'],
|
||
'left_knee': ['left_hip', 'left_knee', 'left_ankle'],
|
||
'right_knee': ['right_hip', 'right_knee', 'right_ankle'],
|
||
'left_hip': ['left_knee', 'left_hip', 'neck'],
|
||
'right_hip': ['right_knee', 'right_hip', 'neck'],
|
||
}
|
||
|
||
# 关节权重(重要关节权重更高)
|
||
self.joint_weights = {
|
||
'left_elbow': 1.2,
|
||
'right_elbow': 1.2,
|
||
'left_shoulder': 1.0,
|
||
'right_shoulder': 1.0,
|
||
'left_knee': 1.3,
|
||
'right_knee': 1.3,
|
||
'left_hip': 1.1,
|
||
'right_hip': 1.1,
|
||
}
|
||
|
||
def calculate_angle(self, p1, p2, p3):
|
||
"""计算三个点组成的角度"""
|
||
try:
|
||
# 向量v1 = p1 - p2, v2 = p3 - p2
|
||
v1 = np.array([p1[0] - p2[0], p1[1] - p2[1]], dtype=np.float64)
|
||
v2 = np.array([p3[0] - p2[0], p3[1] - p2[1]], dtype=np.float64)
|
||
|
||
# 计算向量长度
|
||
v1_norm = np.linalg.norm(v1)
|
||
v2_norm = np.linalg.norm(v2)
|
||
|
||
if v1_norm == 0 or v2_norm == 0:
|
||
return None
|
||
|
||
# 计算夹角(弧度)
|
||
cos_angle = np.dot(v1, v2) / (v1_norm * v2_norm)
|
||
cos_angle = np.clip(cos_angle, -1.0, 1.0) # 防止数值误差
|
||
angle = np.arccos(cos_angle)
|
||
|
||
# 转换为角度
|
||
return np.degrees(angle)
|
||
except Exception as e:
|
||
return None
|
||
|
||
def extract_joint_angles(self, keypoints, scores, confidence_threshold=0.3):
|
||
"""从关键点提取关节角度"""
|
||
if keypoints is None or len(keypoints) == 0:
|
||
return None
|
||
|
||
try:
|
||
# 取第一个人的关键点
|
||
person_keypoints = keypoints[0] if len(keypoints.shape) > 2 else keypoints
|
||
person_scores = scores[0] if len(scores.shape) > 1 else scores
|
||
|
||
joint_angles_result = {}
|
||
|
||
for joint_name, (p1_name, p2_name, p3_name) in self.joint_angles.items():
|
||
try:
|
||
# 获取关键点索引
|
||
if (p1_name not in self.keypoint_map or
|
||
p2_name not in self.keypoint_map or
|
||
p3_name not in self.keypoint_map):
|
||
continue
|
||
|
||
p1_idx = self.keypoint_map[p1_name]
|
||
p2_idx = self.keypoint_map[p2_name]
|
||
p3_idx = self.keypoint_map[p3_name]
|
||
|
||
# 检查索引范围
|
||
if (p1_idx >= len(person_scores) or
|
||
p2_idx >= len(person_scores) or
|
||
p3_idx >= len(person_scores)):
|
||
continue
|
||
|
||
# 检查置信度
|
||
if (person_scores[p1_idx] < confidence_threshold or
|
||
person_scores[p2_idx] < confidence_threshold or
|
||
person_scores[p3_idx] < confidence_threshold):
|
||
continue
|
||
|
||
# 获取坐标
|
||
p1 = person_keypoints[p1_idx]
|
||
p2 = person_keypoints[p2_idx]
|
||
p3 = person_keypoints[p3_idx]
|
||
|
||
# 计算角度
|
||
angle = self.calculate_angle(p1, p2, p3)
|
||
if angle is not None:
|
||
joint_angles_result[joint_name] = angle
|
||
|
||
except (KeyError, IndexError, TypeError) as e:
|
||
continue
|
||
|
||
return joint_angles_result
|
||
except Exception as e:
|
||
return None
|
||
|
||
def calculate_similarity(self, angles1, angles2):
|
||
"""计算两组关节角度的相似度"""
|
||
if not angles1 or not angles2:
|
||
return 0.0
|
||
|
||
try:
|
||
# 找到共同的关节
|
||
common_joints = set(angles1.keys()) & set(angles2.keys())
|
||
if not common_joints:
|
||
return 0.0
|
||
|
||
total_weight = 0
|
||
weighted_similarity = 0
|
||
|
||
for joint in common_joints:
|
||
angle_diff = abs(angles1[joint] - angles2[joint])
|
||
|
||
# 角度差异转换为相似度(0-1)
|
||
# 使用高斯函数,角度差异越小,相似度越高
|
||
similarity = math.exp(-(angle_diff ** 2) / (2 * (30 ** 2))) # 30度标准差
|
||
|
||
# 应用权重
|
||
weight = self.joint_weights.get(joint, 1.0)
|
||
weighted_similarity += similarity * weight
|
||
total_weight += weight
|
||
|
||
# 归一化相似度
|
||
final_similarity = weighted_similarity / total_weight if total_weight > 0 else 0
|
||
return min(max(final_similarity * 100, 0), 100) # 转换为0-100分
|
||
except Exception as e:
|
||
return 0.0
|
||
|
||
def add_similarity_score(self, score, timestamp=None):
|
||
"""添加相似度分数到历史记录"""
|
||
try:
|
||
if self.start_time is None:
|
||
self.start_time = time.time()
|
||
|
||
if timestamp is None:
|
||
timestamp = time.time() - self.start_time
|
||
|
||
self.similarity_history.append(float(score))
|
||
self.frame_timestamps.append(float(timestamp))
|
||
except Exception as e:
|
||
pass # 忽略记录错误
|
||
|
||
def get_similarity_plot(self):
|
||
"""生成相似度变化折线图"""
|
||
try:
|
||
if len(self.similarity_history) < 2:
|
||
return None
|
||
|
||
fig = go.Figure()
|
||
|
||
fig.add_trace(go.Scatter(
|
||
x=list(self.frame_timestamps),
|
||
y=list(self.similarity_history),
|
||
mode='lines+markers',
|
||
name='动作相似度',
|
||
line=dict(color='#2E86AB', width=2),
|
||
marker=dict(size=4)
|
||
))
|
||
|
||
fig.update_layout(
|
||
title='动作相似度变化趋势',
|
||
xaxis_title='时间 (秒)',
|
||
yaxis_title='相似度分数 (%)',
|
||
yaxis=dict(range=[0, 100]),
|
||
height=250, # 减小图表高度为相似度区域腾出空间
|
||
margin=dict(l=50, r=50, t=50, b=50),
|
||
showlegend=False
|
||
)
|
||
|
||
# 添加平均分数线
|
||
if len(self.similarity_history) > 0:
|
||
avg_score = sum(self.similarity_history) / len(self.similarity_history)
|
||
fig.add_hline(
|
||
y=avg_score,
|
||
line_dash="dash",
|
||
line_color="red",
|
||
annotation_text=f"平均分: {avg_score:.1f}%"
|
||
)
|
||
|
||
return fig
|
||
except Exception as e:
|
||
return None
|
||
|
||
def reset(self):
|
||
"""重置分析器"""
|
||
try:
|
||
self.similarity_history.clear()
|
||
self.frame_timestamps.clear()
|
||
self.start_time = None
|
||
except Exception as e:
|
||
pass
|
||
|
||
|
||
class MotionComparisonApp:
|
||
"""动作比较应用程序主类"""
|
||
|
||
def __init__(self):
|
||
self.body_detector = None
|
||
self.is_running = False
|
||
self.standard_video_path = None
|
||
self.webcam_cap = None
|
||
self.standard_cap = None
|
||
self.similarity_analyzer = PoseSimilarityAnalyzer()
|
||
self.frame_counter = 0
|
||
self.audio_player = AudioPlayer()
|
||
|
||
# 显示设置
|
||
self.display_settings = {
|
||
'resolution_mode': 'high', # high, medium, low
|
||
'target_width': 960, # 增大默认分辨率
|
||
'target_height': 720
|
||
}
|
||
|
||
# RealSense相关
|
||
self.realsense_pipeline = None
|
||
self.realsense_config = None
|
||
self.is_realsense_active = False
|
||
|
||
# 状态管理
|
||
self.last_error_time = 0
|
||
self.error_count = 0
|
||
|
||
# 控制状态
|
||
if 'comparison_state' not in st.session_state:
|
||
st.session_state.comparison_state = {
|
||
'is_running': False,
|
||
'should_stop': False,
|
||
'should_restart': False
|
||
}
|
||
|
||
def get_display_resolution(self):
|
||
"""根据设置获取显示分辨率"""
|
||
resolution_modes = {
|
||
'high': (1280, 800), # 高清模式
|
||
'medium': (960, 720), # 中等模式
|
||
'low': (640, 480) # 低分辨率模式
|
||
}
|
||
|
||
mode = self.display_settings.get('resolution_mode', 'high')
|
||
return resolution_modes.get(mode, (960, 720))
|
||
|
||
def initialize_realsense(self):
|
||
"""初始化RealSense摄像头"""
|
||
if not REALSENSE_AVAILABLE:
|
||
return self.initialize_webcam()
|
||
|
||
try:
|
||
# 创建pipeline和config
|
||
self.realsense_pipeline = rs.pipeline()
|
||
self.realsense_config = rs.config()
|
||
|
||
# 配置RGB流 - 更高分辨率
|
||
width, height = self.get_display_resolution()
|
||
self.realsense_config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
|
||
|
||
# 启动pipeline
|
||
profile = self.realsense_pipeline.start(self.realsense_config)
|
||
|
||
# 获取设备信息
|
||
device = profile.get_device()
|
||
device_info = f"RealSense {device.get_info(rs.camera_info.name)}"
|
||
|
||
self.is_realsense_active = True
|
||
st.success(f"✅ RealSense摄像头初始化成功: {device_info} ({width}x{height})")
|
||
return True
|
||
|
||
except Exception as e:
|
||
st.warning(f"RealSense摄像头初始化失败: {str(e)}")
|
||
return self.initialize_webcam()
|
||
|
||
def initialize_webcam(self):
|
||
"""初始化普通USB摄像头"""
|
||
try:
|
||
self.webcam_cap = cv2.VideoCapture(0)
|
||
if self.webcam_cap.isOpened():
|
||
# 设置摄像头参数为更高分辨率
|
||
width, height = self.get_display_resolution()
|
||
self.webcam_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
|
||
self.webcam_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||
self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
|
||
|
||
# 验证实际分辨率
|
||
actual_width = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
actual_height = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
||
st.success(f"✅ USB摄像头初始化成功 ({actual_width}x{actual_height})")
|
||
return True
|
||
else:
|
||
st.error("❌ 无法打开USB摄像头")
|
||
return False
|
||
except Exception as e:
|
||
st.error(f"❌ USB摄像头初始化失败: {str(e)}")
|
||
return False
|
||
|
||
def read_realsense_frame(self):
|
||
"""从RealSense读取一帧图像"""
|
||
if not self.is_realsense_active or self.realsense_pipeline is None:
|
||
return False, None
|
||
|
||
try:
|
||
# 等待新的帧
|
||
frames = self.realsense_pipeline.wait_for_frames(timeout_ms=1000)
|
||
|
||
# 获取RGB帧
|
||
color_frame = frames.get_color_frame()
|
||
if not color_frame:
|
||
return False, None
|
||
|
||
# 转换为numpy数组
|
||
color_image = np.asanyarray(color_frame.get_data())
|
||
|
||
return True, color_image
|
||
|
||
except Exception as e:
|
||
current_time = time.time()
|
||
if current_time - self.last_error_time > 5: # 每5秒只输出一次错误
|
||
st.warning(f"读取RealSense帧失败: {str(e)}")
|
||
self.last_error_time = current_time
|
||
return False, None
|
||
|
||
def read_webcam_frame(self):
|
||
"""统一的摄像头读取接口"""
|
||
if self.is_realsense_active:
|
||
return self.read_realsense_frame()
|
||
elif self.webcam_cap is not None and self.webcam_cap.isOpened():
|
||
ret, frame = self.webcam_cap.read()
|
||
return ret, frame
|
||
else:
|
||
return False, None
|
||
|
||
def get_camera_preview_frame(self):
|
||
"""获取摄像头预览帧"""
|
||
try:
|
||
ret, frame = self.read_webcam_frame()
|
||
if not ret or frame is None:
|
||
return None
|
||
|
||
# 翻转摄像头画面(镜像效果)
|
||
frame = cv2.flip(frame, 1)
|
||
|
||
# 如果检测器已初始化,显示关键点
|
||
if self.body_detector is not None:
|
||
try:
|
||
keypoints, scores = self.body_detector(frame)
|
||
frame = draw_skeleton(
|
||
frame.copy(),
|
||
keypoints,
|
||
scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
except Exception as e:
|
||
# 预览阶段,忽略检测错误
|
||
pass
|
||
|
||
# 转换颜色空间
|
||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||
return frame_rgb
|
||
|
||
except Exception as e:
|
||
return None
|
||
|
||
def stop_realsense(self):
|
||
"""停止RealSense摄像头"""
|
||
if self.is_realsense_active and self.realsense_pipeline is not None:
|
||
try:
|
||
self.realsense_pipeline.stop()
|
||
self.is_realsense_active = False
|
||
except Exception as e:
|
||
pass
|
||
|
||
def cleanup(self):
|
||
"""清理所有资源"""
|
||
try:
|
||
if self.standard_cap is not None and self.standard_cap.isOpened():
|
||
self.standard_cap.release()
|
||
self.standard_cap = None
|
||
|
||
if self.webcam_cap is not None and self.webcam_cap.isOpened():
|
||
self.webcam_cap.release()
|
||
self.webcam_cap = None
|
||
|
||
self.stop_realsense()
|
||
self.audio_player.cleanup()
|
||
self.is_running = False
|
||
st.session_state.comparison_state['is_running'] = False
|
||
except Exception as e:
|
||
pass
|
||
|
||
def initialize_detector(self):
|
||
"""初始化身体关键点检测器"""
|
||
try:
|
||
if self.body_detector is None:
|
||
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
|
||
self.body_detector = Body(
|
||
mode='lightweight', # 使用轻量级模式提高实时性能
|
||
to_openpose=True,
|
||
backend='onnxruntime',
|
||
device=device
|
||
)
|
||
st.success(f"关键点检测器初始化完成 (设备: {device})")
|
||
return True
|
||
except Exception as e:
|
||
st.error(f"检测器初始化失败: {str(e)}")
|
||
return False
|
||
|
||
def start_comparison(self, video_path):
|
||
"""开始动作比较"""
|
||
try:
|
||
# 设置运行状态
|
||
st.session_state.comparison_state = {
|
||
'is_running': True,
|
||
'should_stop': False,
|
||
'should_restart': False
|
||
}
|
||
|
||
self.is_running = True
|
||
self.standard_video_path = video_path
|
||
self.frame_counter = 0
|
||
self.error_count = 0
|
||
|
||
# 重置相似度分析器
|
||
self.similarity_analyzer.reset()
|
||
|
||
# 加载音频
|
||
audio_loaded = self.audio_player.load_audio(video_path)
|
||
if audio_loaded:
|
||
st.success("✅ 音频加载成功")
|
||
else:
|
||
st.info("ℹ️ 音频加载失败或视频无音频,将静音播放")
|
||
|
||
# 打开标准视频
|
||
self.standard_cap = cv2.VideoCapture(video_path)
|
||
if not self.standard_cap.isOpened():
|
||
st.error("无法打开标准视频文件")
|
||
return
|
||
|
||
# 确保摄像头正常工作
|
||
if not self.is_realsense_active and (self.webcam_cap is None or not self.webcam_cap.isOpened()):
|
||
if not self.initialize_realsense():
|
||
st.error("无法初始化摄像头")
|
||
return
|
||
|
||
# 获取视频信息
|
||
fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
|
||
total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
frame_delay = 1.0 / fps if fps > 0 else 1.0 / 30
|
||
|
||
# 获取目标分辨率
|
||
target_width, target_height = self.get_display_resolution()
|
||
|
||
# 创建大尺寸视频显示区域
|
||
st.markdown("### 📺 视频对比显示")
|
||
|
||
# 使用更大的列布局比例
|
||
video_col1, video_col2 = st.columns([1, 1], gap="small")
|
||
|
||
with video_col1:
|
||
st.markdown("#### 🎯 标准动作视频")
|
||
standard_placeholder = st.empty()
|
||
|
||
with video_col2:
|
||
camera_type = "RealSense摄像头" if self.is_realsense_active else "USB摄像头"
|
||
st.markdown(f"#### 📹 {camera_type}实时影像")
|
||
webcam_placeholder = st.empty()
|
||
|
||
# 创建控制按钮区域(紧凑布局)
|
||
st.markdown("---")
|
||
control_container = st.container()
|
||
with control_container:
|
||
control_col1, control_col2, control_col3, control_col4 = st.columns([1, 1, 1, 1])
|
||
|
||
with control_col1:
|
||
stop_button = st.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
|
||
with control_col2:
|
||
restart_button = st.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
|
||
with control_col3:
|
||
if audio_loaded:
|
||
if st.button("🔊 音频状态", use_container_width=True, key="audio_status"):
|
||
st.info(f"音频: {'播放中' if self.audio_player.is_playing else '已停止'}")
|
||
with control_col4:
|
||
# 分辨率切换按钮
|
||
if st.button("📐 切换分辨率", use_container_width=True, key="resolution_toggle"):
|
||
modes = ['high', 'medium', 'low']
|
||
current_idx = modes.index(self.display_settings.get('resolution_mode', 'high'))
|
||
next_idx = (current_idx + 1) % len(modes)
|
||
self.display_settings['resolution_mode'] = modes[next_idx]
|
||
st.info(f"分辨率模式: {modes[next_idx]}")
|
||
|
||
# 相似度显示区域(紧凑显示)
|
||
st.markdown("---")
|
||
st.markdown("### 📊 动作相似度分析")
|
||
|
||
similarity_container = st.container()
|
||
with similarity_container:
|
||
# 使用3列布局来压缩相似度显示区域
|
||
sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
|
||
|
||
with sim_col1:
|
||
similarity_score_placeholder = st.empty()
|
||
|
||
with sim_col2:
|
||
avg_score_placeholder = st.empty()
|
||
|
||
with sim_col3:
|
||
similarity_plot_placeholder = st.empty()
|
||
|
||
# 处理按钮点击
|
||
if stop_button:
|
||
st.session_state.comparison_state['should_stop'] = True
|
||
if restart_button:
|
||
st.session_state.comparison_state['should_restart'] = True
|
||
|
||
# 状态显示区域
|
||
status_container = st.container()
|
||
with status_container:
|
||
progress_bar = st.progress(0)
|
||
status_text = st.empty()
|
||
|
||
# 主循环
|
||
video_frame_idx = 0
|
||
start_time = time.time()
|
||
current_similarity = 0
|
||
last_plot_update = 0
|
||
|
||
# 开始播放音频
|
||
if audio_loaded:
|
||
self.audio_player.play()
|
||
|
||
while (st.session_state.comparison_state['is_running'] and
|
||
not st.session_state.comparison_state['should_stop']):
|
||
|
||
loop_start = time.time()
|
||
self.frame_counter += 1
|
||
|
||
# 检查重新开始
|
||
if st.session_state.comparison_state['should_restart']:
|
||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||
self.similarity_analyzer.reset()
|
||
start_time = time.time()
|
||
video_frame_idx = 0
|
||
if audio_loaded:
|
||
self.audio_player.restart()
|
||
st.session_state.comparison_state['should_restart'] = False
|
||
continue
|
||
|
||
# 读取标准视频的当前帧
|
||
ret_standard, standard_frame = self.standard_cap.read()
|
||
if not ret_standard:
|
||
# 视频结束,重新开始
|
||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||
video_frame_idx = 0
|
||
self.similarity_analyzer.reset()
|
||
start_time = time.time()
|
||
if audio_loaded:
|
||
self.audio_player.restart()
|
||
continue
|
||
|
||
# 读取摄像头当前帧
|
||
ret_webcam, webcam_frame = self.read_webcam_frame()
|
||
if not ret_webcam or webcam_frame is None:
|
||
self.error_count += 1
|
||
if self.error_count > 100: # 连续错误过多时停止
|
||
st.error("摄像头连接出现问题,停止比较")
|
||
break
|
||
continue
|
||
|
||
self.error_count = 0 # 重置错误计数
|
||
|
||
# 翻转摄像头画面(镜像效果)
|
||
webcam_frame = cv2.flip(webcam_frame, 1)
|
||
|
||
# 调整尺寸使两个视频大小一致(使用更高分辨率)
|
||
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
|
||
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
|
||
|
||
# 处理关键点检测
|
||
try:
|
||
standard_keypoints, standard_scores = self.body_detector(standard_frame)
|
||
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
|
||
except Exception as e:
|
||
# 关键点检测失败时继续显示原始图像
|
||
standard_keypoints, standard_scores = None, None
|
||
webcam_keypoints, webcam_scores = None, None
|
||
|
||
# 计算相似度(每5帧计算一次以提高性能)
|
||
if (self.frame_counter % 5 == 0 and
|
||
standard_keypoints is not None and
|
||
webcam_keypoints is not None):
|
||
try:
|
||
# 提取当前标准视频帧和摄像头帧的关节角度
|
||
standard_angles = self.similarity_analyzer.extract_joint_angles(
|
||
standard_keypoints, standard_scores
|
||
)
|
||
webcam_angles = self.similarity_analyzer.extract_joint_angles(
|
||
webcam_keypoints, webcam_scores
|
||
)
|
||
|
||
# 计算相似度
|
||
if standard_angles and webcam_angles:
|
||
current_similarity = self.similarity_analyzer.calculate_similarity(
|
||
standard_angles, webcam_angles
|
||
)
|
||
|
||
# 添加到历史记录
|
||
timestamp = time.time() - start_time
|
||
self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
|
||
except Exception as e:
|
||
pass # 忽略相似度计算错误
|
||
|
||
# 绘制关键点
|
||
try:
|
||
if standard_keypoints is not None and standard_scores is not None:
|
||
standard_with_keypoints = draw_skeleton(
|
||
standard_frame.copy(),
|
||
standard_keypoints,
|
||
standard_scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
else:
|
||
standard_with_keypoints = standard_frame.copy()
|
||
|
||
if webcam_keypoints is not None and webcam_scores is not None:
|
||
webcam_with_keypoints = draw_skeleton(
|
||
webcam_frame.copy(),
|
||
webcam_keypoints,
|
||
webcam_scores,
|
||
openpose_skeleton=True,
|
||
kpt_thr=0.43
|
||
)
|
||
else:
|
||
webcam_with_keypoints = webcam_frame.copy()
|
||
|
||
except Exception as e:
|
||
# 如果绘制失败,使用原始帧
|
||
standard_with_keypoints = standard_frame.copy()
|
||
webcam_with_keypoints = webcam_frame.copy()
|
||
|
||
# 转换颜色空间 (BGR to RGB)
|
||
standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
|
||
webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
|
||
|
||
# 添加帧信息到图像上
|
||
current_time = time.time() - start_time
|
||
frame_info = f"时间: {current_time:.1f}s | 帧: {video_frame_idx}/{total_frames}"
|
||
audio_info = f" | 音频: {'🔊' if self.audio_player.is_playing else '🔇'}" if audio_loaded else ""
|
||
resolution_info = f" | {target_width}x{target_height}"
|
||
|
||
# 显示大尺寸画面
|
||
with video_col1:
|
||
standard_placeholder.image(
|
||
standard_rgb,
|
||
caption=f"标准动作 - {frame_info}{audio_info}{resolution_info}",
|
||
use_container_width=True
|
||
)
|
||
|
||
with video_col2:
|
||
webcam_placeholder.image(
|
||
webcam_rgb,
|
||
caption=f"您的动作 - 实时画面{resolution_info}",
|
||
use_container_width=True
|
||
)
|
||
|
||
# 显示相似度信息(紧凑显示)
|
||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||
try:
|
||
avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||
|
||
# 使用不同颜色显示相似度
|
||
if current_similarity >= 80:
|
||
similarity_color = "🟢"
|
||
level = "优秀"
|
||
elif current_similarity >= 60:
|
||
similarity_color = "🟡"
|
||
level = "良好"
|
||
else:
|
||
similarity_color = "🔴"
|
||
level = "需要改进"
|
||
|
||
with sim_col1:
|
||
similarity_score_placeholder.metric(
|
||
"当前相似度",
|
||
f"{similarity_color} {current_similarity:.1f}%",
|
||
delta=f"{level}"
|
||
)
|
||
|
||
with sim_col2:
|
||
avg_score_placeholder.metric(
|
||
"平均相似度",
|
||
f"{avg_similarity:.1f}%"
|
||
)
|
||
|
||
# 更新相似度图表(每20帧更新一次以提高性能)
|
||
if (len(self.similarity_analyzer.similarity_history) >= 2 and
|
||
self.frame_counter - last_plot_update >= 20):
|
||
try:
|
||
similarity_plot = self.similarity_analyzer.get_similarity_plot()
|
||
if similarity_plot:
|
||
with sim_col3:
|
||
similarity_plot_placeholder.plotly_chart(
|
||
similarity_plot,
|
||
use_container_width=True,
|
||
key=f"similarity_plot_{int(time.time() * 1000)}" # 使用时间戳避免重复ID
|
||
)
|
||
last_plot_update = self.frame_counter
|
||
except Exception as e:
|
||
pass # 忽略图表更新错误
|
||
except Exception as e:
|
||
pass # 忽略显示错误
|
||
|
||
# 更新进度和状态(紧凑显示)
|
||
try:
|
||
progress = min(video_frame_idx / total_frames, 1.0) if total_frames > 0 else 0
|
||
progress_bar.progress(progress)
|
||
|
||
elapsed_time = time.time() - start_time
|
||
fps_actual = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
|
||
|
||
status_text.text(
|
||
f"进度: {video_frame_idx}/{total_frames} | "
|
||
f"实际FPS: {fps_actual:.1f} | "
|
||
f"分辨率: {target_width}x{target_height} | "
|
||
f"模式: {self.display_settings['resolution_mode']}"
|
||
)
|
||
except Exception as e:
|
||
pass # 忽略状态更新错误
|
||
|
||
video_frame_idx += 1
|
||
|
||
# 精确的帧率控制
|
||
loop_elapsed = time.time() - loop_start
|
||
sleep_time = max(0, frame_delay - loop_elapsed)
|
||
if sleep_time > 0:
|
||
time.sleep(sleep_time)
|
||
|
||
# 强制更新UI(每30帧一次)
|
||
if self.frame_counter % 30 == 0:
|
||
st.empty() # 触发UI更新
|
||
|
||
# 停止音频
|
||
self.audio_player.stop()
|
||
|
||
# 显示最终统计
|
||
self.show_final_statistics()
|
||
|
||
except Exception as e:
|
||
st.error(f"比较过程中出现错误: {str(e)}")
|
||
finally:
|
||
self.is_running = False
|
||
st.session_state.comparison_state['is_running'] = False
|
||
self.audio_player.stop()
|
||
|
||
def show_final_statistics(self):
|
||
"""显示最终统计信息"""
|
||
try:
|
||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||
final_avg = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||
max_similarity = max(self.similarity_analyzer.similarity_history)
|
||
min_similarity = min(self.similarity_analyzer.similarity_history)
|
||
|
||
st.success(f"🎉 动作比较完成!")
|
||
|
||
# 评价等级
|
||
if final_avg >= 80:
|
||
level = "优秀!👏"
|
||
color = "success"
|
||
elif final_avg >= 60:
|
||
level = "良好!👍"
|
||
color = "info"
|
||
else:
|
||
level = "继续加油!💪"
|
||
color = "warning"
|
||
|
||
st.markdown(f"**整体评价**: :{color}[{level}]")
|
||
|
||
col1, col2, col3 = st.columns(3)
|
||
with col1:
|
||
st.metric("平均相似度", f"{final_avg:.1f}%")
|
||
with col2:
|
||
st.metric("最高相似度", f"{max_similarity:.1f}%")
|
||
with col3:
|
||
st.metric("最低相似度", f"{min_similarity:.1f}%")
|
||
|
||
# 显示改进建议
|
||
if final_avg < 60:
|
||
with st.expander("💡 改进建议"):
|
||
st.markdown("""
|
||
- 确保全身在摄像头视野内
|
||
- 动作要清晰、幅度到位
|
||
- 保持与标准视频相同的节奏
|
||
- 检查光线是否充足
|
||
- 尽量保持身体正对摄像头
|
||
""")
|
||
except Exception as e:
|
||
st.warning("无法显示统计信息")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
try:
|
||
# 确保目录存在
|
||
os.makedirs("preset_videos", exist_ok=True)
|
||
|
||
# 页面配置 - 设置为宽模式以最大化显示空间
|
||
st.set_page_config(
|
||
page_title="动作比较与关键点检测",
|
||
page_icon="🏃",
|
||
layout="wide",
|
||
initial_sidebar_state="expanded"
|
||
)
|
||
|
||
st.title("🏃 动作比较与关键点检测系统")
|
||
st.markdown("---")
|
||
|
||
# 创建应用实例
|
||
if 'app' not in st.session_state:
|
||
st.session_state.app = MotionComparisonApp()
|
||
|
||
app = st.session_state.app
|
||
|
||
# 侧边栏控制面板(紧凑设计)
|
||
with st.sidebar:
|
||
st.header("🎛️ 控制面板")
|
||
|
||
# 显示设置
|
||
st.subheader("🖥️ 显示设置")
|
||
resolution_options = {
|
||
"高清 (1280x800)": "high",
|
||
"中等 (960x720)": "medium",
|
||
"标准 (640x480)": "low"
|
||
}
|
||
|
||
selected_resolution = st.selectbox(
|
||
"选择显示分辨率",
|
||
list(resolution_options.keys()),
|
||
index=1 # 默认选择中等分辨率
|
||
)
|
||
app.display_settings['resolution_mode'] = resolution_options[selected_resolution]
|
||
|
||
st.markdown("---")
|
||
|
||
# 预设视频选择
|
||
preset_videos = {
|
||
"六字诀": "preset_videos/liuzi.mp4",
|
||
}
|
||
|
||
# 预设视频状态检查
|
||
st.subheader("📁 预设视频")
|
||
video_status_ok = False
|
||
for name, path in preset_videos.items():
|
||
if os.path.exists(path):
|
||
st.success(f"✅ {name}")
|
||
video_status_ok = True
|
||
else:
|
||
st.error(f"❌ {name}")
|
||
|
||
# 视频来源选择
|
||
video_source = st.radio(
|
||
"视频来源",
|
||
["预设视频", "上传视频"],
|
||
index=0 if video_status_ok else 1
|
||
)
|
||
|
||
video_path = None
|
||
|
||
if video_source == "预设视频":
|
||
if video_status_ok:
|
||
selected_preset = st.selectbox(
|
||
"选择视频",
|
||
list(preset_videos.keys())
|
||
)
|
||
video_path = preset_videos[selected_preset]
|
||
else:
|
||
uploaded_video = st.file_uploader(
|
||
"上传视频文件",
|
||
type=['mp4', 'avi', 'mov', 'mkv']
|
||
)
|
||
|
||
if uploaded_video is not None:
|
||
try:
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||
tmp_file.write(uploaded_video.read())
|
||
video_path = tmp_file.name
|
||
st.success("✅ 上传成功")
|
||
except Exception as e:
|
||
st.error(f"上传失败: {str(e)}")
|
||
|
||
st.markdown("---")
|
||
|
||
# 初始化按钮
|
||
st.subheader("⚙️ 系统初始化")
|
||
if st.button("🚀 初始化系统", use_container_width=True):
|
||
with st.spinner("初始化中..."):
|
||
if app.initialize_detector():
|
||
app.initialize_realsense()
|
||
|
||
# 系统信息
|
||
st.subheader("ℹ️ 系统状态")
|
||
if torch.cuda.device_count() > 0:
|
||
st.success("✅ CUDA")
|
||
else:
|
||
st.info("💻 CPU")
|
||
|
||
if REALSENSE_AVAILABLE:
|
||
st.success("✅ RealSense")
|
||
else:
|
||
st.info("📹 USB摄像头")
|
||
|
||
if PYGAME_AVAILABLE:
|
||
st.success("✅ 音频支持")
|
||
else:
|
||
st.warning("🔇 无音频")
|
||
|
||
# 主界面
|
||
if video_path is not None:
|
||
try:
|
||
# 显示视频基本信息(紧凑显示)
|
||
cap = cv2.VideoCapture(video_path)
|
||
if cap.isOpened():
|
||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||
duration = frame_count / fps if fps > 0 else 0
|
||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
||
info_col1, info_col2, info_col3, info_col4 = st.columns(4)
|
||
with info_col1:
|
||
st.metric("⏱️ 时长", f"{duration:.1f}s")
|
||
with info_col2:
|
||
st.metric("🔄 帧率", f"{fps:.1f}")
|
||
with info_col3:
|
||
st.metric("📐 原始分辨率", f"{width}×{height}")
|
||
with info_col4:
|
||
current_res = app.get_display_resolution()
|
||
st.metric("🖥️ 显示分辨率", f"{current_res[0]}×{current_res[1]}")
|
||
|
||
cap.release()
|
||
|
||
st.markdown("---")
|
||
|
||
# 初始化session_state
|
||
if 'show_preview' not in st.session_state:
|
||
st.session_state.show_preview = False
|
||
|
||
# 预览区域
|
||
if st.session_state.show_preview:
|
||
st.markdown("### 📷 摄像头预览")
|
||
|
||
# 初始化摄像头(如果还没有初始化)
|
||
if not app.is_realsense_active and (app.webcam_cap is None or not app.webcam_cap.isOpened()):
|
||
app.initialize_realsense()
|
||
|
||
# 获取预览帧
|
||
preview_frame = app.get_camera_preview_frame()
|
||
if preview_frame is not None:
|
||
# 使用大尺寸预览
|
||
preview_col1, preview_col2, preview_col3 = st.columns([1, 3, 1])
|
||
with preview_col2:
|
||
st.image(preview_frame, caption="摄像头预览 (已镜像)", use_container_width=True)
|
||
|
||
# 显示提示信息
|
||
camera_type = "RealSense摄像头" if app.is_realsense_active else "USB摄像头"
|
||
st.info(f"🎯 正在使用: {camera_type} - 请调整您的位置,确保全身在画面中清晰可见")
|
||
|
||
# 预览控制按钮
|
||
preview_btn_col1, preview_btn_col2, preview_btn_col3 = st.columns(3)
|
||
with preview_btn_col1:
|
||
if st.button("🔄 刷新画面", key="refresh_preview", use_container_width=True):
|
||
st.rerun()
|
||
with preview_btn_col2:
|
||
if st.button("⏹️ 停止预览", key="stop_preview", use_container_width=True):
|
||
st.session_state.show_preview = False
|
||
st.rerun()
|
||
with preview_btn_col3:
|
||
if st.button("🚀 开始比较", key="start_from_preview", use_container_width=True):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先初始化关键点检测器")
|
||
else:
|
||
st.session_state.show_preview = False
|
||
app.start_comparison(video_path)
|
||
else:
|
||
st.error("❌ 无法获取摄像头画面,请检查摄像头连接")
|
||
st.session_state.show_preview = False
|
||
|
||
else:
|
||
# 主控制按钮(大按钮)
|
||
main_btn_col1, main_btn_col2, main_btn_col3 = st.columns(3)
|
||
|
||
with main_btn_col1:
|
||
if st.button("📷 预览摄像头", use_container_width=True, help="先预览摄像头调整位置"):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||
else:
|
||
st.session_state.show_preview = True
|
||
st.rerun()
|
||
|
||
with main_btn_col2:
|
||
if st.button("🚀 直接开始比较", use_container_width=True, help="直接开始动作比较"):
|
||
if app.body_detector is None:
|
||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||
else:
|
||
app.start_comparison(video_path)
|
||
|
||
with main_btn_col3:
|
||
st.markdown("") # 占位
|
||
|
||
except Exception as e:
|
||
st.error(f"❌ 处理视频时出现错误: {str(e)}")
|
||
|
||
else:
|
||
# 无视频时显示说明
|
||
st.info("👈 请在左侧选择或上传标准动作视频")
|
||
|
||
# 显示使用说明(优化布局)
|
||
with st.expander("📖 使用说明", expanded=True):
|
||
usage_col1, usage_col2 = st.columns(2)
|
||
|
||
with usage_col1:
|
||
st.markdown("""
|
||
#### 🚀 快速开始:
|
||
1. **选择视频**: 侧边栏选择预设或上传视频
|
||
2. **调整分辨率**: 根据设备性能选择合适分辨率
|
||
3. **初始化系统**: 点击"初始化系统"按钮
|
||
4. **预览摄像头**: 调整位置确保全身可见
|
||
5. **开始比较**: 跟随标准视频做动作
|
||
|
||
#### ⚙️ 分辨率建议:
|
||
- **高性能设备**: 高清模式 (1280x800)
|
||
- **一般设备**: 中等模式 (960x720)
|
||
- **低配设备**: 标准模式 (640x480)
|
||
""")
|
||
|
||
with usage_col2:
|
||
st.markdown("""
|
||
#### ✨ 主要特性:
|
||
- 🎯 实时关键点检测和动作分析
|
||
- 📹 支持RealSense和USB摄像头
|
||
- 🔊 视频音频同步播放
|
||
- 📊 实时相似度图表分析
|
||
- 🖥️ 大屏幕优化显示
|
||
- ⚡ GPU加速支持
|
||
|
||
#### 📋 系统要求:
|
||
- 摄像头设备及权限
|
||
- 充足的光照条件
|
||
- Python 3.7+ 环境
|
||
- 推荐GPU支持(可选)
|
||
""")
|
||
|
||
except Exception as e:
|
||
st.error(f"❌ 应用程序启动失败: {str(e)}")
|
||
st.info("💡 请检查所有依赖库是否正确安装")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|