import json
import os
import shutil
import tempfile
import time
import traceback
import zipfile
from collections import defaultdict
from datetime import datetime
from typing import Union, Dict, Any, Tuple, List, Optional

import gradio as gr
import numpy as np
import onnxruntime as ort
import pandas as pd
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from PIL import Image

from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads

# Global variables for model components (for memory management).
# Populated by _load_model_components_optimized() and cleared by unload_model().
CURRENT_MODEL = None
CURRENT_MODEL_NAME = None
CURRENT_TAGS_DF = None
CURRENT_D_IPS = None
CURRENT_PREPROCESS_FUNC = None
CURRENT_THRESHOLDS = None
CURRENT_CATEGORY_NAMES = None

# Model used when the caller (or a cleared UI dropdown) does not supply one.
_DEFAULT_MODEL_NAME = 'deepghs/pixai-tagger-v0.9-onnx'

# Fixed input resolution and per-channel normalization fed to the PixAI ONNX model.
# These mean/std values are OpenAI CLIP's normalization constants (not ImageNet's).
_PIXAI_INPUT_SIZE = 448
_PIXAI_MEAN = (0.48145466, 0.4578275, 0.40821073)
_PIXAI_STD = (0.26862954, 0.26130258, 0.27577711)

css = """
#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;}
#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
"""


def _escape_tag(tag: str) -> str:
    """Escape parentheses and turn underscores into spaces for prompt-style tag output."""
    return tag.replace("(", "\\(").replace(")", "\\)").replace("_", " ")


def preprocess_on_gpu(img, device='cuda'):
    """Preprocess an image with torchvision, staging the tensor on *device* when CUDA is available.

    Args:
        img: PIL image (converted to RGB if needed).
        device: torch device to move the tensor to when CUDA is available.

    Returns:
        A numpy array of shape (1, 3, 448, 448) normalized with the same mean/std
        as the ONNX preprocessing path.

    NOTE: the tensor is copied back to CPU before returning, so the GPU hop does no
    work there; this helper is not used by the ONNX pipeline below.
    """
    import torch
    import torchvision.transforms as transforms

    # Normalize() below expects exactly 3 channels; RGBA/L/P inputs would fail otherwise.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((_PIXAI_INPUT_SIZE, _PIXAI_INPUT_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_PIXAI_MEAN), std=list(_PIXAI_STD)),
    ])
    tensor_img = transform(img).unsqueeze(0)  # add batch dimension
    if torch.cuda.is_available():
        tensor_img = tensor_img.to(device)
    return tensor_img.cpu().numpy()


class Timer:
    """Collect labelled timestamps and print per-step and total execution times."""

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Record the current time under *label*."""
        self.checkpoints.append((label, time.perf_counter()))

    def report(self, is_clear_checkpoints=True):
        """Print the time elapsed between consecutive checkpoints.

        When *is_clear_checkpoints* is true, the history is reset and a fresh
        baseline checkpoint is recorded.
        """
        max_label_length = max((len(label) for label, _ in self.checkpoints), default=0)
        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(max_label_length)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print every step relative to start_time plus the total, then clear the history."""
        print('\n> Execution Time Report:')
        max_label_length = max((len(label) for label, _ in self.checkpoints), default=0)
        prev_time = self.start_time
        # checkpoints[0] is assumed to be the 'Start' entry created by __init__/restart().
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(max_label_length)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time
        total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
        print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")  # Performance tests
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and the checkpoint history."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]


def _get_repo_id(model_name: str) -> str:
    """Get the repository ID for the specified model name.

    Names containing '/' are treated as full repo IDs; anything else is
    expanded to 'deepghs/pixai-tagger-{model_name}-onnx'.
    """
    if '/' in model_name:
        return model_name
    return f'deepghs/pixai-tagger-{model_name}-onnx'


def _download_model_files(model_name: str):
    """Download all required model files from the Hugging Face Hub.

    Returns:
        (model_path, tags_path, preprocess_path, thresholds_path); thresholds_path
        is None when the repo has no thresholds.csv.
    """
    repo_id = _get_repo_id(model_name)

    def fetch(filename):
        return hf_hub_download(repo_id=repo_id, filename=filename, library_name="pixai-tagger")

    model_path = fetch('model.onnx')
    tags_path = fetch('selected_tags.csv')
    preprocess_path = fetch('preprocess.json')
    try:
        thresholds_path = fetch('thresholds.csv')
    except EntryNotFoundError:
        thresholds_path = None
    return model_path, tags_path, preprocess_path, thresholds_path


def create_optimized_ort_session(model_path):
    """Create an ONNX Runtime session, preferring CUDA and always falling back to CPU."""
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.intra_op_num_threads = 0  # 0 lets ONNX Runtime choose the thread count
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess_options.enable_mem_pattern = True
    sess_options.enable_cpu_mem_arena = True

    available_providers = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available_providers}")

    # Execution providers in order of preference.
    providers = []
    if 'CUDAExecutionProvider' in available_providers:
        providers.append(('CUDAExecutionProvider', {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # 4GB VRAM
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        }))
        print("Using CUDA provider for ONNX inference")
    else:
        print("CUDA provider not available, falling back to CPU")
    # Always include CPU as fallback (FOR HF)
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, sess_options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as e:
        print(f"Failed to create ONNX session: {e}")
        raise


def _pixai_transform(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a normalized float32 (C, H, W) array for the ONNX model."""
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # Resize to 448x448 <- marked as very important by the original author.
    img = img.resize((_PIXAI_INPUT_SIZE, _PIXAI_INPUT_SIZE), Image.Resampling.LANCZOS)
    pixels = np.asarray(img, dtype=np.float32) / 255.0  # scale to [0, 1]
    pixels = (pixels - np.asarray(_PIXAI_MEAN, dtype=np.float32)) / np.asarray(_PIXAI_STD, dtype=np.float32)
    return np.transpose(pixels, (2, 0, 1))  # (H, W, C) -> (C, H, W)


def _load_category_thresholds(thresholds_path):
    """Return ({category: threshold}, {category: name}) from thresholds.csv, or built-in defaults."""
    if not (thresholds_path and os.path.exists(thresholds_path)):
        # Default thresholds if file doesn't exist
        return {0: 0.3, 4: 0.85, 9: 0.85}, {0: 'general', 4: 'character', 9: 'rating'}

    thresholds, names = {}, {}
    df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
    for item in df_category_thresholds.to_dict('records'):
        # The first row seen for a category wins.
        if item['category'] not in thresholds:
            thresholds[item['category']] = item['threshold']
            names[item['category']] = item['name']
    return thresholds, names


def _load_model_components_optimized(model_name: str):
    """Return the cached model components for *model_name*, loading them on first use or on change.

    Returns:
        (session, tags_df, ips_by_tag, preprocess_func, thresholds, category_names)
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    if CURRENT_MODEL_NAME != model_name:
        # Invalidate the cache first: if anything below raises, the next call reloads
        # instead of returning a mix of old and new components. Dropping the old
        # session before creating the new one also avoids holding two models at once.
        CURRENT_MODEL_NAME = None
        CURRENT_MODEL = None

        # preprocess.json is downloaded but not parsed: preprocessing is the
        # hard-coded _pixai_transform.
        model_path, tags_path, _preprocess_path, thresholds_path = _download_model_files(model_name)
        session = create_optimized_ort_session(model_path)

        tags_df = pd.read_csv(tags_path)
        ips_by_tag = {}
        if 'ips' in tags_df.columns:
            # Each 'ips' cell is a JSON string; empty cells become '{}'.
            tags_df['ips'] = tags_df['ips'].fillna('{}').map(json.loads)
            ips_by_tag = {name: ips for name, ips in zip(tags_df['name'], tags_df['ips']) if ips}

        thresholds, category_names = _load_category_thresholds(thresholds_path)

        # Commit everything together; the name goes last so it marks a complete load.
        CURRENT_MODEL = session
        CURRENT_TAGS_DF = tags_df
        CURRENT_D_IPS = ips_by_tag
        CURRENT_PREPROCESS_FUNC = _pixai_transform
        CURRENT_THRESHOLDS = thresholds
        CURRENT_CATEGORY_NAMES = category_names
        CURRENT_MODEL_NAME = model_name

    return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
            CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)


def _raw_predict(image: Image.Image, model_name: str):
    """Run the PixAI tagger on *image* and return {output_name: first batch element}.

    Raises:
        RuntimeError: wrapping any failure (including non-PIL input).
    """
    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)

        input_tensor = preprocess_func(image)
        if input_tensor.ndim == 3:
            input_tensor = np.expand_dims(input_tensor, axis=0)  # add batch dimension

        # Use the graph's declared input name rather than assuming 'input'.
        input_name = model.get_inputs()[0].name
        output_names = [output.name for output in model.get_outputs()]
        output_values = model.run(output_names, {input_name: input_tensor.astype(np.float32)})
        return {name: value[0] for name, value in zip(output_names, output_values)}
    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e


def get_pixai_tags(
    image: Union[str, Image.Image],
    model_name: str = _DEFAULT_MODEL_NAME,
    thresholds: Optional[Union[float, Dict[Any, float]]] = None,
    fmt='all'
):
    """Tag an image with the PixAI tagger.

    Args:
        image: file path or PIL image.
        model_name: repo ID or short model name (see _get_repo_id).
        thresholds: one number for every category, or a dict keyed by category id
            or category name; categories not covered use the model's defaults (0.85
            when the model has none).
        fmt: 'all' returns a tuple of {tag: score} dicts, one per category in
            ascending category-id order. Any key present in the result dict
            ('tag', 'ips', 'ips_mapping', 'ips_count', a category name, or a raw
            model output) returns that entry. Any other scalar returns the full
            result dict. A tuple/list of such keys returns a tuple with one item
            per key, using None for keys that are missing.

    Raises:
        RuntimeError: wrapping any failure.
    """
    try:
        if isinstance(image, str):
            # Load eagerly and release the file handle immediately.
            with Image.open(image) as opened:
                pil_image = opened.copy()
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise ValueError("Image must be a file path or PIL Image")

        _, df_tags, d_ips, _, default_thresholds, category_names = _load_model_components_optimized(model_name)

        values = _raw_predict(pil_image, model_name)
        prediction = values.get('prediction', np.array([]))
        if prediction.size == 0:
            raise RuntimeError("Model did not return valid predictions")

        def resolve_threshold(category):
            # A single number applies to every category (ints and numpy scalars included).
            if isinstance(thresholds, (int, float, np.integer, np.floating)) and not isinstance(thresholds, bool):
                return thresholds
            if isinstance(thresholds, dict):
                if category in thresholds:
                    return thresholds[category]
                category_label = category_names.get(category, '')
                if category_label in thresholds:
                    return thresholds[category_label]
            return default_thresholds.get(category, 0.85)

        categories = sorted(set(df_tags['category'].tolist()))
        tag_categories = df_tags['category'].to_numpy()
        tag_names = df_tags['name'].to_numpy()

        tags = {}
        for category in categories:
            mask = tag_categories == category
            category_pred = prediction[mask]
            keep = category_pred >= resolve_threshold(category)
            # Highest confidence first; ties broken alphabetically.
            cate_tags = dict(sorted(
                zip(tag_names[mask][keep].tolist(), category_pred[keep].tolist()),
                key=lambda x: (-x[1], x[0])
            ))
            values[category_names.get(category, f"category_{category}")] = cate_tags
            tags.update(cate_tags)
        values['tag'] = tags

        # Handle IPs if available
        if 'ips' in df_tags.columns:
            ips_mapping, ips_counts = {}, defaultdict(int)
            for tag in tags:
                if tag in d_ips:
                    ips_mapping[tag] = d_ips[tag]
                    for ip_name in d_ips[tag]:
                        ips_counts[ip_name] += 1
            values['ips_mapping'] = ips_mapping
            values['ips_count'] = dict(ips_counts)
            values['ips'] = [name for name, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]

        category_labels = [category_names.get(cat, f"category_{cat}") for cat in categories]

        def select(key, missing):
            if key == 'all':
                return tuple(values.get(label, {}) for label in category_labels)
            return values[key] if key in values else missing

        if isinstance(fmt, (tuple, list)):
            return tuple(select(key, None) for key in fmt)
        return select(fmt, values)
    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e


def format_ips_output(ips_result, ips_mapping):
    """Format detected IPs and the character->IP mapping as one escaped, comma-separated string."""
    if not ips_result and not ips_mapping:
        return ""
    detected = [_escape_tag(ip) for ip in (ips_result or [])]
    mapping_lines = [
        f"{_escape_tag(char)}: {', '.join(_escape_tag(ip) for ip in ips)}"
        for char, ips in (ips_mapping or {}).items()
    ]
    return ", ".join(detected + mapping_lines)


def process_single_image(
    image_path,
    model_name=_DEFAULT_MODEL_NAME,
    general_threshold=0.3,
    character_threshold=0.85,
    progress=None,
    idx=0,
    total_images=1
):
    """Tag one image and return its formatted outputs.

    Returns:
        (character_output, general_output, ips_output, combined_output,
        json_data, rating_output). On error, the four text slots hold
        "Error: ..." and the two dict slots are empty.
    """
    try:
        if image_path is None:
            return "", "", "", "", {}, {}
        if progress:
            progress(idx / total_images, desc=f"Processing image {idx+1}/{total_images}")

        with Image.open(image_path) as opened:
            pil_image = opened.copy()  # load now so the file handle is released

        # Unset (None) thresholds fall back to the model's defaults.
        thresholds = {
            key: value
            for key, value in (('general', general_threshold), ('character', character_threshold))
            if value is not None
        }

        # One inference call yields every output we need.
        all_categories, ips_result, ips_mapping = get_pixai_tags(
            pil_image, model_name, thresholds, fmt=('all', 'ips', 'ips_mapping')
        )
        ips_result = ips_result or []
        ips_mapping = ips_mapping or {}

        # NOTE(review): categories are taken by position (ascending category id), assumed
        # to be general, character, rating. That only holds for models whose category ids
        # sort that way; this may be why the rating output is reported as not working.
        padded = tuple(all_categories) + ({},) * max(0, 3 - len(all_categories))
        general_tags, character_tags, rating_tags = padded[:3]

        character_output = ", ".join(_escape_tag(name) for name in character_tags)
        general_output = ", ".join(_escape_tag(name) for name in general_tags)
        ips_output = format_ips_output(ips_result, ips_mapping)
        # Character tags first, then general tags, then IP tags.
        combined_output = ", ".join(part for part in (character_output, general_output, ips_output) if part)

        json_data = {
            "character_tags": character_tags,
            "general_tags": general_tags,
            "rating_tags": rating_tags,
            "ips_result": ips_result,
            "ips_mapping": ips_mapping
        }
        # Label-compatible {name: score} dict
        rating_output = {_escape_tag(k): v for k, v in rating_tags.items()}

        return (
            character_output,   # Character tags
            general_output,     # General tags
            ips_output,         # IP Detection
            combined_output,    # Combined tags
            json_data,          # Detailed JSON
            rating_output       # Rating <- Not working atm
        )
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        # Return error message for all 6 outputs
        return error_msg, error_msg, error_msg, error_msg, {}, {}


# --- GPU / memory management ---

def unload_model():
    """Explicitly unload the current model and its metadata from memory."""
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    CURRENT_MODEL = None  # releases the ONNX Runtime session
    CURRENT_TAGS_DF = None
    CURRENT_D_IPS = None
    CURRENT_PREPROCESS_FUNC = None
    CURRENT_THRESHOLDS = None
    CURRENT_CATEGORY_NAMES = None
    CURRENT_MODEL_NAME = None

    import gc
    gc.collect()

    # Presumably this only frees memory cached by PyTorch's own allocator, if torch is installed.
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass


def cleanup_after_processing():
    """Hook run after a batch finishes; currently unloads the model."""
    unload_model()


def _gallery_item_path(item):
    """Return the file path of a gallery entry, which is either a path or a (path, caption) pair."""
    return item[0] if isinstance(item, (list, tuple)) else item


def _unique_stem(path, used):
    """Return the file stem of *path*, suffixed with -N if needed so it is not in *used*; record it."""
    stem = os.path.splitext(os.path.basename(path))[0]
    candidate, counter = stem, 1
    while candidate in used:
        candidate = f"{stem}-{counter}"
        counter += 1
    used.add(candidate)
    return candidate


def _result_outputs(result):
    """Order a stored per-image result as the UI expects: (character, general, combined, json, rating, ips)."""
    return (
        result['character_tags'],
        result['general_tags'],
        result['combined_tags'],
        result['json_data'],
        result['rating'],
        result['ips_detection']
    )


def process_gallery_images(
    gallery,
    model_name,
    general_threshold,
    character_threshold,
    progress=gr.Progress()
):
    """Process every image in the gallery and return results plus a downloadable zip.

    Returns:
        (tag_results, character, general, combined, json, rating, ips, zip_path).
        The middle six values are the first image's outputs.
    """
    empty_outputs = ({}, "", "", "", {}, {}, "", None)
    if not gallery:
        return empty_outputs

    # A cleared dropdown yields None; fall back to the default model.
    model_name = model_name or _DEFAULT_MODEL_NAME

    tag_results = {}
    archive_entries = []
    used_stems = set()
    output_dir = tempfile.mkdtemp()
    total_images = len(gallery)
    timer = Timer()

    try:
        for idx, image_data in enumerate(gallery):
            image_path = _gallery_item_path(image_data)
            try:
                results = process_single_image(
                    image_path, model_name, general_threshold, character_threshold,
                    progress, idx, total_images
                )
                tag_results[image_path] = {
                    'character_tags': results[0],
                    'general_tags': results[1],
                    'ips_detection': results[2],
                    'combined_tags': results[3],
                    'json_data': results[4],
                    'rating': results[5]
                }

                # Unique stem so images sharing a basename don't overwrite each other's files.
                image_name = _unique_stem(image_path, used_stems)
                files_to_create = [
                    (f"character_tags-{image_name}.txt", results[0]),
                    (f"general_tags-{image_name}.txt", results[1]),
                    (f"ips_detection-{image_name}.txt", results[2]),
                    (f"combined_tags-{image_name}.txt", results[3]),
                    (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
                ]
                for file_name, content in files_to_create:
                    file_path = os.path.join(output_dir, file_name)
                    with open(file_path, 'w', encoding='utf-8') as f:
                        f.write(content)
                    archive_entries.append({'path': file_path, 'name': file_name})

                # Copy the original bytes rather than re-encoding through PIL.
                copy_name = f"{image_name}{os.path.splitext(image_path)[1]}"
                copy_path = os.path.join(output_dir, copy_name)
                shutil.copyfile(image_path, copy_path)
                archive_entries.append({'path': copy_path, 'name': copy_name})

                timer.checkpoint(f"image{idx:02d}, processed")
            except Exception as e:
                print(f"Error processing image {image_path}: {str(e)}")
                print(traceback.format_exc())
                continue

        download_zip_path = os.path.join(
            output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip"
        )
        with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for entry in archive_entries:
                zipf.write(entry['path'], arcname=entry['name'])

        progress(1.0, desc="Processing complete")
        timer.report_all()
        print('Processing is complete.')

        # Return first image results as default if available even if we are tagging 1000+ images.
        first_outputs = ("", "", "", {}, {}, "")
        first_result = tag_results.get(_gallery_item_path(gallery[0]))
        if first_result is not None:
            first_outputs = _result_outputs(first_result)
        return (tag_results, *first_outputs, download_zip_path)
    except Exception as e:
        print(f"Error in process_gallery_images: {str(e)}")
        print(traceback.format_exc())
        progress(1.0, desc="Processing failed")
        return empty_outputs
    finally:
        # The model is unloaded after every run, even a failed one.
        cleanup_after_processing()  # Comment here to turn off this behavior


def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
    """Handle gallery image selection and show that image's stored results."""
    empty = ("", "", "", {}, {}, "")
    if not selected_state or not tag_results:
        return empty

    selected_value = selected_state.value
    if isinstance(selected_value, dict) and 'image' in selected_value:
        image_path = selected_value['image']['path']
    elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
        image_path = selected_value[0]
    else:
        image_path = str(selected_value)

    if image_path in tag_results:
        return _result_outputs(tag_results[image_path])
    return empty


def append_gallery(gallery, image):
    """Add a single media file (image or video) to the gallery."""
    return handle_single_media_upload(image, gallery)


def extend_gallery(gallery, images):
    """Add multiple media files (images or videos) to the gallery."""
    return handle_multiple_media_uploads(images, gallery)


def create_pixai_interface():
    """Create the PixAI Gradio interface."""
    with gr.Blocks(css=css, fill_width=True) as demo:
        # State to store results
        tag_results = gr.State({})
        selected_image = gr.Textbox(label='Selected Image', visible=False)

        with gr.Row():
            with gr.Column():
                # Image upload section
                with gr.Column(variant='panel'):
                    image_input = gr.Image(
                        label='Upload an Image (or paste from clipboard)',
                        type='filepath',
                        sources=['upload', 'clipboard'],
                        height=150
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            'Upload multiple images or videos',
                            file_types=['image', 'video'],
                            file_count='multiple',
                            size='md'
                        )
                    gallery = gr.Gallery(
                        columns=2,
                        show_share_button=False,
                        interactive=True,
                        height='auto',
                        label='Grid of images',
                        preview=False,
                        elem_id='custom-gallery'
                    )
                    run_button = gr.Button("Analyze Images", variant="primary", size='lg')
                    clear_gallery = gr.ClearButton(
                        components=[gallery], value='Clear Gallery', variant='secondary', size='sm'
                    )

                model_dropdown = gr.Dropdown(
                    choices=["deepghs/pixai-tagger-v0.9-onnx"],
                    value="deepghs/pixai-tagger-v0.9-onnx",
                    label="Model"
                )

                # Threshold controls
                with gr.Row():
                    general_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.30, step=0.05,
                        label="General Tags Threshold", scale=3
                    )
                    character_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.85, step=0.05,
                        label="Character Tags Threshold", scale=3
                    )

                with gr.Row():
                    clear = gr.ClearButton(
                        components=[gallery, model_dropdown, general_threshold, character_threshold],
                        value="Clear Everything",
                        variant='secondary',
                        size='lg'
                    )
                    clear.add([tag_results])

                detailed_json_output = gr.JSON(label="Detailed JSON")

            with gr.Column(variant='panel'):
                download_file = gr.File(label="Download")

                # Output blocks
                character_tags_output = gr.Textbox(label="Character tags", show_copy_button=True, lines=3)
                general_tags_output = gr.Textbox(label="General tags", show_copy_button=True, lines=3)
                ips_detection_output = gr.Textbox(label="IPs Detection", show_copy_button=True, lines=5)
                combined_tags_output = gr.Textbox(label="Combined tags", show_copy_button=True, lines=6)
                rating_output = gr.Label(label="Rating")

        # Clear button targets
        clear.add([
            download_file,
            character_tags_output,
            general_tags_output,
            ips_detection_output,
            combined_tags_output,
            rating_output,
            detailed_json_output
        ])

        # Event handlers
        image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
        upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[
                character_tags_output,
                general_tags_output,
                combined_tags_output,
                detailed_json_output,
                rating_output,
                ips_detection_output
            ]
        )
        run_button.click(
            process_gallery_images,
            inputs=[gallery, model_dropdown, general_threshold, character_threshold],
            outputs=[
                tag_results,
                character_tags_output,
                general_tags_output,
                combined_tags_output,
                detailed_json_output,
                rating_output,
                ips_detection_output,
                download_file
            ]
        )

        gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)')

    return demo


# Export public API
__all__ = [
    'get_pixai_tags',
    'process_single_image',
    'process_gallery_images',
    'create_pixai_interface',
    'unload_model',
    'cleanup_after_processing'
]