"""Gradio demo for TRIQA image quality assessment.

Two ConvNeXt backbones are used:
- a content-aware one whose weights come from ``convnext_tiny_22k_224.pth``;
- a quality-aware one whose weights come from ``triqa_quality_aware.pth``.

Both run on the uploaded image at full and half resolution. The outputs of
each model's ``norm`` sub-module are flattened and concatenated. The combined
vector goes through a pickled scaler, then into a pickled regression model
whose prediction is shown as the quality score.
"""
import os
import pickle

import gradio as gr
import numpy as np
import torch
from PIL import Image

from convnext_original import ConvNeXt as ConvNeXtOriginal
from convnext_finetune import ConvNeXt

# Global variables for models. load_models() assigns them only after every
# artifact has loaded, so after a failed load they remain None.
content_model = None
quality_model = None
scaler = None
regression_model = None
device = None

# Artifacts required by load_models().
CONTENT_WEIGHTS_PATH = 'feature_models/convnext_tiny_22k_224.pth'
QUALITY_WEIGHTS_PATH = 'feature_models/triqa_quality_aware.pth'
SCALER_PATH = 'Regression_Models/KonIQ_scaler.save'
REGRESSOR_PATH = 'Regression_Models/KonIQ_TRIQA.save'

# Name (as reported by named_modules()) of the sub-module whose output is
# used as the feature vector for both backbones.
FEATURE_LAYER = 'norm'

INTRO_MARKDOWN = """
# TRIQA: Image Quality Assessment

**TRIQA** combines content-aware and quality-aware features from ConvNeXt models to predict image quality scores on a 1-5 scale.

### How to use:
1. Upload an image using the file uploader below
2. Click "Assess Quality" to get the quality score
3. The score ranges from 1-5, where 5 represents the highest quality

### Paper Links:
- **arXiv**: [https://arxiv.org/pdf/2507.12687](https://arxiv.org/pdf/2507.12687)
- **IEEE Xplore**: [https://ieeexplore.ieee.org/abstract/document/11084443](https://ieeexplore.ieee.org/abstract/document/11084443)
"""

CITATION_MARKDOWN = """
### Citation:
If you use this code in your research, please cite our paper:

```bibtex
@INPROCEEDINGS{11084443,
  author={Sureddi, Rajesh and Zadtootaghaj, Saman and Barman, Nabajeet and Bovik, Alan C.},
  booktitle={2025 IEEE International Conference on Image Processing (ICIP)},
  title={Triqa: Image Quality Assessment by Contrastive Pretraining on Ordered Distortion Triplets},
  year={2025},
  volume={},
  number={},
  pages={1744-1749},
  keywords={Image quality;Training;Deep learning;Contrastive learning;Predictive models;Feature extraction;Distortion;Data models;Synthetic data;Image Quality Assessment;Contrastive Learning},
  doi={10.1109/ICIP55913.2025.11084443}}
```
"""

SAMPLE_IMAGES = [
    ["sample_image/233045618.jpg"],
    ["sample_image/25239707.jpg"],
    ["sample_image/44009500.jpg"],
    ["sample_image/5129172.jpg"],
    ["sample_image/85119046.jpg"],
]


def get_activation(name, activations):
    """Build a forward hook that stores a module's output.

    Args:
        name: Key under which the output is stored.
        activations: Dict that the hook writes into.

    Returns:
        A hook callable suitable for ``nn.Module.register_forward_hook``. Each
        call overwrites ``activations[name]`` with the detached output.
    """
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook


def register_hooks(model, layer_names=None):
    """Register output-capturing forward hooks on sub-modules of ``model``.

    Args:
        model: A ``torch.nn.Module``.
        layer_names: Optional name, or iterable of names, as yielded by
            ``model.named_modules()``. ``None`` hooks every module, including
            the root (name ``''``). In that case the detached output of every
            module from the most recent forward pass stays referenced, so
            prefer passing only the layers you actually read.

    Returns:
        A dict mapping module name to the detached output of that module from
        the most recent forward pass. It is filled in lazily by the hooks.

    Raises:
        ValueError: If a requested layer name does not exist in ``model``.
    """
    activations = {}
    modules = dict(model.named_modules())

    if layer_names is None:
        targets = modules
    else:
        if isinstance(layer_names, str):
            # A bare string would otherwise be iterated character by character.
            layer_names = [layer_names]
        missing = [name for name in layer_names if name not in modules]
        if missing:
            raise ValueError(f"Unknown layer name(s): {missing}")
        targets = {name: modules[name] for name in layer_names}

    for name, module in targets.items():
        module.register_forward_hook(get_activation(name, activations))
    return activations


def preprocess_image(image):
    """Convert a PIL image into an ImageNet-normalised model input tensor.

    Non-RGB images (e.g. RGBA, grayscale, palette) are first converted to RGB.
    This guarantees three channels for the per-channel mean/std and the
    channel-first permute.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # ImageNet normalization parameters
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    img_array = np.array(image, dtype=np.float32) / 255.0
    img_array = (img_array - mean) / std
    # HWC -> CHW, then add a leading batch dimension.
    return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float()


def load_models():
    """Load both backbones, the scaler and the regression model.

    On success, the module-level globals are assigned and each backbone gets
    an ``activations`` attribute. That attribute is the dict populated by a
    forward hook on its ``FEATURE_LAYER`` sub-module. On any failure the
    globals (except ``device``) are left untouched.

    Returns:
        ``(content_activations, quality_activations)`` on success. Otherwise
        ``(None, None)``, in which case a message describing the problem has
        been printed.
    """
    global content_model, quality_model, scaler, regression_model, device

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Check if model files exist
    required_files = [
        CONTENT_WEIGHTS_PATH,
        QUALITY_WEIGHTS_PATH,
        SCALER_PATH,
        REGRESSOR_PATH,
    ]
    missing_files = [f for f in required_files if not os.path.exists(f)]
    if missing_files:
        print(f"Missing model files: {missing_files}")
        print("Please download model files from the Box link and place them in the correct directories.")
        return None, None

    try:
        # Load content-aware model (using original ConvNeXt)
        content = ConvNeXtOriginal(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        content_state_dict = torch.load(CONTENT_WEIGHTS_PATH, map_location=device)['model']
        # Drop classification-head weights; only the backbone features are used.
        content_state_dict = {
            k: v for k, v in content_state_dict.items() if not k.startswith('head.')
        }
        content.load_state_dict(content_state_dict, strict=False)
        content.to(device).eval()

        # Load quality-aware model
        quality = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
        quality_state_dict = torch.load(QUALITY_WEIGHTS_PATH, map_location=device)
        quality.load_state_dict(quality_state_dict, strict=True)
        quality.to(device).eval()

        # Hook only the layer that is read. Hooking every module would keep
        # every intermediate output of the last forward pass alive.
        content_activations = register_hooks(content, [FEATURE_LAYER])
        quality_activations = register_hooks(quality, [FEATURE_LAYER])

        # Load scaler and regression model. pickle can execute arbitrary code:
        # these must be trusted, locally shipped artifacts, never user uploads.
        with open(SCALER_PATH, 'rb') as f:
            loaded_scaler = pickle.load(f)
        with open(REGRESSOR_PATH, 'rb') as f:
            loaded_regressor = pickle.load(f)
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None

    content.activations = content_activations
    quality.activations = quality_activations

    # Publish only a fully loaded set of models.
    content_model, quality_model = content, quality
    scaler, regression_model = loaded_scaler, loaded_regressor
    return content_activations, quality_activations


def _extract_features(model, img_full, img_half):
    """Return ``model``'s flattened ``FEATURE_LAYER`` outputs for both scales, concatenated."""
    features = []
    for batch in (img_full, img_half):
        model(batch)  # the forward hook fills model.activations
        features.append(model.activations[FEATURE_LAYER].cpu().numpy().flatten())
    return np.concatenate(features)


def predict_quality(image):
    """Predict the quality score of ``image`` and format it for display.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A human-readable string: the score with two decimals, or a message if
        the models are not loaded or no image was provided. The regressor's
        output is displayed as-is ("/5.0"); it is not clamped to [1, 5].

    Note:
        The activation dicts read here are shared state written by forward
        hooks. Concurrent calls could interleave if the app is configured to
        run this handler in parallel.
    """
    if (content_model is None or quality_model is None
            or scaler is None or regression_model is None):
        return "Models not loaded. Please wait..."
    if image is None:
        return "Please upload an image first."

    # Load and preprocess image
    image_half = image.resize((image.size[0] // 2, image.size[1] // 2), Image.LANCZOS)
    img_full = preprocess_image(image).to(device)
    img_half = preprocess_image(image_half).to(device)

    with torch.no_grad():
        # Feature order must match what the scaler/regressor were fitted on:
        # [content_full, content_half, quality_full, quality_half].
        content_features = _extract_features(content_model, img_full, img_half)
        quality_features = _extract_features(quality_model, img_full, img_half)

    # Combine features and predict
    combined_features = np.concatenate([content_features, quality_features])
    normalized_features = scaler.transform(combined_features.reshape(1, -1))
    quality_score = regression_model.predict(normalized_features)[0]

    return f"Quality Score: {quality_score:.2f}/5.0"


def create_demo():
    """Load the models and build the Gradio interface.

    Returns:
        A ``gr.Blocks`` app, or ``None`` if the models could not be loaded.
    """
    content_activations, quality_activations = load_models()
    if content_activations is None or quality_activations is None:
        # load_models() has already printed the specific reason.
        print("Error loading models: see the messages above.")
        return None
    print("Models loaded successfully!")

    # Create Gradio interface
    with gr.Blocks(title="TRIQA: Image Quality Assessment", theme=gr.themes.Soft()) as demo:
        gr.Markdown(INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400,
                )
                submit_btn = gr.Button("Assess Quality", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Quality Assessment Result",
                    value="Upload an image and click 'Assess Quality' to get the quality score.",
                    interactive=False,
                )

        gr.Examples(
            examples=SAMPLE_IMAGES,
            inputs=input_image,
            label="Sample Images",
        )

        submit_btn.click(
            fn=predict_quality,
            inputs=input_image,
            outputs=output_text,
        )

        gr.Markdown(CITATION_MARKDOWN)

    return demo


if __name__ == "__main__":
    demo = create_demo()
    if demo is not None:
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
    else:
        print("Failed to create demo. Please check model files.")