XR-Lin commited on
Commit
3d19d4f
Β·
verified Β·
1 Parent(s): 433c066

Upload Music VAE checkpoint (epoch 99)

Browse files
Files changed (3) hide show
  1. README.md +113 -0
  2. best_model.pt +3 -0
  3. config.json +8 -0
README.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - music-generation
5
+ - variational-autoencoder
6
+ - piano-roll
7
+ - midi
8
+ - pytorch
9
+ ---
10
+
11
+ # Music VAE β€” CNN pretrained on the Lakh MIDI Dataset
12
+
13
+ ## Overview
14
+
15
+ This is a **Convolutional Variational Autoencoder (CNN VAE)** pretrained on the
16
+ [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (lmd_full, ~175k MIDI files).
17
+
18
+ It was trained as part of a machine learning course assignment at Purdue University
19
+ to give students a meaningful starting point for music-generation tasks.
20
+
21
+ ## Input / Output Format
22
+
23
+ | Property | Value |
24
+ |-----------------|--------------------------------|
25
+ | Input shape | `[batch, 1, 88, 32]` β€” float32 |
26
+ | Output shape | `[batch, 1, 88, 32]` β€” float32 |
27
+ | Pitch range | MIDI 21–108 (A0 – C8, 88 keys) |
28
+ | Time resolution | 16th notes at 120 BPM |
29
+ | Segment length | 2 bars = 32 timesteps |
30
+ | Value range | [0, 1] (Sigmoid output) |
31
+
32
+ A tensor value of `1` at position `[pitch_idx, time_step]` means that pitch
33
+ `21 + pitch_idx` is active at that 16th-note time step.
34
+
35
+ ## Architecture Summary
36
+
37
+ ```
38
+ ENCODER
39
+ Conv2d(1 β†’ 32, k=4, s=2, p=1) + ReLU + BN β†’ [B, 32, 44, 16]
40
+ Conv2d(32 β†’ 64, k=4, s=2, p=1) + ReLU + BN β†’ [B, 64, 22, 8]
41
+ Conv2d(64 β†’ 128,k=4, s=2, p=1) + ReLU + BN β†’ [B, 128, 11, 4]
42
+ Conv2d(128β†’ 256,k=4, s=2, p=1) + ReLU + BN β†’ [B, 256, 5, 2]
43
+ Flatten β†’ 2560
44
+ Linear β†’ mu [B, 256]
45
+ Linear β†’ log_var [B, 256]
46
+
47
+ REPARAMETERISATION
48
+ z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I)
49
+
50
+ DECODER
51
+ Linear(256 β†’ 2560) β†’ Reshape [B, 256, 5, 2]
52
+ ConvTranspose2d(256β†’128, k=4, s=2, p=1, output_padding=(1,0)) β†’ [B, 128, 11, 4]
53
+ ConvTranspose2d(128β†’ 64, k=4, s=2, p=1) β†’ [B, 64, 22, 8]
54
+ ConvTranspose2d( 64β†’ 32, k=4, s=2, p=1) β†’ [B, 32, 44, 16]
55
+ ConvTranspose2d( 32β†’ 1, k=4, s=2, p=1) + Sigmoid β†’ [B, 1, 88, 32]
56
+ ```
57
+
58
+ - **Latent dimension**: 256
59
+ - **Trainable parameters**: ~4.2M
60
+
61
+ ## Loading the Model (Course Assignment)
62
+
63
+ ```python
64
+ import torch
65
+ from model import MusicVAE # copy src/model.py into your project
66
+
67
+ # Load checkpoint
68
+ ckpt = torch.load("best_model.pt", map_location="cpu")
69
+ config = ckpt["config"]
70
+
71
+ model = MusicVAE(latent_dim=config["latent_dim"])
72
+ model.load_state_dict(ckpt["model_state"])
73
+ model.eval()
74
+
75
+ # Generate new piano rolls
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
+ model = model.to(device)
78
+ samples = model.sample(n=4, device=device) # [4, 1, 88, 32]
79
+
80
+ # Encode a segment and reconstruct it
81
+ x = ... # your [1, 1, 88, 32] piano-roll tensor
82
+ x_recon, mu, log_var = model(x.to(device))
83
+
84
+ # Interpolate between two points in latent space
85
+ z1 = mu[0:1]
86
+ z2 = mu[1:2] # second example
87
+ interp = model.interpolate(z1, z2, steps=8)
88
+ ```
89
+
90
+ Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`.
91
+
92
+ ## Training Details
93
+
94
+ - **Dataset**: Lakh MIDI Dataset (lmd_full)
95
+ - **Piano roll**: 88-pitch binary, 16th-note resolution, 120 BPM normalised
96
+ - **Segments**: 2 bars (32 frames), stride 1 bar (16 frames)
97
+ - **Loss**: BCE reconstruction + Ξ²-annealed KL (Ξ²: 0 β†’ 1 over 20 epochs)
98
+ - **Optimizer**: Adam, lr=1e-3, ReduceLROnPlateau (patience=5, factor=0.5)
99
+ - **Batch size**: 256 | **Epochs**: 100 | **Gradient clip**: 1.0
100
+
101
+ ## Citation
102
+
103
+ If you use this model in your work, please cite the Lakh MIDI Dataset:
104
+
105
+ ```bibtex
106
+ @inproceedings{Raffel2016,
107
+ author = {Colin Raffel},
108
+ title = {Learning-Based Methods for Comparing Sequences, with Applications
109
+ to Audio-to-{MIDI} Alignment and Matching},
110
+ booktitle = {PhD Thesis, Columbia University},
111
+ year = {2016}
112
+ }
113
+ ```
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f62b170a51e317c2c930d612a878aed2331f6855cc1272a160df019564abcc1
3
+ size 13414709
config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "latent_dim": 256,
3
+ "n_pitches": 88,
4
+ "seg_frames": 32,
5
+ "pitch_min": 21,
6
+ "pitch_max": 108,
7
+ "architecture": "CNN-VAE"
8
+ }