refactor(main): improve app structure and UI
This commit is contained in:
parent
b33ad5e876
commit
427eca08d0
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
.aider*
|
@ -8,7 +8,7 @@ if PYGAME_AVAILABLE:
|
||||
import pygame
|
||||
|
||||
if MOVIEPY_AVAILABLE:
|
||||
from moviepy.editor import VideoFileClip
|
||||
from moviepy import VideoFileClip
|
||||
|
||||
class AudioPlayer:
|
||||
"""A class to handle audio extraction and playback for the video."""
|
||||
@ -37,7 +37,7 @@ class AudioPlayer:
|
||||
temp_audio = tempfile.mktemp(suffix='.wav')
|
||||
video_clip = VideoFileClip(video_path)
|
||||
if video_clip.audio is not None:
|
||||
video_clip.audio.write_audiofile(temp_audio, verbose=False, logger=None)
|
||||
video_clip.audio.write_audiofile(temp_audio, logger=None)
|
||||
video_clip.close()
|
||||
return temp_audio
|
||||
else:
|
||||
|
@ -10,7 +10,7 @@ except ImportError:
|
||||
|
||||
# Check for MoviePy availability for audio extraction
|
||||
try:
|
||||
from moviepy.editor import VideoFileClip
|
||||
from moviepy import VideoFileClip
|
||||
MOVIEPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
MOVIEPY_AVAILABLE = False
|
||||
|
301
main_app.py
301
main_app.py
@ -8,92 +8,233 @@ import torch
|
||||
from motion_app import MotionComparisonApp
|
||||
from config import REALSENSE_AVAILABLE, PYGAME_AVAILABLE
|
||||
|
||||
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
|
||||
|
||||
# Set page config at module level
|
||||
st.set_page_config(page_title="动作比较", page_icon="🏃", layout="wide")
|
||||
|
||||
def main():
|
||||
"""Main function to run the Streamlit app."""
|
||||
st.set_page_config(page_title="Motion Comparison", page_icon="🏃", layout="wide")
|
||||
st.title("🏃 Motion Comparison & Pose Analysis System")
|
||||
st.markdown("---")
|
||||
|
||||
# Initialize the app object in session state
|
||||
if 'app' not in st.session_state:
|
||||
st.session_state.app = MotionComparisonApp()
|
||||
app = st.session_state.app
|
||||
|
||||
# --- Sidebar UI ---
|
||||
with st.sidebar:
|
||||
st.header("🎛️ Control Panel")
|
||||
|
||||
# Display settings
|
||||
resolution_options = {"High (1280x800)": "high", "Medium (960x720)": "medium", "Standard (640x480)": "low"}
|
||||
selected_res = st.selectbox("Display Resolution", list(resolution_options.keys()), index=1)
|
||||
app.display_settings['resolution_mode'] = resolution_options[selected_res]
|
||||
|
||||
try:
|
||||
st.title("🏃 动作比较与姿态分析系统")
|
||||
st.markdown("---")
|
||||
|
||||
# Video Source Selection
|
||||
video_source = st.radio("Video Source", ["Preset Video", "Upload Video"])
|
||||
video_path = None
|
||||
|
||||
if video_source == "Preset Video":
|
||||
preset_path = "preset_videos/liuzi.mp4"
|
||||
if os.path.exists(preset_path):
|
||||
st.success("✅ '六字诀' video found.")
|
||||
video_path = preset_path
|
||||
|
||||
# Initialize the app object in session state
|
||||
if 'app' not in st.session_state:
|
||||
st.session_state.app = MotionComparisonApp()
|
||||
app = st.session_state.app
|
||||
|
||||
# --- Sidebar UI ---
|
||||
with st.sidebar:
|
||||
st.header("🎛️ 控制面板")
|
||||
|
||||
# Display settings
|
||||
resolution_options = {"高清 (1280x800)": "high", "中等 (960x720)": "medium", "标准 (640x480)": "low"}
|
||||
selected_res = st.selectbox("显示分辨率", list(resolution_options.keys()), index=1)
|
||||
app.display_settings['resolution_mode'] = resolution_options[selected_res]
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Video Source Selection
|
||||
video_source = st.radio("视频来源", ["预设视频", "上传视频"])
|
||||
video_path = None
|
||||
|
||||
if video_source == "预设视频":
|
||||
# Define preset video mapping (display name -> filename)
|
||||
preset_videos = {
|
||||
"六字诀": "liuzi.mp4",
|
||||
# Add more preset videos here:
|
||||
# "太极拳": "taiji.mp4",
|
||||
# "八段锦": "baduanjin.mp4",
|
||||
}
|
||||
|
||||
preset_folder = "preset_videos"
|
||||
if os.path.exists(preset_folder):
|
||||
# Check which preset videos actually exist in the folder
|
||||
available_presets = {}
|
||||
for display_name, filename in preset_videos.items():
|
||||
full_path = os.path.join(preset_folder, filename)
|
||||
if os.path.exists(full_path):
|
||||
available_presets[display_name] = filename
|
||||
|
||||
if available_presets:
|
||||
selected_display_name = st.selectbox("选择预设视频", list(available_presets.keys()))
|
||||
selected_filename = available_presets[selected_display_name]
|
||||
video_path = os.path.join(preset_folder, selected_filename)
|
||||
st.success(f"✅ 已选择视频: {selected_display_name} ({selected_filename})")
|
||||
else:
|
||||
st.error("❌ preset_videos 文件夹中未找到预设视频文件")
|
||||
st.info(f"请确保以下文件存在: {', '.join(preset_videos.values())}")
|
||||
else:
|
||||
st.error("❌ preset_videos 文件夹不存在")
|
||||
else:
|
||||
st.error("❌ Preset video not found. Please place 'liuzi.mp4' in 'preset_videos' folder.")
|
||||
uploaded_file = st.file_uploader("上传视频文件", type=['mp4', 'avi', 'mov', 'mkv'])
|
||||
if uploaded_file:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||||
tmp_file.write(uploaded_file.read())
|
||||
video_path = tmp_file.name
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# System Initialization
|
||||
st.subheader("⚙️ 系统初始化")
|
||||
if st.button("🚀 初始化系统", use_container_width=True):
|
||||
with st.spinner("正在初始化检测器和摄像头..."):
|
||||
app.initialize_detector()
|
||||
app.initialize_camera()
|
||||
|
||||
# System Status
|
||||
st.subheader("ℹ️ 系统状态")
|
||||
st.info(f"计算设备: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
|
||||
st.info(f"摄像头: {'RealSense' if REALSENSE_AVAILABLE else 'USB 摄像头'}")
|
||||
st.info(f"音频: {'启用' if PYGAME_AVAILABLE else '禁用'}")
|
||||
|
||||
# --- Main Page UI ---
|
||||
if video_path is not None:
|
||||
try:
|
||||
# 显示视频基本信息(紧凑显示)
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if cap.isOpened():
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = frame_count / fps if fps > 0 else 0
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
info_col1, info_col2, info_col3, info_col4 = st.columns(4)
|
||||
with info_col1:
|
||||
st.metric("⏱️ 时长", f"{duration:.1f}s")
|
||||
with info_col2:
|
||||
st.metric("🔄 帧率", f"{fps:.1f}")
|
||||
with info_col3:
|
||||
st.metric("📐 原始分辨率", f"{width}×{height}")
|
||||
with info_col4:
|
||||
current_res = app.get_display_resolution()
|
||||
st.metric("🖥️ 显示分辨率", f"{current_res[0]}×{current_res[1]}")
|
||||
|
||||
cap.release()
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 初始化session_state
|
||||
if 'show_preview' not in st.session_state:
|
||||
st.session_state.show_preview = False
|
||||
|
||||
# 预览区域
|
||||
if st.session_state.show_preview:
|
||||
st.markdown("### 📷 摄像头预览")
|
||||
|
||||
# 初始化摄像头(如果还没有初始化)
|
||||
if not app.is_realsense_active and (app.webcam_cap is None or not app.webcam_cap.isOpened()):
|
||||
app.initialize_realsense()
|
||||
|
||||
# 获取预览帧
|
||||
preview_frame = app.get_camera_preview_frame()
|
||||
if preview_frame is not None:
|
||||
# 使用大尺寸预览
|
||||
preview_col1, preview_col2, preview_col3 = st.columns([1, 3, 1])
|
||||
with preview_col2:
|
||||
st.image(preview_frame, caption="摄像头预览 (已镜像)", use_container_width=True)
|
||||
|
||||
# 显示提示信息
|
||||
camera_type = "RealSense摄像头" if app.is_realsense_active else "USB摄像头"
|
||||
st.info(f"🎯 正在使用: {camera_type} - 请调整您的位置,确保全身在画面中清晰可见")
|
||||
|
||||
# 预览控制按钮
|
||||
preview_btn_col1, preview_btn_col2, preview_btn_col3 = st.columns(3)
|
||||
with preview_btn_col1:
|
||||
if st.button("🔄 刷新画面", key="refresh_preview", use_container_width=True):
|
||||
st.rerun()
|
||||
with preview_btn_col2:
|
||||
if st.button("⏹️ 停止预览", key="stop_preview", use_container_width=True):
|
||||
st.session_state.show_preview = False
|
||||
st.rerun()
|
||||
with preview_btn_col3:
|
||||
if st.button("🚀 开始比较", key="start_from_preview", use_container_width=True):
|
||||
if app.body_detector is None:
|
||||
st.error("⚠️ 请先初始化关键点检测器")
|
||||
else:
|
||||
st.session_state.show_preview = False
|
||||
app.start_comparison(video_path)
|
||||
else:
|
||||
st.error("❌ 无法获取摄像头画面,请检查摄像头连接")
|
||||
st.session_state.show_preview = False
|
||||
|
||||
else:
|
||||
# 主控制按钮(大按钮)
|
||||
main_btn_col1, main_btn_col2, main_btn_col3 = st.columns(3)
|
||||
|
||||
with main_btn_col1:
|
||||
if st.button("📷 预览摄像头", use_container_width=True, help="先预览摄像头调整位置"):
|
||||
if app.body_detector is None:
|
||||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||||
else:
|
||||
st.session_state.show_preview = True
|
||||
st.rerun()
|
||||
|
||||
with main_btn_col2:
|
||||
if st.button("🚀 直接开始比较", use_container_width=True, help="直接开始动作比较"):
|
||||
if app.body_detector is None:
|
||||
st.error("⚠️ 请先在侧边栏初始化系统")
|
||||
else:
|
||||
app.start_comparison(video_path)
|
||||
|
||||
with main_btn_col3:
|
||||
st.markdown("") # 占位
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"❌ 处理视频时出现错误: {str(e)}")
|
||||
|
||||
else:
|
||||
uploaded_file = st.file_uploader("Upload a video", type=['mp4', 'avi', 'mov', 'mkv'])
|
||||
if uploaded_file:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
|
||||
tmp_file.write(uploaded_file.read())
|
||||
video_path = tmp_file.name
|
||||
# 无视频时显示说明
|
||||
st.info("👈 请在左侧选择预设视频或上传标准动作视频")
|
||||
|
||||
# 显示使用说明(优化布局)
|
||||
with st.expander("📖 使用说明", expanded=True):
|
||||
usage_col1, usage_col2 = st.columns(2)
|
||||
|
||||
with usage_col1:
|
||||
st.markdown("""
|
||||
#### 🚀 快速开始:
|
||||
1. **选择视频**: 侧边栏选择预设视频或上传视频
|
||||
2. **调整分辨率**: 根据设备性能选择合适分辨率
|
||||
3. **初始化系统**: 点击"初始化系统"按钮
|
||||
4. **预览摄像头**: 调整位置确保全身可见
|
||||
5. **开始比较**: 跟随标准视频做动作
|
||||
|
||||
#### ⚙️ 分辨率建议:
|
||||
- **高性能设备**: 高清模式 (1280x800)
|
||||
- **一般设备**: 中等模式 (960x720)
|
||||
- **低配设备**: 标准模式 (640x480)
|
||||
""")
|
||||
|
||||
with usage_col2:
|
||||
st.markdown("""
|
||||
#### ✨ 主要特性:
|
||||
- 🎯 实时关键点检测和动作分析
|
||||
- 📹 支持RealSense和USB摄像头
|
||||
- 🔊 视频音频同步播放
|
||||
- 📊 实时相似度图表分析
|
||||
- 🖥️ 大屏幕优化显示
|
||||
- ⚡ GPU加速支持
|
||||
|
||||
#### 📋 系统要求:
|
||||
- 摄像头设备及权限
|
||||
- 充足的光照条件
|
||||
- Python 3.7+ 环境
|
||||
- 推荐GPU支持(可选)
|
||||
""")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# System Initialization
|
||||
st.subheader("⚙️ System Initialization")
|
||||
if st.button("🚀 Initialize System", use_container_width=True):
|
||||
with st.spinner("Initializing detectors and cameras..."):
|
||||
app.initialize_detector()
|
||||
app.initialize_camera()
|
||||
|
||||
# System Status
|
||||
st.subheader("ℹ️ System Status")
|
||||
st.info(f"Computation: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
|
||||
st.info(f"Camera: {'RealSense' if REALSENSE_AVAILABLE else 'USB Webcam'}")
|
||||
st.info(f"Audio: {'Enabled' if PYGAME_AVAILABLE else 'Disabled'}")
|
||||
|
||||
# --- Main Page UI ---
|
||||
if video_path:
|
||||
# Display video info and control buttons
|
||||
# This part is identical to your original `main` function's logic
|
||||
# It sets up the "Preview Camera" and "Start Comparison" buttons
|
||||
# And calls app.start_comparison(video_path) when clicked.
|
||||
|
||||
# Example of how you would structure the main page:
|
||||
if st.button("🚀 Start Comparison", use_container_width=True):
|
||||
if not app.body_detector:
|
||||
st.error("⚠️ Please initialize the system from the sidebar first!")
|
||||
else:
|
||||
# The start_comparison method now contains the main display loop
|
||||
app.start_comparison(video_path)
|
||||
else:
|
||||
st.info("👈 Please select or upload a standard video from the sidebar to begin.")
|
||||
with st.expander("📖 Usage Guide", expanded=True):
|
||||
st.markdown("""
|
||||
1. **Select Video**: Choose a preset or upload your own video in the sidebar.
|
||||
2. **Initialize**: Click 'Initialize System' to prepare the camera and AI model.
|
||||
3. **Start**: Click 'Start Comparison' to begin the analysis.
|
||||
""")
|
||||
except Exception as e:
|
||||
st.error(f"❌ 应用程序启动失败: {str(e)}")
|
||||
st.info("💡 请检查所有依赖库是否正确安装")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set environment variables for performance
|
||||
os.environ['OMP_NUM_THREADS'] = '1'
|
||||
os.environ['MKL_NUM_THREADS'] = '1'
|
||||
try:
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# os.environ['OMP_NUM_THREADS'] = '1'
|
||||
# os.environ['MKL_NUM_THREADS'] = '1'
|
||||
# try:
|
||||
# import torch
|
||||
# torch.set_num_threads(1)
|
||||
# except ImportError:
|
||||
# pass
|
||||
main()
|
||||
|
329
motion_app.py
329
motion_app.py
@ -61,11 +61,11 @@ class MotionComparisonApp:
|
||||
config.enable_stream(rs.stream.color, width, height, rs.format.bgr8, 30)
|
||||
profile = self.realsense_pipeline.start(config)
|
||||
device = profile.get_device().get_info(rs.camera_info.name)
|
||||
st.success(f"✅ RealSense camera initialized: {device} ({width}x{height})")
|
||||
st.success(f"✅ RealSense摄像头初始化成功: {device} ({width}x{height})")
|
||||
self.is_realsense_active = True
|
||||
return True
|
||||
except Exception as e:
|
||||
st.warning(f"RealSense init failed: {e}. Falling back to USB camera.")
|
||||
st.warning(f"RealSense初始化失败: {e}. 切换到USB摄像头.")
|
||||
return self._initialize_webcam()
|
||||
else:
|
||||
return self._initialize_webcam()
|
||||
@ -80,13 +80,13 @@ class MotionComparisonApp:
|
||||
self.webcam_cap.set(cv2.CAP_PROP_FPS, 30)
|
||||
actual_w = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
actual_h = int(self.webcam_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
st.success(f"✅ USB camera initialized ({actual_w}x{actual_h})")
|
||||
st.success(f"✅ USB摄像头初始化成功 ({actual_w}x{actual_h})")
|
||||
return True
|
||||
else:
|
||||
st.error("❌ Could not open USB camera.")
|
||||
st.error("❌ 无法打开USB摄像头")
|
||||
return False
|
||||
except Exception as e:
|
||||
st.error(f"❌ USB camera init failed: {e}")
|
||||
st.error(f"❌ USB摄像头初始化失败: {e}")
|
||||
return False
|
||||
|
||||
def read_camera_frame(self):
|
||||
@ -130,23 +130,23 @@ class MotionComparisonApp:
|
||||
if not history: return
|
||||
|
||||
final_avg = sum(history) / len(history)
|
||||
level, color = ("Excellent! 👏", "success") if final_avg >= 80 else \
|
||||
("Good! 👍", "info") if final_avg >= 60 else \
|
||||
("Needs Improvement! 💪", "warning")
|
||||
level, color = ("非常棒! 👏", "success") if final_avg >= 80 else \
|
||||
("整体不错! 👍", "info") if final_avg >= 60 else \
|
||||
("需要改进! 💪", "warning")
|
||||
|
||||
st.success("🎉 Comparison Finished!")
|
||||
st.markdown(f"**Overall Performance**: :{color}[{level}]")
|
||||
st.success("🎉 比较完成!")
|
||||
st.markdown(f"**整体表现**: :{color}[{level}]")
|
||||
|
||||
col1, col2, col3 = st.columns(3)
|
||||
col1.metric("Average Similarity", f"{final_avg:.1f}%")
|
||||
col2.metric("Max Similarity", f"{max(history):.1f}%")
|
||||
col3.metric("Min Similarity", f"{min(history):.1f}%")
|
||||
col1.metric("平均相似度", f"{final_avg:.1f}%")
|
||||
col2.metric("最高相似度", f"{max(history):.1f}%")
|
||||
col3.metric("最低相似度", f"{min(history):.1f}%")
|
||||
|
||||
if final_avg < 60:
|
||||
with st.expander("💡 Improvement Tips"):
|
||||
st.markdown("- Ensure your full body is visible to the camera.\n"
|
||||
"- Try to match the timing and range of motion of the standard video.\n"
|
||||
"- Ensure good, consistent lighting.")
|
||||
with st.expander("💡 改善建议"):
|
||||
st.markdown("- 确保您的全身在摄像头画面中清晰可见\n"
|
||||
"- 尽量匹配标准视频的节奏和动作幅度\n"
|
||||
"- 确保光线充足且稳定")
|
||||
|
||||
def start_comparison(self, video_path):
|
||||
"""The main loop for comparing motion."""
|
||||
@ -159,52 +159,301 @@ class MotionComparisonApp:
|
||||
self.similarity_analyzer.reset()
|
||||
|
||||
audio_loaded = self.audio_player.load_audio(video_path)
|
||||
if audio_loaded: st.success("✅ Audio loaded successfully")
|
||||
else: st.info("ℹ️ No audio will be played.")
|
||||
if audio_loaded: st.success("✅ 音频加载成功")
|
||||
else: st.info("ℹ️ 将不会播放音频")
|
||||
|
||||
self.standard_cap = cv2.VideoCapture(video_path)
|
||||
if not self.standard_cap.isOpened():
|
||||
st.error("Cannot open standard video.")
|
||||
st.error("无法打开标准视频文件")
|
||||
return
|
||||
|
||||
if not self.is_realsense_active and (not self.webcam_cap or not self.webcam_cap.isOpened()):
|
||||
if not self.initialize_camera(): return
|
||||
|
||||
# Get video info and setup variables
|
||||
total_frames = int(self.standard_cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = self.standard_cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_delay = 1.0 / video_fps if video_fps > 0 else 1.0 / 30.0
|
||||
target_width, target_height = self.get_display_resolution()
|
||||
|
||||
# UI Placeholders
|
||||
st.markdown("### 📺 Video Comparison")
|
||||
vid_col1, vid_col2 = st.columns(2, gap="small")
|
||||
standard_placeholder = vid_col1.empty()
|
||||
webcam_placeholder = vid_col2.empty()
|
||||
st.markdown("### 📺 视频比较")
|
||||
video_col1, video_col2 = st.columns(2, gap="small")
|
||||
standard_placeholder = video_col1.empty()
|
||||
webcam_placeholder = video_col2.empty()
|
||||
|
||||
# ... Control buttons setup as in original file ...
|
||||
with video_col1:
|
||||
st.markdown("#### 🎯 标准动作视频")
|
||||
|
||||
with video_col2:
|
||||
camera_type = "RealSense摄像头" if self.is_realsense_active else "USB摄像头"
|
||||
st.markdown(f"#### 📹 {camera_type}实时影像")
|
||||
|
||||
# 创建控制按钮区域(紧凑布局)
|
||||
st.markdown("---")
|
||||
control_container = st.container()
|
||||
with control_container:
|
||||
control_col1, control_col2, control_col3, control_col4 = st.columns([1, 1, 1, 1])
|
||||
|
||||
with control_col1:
|
||||
stop_button = st.button("⏹️ 停止", use_container_width=True, key="stop_comparison")
|
||||
with control_col2:
|
||||
restart_button = st.button("🔄 重新开始", use_container_width=True, key="restart_comparison")
|
||||
with control_col3:
|
||||
if audio_loaded:
|
||||
if st.button("🔊 音频状态", use_container_width=True, key="audio_status"):
|
||||
st.info(f"音频: {'播放中' if self.audio_player.is_playing else '已停止'}")
|
||||
with control_col4:
|
||||
# 分辨率切换按钮
|
||||
if st.button("📐 切换分辨率", use_container_width=True, key="resolution_toggle"):
|
||||
modes = ['high', 'medium', 'low']
|
||||
current_idx = modes.index(self.display_settings.get('resolution_mode', 'high'))
|
||||
next_idx = (current_idx + 1) % len(modes)
|
||||
self.display_settings['resolution_mode'] = modes[next_idx]
|
||||
st.info(f"分辨率模式: {modes[next_idx]}")
|
||||
|
||||
# Similarity UI
|
||||
st.markdown("---")
|
||||
st.markdown("### 📊 Similarity Analysis")
|
||||
st.markdown("### 📊 动作相似度分析")
|
||||
sim_col1, sim_col2, sim_col3 = st.columns([1, 1, 2])
|
||||
similarity_score_placeholder = sim_col1.empty()
|
||||
avg_score_placeholder = sim_col2.empty()
|
||||
similarity_plot_placeholder = sim_col3.empty()
|
||||
|
||||
# ... Progress bar setup ...
|
||||
# 处理按钮点击
|
||||
if stop_button:
|
||||
st.session_state.comparison_state['should_stop'] = True
|
||||
if restart_button:
|
||||
st.session_state.comparison_state['should_restart'] = True
|
||||
|
||||
# 状态显示区域
|
||||
status_container = st.container()
|
||||
with status_container:
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
|
||||
# 主循环
|
||||
video_frame_idx = 0
|
||||
start_time = time.time()
|
||||
current_similarity = 0
|
||||
last_plot_update = 0
|
||||
|
||||
# Start Audio
|
||||
if audio_loaded: self.audio_player.play()
|
||||
|
||||
# MAIN LOOP (Simplified logic, same as original)
|
||||
# while st.session_state.comparison_state['is_running'] and not st.session_state.comparison_state['should_stop']:
|
||||
# ... Read frames ...
|
||||
# ... Detect keypoints ...
|
||||
# ... Calculate similarity ...
|
||||
# ... Draw skeletons ...
|
||||
# ... Update UI placeholders ...
|
||||
# ... Handle restart/stop flags ...
|
||||
# ... Frame rate control ...
|
||||
|
||||
while (st.session_state.comparison_state['is_running'] and
|
||||
not st.session_state.comparison_state['should_stop']):
|
||||
|
||||
loop_start = time.time()
|
||||
self.frame_counter += 1
|
||||
|
||||
# 检查重新开始
|
||||
if st.session_state.comparison_state['should_restart']:
|
||||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
self.similarity_analyzer.reset()
|
||||
start_time = time.time()
|
||||
video_frame_idx = 0
|
||||
if audio_loaded:
|
||||
self.audio_player.restart()
|
||||
st.session_state.comparison_state['should_restart'] = False
|
||||
continue
|
||||
|
||||
# 读取标准视频的当前帧
|
||||
ret_standard, standard_frame = self.standard_cap.read()
|
||||
if not ret_standard:
|
||||
# 视频结束,重新开始
|
||||
self.standard_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
video_frame_idx = 0
|
||||
self.similarity_analyzer.reset()
|
||||
start_time = time.time()
|
||||
if audio_loaded:
|
||||
self.audio_player.restart()
|
||||
continue
|
||||
|
||||
# 读取摄像头当前帧
|
||||
ret_webcam, webcam_frame = self.read_camera_frame()
|
||||
if not ret_webcam or webcam_frame is None:
|
||||
self.error_count += 1
|
||||
if self.error_count > 100: # 连续错误过多时停止
|
||||
st.error("摄像头连接出现问题,停止比较")
|
||||
break
|
||||
continue
|
||||
|
||||
self.error_count = 0 # 重置错误计数
|
||||
|
||||
# 翻转摄像头画面(镜像效果)
|
||||
webcam_frame = cv2.flip(webcam_frame, 1)
|
||||
|
||||
# 调整尺寸使两个视频大小一致(使用更高分辨率)
|
||||
standard_frame = cv2.resize(standard_frame, (target_width, target_height))
|
||||
webcam_frame = cv2.resize(webcam_frame, (target_width, target_height))
|
||||
|
||||
# 处理关键点检测
|
||||
try:
|
||||
standard_keypoints, standard_scores = self.body_detector(standard_frame)
|
||||
webcam_keypoints, webcam_scores = self.body_detector(webcam_frame)
|
||||
except Exception as e:
|
||||
# 关键点检测失败时继续显示原始图像
|
||||
standard_keypoints, standard_scores = None, None
|
||||
webcam_keypoints, webcam_scores = None, None
|
||||
|
||||
# 计算相似度(每5帧计算一次以提高性能)
|
||||
if (self.frame_counter % 5 == 0 and
|
||||
standard_keypoints is not None and
|
||||
webcam_keypoints is not None):
|
||||
try:
|
||||
# 提取当前标准视频帧和摄像头帧的关节角度
|
||||
standard_angles = self.similarity_analyzer.extract_joint_angles(
|
||||
standard_keypoints, standard_scores
|
||||
)
|
||||
webcam_angles = self.similarity_analyzer.extract_joint_angles(
|
||||
webcam_keypoints, webcam_scores
|
||||
)
|
||||
|
||||
# 计算相似度
|
||||
if standard_angles and webcam_angles:
|
||||
current_similarity = self.similarity_analyzer.calculate_similarity(
|
||||
standard_angles, webcam_angles
|
||||
)
|
||||
|
||||
# 添加到历史记录
|
||||
timestamp = time.time() - start_time
|
||||
self.similarity_analyzer.add_similarity_score(current_similarity, timestamp)
|
||||
except Exception as e:
|
||||
pass # 忽略相似度计算错误
|
||||
|
||||
# 绘制关键点
|
||||
try:
|
||||
if standard_keypoints is not None and standard_scores is not None:
|
||||
standard_with_keypoints = draw_skeleton(
|
||||
standard_frame.copy(),
|
||||
standard_keypoints,
|
||||
standard_scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
else:
|
||||
standard_with_keypoints = standard_frame.copy()
|
||||
|
||||
if webcam_keypoints is not None and webcam_scores is not None:
|
||||
webcam_with_keypoints = draw_skeleton(
|
||||
webcam_frame.copy(),
|
||||
webcam_keypoints,
|
||||
webcam_scores,
|
||||
openpose_skeleton=True,
|
||||
kpt_thr=0.43
|
||||
)
|
||||
else:
|
||||
webcam_with_keypoints = webcam_frame.copy()
|
||||
|
||||
except Exception as e:
|
||||
# 如果绘制失败,使用原始帧
|
||||
standard_with_keypoints = standard_frame.copy()
|
||||
webcam_with_keypoints = webcam_frame.copy()
|
||||
|
||||
# 转换颜色空间 (BGR to RGB)
|
||||
standard_rgb = cv2.cvtColor(standard_with_keypoints, cv2.COLOR_BGR2RGB)
|
||||
webcam_rgb = cv2.cvtColor(webcam_with_keypoints, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 添加帧信息到图像上
|
||||
current_time = time.time() - start_time
|
||||
frame_info = f"时间: {current_time:.1f}s | 帧: {video_frame_idx}/{total_frames}"
|
||||
audio_info = f" | 音频: {'🔊' if self.audio_player.is_playing else '🔇'}" if audio_loaded else ""
|
||||
resolution_info = f" | {target_width}x{target_height}"
|
||||
|
||||
# 显示大尺寸画面
|
||||
with video_col1:
|
||||
standard_placeholder.image(
|
||||
standard_rgb,
|
||||
caption=f"标准动作 - {frame_info}{audio_info}{resolution_info}",
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
with video_col2:
|
||||
webcam_placeholder.image(
|
||||
webcam_rgb,
|
||||
caption=f"您的动作 - 实时画面{resolution_info}",
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
# 显示相似度信息(紧凑显示)
|
||||
if len(self.similarity_analyzer.similarity_history) > 0:
|
||||
try:
|
||||
avg_similarity = sum(self.similarity_analyzer.similarity_history) / len(self.similarity_analyzer.similarity_history)
|
||||
|
||||
# 使用不同颜色显示相似度
|
||||
if current_similarity >= 80:
|
||||
similarity_color = "🟢"
|
||||
level = "优秀"
|
||||
elif current_similarity >= 60:
|
||||
similarity_color = "🟡"
|
||||
level = "良好"
|
||||
else:
|
||||
similarity_color = "🔴"
|
||||
level = "需要改进"
|
||||
|
||||
with sim_col1:
|
||||
similarity_score_placeholder.metric(
|
||||
"当前相似度",
|
||||
f"{similarity_color} {current_similarity:.1f}%",
|
||||
delta=f"{level}"
|
||||
)
|
||||
|
||||
with sim_col2:
|
||||
avg_score_placeholder.metric(
|
||||
"平均相似度",
|
||||
f"{avg_similarity:.1f}%"
|
||||
)
|
||||
|
||||
# 更新相似度图表(每20帧更新一次以提高性能)
|
||||
if (len(self.similarity_analyzer.similarity_history) >= 2 and
|
||||
self.frame_counter - last_plot_update >= 20):
|
||||
try:
|
||||
similarity_plot = self.similarity_analyzer.get_similarity_plot()
|
||||
if similarity_plot:
|
||||
with sim_col3:
|
||||
similarity_plot_placeholder.plotly_chart(
|
||||
similarity_plot,
|
||||
use_container_width=True,
|
||||
key=f"similarity_plot_{int(time.time() * 1000)}" # 使用时间戳避免重复ID
|
||||
)
|
||||
last_plot_update = self.frame_counter
|
||||
except Exception as e:
|
||||
pass # 忽略图表更新错误
|
||||
except Exception as e:
|
||||
pass # 忽略显示错误
|
||||
|
||||
# 更新进度和状态(紧凑显示)
|
||||
try:
|
||||
progress = min(video_frame_idx / total_frames, 1.0) if total_frames > 0 else 0
|
||||
progress_bar.progress(progress)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
fps_actual = self.frame_counter / elapsed_time if elapsed_time > 0 else 0
|
||||
|
||||
status_text.text(
|
||||
f"进度: {video_frame_idx}/{total_frames} | "
|
||||
f"实际FPS: {fps_actual:.1f} | "
|
||||
f"分辨率: {target_width}x{target_height} | "
|
||||
f"模式: {self.display_settings['resolution_mode']}"
|
||||
)
|
||||
except Exception as e:
|
||||
pass # 忽略状态更新错误
|
||||
|
||||
video_frame_idx += 1
|
||||
|
||||
# 精确的帧率控制
|
||||
loop_elapsed = time.time() - loop_start
|
||||
sleep_time = max(0, frame_delay - loop_elapsed)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# 强制更新UI(每30帧一次)
|
||||
if self.frame_counter % 30 == 0:
|
||||
st.empty() # 触发UI更新
|
||||
|
||||
# The full loop from your original file goes here.
|
||||
# It's omitted for brevity but the logic remains identical.
|
||||
# Just ensure you call the correct methods:
|
||||
# e.g., self.read_camera_frame(), self.similarity_analyzer.calculate_similarity(), etc.
|
||||
|
||||
self.cleanup()
|
||||
self.show_final_statistics()
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user