"""
Visualize training metrics from MLflow runs.
Generates plots showing:
- Loss curves (training/validation)
- Task-specific metrics over time
- Learning rate schedule
- Training speed analysis
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import mlflow
import mlflow.tracking
from mlflow.entities import Run
import seaborn as sns
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.utils.logging import configure_logging, get_logger
configure_logging()
logger = get_logger(__name__)
# Configure plotting style
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 100
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
MLRUNS_DIR = PROJECT_ROOT / "mlruns"
def load_training_history() -> dict[str, object] | None:
"""Load training history from JSON if available."""
history_path = OUTPUTS_DIR / "training_history.json"
if history_path.exists():
with open(history_path) as f:
data: dict[str, object] = json.load(f)
return data
return None
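# NOTE: load_training_history() is not invoked by main() below; it can be used
# for ad-hoc inspection, e.g. from a notebook:
#   history = load_training_history()  # raw dict, or None if no JSON exists yet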
def get_latest_run() -> Run | None:
    """Return the most recent MLflow run, or None if no experiment or run exists."""
    mlflow.set_tracking_uri(MLRUNS_DIR.as_uri())  # portable file:// URI (also correct on Windows)
client = mlflow.tracking.MlflowClient()
# Get the experiment (LexiMind)
experiment = client.get_experiment_by_name("LexiMind")
if not experiment:
logger.error("No 'LexiMind' experiment found")
return None
# Get all runs, sorted by start time
runs = client.search_runs(
experiment_ids=[experiment.experiment_id],
order_by=["start_time DESC"],
max_results=1,
)
if not runs:
logger.error("No runs found in experiment")
return None
return runs[0]
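# The same runs can also be browsed interactively with the MLflow UI, e.g.:
#   mlflow ui --backend-store-uri file:./mlruns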
def plot_loss_curves(run: Run) -> None:
"""Plot training and validation loss over time."""
client = mlflow.tracking.MlflowClient()
# Get metrics
train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
val_loss = client.get_metric_history(run.info.run_id, "val_total_loss")
fig, ax = plt.subplots(figsize=(12, 6))
if not train_loss:
# Create placeholder plot
ax.text(
0.5,
0.5,
"No training data yet\n\nWaiting for first epoch to complete...",
ha="center",
va="center",
fontsize=14,
color="gray",
)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
else:
# Extract steps and values
train_steps = [m.step for m in train_loss]
train_values = [m.value for m in train_loss]
ax.plot(train_steps, train_values, label="Training Loss", linewidth=2, alpha=0.8)
if val_loss:
val_steps = [m.step for m in val_loss]
val_values = [m.value for m in val_loss]
ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2, alpha=0.8)
ax.legend(fontsize=11)
ax.set_xlabel("Epoch", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.set_title("Training Progress: Total Loss", fontsize=14, fontweight="bold")
ax.grid(True, alpha=0.3)
plt.tight_layout()
output_path = OUTPUTS_DIR / "training_loss_curve.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
logger.info(f"✓ Saved loss curve to {output_path}")
plt.close()
def plot_task_metrics(run: Run) -> None:
"""Plot metrics for each task."""
client = mlflow.tracking.MlflowClient()
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("Task-Specific Training Metrics", fontsize=16, fontweight="bold")
# Summarization
ax = axes[0, 0]
train_sum = client.get_metric_history(run.info.run_id, "train_summarization_loss")
val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")
if train_sum:
ax.plot(
[m.step for m in train_sum], [m.value for m in train_sum], label="Train", linewidth=2
)
if val_sum:
ax.plot([m.step for m in val_sum], [m.value for m in val_sum], label="Val", linewidth=2)
ax.set_title("Summarization Loss", fontweight="bold")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True, alpha=0.3)
# Emotion
ax = axes[0, 1]
train_emo = client.get_metric_history(run.info.run_id, "train_emotion_loss")
val_emo = client.get_metric_history(run.info.run_id, "val_emotion_loss")
train_f1 = client.get_metric_history(run.info.run_id, "train_emotion_f1")
val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")
if train_emo:
ax.plot(
[m.step for m in train_emo],
[m.value for m in train_emo],
label="Train Loss",
linewidth=2,
)
if val_emo:
ax.plot(
[m.step for m in val_emo], [m.value for m in val_emo], label="Val Loss", linewidth=2
)
ax2 = ax.twinx()
if train_f1:
ax2.plot(
[m.step for m in train_f1],
[m.value for m in train_f1],
label="Train F1",
linewidth=2,
linestyle="--",
alpha=0.7,
)
if val_f1:
ax2.plot(
[m.step for m in val_f1],
[m.value for m in val_f1],
label="Val F1",
linewidth=2,
linestyle="--",
alpha=0.7,
)
ax.set_title("Emotion Detection", fontweight="bold")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax2.set_ylabel("F1 Score")
    ax.legend(loc="upper left")
    if train_f1 or val_f1:  # avoid an empty-legend warning on the twin axis
        ax2.legend(loc="upper right")
ax.grid(True, alpha=0.3)
# Topic
ax = axes[1, 0]
train_topic = client.get_metric_history(run.info.run_id, "train_topic_loss")
val_topic = client.get_metric_history(run.info.run_id, "val_topic_loss")
train_acc = client.get_metric_history(run.info.run_id, "train_topic_accuracy")
val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")
if train_topic:
ax.plot(
[m.step for m in train_topic],
[m.value for m in train_topic],
label="Train Loss",
linewidth=2,
)
if val_topic:
ax.plot(
[m.step for m in val_topic], [m.value for m in val_topic], label="Val Loss", linewidth=2
)
ax2 = ax.twinx()
if train_acc:
ax2.plot(
[m.step for m in train_acc],
[m.value for m in train_acc],
label="Train Acc",
linewidth=2,
linestyle="--",
alpha=0.7,
)
if val_acc:
ax2.plot(
[m.step for m in val_acc],
[m.value for m in val_acc],
label="Val Acc",
linewidth=2,
linestyle="--",
alpha=0.7,
)
ax.set_title("Topic Classification", fontweight="bold")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax2.set_ylabel("Accuracy")
    ax.legend(loc="upper left")
    if train_acc or val_acc:  # avoid an empty-legend warning on the twin axis
        ax2.legend(loc="upper right")
ax.grid(True, alpha=0.3)
# Summary statistics
ax = axes[1, 1]
ax.axis("off")
# Get final metrics
summary_text = "Final Metrics (Last Epoch)\n" + "=" * 35 + "\n\n"
    if val_acc:
        summary_text += f"Topic Accuracy: {val_acc[-1].value:.1%}\n"
    if val_f1:
        summary_text += f"Emotion F1: {val_f1[-1].value:.1%}\n"
if val_sum:
summary_text += f"Summarization Loss: {val_sum[-1].value:.3f}\n"
ax.text(0.1, 0.5, summary_text, fontsize=12, family="monospace", verticalalignment="center")
plt.tight_layout()
output_path = OUTPUTS_DIR / "task_metrics.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
logger.info(f"✓ Saved task metrics to {output_path}")
plt.close()
def plot_learning_rate(run: Run) -> None:
"""Plot learning rate schedule if available."""
client = mlflow.tracking.MlflowClient()
lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")
fig, ax = plt.subplots(figsize=(12, 5))
if not lr_metrics:
# Create placeholder
ax.text(
0.5,
0.5,
"No learning rate data yet\n\n(Will be logged in future training runs)",
ha="center",
va="center",
fontsize=14,
color="gray",
)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
else:
steps = [m.step for m in lr_metrics]
values = [m.value for m in lr_metrics]
ax.plot(steps, values, linewidth=2, color="darkblue")
        # Mark the end of warmup (hard-coded to match the trainer config;
        # keep in sync if the config changes)
        warmup_steps = 1000
if warmup_steps < max(steps):
ax.axvline(warmup_steps, color="red", linestyle="--", alpha=0.5, label="Warmup End")
ax.legend()
ax.set_xlabel("Step", fontsize=12)
ax.set_ylabel("Learning Rate", fontsize=12)
ax.set_title("Learning Rate Schedule (Cosine with Warmup)", fontsize=14, fontweight="bold")
ax.grid(True, alpha=0.3)
plt.tight_layout()
output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
logger.info(f"✓ Saved LR schedule to {output_path}")
plt.close()
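# The module docstring promises a training-speed plot, but none was implemented.
# The sketch below fills that gap. It assumes the trainer logs a
# "steps_per_second" metric -- that key is NOT confirmed by the training code,
# so adjust it to whatever the training loop actually records.
def plot_training_speed(run: Run) -> None:
    """Plot training throughput over time (sketch; assumes a 'steps_per_second' metric)."""
    client = mlflow.tracking.MlflowClient()
    speed = client.get_metric_history(run.info.run_id, "steps_per_second")
    fig, ax = plt.subplots(figsize=(12, 5))
    if not speed:
        # Placeholder plot, mirroring plot_learning_rate()
        ax.text(
            0.5,
            0.5,
            "No training speed data yet\n\n(expected metric: 'steps_per_second')",
            ha="center",
            va="center",
            fontsize=14,
            color="gray",
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    else:
        ax.plot(
            [m.step for m in speed],
            [m.value for m in speed],
            linewidth=2,
            color="darkgreen",
        )
    ax.set_xlabel("Step", fontsize=12)
    ax.set_ylabel("Steps / Second", fontsize=12)
    ax.set_title("Training Speed", fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    output_path = OUTPUTS_DIR / "training_speed.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    logger.info(f"✓ Saved training speed plot to {output_path}")
    plt.close()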
def main() -> None:
"""Generate all training visualizations."""
logger.info("Loading MLflow data...")
run = get_latest_run()
if not run:
logger.error("No training run found. Make sure training has started.")
return
logger.info(f"Analyzing run: {run.info.run_id}")
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
logger.info("Generating visualizations...")
plot_loss_curves(run)
plot_task_metrics(run)
plot_learning_rate(run)
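    plot_training_speed(run)  # sketch above; assumes a 'steps_per_second' metric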
logger.info("\n" + "=" * 60)
logger.info("✓ All visualizations saved to outputs/")
logger.info("=" * 60)
logger.info(" - training_loss_curve.png")
logger.info(" - task_metrics.png")
logger.info(" - learning_rate_schedule.png")
logger.info("=" * 60)
if __name__ == "__main__":
main()