# app.py
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import sys
import json
from PIL import Image

# Add project root to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.append(project_root)
# Internal modules
from src.model.architecture import build_model
from src.model.gradcam import GradCAMPlusPlus as GradCAM
from src.utils.config import BEST_MODEL_PATH

# Load disease information
with open("disease_info.json", "r") as f:
    DISEASE_INFO = json.load(f)

# Load label mapping
with open("models/labels.json", "r") as f:
    idx_to_class = json.load(f)

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(idx_to_class)
model = build_model(num_classes=num_classes, freeze_backbone=False)
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()
target_layer = model.features[-1]
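# Note: using model.features[-1] as the Grad-CAM target assumes build_model returns a
# torchvision-style CNN (e.g. an EfficientNet/MobileNet-like backbone) whose `features`
# attribute ends with the final convolutional block; adjust if the architecture differs.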
# Sample images
SAMPLE_DIR = "sample_images"
sample_choices = sorted(os.listdir(SAMPLE_DIR))
# Utility Functions
def beautify_name(raw_classname):
    # Class names are expected to contain at least three "___"-separated parts,
    # with the plant name at index 1 and the disease name at index 2.
    parts = raw_classname.split("___")
    if len(parts) >= 3:
        plant = parts[1].title()
        disease = parts[2].replace("_", " ").replace("(", "").replace(")", "").replace("__", " ").title()
        return plant, disease
    else:
        return "Unknown", "Unknown"
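# Example with a hypothetical class name of that shape (the exact label format
# depends on the training dataset):
#   beautify_name("dataset___Tomato___Late_blight")  ->  ("Tomato", "Late Blight")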
def generate_gradcam(model, input_tensor, target_layer):
    gradcam = GradCAM(model, target_layer)
    cam = gradcam.generate(input_tensor)
    return cam

def preprocess_image(image):
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image)
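# The 224x224 resize and the mean/std above are the standard ImageNet statistics used
# with pretrained torchvision backbones; they should match whatever preprocessing was
# applied during training of BEST_MODEL_PATH.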
def predict(image):
    if image is None:
        # Must match the six outputs wired up in the Gradio interface below
        return None, None, None, None, None, None

    model.eval()
    image_tensor = preprocess_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    top3_indices = probs.argsort()[-3:][::-1]
    top3_classes = [(idx_to_class[str(idx)], float(probs[idx])) for idx in top3_indices]
    pred_class, pred_prob = top3_classes[0]
    plant, disease = beautify_name(pred_class)

    # Format Top-3 nicely
    top3_text = ""
    for i, (c, p) in enumerate(top3_classes, 1):
        c_plant, c_disease = beautify_name(c)
        top3_text += f"{i}. {c_plant} - {c_disease} ({p*100:.2f}%)\n"

    # GradCAM
    cam = generate_gradcam(model, image_tensor, target_layer)
    img_np = np.array(image) / 255.0

    # 🔥 Resize uploaded image to 224x224 to match the heatmap
    img_np_resized = cv2.resize(img_np, (224, 224))
    img_gray = cv2.cvtColor(np.uint8(img_np_resized * 255), cv2.COLOR_RGB2GRAY) / 255.0
    img_gray_3ch = np.stack([img_gray] * 3, axis=-1)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_PLASMA)
    # applyColorMap returns BGR; convert so the overlay matches the 'plasma' colorbar
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    overlay = heatmap + img_gray_3ch
    overlay = overlay / np.max(overlay)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(overlay)
    ax.axis("off")
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap='plasma'), orientation='horizontal', pad=0.05, ax=ax)
    cbar.set_ticks([0, 0.5, 1])
    cbar.set_ticklabels(['Low Focus', 'Medium Focus', 'High Focus'])
    plt.tight_layout()
    cam_path = "cam_output.png"
    fig.savefig(cam_path)
    plt.close(fig)

    # Health Status + Disease Info
    if "healthy" in pred_class.lower():
        health_status = f"{plant} - Healthy"
        identified_disease = "None"
        disease_info_text = "✅ No disease detected."
    else:
        health_status = f"{plant} - Diseased"
        identified_disease = disease
        disease_data = DISEASE_INFO.get(pred_class, {})
        disease_info_text = f"""
**Symptoms:** {disease_data.get('symptoms', 'No information available.')}

**Causes:** {disease_data.get('causes', 'No information available.')}

**Disease Cycle:** {disease_data.get('disease_cycle', 'No information available.')}

**Care & Treatment:** {disease_data.get('care_treatment', 'No information available.')}

[Learn more on Wikipedia]({disease_data.get('wiki_url', '#')})
"""

    alert = None
    if pred_prob < 0.6:
        alert = "⚠️ Low confidence in prediction! Please verify manually."

    return health_status, identified_disease, top3_text, alert, cam_path, disease_info_text
def load_sample_image(sample_name):
    img_path = os.path.join(SAMPLE_DIR, sample_name)
    img = Image.open(img_path).convert("RGB")
    return img
# Interface
title = "CropGuard: Leaf Disease Detector"
copyright_text = "© 2025 Made by [Arka Mitra](https://github.com/mitraarka27)"
instruction_text = """
Upload a clear image of a **potato**, **tomato**, or **grape** leaf.

CropGuard will predict:
- Whether the leaf is **healthy** or **diseased**.
- The likely disease (if any).
- Where the model focused its attention.

⚡ **Note**: Currently supports only **Potato, Tomato, Grape** leaves.
"""
with gr.Blocks(theme="default") as app:
    with gr.Row():
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown("<p style='text-align: center;'>© 2025 Made by <a href='https://github.com/mitraarka27' target='_blank'>Arka Mitra</a></p>")

    with gr.Row():
        with gr.Column(scale=2):
            upload = gr.Image(
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                label="Upload, Capture, or Paste Leaf Image"
            )
            gr.Markdown("**OR** choose from sample images below:")
            sample_dropdown = gr.Dropdown(choices=sample_choices, label="Select a Sample Image")
            load_btn = gr.Button("Load Sample Image")
            predict_btn = gr.Button("Predict", variant="primary")
            gr.Markdown(instruction_text)
            alert_box = gr.Textbox(label="Prediction Alert", lines=2, interactive=False)
            top3_preds = gr.Textbox(label="Top-3 Predictions", lines=5, interactive=False)

        with gr.Column(scale=3):
            health_status = gr.Label(label="Plant Health Status")
            disease_name = gr.Label(label="Identified Disease (includes details)")
            disease_info = gr.Markdown()
            heatmap = gr.Image(label="Model Focus Heatmap")

    load_btn.click(
        fn=load_sample_image,
        inputs=[sample_dropdown],
        outputs=[upload]
    )

    predict_btn.click(
        fn=predict,
        inputs=[upload],
        outputs=[health_status, disease_name, top3_preds, alert_box, heatmap, disease_info]
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
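# To try the app locally (assuming the model weights, models/labels.json,
# disease_info.json, and sample_images/ are present), run:
#   python app.py
# then open http://localhost:7860 in a browser.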