| """ | |
| Utility Functions | |
| Helper functions for file handling, validation, progress tracking, | |
| and system management for the knowledge distillation application. | |
| """ | |
| import os | |
| import logging | |
| import asyncio | |
| import hashlib | |
| import mimetypes | |
| import shutil | |
| import psutil | |
| import time | |
| from typing import Dict, Any, List, Optional, Union | |
| from pathlib import Path | |
| import json | |
| import tempfile | |
| from datetime import datetime, timedelta | |
| import torch | |
| import numpy as np | |
| from fastapi import UploadFile | |

# Configure logging
def setup_logging(level: str = "INFO", log_file: Optional[str] = None) -> None:
    """
    Set up application logging.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_file: Optional log file path
    """
    log_level = getattr(logging, level.upper(), logging.INFO)

    # Shared format for all handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    handlers = []

    # Console handler (always available)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    handlers.append(console_handler)

    # File handler (only if the target path is writable)
    try:
        logs_dir = Path("logs")
        logs_dir.mkdir(exist_ok=True)
        if log_file is None:
            log_file = f"logs/app_{datetime.now().strftime('%Y%m%d')}.log"
        # Verify we can write to the log file before attaching the handler
        Path(log_file).touch()
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        handlers.append(file_handler)
    except (PermissionError, OSError) as e:
        # Fall back to console-only logging
        print(f"Warning: cannot write to log file ({e}); using console logging only")

    # Configure the root logger
    logging.basicConfig(
        level=log_level,
        handlers=handlers,
        force=True
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized with level: {level}")
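
# Usage sketch (an illustrative call site, not part of this module's API): call
# once at application start-up, before any loggers are used.
#
#     setup_logging("DEBUG")                        # console + dated file under logs/
#     setup_logging("INFO", log_file="logs/api.log")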

def validate_file(file: UploadFile) -> Dict[str, Any]:
    """
    Validate an uploaded file for security and format compliance.

    Args:
        file: FastAPI UploadFile object

    Returns:
        Validation result dictionary
    """
    try:
        # File size limits (in bytes)
        MAX_FILE_SIZE = 5 * 1024 * 1024 * 1024  # 5 GB
        MIN_FILE_SIZE = 1024  # 1 KB

        # Allowed file extensions
        ALLOWED_EXTENSIONS = {
            '.pt', '.pth', '.bin', '.safetensors',
            '.onnx', '.h5', '.pkl', '.joblib'
        }

        # Allowed MIME types ('application/octet-stream' is the generic type
        # most clients report for binary model files)
        ALLOWED_MIME_TYPES = {
            'application/octet-stream',
            'application/x-pytorch',
            'application/x-pickle',
            'application/x-hdf5'
        }

        # Check the filename for security before parsing it further
        if not file.filename or not _is_safe_filename(file.filename):
            return {
                'valid': False,
                'error': 'Invalid filename. Contains unsafe characters.'
            }

        # Check file size
        if hasattr(file, 'size') and file.size:
            if file.size > MAX_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too large. Maximum size: {MAX_FILE_SIZE // (1024**3)}GB'
                }
            if file.size < MIN_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too small. Minimum size: {MIN_FILE_SIZE} bytes'
                }

        # Check file extension
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in ALLOWED_EXTENSIONS:
            return {
                'valid': False,
                'error': f'Invalid file extension. Allowed: {", ".join(sorted(ALLOWED_EXTENSIONS))}'
            }

        # Check MIME type; unexpected types are logged but not rejected, since
        # MIME detection for binary model files is unreliable
        mime_type, _ = mimetypes.guess_type(file.filename)
        if mime_type and mime_type not in ALLOWED_MIME_TYPES:
            logging.warning(f"Unexpected MIME type: {mime_type} for {file.filename}")

        return {
            'valid': True,
            'extension': file_extension,
            'mime_type': mime_type,
            'size': getattr(file, 'size', None)
        }

    except Exception as e:
        return {
            'valid': False,
            'error': f'Validation error: {str(e)}'
        }
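
# Usage sketch inside a FastAPI route (the endpoint and HTTPException handling
# below are assumptions about the hosting app, not defined in this module):
#
#     @app.post("/upload")
#     async def upload_model(file: UploadFile):
#         result = validate_file(file)
#         if not result['valid']:
#             raise HTTPException(status_code=400, detail=result['error'])
#         ...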

def _is_safe_filename(filename: str) -> bool:
    """Check if a filename is safe (no path traversal, null bytes, or control characters)."""
    if not filename:
        return False
    # Reject path traversal attempts and path separators
    if '..' in filename or '/' in filename or '\\' in filename:
        return False
    # Reject null bytes
    if '\x00' in filename:
        return False
    # Reject other control characters
    if any(ord(c) < 32 for c in filename):
        return False
    return True

def get_system_info() -> Dict[str, Any]:
    """
    Get system information for monitoring and debugging.

    Returns:
        System information dictionary
    """
    try:
        # CPU information
        cpu_info = {
            'count': psutil.cpu_count(),
            'usage_percent': psutil.cpu_percent(interval=1),
            'frequency': psutil.cpu_freq()._asdict() if psutil.cpu_freq() else None
        }

        # Memory information
        memory = psutil.virtual_memory()
        memory_info = {
            'total_gb': round(memory.total / (1024**3), 2),
            'available_gb': round(memory.available / (1024**3), 2),
            'used_gb': round(memory.used / (1024**3), 2),
            'percent': memory.percent
        }

        # Disk information
        disk = psutil.disk_usage('/')
        disk_info = {
            'total_gb': round(disk.total / (1024**3), 2),
            'free_gb': round(disk.free / (1024**3), 2),
            'used_gb': round(disk.used / (1024**3), 2),
            'percent': round((disk.used / disk.total) * 100, 2)
        }

        # GPU information
        if torch.cuda.is_available():
            gpu_info = {
                'available': True,
                'count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device(),
                'device_name': torch.cuda.get_device_name(),
                'memory_allocated_gb': round(torch.cuda.memory_allocated() / (1024**3), 2),
                'memory_reserved_gb': round(torch.cuda.memory_reserved() / (1024**3), 2)
            }
        else:
            gpu_info = {'available': False}

        return {
            'cpu': cpu_info,
            'memory': memory_info,
            'disk': disk_info,
            'gpu': gpu_info,
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'platform': os.name
        }

    except Exception as e:
        logging.error(f"Error getting system info: {e}")
        return {'error': str(e)}
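
# Usage sketch: expose as a monitoring endpoint (the route name is an assumption):
#
#     @app.get("/system")
#     async def system_status():
#         return get_system_info()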

def cleanup_temp_files(max_age_hours: int = 24) -> Dict[str, Any]:
    """
    Clean up temporary files older than the specified age.

    Args:
        max_age_hours: Maximum age of files to keep (in hours)

    Returns:
        Cleanup statistics
    """
    try:
        cleanup_stats = {
            'files_removed': 0,
            'bytes_freed': 0,
            'directories_cleaned': []
        }

        cutoff_time = time.time() - (max_age_hours * 3600)

        # Directories to clean
        temp_dirs = ['temp', 'uploads']

        for dir_name in temp_dirs:
            dir_path = Path(dir_name)
            if not dir_path.exists():
                continue

            files_removed = 0
            bytes_freed = 0

            for file_path in dir_path.rglob('*'):
                if file_path.is_file():
                    try:
                        # Remove the file if it is older than the cutoff
                        if file_path.stat().st_mtime < cutoff_time:
                            file_size = file_path.stat().st_size
                            file_path.unlink()
                            files_removed += 1
                            bytes_freed += file_size
                    except Exception as e:
                        logging.warning(f"Error removing file {file_path}: {e}")

            if files_removed > 0:
                cleanup_stats['directories_cleaned'].append({
                    'directory': str(dir_path),
                    'files_removed': files_removed,
                    'bytes_freed': bytes_freed
                })
                cleanup_stats['files_removed'] += files_removed
                cleanup_stats['bytes_freed'] += bytes_freed

        logging.info(f"Cleanup completed: {cleanup_stats['files_removed']} files removed, "
                     f"{cleanup_stats['bytes_freed'] / (1024**2):.2f} MB freed")

        return cleanup_stats

    except Exception as e:
        logging.error(f"Error during cleanup: {e}")
        return {'error': str(e)}
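
# Usage sketch: invoke from a periodic background job (the scheduling mechanism
# is left to the hosting app; this call itself is synchronous):
#
#     stats = cleanup_temp_files(max_age_hours=24)
#     logging.info(f"Freed {stats.get('bytes_freed', 0)} bytes")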

def calculate_file_hash(file_path: Union[str, Path], algorithm: str = 'sha256') -> str:
    """
    Calculate the hash of a file.

    Args:
        file_path: Path to the file
        algorithm: Hash algorithm (md5, sha1, sha256, etc.)

    Returns:
        Hexadecimal hash string
    """
    try:
        hash_obj = hashlib.new(algorithm)
        with open(file_path, 'rb') as f:
            # Read in 8 KB chunks so large model files are never fully in memory
            for chunk in iter(lambda: f.read(8192), b""):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()
    except Exception as e:
        logging.error(f"Error calculating hash for {file_path}: {e}")
        raise
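
# Usage sketch: deduplicate uploads by content hash (the path and the
# `known_hashes` registry are hypothetical):
#
#     digest = calculate_file_hash("uploads/model.safetensors")
#     if digest in known_hashes:
#         ...  # skip re-processing an identical model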

def format_bytes(bytes_value: float) -> str:
    """
    Format a byte count into a human-readable string.

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string (e.g., "1.5 GB")
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if bytes_value < 1024.0:
            return f"{bytes_value:.1f} {unit}"
        bytes_value /= 1024.0
    return f"{bytes_value:.1f} PB"

def format_duration(seconds: float) -> str:
    """
    Format a duration in seconds as a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string (e.g., "2h 30m 15s")
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes}m {secs}s"
    else:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours}h {minutes}m {secs}s"

def create_progress_tracker():
    """
    Create a progress tracking utility.

    Returns:
        Progress tracker instance
    """
    class ProgressTracker:
        def __init__(self):
            self.start_time = time.time()
            self.last_update = self.start_time
            self.steps_completed = 0
            self.total_steps = 0

        def update(self, current_step: int, total_steps: int, message: str = ""):
            self.steps_completed = current_step
            self.total_steps = total_steps
            self.last_update = time.time()

            # Calculate progress metrics
            progress = current_step / total_steps if total_steps > 0 else 0
            elapsed = self.last_update - self.start_time

            # Estimate time remaining from the average time per unit of progress
            if progress > 0:
                eta = (elapsed / progress) * (1 - progress)
                eta_str = format_duration(eta)
            else:
                eta_str = "Unknown"

            return {
                'progress': progress,
                'current_step': current_step,
                'total_steps': total_steps,
                'elapsed': format_duration(elapsed),
                'eta': eta_str,
                'message': message
            }

    return ProgressTracker()
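
# Usage sketch inside a training/distillation loop (the loop body is assumed):
#
#     tracker = create_progress_tracker()
#     for step in range(total_steps):
#         ...  # one distillation step
#         status = tracker.update(step + 1, total_steps, message="distilling")
#         # status['progress'] and status['eta'] can be forwarded to the UI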

def safe_json_load(file_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """
    Safely load a JSON file with error handling.

    Args:
        file_path: Path to JSON file

    Returns:
        Loaded JSON data, or None on error
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        logging.warning(f"Error loading JSON from {file_path}: {e}")
        return None

def safe_json_save(data: Dict[str, Any], file_path: Union[str, Path]) -> bool:
    """
    Safely save data to a JSON file.

    Args:
        data: Data to save
        file_path: Path to save file

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure the parent directory exists
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        return True
    except Exception as e:
        logging.error(f"Error saving JSON to {file_path}: {e}")
        return False
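
# Usage sketch: persist small job metadata between runs (the file path is an
# assumption; safe_json_load returns None for missing or corrupt files):
#
#     state = safe_json_load("temp/job_state.json") or {}
#     state['last_run'] = datetime.now().isoformat()
#     safe_json_save(state, "temp/job_state.json")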

def get_available_memory() -> float:
    """
    Get available system memory in GB.

    Returns:
        Available memory in GB
    """
    try:
        memory = psutil.virtual_memory()
        return memory.available / (1024**3)
    except Exception:
        return 0.0

def check_disk_space(path: str = ".", min_gb: float = 1.0) -> bool:
    """
    Check whether enough disk space is available.

    Args:
        path: Path to check
        min_gb: Minimum required space in GB

    Returns:
        True if enough space is available
    """
    try:
        disk = psutil.disk_usage(path)
        free_gb = disk.free / (1024**3)
        return free_gb >= min_gb
    except Exception:
        return False
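
# Minimal smoke test (an added convenience, not required by the application):
# running this module directly exercises a few utilities for a quick sanity check.
if __name__ == "__main__":
    setup_logging("INFO")
    print(f"Available memory: {get_available_memory():.2f} GB")
    print(f"Disk has >= 1 GB free: {check_disk_space('.', min_gb=1.0)}")
    print(f"Formatted: {format_bytes(1536 * 1024 ** 2)}, {format_duration(9015)}")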