# triqa-iqa / app.py
# S-Rajesh's picture
# Upload app.py with huggingface_hub
# 6177940 verified
import gradio as gr
import torch
import numpy as np
import pickle
from PIL import Image
import os
from convnext_original import ConvNeXt as ConvNeXtOriginal
from convnext_finetune import ConvNeXt
# Shared module state. Everything starts as None and is filled in by
# load_models(); predict_quality() reads these on every request.
content_model = quality_model = None  # the two ConvNeXt feature extractors
scaler = regression_model = None  # objects unpickled from Regression_Models/
device = None  # torch.device chosen in load_models()
def get_activation(name, activations):
    """Build a forward hook that records a layer's output under ``name``.

    Args:
        name: Key under which the output is stored.
        activations: Dict that receives the captured output; an existing
            entry for ``name`` is overwritten on every call of the hook.

    Returns:
        A callable taking three positional arguments (module, inputs,
        output) that stores ``output.detach()`` in ``activations[name]``.
    """
    def capture(module, inputs, output):
        # Only the output is used; the first two arguments are ignored.
        detached = output.detach()
        activations[name] = detached

    return capture
def register_hooks(model):
    """Attach an output-capturing forward hook to every module of ``model``.

    Iterates ``model.named_modules()`` and registers one hook per entry,
    each built by ``get_activation`` with the module's name as key.

    Returns:
        The dict shared by all hooks; it is empty until the model is run
        and is then keyed by module name.
    """
    captured = {}
    for layer_name, layer in model.named_modules():
        capture_hook = get_activation(layer_name, captured)
        layer.register_forward_hook(capture_hook)
    return captured
def preprocess_image(image):
    """Turn an image into a normalized, batched, channel-first tensor.

    Args:
        image: A PIL image, or any object ``np.array`` can turn into an
            ``(H, W, 3)`` array. Objects exposing a ``mode`` attribute other
            than ``"RGB"`` are converted with ``image.convert("RGB")`` first.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)``: values divided
        by 255 and then standardized per channel with the ImageNet mean and
        standard deviation.
    """
    # Grayscale, palette and RGBA images do not have exactly three channels,
    # which would break the per-channel normalization or the permute below.
    # Objects without a ``mode`` attribute (e.g. NumPy arrays) pass through.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # ImageNet normalization parameters
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    img_array = np.array(image, dtype=np.float32) / 255.0
    img_array = (img_array - mean) / std

    # HWC -> CHW, add the batch axis, and cast back to float32 (the float64
    # mean/std above promote the array).
    return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
def load_models():
    """Load both ConvNeXt feature extractors, the scaler and the regressor.

    The module-level ``content_model``, ``quality_model``, ``scaler`` and
    ``regression_model`` are assigned together, and only after every load
    step has succeeded, so a failure never leaves them half-initialized.
    ``device`` is always set.

    Returns:
        ``(content_activations, quality_activations)``: the dicts returned
        by ``register_hooks`` for the two models, or ``(None, None)`` when a
        required file is missing or any load step raised.
    """
    global content_model, quality_model, scaler, regression_model, device

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    content_weights_path = 'feature_models/convnext_tiny_22k_224.pth'
    quality_weights_path = 'feature_models/triqa_quality_aware.pth'
    scaler_path = 'Regression_Models/KonIQ_scaler.save'
    regressor_path = 'Regression_Models/KonIQ_TRIQA.save'

    # Check if model files exist
    required_files = [
        content_weights_path,
        quality_weights_path,
        scaler_path,
        regressor_path,
    ]
    missing_files = [f for f in required_files if not os.path.exists(f)]
    if missing_files:
        print(f"Missing model files: {missing_files}")
        print("Please download model files from the Box link and place them in the correct directories.")
        return None, None

    try:
        # Everything is built in locals first; the globals are only touched
        # once nothing can fail any more.

        # Load content-aware model (using original ConvNeXt)
        new_content_model = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        content_state_dict = torch.load(content_weights_path, map_location=device)['model']
        # Drop the classification head weights; loading is non-strict.
        content_state_dict = {
            k: v for k, v in content_state_dict.items() if not k.startswith('head.')
        }
        new_content_model.load_state_dict(content_state_dict, strict=False)
        new_content_model.to(device).eval()

        # Load quality-aware model
        new_quality_model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        quality_state_dict = torch.load(quality_weights_path, map_location=device)
        new_quality_model.load_state_dict(quality_state_dict, strict=True)
        new_quality_model.to(device).eval()

        # Register hooks for feature extraction
        content_activations = register_hooks(new_content_model)
        quality_activations = register_hooks(new_quality_model)

        # Load scaler and regression model.
        # NOTE: pickle.load can execute arbitrary code while unpickling;
        # these files must come from a trusted source.
        with open(scaler_path, 'rb') as f:
            new_scaler = pickle.load(f)
        with open(regressor_path, 'rb') as f:
            new_regression_model = pickle.load(f)
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None

    content_model = new_content_model
    quality_model = new_quality_model
    scaler = new_scaler
    regression_model = new_regression_model
    return content_activations, quality_activations
def _extract_norm_features(model, img_tensor):
    """Run ``model`` on ``img_tensor`` and return the hooked 'norm' output, flattened."""
    _ = model(img_tensor)
    # The forward hooks overwrite model.activations on every forward pass,
    # so the value has to be read before the model is run again.
    return model.activations['norm'].cpu().numpy().flatten()


def predict_quality(image):
    """Predict an image quality score and format it for display.

    Features are read from the 'norm' activation of both models at full and
    half resolution, concatenated (content full, content half, quality full,
    quality half), scaled with ``scaler`` and passed to ``regression_model``.

    Args:
        image: PIL image supplied by the UI, or ``None`` when nothing was
            uploaded.

    Returns:
        ``"Quality Score: <score>/5.0"`` with two decimals, or a short
        message when no image was given or the models are not loaded. The
        score is the regressor's raw output; it is not clamped here.
    """
    # Clicking the button without an upload passes None.
    if image is None:
        return "Please upload an image first."

    components = (content_model, quality_model, scaler, regression_model)
    if any(component is None for component in components):
        return "Models not loaded. Please wait..."

    # Normalize the mode before resizing so both scales come from the same
    # three-channel pixels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Load and preprocess image
    image_half = image.resize((image.size[0]//2, image.size[1]//2), Image.LANCZOS)
    img_full = preprocess_image(image).to(device)
    img_half = preprocess_image(image_half).to(device)

    with torch.no_grad():
        # Extract content features using hooks
        content_full = _extract_norm_features(content_model, img_full)
        content_half = _extract_norm_features(content_model, img_half)
        # Extract quality features using hooks
        quality_full = _extract_norm_features(quality_model, img_full)
        quality_half = _extract_norm_features(quality_model, img_half)

    # Combine features and predict
    combined_features = np.concatenate(
        [content_full, content_half, quality_full, quality_half]
    )
    normalized_features = scaler.transform(combined_features.reshape(1, -1))
    quality_score = regression_model.predict(normalized_features)[0]
    return f"Quality Score: {quality_score:.2f}/5.0"
def create_demo():
    """Create the Gradio demo interface.

    Loads the models, attaches each model's activation dict to it as
    ``.activations`` (read by ``predict_quality``), and builds the UI.

    Returns:
        The ``gr.Blocks`` demo, or ``None`` when the models could not be
        loaded.
    """
    # Load models
    try:
        content_activations, quality_activations = load_models()
    except Exception as e:
        print(f"Error loading models: {e}")
        return None

    # load_models() signals failure by returning (None, None) after printing
    # the reason. Check that explicitly instead of relying on an
    # AttributeError, which does not occur when the models were created but
    # a later load step failed.
    if (
        content_activations is None
        or quality_activations is None
        or content_model is None
        or quality_model is None
    ):
        return None

    content_model.activations = content_activations
    quality_model.activations = quality_activations
    print("Models loaded successfully!")

    # Create Gradio interface
    # The Markdown literals below are kept flush-left so that their text is
    # unchanged.
    with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# TRIQA: Image Quality Assessment
**TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale.
### How to use:
1. Upload an image using the file uploader below
2. Click "Assess Quality" to get the quality score
3. The score ranges from 1-5, where 5 represents the highest quality
### Paper Links:
- **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687)
- **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443)
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                submit_btn = gr.Button("Assess Quality", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Quality Assessment Result",
                    value="Upload an image and click 'Assess Quality' to get the quality score.",
                    interactive=False
                )
        gr.Examples(
            examples=[
                ["sample_image/233045618.jpg"],
                ["sample_image/25239707.jpg"],
                ["sample_image/44009500.jpg"],
                ["sample_image/5129172.jpg"],
                ["sample_image/85119046.jpg"]
            ],
            inputs=input_image,
            label="Sample Images"
        )
        # Run the prediction when the button is clicked.
        submit_btn.click(
            fn=predict_quality,
            inputs=input_image,
            outputs=output_text
        )
        gr.Markdown("""
### Citation:
If you use this code in your research, please cite our paper:
```bibtex
@INPROCEEDINGS{11084443,
author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.},
booktitle={2025 IEEE International Conference on Image Processing (ICIP)},
title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets},
year={2025},
volume={},
number={},
pages={1744-1749},
keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning},
doi={10.1109/ICIP55913.2025.11084443}}
```
""")
    return demo
# Script entry point: build the demo and serve it, or report the failure.
if __name__ == "__main__":
    demo = create_demo()
    if not demo:
        print("Failed to create demo. Please check model files.")
    else:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)