Quantization-Aware Training (QAT) – Violence Detection
This notebook trains a violence detection model using Quantization-Aware Training (QAT) via TensorFlow Model Optimization Toolkit (TFMOT), then exports it as an INT8 TFLite model for edge deployment. It builds on a MobileNetV2 + BiLSTM architecture trained on the Real Life Violence Situations Dataset.
Overview
| Property | Details |
|---|---|
| Framework | TensorFlow + tensorflow-model-optimization |
| Architecture | MobileNetV2 (feature extractor) + Bidirectional LSTM (classifier) |
| Input Shape | (batch, 16, 224, 224, 3) – 16-frame sequences at 224×224 |
| Classes | NonViolence, Violence |
| Output | qat_video_model.tflite (INT8) |
| Checkpoint storage | Hugging Face Hub (private repo) |
| Platform | Kaggle (GPU-enabled) |
Key Design Decisions
QAT vs. Post-Training Quantization
Unlike post-training quantization (PTQ), QAT inserts fake-quantization nodes into the model graph during training. This allows the model to learn to compensate for quantization errors, resulting in better INT8 accuracy. Mixed precision (float16) is explicitly disabled here because it conflicts with fake-quantization nodes.
Split Quantization Strategy
The CNN and LSTM parts are quantized separately:
- MobileNetV2 – fully quantized with `tfmot.quantization.keras.quantize_model()`. CNN layers are well-supported by TFMOT.
- Bidirectional LSTM head – left in float32. Bidirectional/LSTM layers have limited TFMOT support and are kept unquantized to prevent graph errors.
LSTM TFLite Compatibility Flags
Both training-time and export-time conversions set these flags to handle dynamic LSTM loops:
```python
import tensorflow as tf

# `model` is the trained QAT Keras model being converted.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # standard TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS      # TF fallback for dynamic LSTM loop ops
]
converter._experimental_lower_tensor_list_ops = False
```
Requirements
```
tensorflow
tensorflow-model-optimization
numpy<2.0.0
opencv-python
scikit-learn
matplotlib
huggingface_hub
```

Install via:

```
pip install -q huggingface_hub tensorflow-model-optimization "numpy<2.0.0"
```
Note: `TF_USE_LEGACY_KERAS=1` must be set before importing TensorFlow to ensure TFMOT compatibility on Kaggle (which uses Keras 3 by default).
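In practice this means the environment variable has to be exported in the very first cell, for example:

```python
import os

# Must run before any `import tensorflow` statement anywhere in the notebook,
# so that TF picks up the legacy (tf_keras) Keras implementation.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
```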
Configuration
Key hyperparameters and paths are defined in Cell 4:
| Variable | Default | Description |
|---|---|---|
| `IMAGE_HEIGHT`/`WIDTH` | 224 | Frame resolution |
| `SEQUENCE_LENGTH` | 16 | Frames per video clip |
| `BATCH_SIZE` | 8 | Training batch size |
| `EPOCHS` | 10 | Total training epochs |
| `RESUME` | `True` | Resume from HF checkpoint if available |
| `DATASET_PATH` | Kaggle input path | Root of the violence dataset |
| `HF_REPO_ID` | Auto-generated | Private HF repo for checkpoints |
Notebook Structure
| Cells | Purpose |
|---|---|
| 1–3 | Install dependencies, imports, TFMOT setup |
| 4–5 | Config: hyperparameters, HF repo, checkpoint paths |
| 6 | GPU detection and QAT mode confirmation |
| 7–10 | Hugging Face upload/download helpers, training state I/O |
| 11 | `HuggingFaceCheckpoint` callback – saves model to HF after each epoch |
| 12 | `frames_extraction()` – reads video, resizes frames, applies motion-based sampling |
| 13 | `collect_video_paths()` – scans dataset directory for .mp4/.avi/.mov files |
| 14 | `VideoDataGenerator` – Keras `Sequence` for memory-efficient batch loading |
| 15 | Custom `flatten_time` / `unflatten_time` Keras-serializable helpers |
| 16 | `create_model()` – builds the QAT MobileNetV2 + BiLSTM model |
| 17 | `plot_training()` – saves accuracy/loss curves to `training_curves.png` |
| 18 | Main training loop – loads state, splits data, trains with `MirroredStrategy` |
| 19 | TFLite conversion – `qat_video_model.tflite` |
| 20–26 | Empty (reserved / scratch) |
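The frame-selection step inside `frames_extraction()` (Cell 12) can be reduced to picking `SEQUENCE_LENGTH` evenly spaced indices; the sketch below shows that simplified uniform sampling only, and omits the notebook's motion-based weighting:

```python
def sample_frame_indices(total_frames, sequence_length=16):
    """Pick `sequence_length` roughly evenly spaced frame indices.

    Simplified stand-in for the notebook's motion-based sampling: short
    videos repeat the last frame so the output length is always fixed.
    """
    if total_frames <= 0:
        return []
    step = max(total_frames // sequence_length, 1)
    return [min(i * step, total_frames - 1) for i in range(sequence_length)]
```

Each selected frame would then be read with OpenCV, resized to 224×224, and normalized before being stacked into the `(16, 224, 224, 3)` clip tensor.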
Training & Checkpointing
Training uses tf.distribute.MirroredStrategy for multi-GPU support. After every epoch, the HuggingFaceCheckpoint callback:
- Saves the model locally as `model_qat.keras`
- Uploads it to your private Hugging Face repo
- Saves training state (epoch number, metrics, learning rate) as `training_state_qat.json`
Set RESUME = True (default) to automatically resume from the latest HF checkpoint.
A Hugging Face token stored as a Kaggle secret named HF_TOKEN is required.
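The resumable-training state file is plain JSON, so the read/write side can be sketched with small helpers (the field names here are assumptions based on the description above, not the notebook's exact schema):

```python
import json

def save_training_state(path, epoch, metrics, learning_rate):
    # Persist everything needed to resume: last epoch, metrics, LR.
    state = {"epoch": epoch, "metrics": metrics, "learning_rate": learning_rate}
    with open(path, "w") as f:
        json.dump(state, f)

def load_training_state(path):
    # Return the saved state, or None on a fresh run (no checkpoint yet).
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return None
```

With `RESUME = True`, the training loop would call `load_training_state()` first and start from `state["epoch"] + 1` when a checkpoint exists.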
Output
| File | Description |
|---|---|
| `qat_video_model.tflite` | INT8 TFLite model ready for edge deployment |
| `training_curves.png` | Accuracy and loss plots across all epochs |
| `model_qat.keras` | Full QAT Keras model (also pushed to HF Hub) |
| `training_state_qat.json` | Epoch/metrics state for resumable training |
Dataset
Real Life Violence Situations Dataset – expected directory structure:

```
Real Life Violence Dataset/
├── Violence/
│   ├── video1.mp4
│   └── ...
└── NonViolence/
    ├── video1.mp4
    └── ...
```
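Given that layout, the directory scan done by `collect_video_paths()` (Cell 13) can be sketched with `pathlib`; the label encoding (0 = NonViolence, 1 = Violence) is an assumption for illustration:

```python
from pathlib import Path

VIDEO_EXTS = {".mp4", ".avi", ".mov"}

def collect_video_paths(dataset_root):
    """Return (path, label) pairs for every video under the class folders."""
    pairs = []
    for label, class_dir in enumerate(["NonViolence", "Violence"]):
        for p in sorted(Path(dataset_root, class_dir).rglob("*")):
            if p.suffix.lower() in VIDEO_EXTS:
                pairs.append((str(p), label))
    return pairs
```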