posedet/similarity_display_widget.py
2025-06-22 16:51:48 +08:00

260 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import queue
import threading

import numpy as np
from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel,
                             QPushButton, QFrame, QGridLayout)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
from PyQt5.QtGui import QFont, QPalette

# Placeholder for matplotlib widget since we'll need it for plotting
try:
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
    from matplotlib.figure import Figure
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
class SimilarityThread(QThread):
    """Background worker that scores pose pairs and reports similarity.

    Consumes ``(elapsed_time, standard_angles, webcam_angles)`` tuples from
    ``motion_app.pose_data_queue``, scores each pair with
    ``motion_app.similarity_analyzer`` and reports results through Qt
    signals so the GUI thread can update its widgets.
    """

    # (current_similarity, average_similarity) after each scored frame.
    similarity_update = pyqtSignal(float, float)
    # Request for the GUI to redraw the trend plot.
    plot_update = pyqtSignal()

    # A plot refresh is requested whenever the history length is a multiple
    # of this value (redrawing on every frame would be too expensive).
    PLOT_EVERY = 10
    # Seconds to block on the queue before re-checking the stop flags.
    QUEUE_TIMEOUT = 1.0

    def __init__(self, motion_app):
        """Store the application object whose queue and analyzer are used."""
        super().__init__()
        self.motion_app = motion_app
        self.is_running = False
        self.should_stop = False

    def start_analysis(self):
        """Clear the stop request and start the worker thread."""
        self.is_running = True
        self.should_stop = False
        self.start()

    def stop_analysis(self):
        """Ask the worker loop to exit.

        Does not wait for the thread. The loop notices the request once the
        current ``get`` returns, i.e. after at most ``QUEUE_TIMEOUT`` seconds.
        """
        self.should_stop = True
        self.is_running = False

    def run(self):
        """Worker loop: drain the pose queue until stopped or a ``None`` arrives."""
        while self.is_running and not self.should_stop:
            pose_queue = self.motion_app.pose_data_queue
            try:
                pose_data = pose_queue.get(timeout=self.QUEUE_TIMEOUT)
            except queue.Empty:
                continue
            try:
                if pose_data is None:  # Poison pill: producer asked us to exit.
                    break
                self._process_pose_data(pose_data)
            except Exception:
                # Best effort: one bad frame must not kill the worker, but the
                # failure is logged instead of being silently dropped.
                logging.getLogger(__name__).exception(
                    "Failed to compute similarity for pose data")
            finally:
                # Balance every successful get() -- including the poison pill
                # and frames that failed -- so queue.join() cannot hang.
                pose_queue.task_done()

    def _process_pose_data(self, pose_data):
        """Score one queued pose pair, record it and emit the update signals."""
        elapsed_time, standard_angles, webcam_angles = pose_data
        if not (standard_angles and webcam_angles):
            return
        analyzer = self.motion_app.similarity_analyzer
        current_similarity = analyzer.calculate_similarity(
            standard_angles, webcam_angles)
        analyzer.add_similarity_score(current_similarity, elapsed_time)
        history = analyzer.similarity_history
        if history:
            avg_similarity = sum(history) / len(history)
            self.similarity_update.emit(current_similarity, avg_similarity)
            # Emit plot update signal (less frequently)
            if len(history) % self.PLOT_EVERY == 0:
                self.plot_update.emit()
class SimilarityPlotWidget(QWidget):
    """Line chart of similarity scores over time with an average line.

    When matplotlib could not be imported, a plain text label is shown
    instead and ``update_plot`` does nothing.
    """

    def __init__(self):
        super().__init__()
        self.setup_ui()

    def setup_ui(self):
        """Build either the matplotlib canvas or the text fallback."""
        layout = QVBoxLayout(self)
        if not MATPLOTLIB_AVAILABLE:
            # Fallback to simple text display
            self.plot_label = QLabel("matplotlib不可用无法显示图表")
            self.plot_label.setAlignment(Qt.AlignCenter)
            layout.addWidget(self.plot_label)
            return

        self.figure = Figure(figsize=(8, 4))
        self.canvas = FigureCanvas(self.figure)
        layout.addWidget(self.canvas)

        self.ax = self.figure.add_subplot(111)
        self.ax.set_title('Similarity Trend')
        self.ax.set_xlabel('Time (s)')
        self.ax.set_ylabel('Score (%)')
        self.ax.set_ylim(0, 100)
        self.ax.grid(True, alpha=0.3)

        # Start with an empty trend line; the average line's y is set on update.
        self.line, = self.ax.plot([], [], 'b-', linewidth=2, label='Similarity')
        self.avg_line = self.ax.axhline(y=0, color='r', linestyle='--', alpha=0.7, label='Average')
        self.ax.legend()
        self.figure.tight_layout()
        self.canvas.draw()

    def update_plot(self, timestamps, similarities):
        """Redraw the chart from the given series.

        Args:
            timestamps: x values, plotted on the "Time (s)" axis.
            similarities: y values, plotted on the 0-100 "Score (%)" axis.

        Does nothing when matplotlib is unavailable or either series is empty.
        """
        if not MATPLOTLIB_AVAILABLE or not timestamps or not similarities:
            return
        # The caller snapshots the two series separately while a worker thread
        # may still be appending, so their lengths can differ. matplotlib
        # raises when drawing x/y data of unequal length, and an exception
        # escaping a PyQt5 slot aborts the app. Plot only the common prefix.
        count = min(len(timestamps), len(similarities))
        xs = list(timestamps)[:count]
        ys = list(similarities)[:count]

        self.line.set_data(xs, ys)
        avg = sum(ys) / count
        self.avg_line.set_ydata([avg, avg])
        self.ax.set_xlim(0, max(xs) + 1)
        self.canvas.draw()
class SimilarityDisplayWidget(QWidget):
    """Panel with live similarity metrics, a trend plot and final statistics.

    Owns a ``SimilarityThread`` that scores poses in the background and feeds
    this widget through Qt signals.
    """

    # Shared part of the performance label's stylesheet; a colour is appended.
    _PERFORMANCE_STYLE = "font-size: 14pt; font-weight: bold; margin: 5px;"

    def __init__(self, motion_app):
        super().__init__()
        self.motion_app = motion_app
        self.similarity_thread = SimilarityThread(motion_app)
        self.setup_ui()
        self.connect_signals()

    def setup_ui(self):
        """Create the title, live metrics, plot and hidden statistics panel."""
        layout = QVBoxLayout(self)

        title_label = QLabel("📊 动作相似度分析")
        title_label.setAlignment(Qt.AlignCenter)
        title_label.setStyleSheet("font-size: 16pt; font-weight: bold; margin: 10px;")
        layout.addWidget(title_label)

        metrics_layout = QGridLayout()
        # Current similarity
        self.current_sim_label = QLabel("当前相似度")
        self.current_sim_value = QLabel("0.0%")
        self.current_sim_value.setAlignment(Qt.AlignCenter)
        self.current_sim_value.setStyleSheet("font-size: 24pt; font-weight: bold; color: #0086d3;")
        # Average similarity
        self.avg_sim_label = QLabel("平均相似度")
        self.avg_sim_value = QLabel("0.0%")
        self.avg_sim_value.setAlignment(Qt.AlignCenter)
        self.avg_sim_value.setStyleSheet("font-size: 24pt; font-weight: bold; color: #2e7d32;")
        metrics_layout.addWidget(self.current_sim_label, 0, 0)
        metrics_layout.addWidget(self.current_sim_value, 1, 0)
        metrics_layout.addWidget(self.avg_sim_label, 0, 1)
        metrics_layout.addWidget(self.avg_sim_value, 1, 1)
        layout.addLayout(metrics_layout)

        self.plot_widget = SimilarityPlotWidget()
        layout.addWidget(self.plot_widget)

        # Final statistics stay hidden until an analysis session ends.
        self.stats_widget = self.create_statistics_widget()
        self.stats_widget.hide()
        layout.addWidget(self.stats_widget)

    def create_statistics_widget(self):
        """Build and return the end-of-session summary frame."""
        stats_widget = QFrame()
        stats_widget.setFrameStyle(QFrame.StyledPanel)
        stats_widget.setStyleSheet("background-color: #f0f8ff; border: 1px solid #ccc; border-radius: 5px;")
        layout = QVBoxLayout(stats_widget)

        self.final_title = QLabel("🎉 比较完成!")
        self.final_title.setAlignment(Qt.AlignCenter)
        self.final_title.setStyleSheet("font-size: 18pt; font-weight: bold; color: #2e7d32; margin: 10px;")

        self.performance_label = QLabel()
        self.performance_label.setAlignment(Qt.AlignCenter)
        self.performance_label.setStyleSheet(self._PERFORMANCE_STYLE)

        final_metrics_layout = QGridLayout()
        self.final_avg_label = QLabel("平均相似度")
        self.final_avg_value = QLabel("0.0%")
        self.final_max_label = QLabel("最高相似度")
        self.final_max_value = QLabel("0.0%")
        self.final_min_label = QLabel("最低相似度")
        self.final_min_value = QLabel("0.0%")
        final_metrics_layout.addWidget(self.final_avg_label, 0, 0)
        final_metrics_layout.addWidget(self.final_avg_value, 1, 0)
        final_metrics_layout.addWidget(self.final_max_label, 0, 1)
        final_metrics_layout.addWidget(self.final_max_value, 1, 1)
        final_metrics_layout.addWidget(self.final_min_label, 0, 2)
        final_metrics_layout.addWidget(self.final_min_value, 1, 2)

        layout.addWidget(self.final_title)
        layout.addWidget(self.performance_label)
        layout.addLayout(final_metrics_layout)
        return stats_widget

    def connect_signals(self):
        """Route worker-thread signals to the GUI update slots."""
        self.similarity_thread.similarity_update.connect(self.update_similarity_display)
        self.similarity_thread.plot_update.connect(self.update_plot)

    def start_analysis(self):
        """Hide old results, reset the analyzer and start the worker."""
        self.stats_widget.hide()
        self.motion_app.similarity_analyzer.reset()
        # Clear the previous session's numbers so stale values are not shown
        # until the first new score arrives.
        self.current_sim_value.setText("0.0%")
        self.avg_sim_value.setText("0.0%")
        self.similarity_thread.start_analysis()

    def stop_analysis(self):
        """Stop the worker and show the session summary."""
        self.similarity_thread.stop_analysis()
        # The worker only requests a redraw every few scores, so refresh once
        # more to include the trailing scores in the final chart.
        self.update_plot()
        self.show_final_statistics()

    def update_similarity_display(self, current_similarity, average_similarity):
        """Slot: show the latest and running-average similarity."""
        self.current_sim_value.setText(f"{current_similarity:.1f}%")
        self.avg_sim_value.setText(f"{average_similarity:.1f}%")

    def update_plot(self):
        """Slot: redraw the plot from the analyzer's recorded history."""
        analyzer = self.motion_app.similarity_analyzer
        if hasattr(analyzer, 'similarity_history') and hasattr(analyzer, 'frame_timestamps'):
            timestamps = list(analyzer.frame_timestamps)
            similarities = list(analyzer.similarity_history)
            self.plot_widget.update_plot(timestamps, similarities)

    @staticmethod
    def _performance_level(average):
        """Return the ``(label, colour)`` pair describing an average score."""
        if average >= 80:
            return "非常棒! 👏", "green"
        if average >= 60:
            return "整体不错! 👍", "blue"
        return "需要改进! 💪", "orange"

    def show_final_statistics(self):
        """Fill in and reveal the summary panel; no-op when nothing was scored."""
        # Take one snapshot: the worker may still be appending until it
        # notices the stop request, and avg/max/min must agree with each other.
        history = list(self.motion_app.similarity_analyzer.similarity_history)
        if not history:
            return
        final_avg = sum(history) / len(history)
        final_max = max(history)
        final_min = min(history)

        level, color = self._performance_level(final_avg)
        self.performance_label.setText(f"整体表现: {level}")
        self.performance_label.setStyleSheet(f"{self._PERFORMANCE_STYLE} color: {color};")
        self.final_avg_value.setText(f"{final_avg:.1f}%")
        self.final_max_value.setText(f"{final_max:.1f}%")
        self.final_min_value.setText(f"{final_min:.1f}%")
        self.stats_widget.show()