import gradio as gr
import torch
import numpy as np
import pickle
from PIL import Image
import os
from convnext_original import ConvNeXt as ConvNeXtOriginal
from convnext_finetune import ConvNeXt

# Global variables for models
content_model = None
quality_model = None
scaler = None
regression_model = None
device = None

def get_activation(name, activations):
    """Hook function to capture activations."""
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook


def register_hooks(model):
    """Register hooks for each layer in the model."""
    activations = {}
    for name, module in model.named_modules():
        module.register_forward_hook(get_activation(name, activations))
    return activations
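
# Illustrative sketch (not part of the demo flow): once register_hooks has been
# called, any forward pass fills the returned dict, keyed by module name. The
# demo reads the output of the final 'norm' layer as the feature vector:
#
#   activations = register_hooks(model)
#   _ = model(torch.randn(1, 3, 224, 224))
#   features = activations['norm']  # final-layer features used below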

def preprocess_image(image):
    """Preprocess image for model input."""
    # ImageNet normalization parameters
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Convert to RGB first so RGBA or grayscale uploads do not break the pipeline
    img_array = np.array(image.convert('RGB'), dtype=np.float32) / 255.0
    img_array = (img_array - mean) / std
    return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
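
# Illustrative sketch (not part of the demo flow): for a 224x224 RGB image,
# preprocess_image yields a normalized float tensor of shape (1, 3, 224, 224):
#
#   img = Image.new('RGB', (224, 224))
#   x = preprocess_image(img)
#   assert x.shape == (1, 3, 224, 224)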

def load_models():
    """Load all required models."""
    global content_model, quality_model, scaler, regression_model, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Check that all model files exist before loading
    required_files = [
        'feature_models/convnext_tiny_22k_224.pth',
        'feature_models/triqa_quality_aware.pth',
        'Regression_Models/KonIQ_scaler.save',
        'Regression_Models/KonIQ_TRIQA.save'
    ]
    missing_files = [f for f in required_files if not os.path.exists(f)]
    if missing_files:
        print(f"Missing model files: {missing_files}")
        print("Please download the model files from the Box link and place them in the correct directories.")
        return None, None
    try:
        # Load the content-aware model (original ImageNet-22k pretrained ConvNeXt-Tiny)
        content_model = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        content_state_dict = torch.load('feature_models/convnext_tiny_22k_224.pth', map_location=device)['model']
        # Drop the classification head weights; only the backbone features are used
        content_state_dict = {k: v for k, v in content_state_dict.items() if not k.startswith('head.')}
        content_model.load_state_dict(content_state_dict, strict=False)
        content_model.to(device).eval()
        # Load the quality-aware model (TRIQA contrastive-pretrained ConvNeXt)
        quality_model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        quality_state_dict = torch.load('feature_models/triqa_quality_aware.pth', map_location=device)
        quality_model.load_state_dict(quality_state_dict, strict=True)
        quality_model.to(device).eval()
        # Register hooks for feature extraction
        content_activations = register_hooks(content_model)
        quality_activations = register_hooks(quality_model)
        # Load the feature scaler and the quality regression model
        with open('Regression_Models/KonIQ_scaler.save', 'rb') as f:
            scaler = pickle.load(f)
        with open('Regression_Models/KonIQ_TRIQA.save', 'rb') as f:
            regression_model = pickle.load(f)
        return content_activations, quality_activations
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None

def predict_quality(image):
    """Predict image quality score on a 1-5 scale."""
    global content_model, quality_model, scaler, regression_model, device
    if content_model is None or quality_model is None:
        return "Models not loaded. Please wait..."
    # Preprocess the image at full and half resolution (multi-scale features)
    image_half = image.resize((image.size[0] // 2, image.size[1] // 2), Image.LANCZOS)
    img_full = preprocess_image(image).to(device)
    img_half = preprocess_image(image_half).to(device)
    with torch.no_grad():
        # Extract content-aware features via the 'norm' layer hook
        _ = content_model(img_full)
        content_full = content_model.activations['norm'].cpu().numpy().flatten()
        _ = content_model(img_half)
        content_half = content_model.activations['norm'].cpu().numpy().flatten()
        content_features = np.concatenate([content_full, content_half])
        # Extract quality-aware features the same way
        _ = quality_model(img_full)
        quality_full = quality_model.activations['norm'].cpu().numpy().flatten()
        _ = quality_model(img_half)
        quality_half = quality_model.activations['norm'].cpu().numpy().flatten()
        quality_features = np.concatenate([quality_full, quality_half])
    # Combine both feature sets, standardize, and regress to a quality score
    combined_features = np.concatenate([content_features, quality_features])
    normalized_features = scaler.transform(combined_features.reshape(1, -1))
    quality_score = regression_model.predict(normalized_features)[0]
    return f"Quality Score: {quality_score:.2f}/5.0"
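
# Illustrative sketch (the score shown is a made-up example value): once the
# models are loaded in create_demo, predict_quality accepts any PIL image:
#
#   img = Image.open('sample_image/233045618.jpg')
#   print(predict_quality(img))  # e.g. "Quality Score: 3.87/5.0"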

def create_demo():
    """Create the Gradio demo interface."""
    # Load models and attach the activation dicts captured by the hooks
    try:
        content_activations, quality_activations = load_models()
        if content_activations is None or quality_activations is None:
            return None
        content_model.activations = content_activations
        quality_model.activations = quality_activations
        print("Models loaded successfully!")
    except Exception as e:
        print(f"Error loading models: {e}")
        return None
    # Create Gradio interface
    with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # TRIQA: Image Quality Assessment

        **TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale.

        ### How to use:
        1. Upload an image using the file uploader below
        2. Click "Assess Quality" to get the quality score
        3. The score ranges from 1-5, where 5 represents the highest quality

        ### Paper Links:
        - **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687)
        - **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443)
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                submit_btn = gr.Button("Assess Quality", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Quality Assessment Result",
                    value="Upload an image and click 'Assess Quality' to get the quality score.",
                    interactive=False
                )
        gr.Examples(
            examples=[
                ["sample_image/233045618.jpg"],
                ["sample_image/25239707.jpg"],
                ["sample_image/44009500.jpg"],
                ["sample_image/5129172.jpg"],
                ["sample_image/85119046.jpg"]
            ],
            inputs=input_image,
            label="Sample Images"
        )
        submit_btn.click(
            fn=predict_quality,
            inputs=input_image,
            outputs=output_text
        )
        gr.Markdown("""
        ### Citation:
        If you use this code in your research, please cite our paper:
        ```bibtex
        @INPROCEEDINGS{11084443,
          author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.},
          booktitle={2025 IEEE International Conference on Image Processing (ICIP)},
          title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets},
          year={2025},
          volume={},
          number={},
          pages={1744-1749},
          keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning},
          doi={10.1109/ICIP55913.2025.11084443}}
        ```
        """)
    return demo
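

# Launch notes (standard Gradio behavior): server_name="0.0.0.0" binds the
# server to all network interfaces on port 7860, and share=True additionally
# requests a temporary public gradio.live link.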
if __name__ == "__main__":
    demo = create_demo()
    if demo:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    else:
        print("Failed to create demo. Please check model files.")