orsier / mles_distortion_rm_proj

  • This project
    • Loading...
  • Sign in
Go to a project
  • Project
  • Repository
  • Issues 0
  • Merge Requests 0
  • Pipelines
  • Wiki
  • Snippets
  • Members
  • Activity
  • Graph
  • Charts
  • Create a new issue
  • Jobs
  • Commits
  • Issue Boards
  • Files
  • Commits
  • Branches
  • Tags
  • Contributors
  • Graph
  • Compare
  • Charts
Switch branch/tag
  • mles_distortion_rm_proj
  • metrics
  • metricsresult.py
Find file
BlameHistoryPermalink
Last commit: Second commit (a9773f04), committed by orsier 4 months ago
metrics/metricsresult.py (1.22 KB)
import pandas as pd
import matplotlib.pyplot as plt

# Load the CSV of audio comparison results
df = pd.read_csv("audio_comparison_results_normalized.csv")

# Compute the mean of each metric over all rows of the CSV
mse_vocals_target_avg = df["mse_vocals_target"].mean()
snr_vocals_target_avg = df["snr_vocals_target"].mean()
sdr_vocals_target_avg = df["sdr_vocals_target"].mean()

mse_mixture_target_avg = df["mse_mixture_target"].mean()
snr_mixture_target_avg = df["snr_mixture_target"].mean()
sdr_mixture_target_avg = df["sdr_mixture_target"].mean()

# Build the table contents: a header row, then one row per comparison
metrics_table = [
    ["DemucsV3", "SNR", "SDR", "MSE"],
    ["vocals vs target", f"{snr_vocals_target_avg:.2f}", f"{sdr_vocals_target_avg:.2f}", f"{mse_vocals_target_avg:.2f}"],
    ["mixture vs target", f"{snr_mixture_target_avg:.2f}", f"{sdr_mixture_target_avg:.2f}", f"{mse_mixture_target_avg:.2f}"]
]

# Create the figure and the axes that will hold the table
fig, ax = plt.subplots(figsize=(8, 3))
ax.axis('off')  # Hide the axes so only the table is drawn

# Draw the table (the header is the first row of cellText, so colLabels is None)
ax.table(
    cellText=metrics_table,
    colLabels=None,
    cellLoc='center',
    loc='center',
    colWidths=[0.2] * 4
)

# Save the table as an image, then display it
plt.savefig("metrics.jpg", bbox_inches='tight')
plt.show()
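Because the script reads six columns by name, a CSV written with different column names only fails later, with a bare `KeyError` at the first `df[...]` lookup. The sketch below loads the file with an up-front check instead. It assumes the `<metric>_<comparison>` column names used in the script; `load_results` and `REQUIRED_COLUMNS` are hypothetical names, and the toy CSV values are made up for illustration.

```python
import io

import pandas as pd

# Hypothetical helper, not part of the project. Column names are the ones
# metricsresult.py reads: "<metric>_<comparison>".
REQUIRED_COLUMNS = [
    f"{metric}_{comparison}"
    for comparison in ("vocals_target", "mixture_target")
    for metric in ("mse", "snr", "sdr")
]


def load_results(source) -> pd.DataFrame:
    """Read the comparison CSV and fail early if an expected metric column is missing."""
    df = pd.read_csv(source)  # accepts a path or any file-like object
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing expected columns: {missing}")
    return df


# Usage with an in-memory toy CSV (made-up values, not project results).
toy_csv = io.StringIO(
    "mse_vocals_target,snr_vocals_target,sdr_vocals_target,"
    "mse_mixture_target,snr_mixture_target,sdr_mixture_target\n"
    "0.01,10.0,8.0,0.2,2.0,1.0\n"
)
df = load_results(toy_csv)
print(df.shape)  # (1, 6)
```

In the script itself, the call would be `load_results("audio_comparison_results_normalized.csv")`.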
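The six explicit `.mean()` calls and the hand-built rows follow a fixed `<metric>_<comparison>` naming scheme, so both can be derived in a loop. This is a sketch under that assumption (the naming is taken from the script). It returns a DataFrame with one row per comparison and one column per metric. `summarize_metrics`, `METRICS` and `COMPARISONS` are hypothetical names. With this structure, adding a metric or a comparison means editing one list rather than several lines.

```python
import pandas as pd

# Hypothetical helper, not part of the project. Assumes the script's
# "<metric>_<comparison>" column naming, e.g. "snr_vocals_target".
METRICS = ["snr", "sdr", "mse"]  # same column order as the script's table
COMPARISONS = ["vocals_target", "mixture_target"]


def summarize_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """Average every metric column: one row per comparison, one column per metric."""
    rows = {}
    for comparison in COMPARISONS:
        cols = [f"{metric}_{comparison}" for metric in METRICS]
        # DataFrame.mean() averages each column and skips NaN by default.
        rows[comparison.replace("_", " vs ")] = df[cols].mean().to_numpy()
    return pd.DataFrame.from_dict(
        rows, orient="index", columns=[metric.upper() for metric in METRICS]
    )


# Usage on a toy frame (made-up values, not results from this project).
toy = pd.DataFrame({
    "snr_vocals_target": [10.0, 12.0],
    "sdr_vocals_target": [8.0, 6.0],
    "mse_vocals_target": [0.01, 0.03],
    "snr_mixture_target": [2.0, 4.0],
    "sdr_mixture_target": [1.0, 3.0],
    "mse_mixture_target": [0.2, 0.4],
})
summary = summarize_metrics(toy)
print(summary.loc["vocals vs target", "SNR"])  # 11.0
```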
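The rendering step can also be written as a reusable function. The sketch assumes a summary DataFrame whose index holds the comparison names and whose columns hold the metrics. It departs from the script in three deliberate ways. First, it selects matplotlib's non-interactive Agg backend so it runs without a display. Second, it passes the header through `colLabels` so matplotlib styles it as a header rather than a data row. Third, it saves a lossless PNG, which keeps thin text and cell borders crisp where JPEG compression can add artifacts. `render_metrics_table` and its `formats` parameter are hypothetical, not part of the project.

```python
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # file-only backend: no display needed
import matplotlib.pyplot as plt
import pandas as pd


def render_metrics_table(summary: pd.DataFrame, model_name: str, path, formats=None) -> Path:
    """Save `summary` (rows = comparisons, columns = metrics) as a table image.

    Hypothetical helper. `formats` optionally maps a column name to a format
    spec, e.g. {"MSE": ".2e"}; other columns use ".2f" as metricsresult.py does.
    """
    formats = formats or {}
    header = [model_name, *summary.columns]
    rows = [
        [label, *(format(value, formats.get(col, ".2f"))
                  for col, value in zip(summary.columns, values))]
        for label, values in zip(summary.index, summary.to_numpy())
    ]
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.axis("off")
    ax.table(
        cellText=rows,
        colLabels=header,  # header drawn as column labels, not as a data row
        cellLoc="center",
        loc="center",
        colWidths=[0.3] + [0.2] * len(summary.columns),  # wider column for row labels
    )
    out = Path(path)
    fig.savefig(out, bbox_inches="tight", dpi=200)
    plt.close(fig)  # release the figure once it is written to disk
    return out


# Usage with a toy summary (made-up values, not results from this project).
summary = pd.DataFrame(
    [[11.0, 7.0, 0.02], [3.0, 2.0, 0.3]],
    index=["vocals vs target", "mixture vs target"],
    columns=["SNR", "SDR", "MSE"],
)
out = render_metrics_table(summary, "DemucsV3", "metrics.png", formats={"MSE": ".2e"})
```

Passing a scientific format for MSE may be worth considering because the input file is normalized: if per-track MSE values are well below 0.01, the script's `.2f` format prints them as `0.00`.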