"""Gradio demo for TRIQA: no-reference image quality assessment built on
content-aware and quality-aware ConvNeXt features."""

import gradio as gr
import torch
import numpy as np
import pickle
from PIL import Image
import os

from convnext_original import ConvNeXt as ConvNeXtOriginal
from convnext_finetune import ConvNeXt


# Globals populated once by load_models() and shared across Gradio requests.
content_model = None
quality_model = None
scaler = None
regression_model = None
device = None


def get_activation(name, activations):
    """Hook function to capture activations."""
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook


def register_hooks(model):
    """Register a forward hook on every module; return the dict the hooks fill."""
    activations = {}
    for name, module in model.named_modules():
        module.register_forward_hook(get_activation(name, activations))
    return activations
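

# Usage sketch (assumes the final LayerNorm is registered under the name
# 'norm', as in the ConvNeXt implementations imported above):
#   acts = register_hooks(model)
#   _ = model(x)                    # a forward pass fills the dict
#   feats = acts['norm'].flatten()  # pooled feature vector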


def preprocess_image(image):
    """Convert a PIL image to a normalized (1, 3, H, W) float tensor."""
    # ImageNet normalization statistics
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Force RGB so grayscale or RGBA uploads broadcast correctly against mean/std.
    img_array = np.array(image.convert('RGB'), dtype=np.float32) / 255.0
    img_array = (img_array - mean) / std
    return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()
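
# Note: no fixed-size resize happens here; the feature extractors run at the
# image's native resolution (the half-scale copy is built by the caller).
# For example, a 1024x768 upload becomes a (1, 3, 768, 1024) tensor.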


def load_models():
    """Load both feature extractors and the trained regression pipeline."""
    global content_model, quality_model, scaler, regression_model, device

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    required_files = [
        'feature_models/convnext_tiny_22k_224.pth',
        'feature_models/triqa_quality_aware.pth',
        'Regression_Models/KonIQ_scaler.save',
        'Regression_Models/KonIQ_TRIQA.save'
    ]

    missing_files = [f for f in required_files if not os.path.exists(f)]
    if missing_files:
        print(f"Missing model files: {missing_files}")
        print("Please download the model files from the Box link and place them in the correct directories.")
        return None, None

    try:
        # Content-aware branch: ImageNet-22k pretrained ConvNeXt-T with the
        # classification head stripped, hence the non-strict load.
        content_model = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        content_state_dict = torch.load('feature_models/convnext_tiny_22k_224.pth', map_location=device)['model']
        content_state_dict = {k: v for k, v in content_state_dict.items() if not k.startswith('head.')}
        content_model.load_state_dict(content_state_dict, strict=False)
        content_model.to(device).eval()

        # Quality-aware branch: ConvNeXt-T pretrained with TRIQA's contrastive
        # objective on ordered distortion triplets.
        quality_model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        quality_state_dict = torch.load('feature_models/triqa_quality_aware.pth', map_location=device)
        quality_model.load_state_dict(quality_state_dict, strict=True)
        quality_model.to(device).eval()

        # Capture intermediate activations from both feature extractors.
        content_activations = register_hooks(content_model)
        quality_activations = register_hooks(quality_model)

        # Feature scaler and quality regressor trained on KonIQ.
        with open('Regression_Models/KonIQ_scaler.save', 'rb') as f:
            scaler = pickle.load(f)
        with open('Regression_Models/KonIQ_TRIQA.save', 'rb') as f:
            regression_model = pickle.load(f)

        return content_activations, quality_activations
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None
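
# Note: the .save files are unpickled directly and are assumed to expose an
# sklearn-style interface: a fitted scaler with .transform() and a regressor
# with .predict(), as used in predict_quality below.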


def predict_quality(image):
    """Predict an image quality score on a 1-5 scale."""
    global content_model, quality_model, scaler, regression_model, device

    if image is None:
        return "Please upload an image first."
    if content_model is None or quality_model is None:
        return "Models not loaded. Please wait..."

    # TRIQA extracts features at two scales: the full image and a half-resolution copy.
    image_half = image.resize((image.size[0] // 2, image.size[1] // 2), Image.LANCZOS)

    img_full = preprocess_image(image).to(device)
    img_half = preprocess_image(image_half).to(device)

    with torch.no_grad():
        # Content-aware features from the final 'norm' layer at both scales.
        _ = content_model(img_full)
        content_full = content_model.activations['norm'].cpu().numpy().flatten()

        _ = content_model(img_half)
        content_half = content_model.activations['norm'].cpu().numpy().flatten()

        content_features = np.concatenate([content_full, content_half])

        # Quality-aware features at both scales.
        _ = quality_model(img_full)
        quality_full = quality_model.activations['norm'].cpu().numpy().flatten()

        _ = quality_model(img_half)
        quality_half = quality_model.activations['norm'].cpu().numpy().flatten()

        quality_features = np.concatenate([quality_full, quality_half])

    # Concatenate both feature sets, normalize, and regress the quality score.
    combined_features = np.concatenate([content_features, quality_features])
    normalized_features = scaler.transform(combined_features.reshape(1, -1))
    quality_score = regression_model.predict(normalized_features)[0]

    return f"Quality Score: {quality_score:.2f}/5.0"
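
# Feature size sketch (assuming the standard ConvNeXt-T pooled 'norm' output
# of 768 values): two scales give 1536 features per branch, so the combined
# content + quality vector handed to the scaler has 3072 entries.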


def create_demo():
    """Create the Gradio demo interface."""
    content_activations, quality_activations = load_models()
    if content_activations is None or quality_activations is None:
        print("Failed to load models; see the messages above.")
        return None

    # Attach the hook dictionaries so predict_quality can read the activations.
    content_model.activations = content_activations
    quality_model.activations = quality_activations
    print("Models loaded successfully!")
    with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # TRIQA: Image Quality Assessment

        **TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale.

        ### How to use:
        1. Upload an image using the file uploader below
        2. Click "Assess Quality" to get the quality score
        3. The score ranges from 1 to 5, where 5 represents the highest quality

        ### Paper Links:
        - **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687)
        - **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443)
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400
                )
                submit_btn = gr.Button("Assess Quality", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Quality Assessment Result",
                    value="Upload an image and click 'Assess Quality' to get the quality score.",
                    interactive=False
                )

        gr.Examples(
            examples=[
                ["sample_image/233045618.jpg"],
                ["sample_image/25239707.jpg"],
                ["sample_image/44009500.jpg"],
                ["sample_image/5129172.jpg"],
                ["sample_image/85119046.jpg"]
            ],
            inputs=input_image,
            label="Sample Images"
        )
        submit_btn.click(
            fn=predict_quality,
            inputs=input_image,
            outputs=output_text
        )
        gr.Markdown("""
        ### Citation:
        If you use this code in your research, please cite our paper:

        ```bibtex
        @INPROCEEDINGS{11084443,
          author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.},
          booktitle={2025 IEEE International Conference on Image Processing (ICIP)},
          title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets},
          year={2025},
          volume={},
          number={},
          pages={1744-1749},
          keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning},
          doi={10.1109/ICIP55913.2025.11084443}}
        ```
        """)

    return demo


if __name__ == "__main__":
    demo = create_demo()
    if demo:
        # Listen on all interfaces; share=True also creates a public Gradio link.
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    else:
        print("Failed to create demo. Please check model files.")