Upload Music VAE checkpoint (epoch 99)
Browse files- README.md +113 -0
- best_model.pt +3 -0
- config.json +8 -0
README.md
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- music-generation
|
| 5 |
+
- variational-autoencoder
|
| 6 |
+
- piano-roll
|
| 7 |
+
- midi
|
| 8 |
+
- pytorch
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Music VAE β CNN pretrained on the Lakh MIDI Dataset
|
| 12 |
+
|
| 13 |
+
## Overview
|
| 14 |
+
|
| 15 |
+
This is a **Convolutional Variational Autoencoder (CNN VAE)** pretrained on the
|
| 16 |
+
[Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (lmd_full, ~175k MIDI files).
|
| 17 |
+
|
| 18 |
+
It was trained as part of a machine learning course assignment at Purdue University
|
| 19 |
+
to give students a meaningful starting point for music-generation tasks.
|
| 20 |
+
|
| 21 |
+
## Input / Output Format
|
| 22 |
+
|
| 23 |
+
| Property | Value |
|
| 24 |
+
|-----------------|--------------------------------|
|
| 25 |
+
| Input shape | `[batch, 1, 88, 32]` β float32 |
|
| 26 |
+
| Output shape | `[batch, 1, 88, 32]` β float32 |
|
| 27 |
+
| Pitch range | MIDI 21β108 (A0 β C8, 88 keys) |
|
| 28 |
+
| Time resolution | 16th notes at 120 BPM |
|
| 29 |
+
| Segment length | 2 bars = 32 timesteps |
|
| 30 |
+
| Value range | [0, 1] (Sigmoid output) |
|
| 31 |
+
|
| 32 |
+
A tensor value of `1` at position `[pitch_idx, time_step]` means that pitch
|
| 33 |
+
`21 + pitch_idx` is active at that 16th-note time step.
|
| 34 |
+
|
| 35 |
+
## Architecture Summary
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
ENCODER
|
| 39 |
+
Conv2d(1 β 32, k=4, s=2, p=1) + ReLU + BN β [B, 32, 44, 16]
|
| 40 |
+
Conv2d(32 β 64, k=4, s=2, p=1) + ReLU + BN β [B, 64, 22, 8]
|
| 41 |
+
Conv2d(64 β 128,k=4, s=2, p=1) + ReLU + BN β [B, 128, 11, 4]
|
| 42 |
+
Conv2d(128β 256,k=4, s=2, p=1) + ReLU + BN β [B, 256, 5, 2]
|
| 43 |
+
Flatten β 2560
|
| 44 |
+
Linear β mu [B, 256]
|
| 45 |
+
Linear β log_var [B, 256]
|
| 46 |
+
|
| 47 |
+
REPARAMETERISATION
|
| 48 |
+
z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I)
|
| 49 |
+
|
| 50 |
+
DECODER
|
| 51 |
+
Linear(256 β 2560) β Reshape [B, 256, 5, 2]
|
| 52 |
+
ConvTranspose2d(256β128, k=4, s=2, p=1, output_padding=(1,0)) β [B, 128, 11, 4]
|
| 53 |
+
ConvTranspose2d(128β 64, k=4, s=2, p=1) β [B, 64, 22, 8]
|
| 54 |
+
ConvTranspose2d( 64β 32, k=4, s=2, p=1) β [B, 32, 44, 16]
|
| 55 |
+
ConvTranspose2d( 32β 1, k=4, s=2, p=1) + Sigmoid β [B, 1, 88, 32]
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
- **Latent dimension**: 256
|
| 59 |
+
- **Trainable parameters**: ~4.2M
|
| 60 |
+
|
| 61 |
+
## Loading the Model (Course Assignment)
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
import torch
|
| 65 |
+
from model import MusicVAE # copy src/model.py into your project
|
| 66 |
+
|
| 67 |
+
# Load checkpoint
|
| 68 |
+
ckpt = torch.load("best_model.pt", map_location="cpu")
|
| 69 |
+
config = ckpt["config"]
|
| 70 |
+
|
| 71 |
+
model = MusicVAE(latent_dim=config["latent_dim"])
|
| 72 |
+
model.load_state_dict(ckpt["model_state"])
|
| 73 |
+
model.eval()
|
| 74 |
+
|
| 75 |
+
# Generate new piano rolls
|
| 76 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 77 |
+
model = model.to(device)
|
| 78 |
+
samples = model.sample(n=4, device=device) # [4, 1, 88, 32]
|
| 79 |
+
|
| 80 |
+
# Encode a segment and reconstruct it
|
| 81 |
+
x = ... # your [1, 1, 88, 32] piano-roll tensor
|
| 82 |
+
x_recon, mu, log_var = model(x.to(device))
|
| 83 |
+
|
| 84 |
+
# Interpolate between two points in latent space
|
| 85 |
+
z1 = mu[0:1]
|
| 86 |
+
z2 = mu[1:2] # second example
|
| 87 |
+
interp = model.interpolate(z1, z2, steps=8)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`.
|
| 91 |
+
|
| 92 |
+
## Training Details
|
| 93 |
+
|
| 94 |
+
- **Dataset**: Lakh MIDI Dataset (lmd_full)
|
| 95 |
+
- **Piano roll**: 88-pitch binary, 16th-note resolution, 120 BPM normalised
|
| 96 |
+
- **Segments**: 2 bars (32 frames), stride 1 bar (16 frames)
|
| 97 |
+
- **Loss**: BCE reconstruction + Ξ²-annealed KL (Ξ²: 0 β 1 over 20 epochs)
|
| 98 |
+
- **Optimizer**: Adam, lr=1e-3, ReduceLROnPlateau (patience=5, factor=0.5)
|
| 99 |
+
- **Batch size**: 256 | **Epochs**: 100 | **Gradient clip**: 1.0
|
| 100 |
+
|
| 101 |
+
## Citation
|
| 102 |
+
|
| 103 |
+
If you use this model in your work, please cite the Lakh MIDI Dataset:
|
| 104 |
+
|
| 105 |
+
```bibtex
|
| 106 |
+
@inproceedings{Raffel2016,
|
| 107 |
+
author = {Colin Raffel},
|
| 108 |
+
title = {Learning-Based Methods for Comparing Sequences, with Applications
|
| 109 |
+
to Audio-to-{MIDI} Alignment and Matching},
|
| 110 |
+
booktitle = {PhD Thesis, Columbia University},
|
| 111 |
+
year = {2016}
|
| 112 |
+
}
|
| 113 |
+
```
|
best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f62b170a51e317c2c930d612a878aed2331f6855cc1272a160df019564abcc1
|
| 3 |
+
size 13414709
|
config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"latent_dim": 256,
|
| 3 |
+
"n_pitches": 88,
|
| 4 |
+
"seg_frames": 32,
|
| 5 |
+
"pitch_min": 21,
|
| 6 |
+
"pitch_max": 108,
|
| 7 |
+
"architecture": "CNN-VAE"
|
| 8 |
+
}
|