| """ | |
| Advanced Memory Manager for CPU-only training with 16GB RAM constraint | |
| Optimized for Hugging Face Spaces free tier | |
| """ | |
| import os | |
| import gc | |
| import psutil | |
| import logging | |
| import threading | |
| import time | |
| from typing import Dict, Any, Optional, List, Callable | |
| from pathlib import Path | |
| import torch | |
| import numpy as np | |
| from contextlib import contextmanager | |
| logger = logging.getLogger(__name__) | |


class AdvancedMemoryManager:
    """
    Advanced memory management for CPU-only training with strict memory constraints.
    """

    def __init__(self, max_memory_gb: float = 14.0):
        """
        Initialize the memory manager.

        Args:
            max_memory_gb: Maximum memory usage in GB (default 14GB, leaving
                headroom on 16GB systems).
        """
        self.max_memory_bytes = max_memory_gb * 1024**3
        self.current_memory_usage = 0
        self.memory_threshold_warning = 0.8     # 80% of limit: warn
        self.memory_threshold_critical = 0.9    # 90% of limit: critical
        self.memory_threshold_emergency = 0.95  # 95% of limit: emergency cleanup

        # Memory tracking
        self.allocated_objects = {}
        self.memory_history = []
        self.cleanup_callbacks = []

        # Threading for monitoring
        self.monitoring_active = False
        self.monitor_thread = None

        # CPU optimization; os.cpu_count() can return None, so fall back to 1
        self.cpu_count = os.cpu_count() or 1
        torch.set_num_threads(min(self.cpu_count, 8))  # Cap threads for stability

        logger.info(f"Memory Manager initialized with {max_memory_gb}GB limit")
        logger.info(f"CPU threads set to: {torch.get_num_threads()}")
    def get_memory_info(self) -> Dict[str, Any]:
        """Get current process and system memory information."""
        process = psutil.Process()
        memory_info = process.memory_info()
        system_memory = psutil.virtual_memory()
        return {
            'process_memory_mb': memory_info.rss / 1024**2,
            'process_memory_percent': (memory_info.rss / system_memory.total) * 100,
            'system_memory_total_gb': system_memory.total / 1024**3,
            'system_memory_available_gb': system_memory.available / 1024**3,
            'system_memory_percent': system_memory.percent,
            'max_allowed_gb': self.max_memory_bytes / 1024**3,
            'torch_allocated_mb': torch.cuda.memory_allocated() / 1024**2 if torch.cuda.is_available() else 0,
            'torch_cached_mb': torch.cuda.memory_reserved() / 1024**2 if torch.cuda.is_available() else 0,
        }
    def check_memory_status(self) -> str:
        """Check current memory status against the configured thresholds."""
        memory_info = self.get_memory_info()
        usage_ratio = memory_info['process_memory_mb'] * 1024**2 / self.max_memory_bytes
        if usage_ratio >= self.memory_threshold_emergency:
            return 'emergency'
        elif usage_ratio >= self.memory_threshold_critical:
            return 'critical'
        elif usage_ratio >= self.memory_threshold_warning:
            return 'warning'
        return 'normal'
    def force_cleanup(self):
        """Force aggressive memory cleanup."""
        logger.warning("Performing emergency memory cleanup")

        # Clear Python garbage
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Clear PyTorch CUDA cache (no-op on CPU-only builds)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Run registered cleanup callbacks
        for callback in self.cleanup_callbacks:
            try:
                callback()
            except Exception as e:
                logger.error(f"Cleanup callback failed: {e}")

        # Force another garbage collection pass
        gc.collect()

        memory_info = self.get_memory_info()
        logger.info(f"Memory after cleanup: {memory_info['process_memory_mb']:.1f}MB")
    @contextmanager
    def memory_context(self, operation_name: str, expected_memory_mb: float = 0):
        """Context manager for memory-aware operations."""
        start_memory = self.get_memory_info()
        logger.debug(f"Starting {operation_name}, memory: {start_memory['process_memory_mb']:.1f}MB")

        # Check whether the expected allocation fits within the limit
        if expected_memory_mb > 0:
            available_mb = (self.max_memory_bytes / 1024**2) - start_memory['process_memory_mb']
            if expected_memory_mb > available_mb * 0.8:  # 80% safety margin
                logger.warning(f"Operation {operation_name} may exceed memory limit")
                self.force_cleanup()

        try:
            yield self
        finally:
            end_memory = self.get_memory_info()
            memory_diff = end_memory['process_memory_mb'] - start_memory['process_memory_mb']
            logger.debug(f"Completed {operation_name}, memory change: {memory_diff:+.1f}MB")

            # Clean up if the operation pushed usage into a bad state
            status = self.check_memory_status()
            if status in ('critical', 'emergency'):
                self.force_cleanup()
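    # A minimal usage sketch (hypothetical caller code): wrap an allocation-heavy
    # step so cleanup runs automatically when usage crosses the thresholds.
    #
    #   manager = AdvancedMemoryManager(max_memory_gb=14.0)
    #   with manager.memory_context("load_model", expected_memory_mb=2048):
    #       model = load_model()  # load_model is illustrative, not part of this module
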
    def register_cleanup_callback(self, callback: Callable):
        """Register a cleanup callback function."""
        self.cleanup_callbacks.append(callback)
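
    # Hypothetical example: a training loop could register a callback that drops
    # its cached batches under memory pressure, e.g.
    #
    #   manager.register_cleanup_callback(batch_cache.clear)  # batch_cache is illustrative
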
    def start_monitoring(self, interval_seconds: float = 30.0):
        """Start the memory monitoring thread."""
        if self.monitoring_active:
            return
        self.monitoring_active = True
        self.monitor_thread = threading.Thread(
            target=self._monitor_memory,
            args=(interval_seconds,),
            daemon=True,
        )
        self.monitor_thread.start()
        logger.info("Memory monitoring started")

    def stop_monitoring(self):
        """Stop memory monitoring."""
        self.monitoring_active = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=5.0)
        logger.info("Memory monitoring stopped")

    def _monitor_memory(self, interval_seconds: float):
        """Internal memory monitoring loop."""
        while self.monitoring_active:
            try:
                memory_info = self.get_memory_info()
                status = self.check_memory_status()

                # Log any abnormal memory status
                if status != 'normal':
                    logger.warning(f"Memory status: {status}, usage: {memory_info['process_memory_mb']:.1f}MB")

                # Auto cleanup if needed
                if status == 'emergency':
                    self.force_cleanup()
                elif status == 'critical':
                    gc.collect()

                # Record history, keeping only the last 100 entries
                self.memory_history.append({
                    'timestamp': time.time(),
                    'memory_mb': memory_info['process_memory_mb'],
                    'status': status,
                })
                if len(self.memory_history) > 100:
                    self.memory_history = self.memory_history[-100:]

                time.sleep(interval_seconds)
            except Exception as e:
                logger.error(f"Memory monitoring error: {e}")
                time.sleep(interval_seconds)

    def get_memory_recommendations(self) -> List[str]:
        """Get memory optimization recommendations."""
        memory_info = self.get_memory_info()
        recommendations = []

        if memory_info['process_memory_mb'] > 8000:  # > 8GB
            recommendations.append("Consider using smaller batch sizes")
            recommendations.append("Enable gradient checkpointing")
            recommendations.append("Use model sharding for large models")

        if memory_info['system_memory_percent'] > 80:
            recommendations.append("Close unnecessary applications")
            recommendations.append("Consider using swap memory")

        if len(self.memory_history) > 10:
            recent_growth = self.memory_history[-1]['memory_mb'] - self.memory_history[-10]['memory_mb']
            if recent_growth > 1000:  # > 1GB growth
                recommendations.append("Memory usage is growing rapidly - check for memory leaks")

        return recommendations

    def optimize_torch_settings(self):
        """Optimize PyTorch settings for CPU and memory efficiency."""
        # Set a stable thread count
        torch.set_num_threads(min(self.cpu_count, 8))

        # Prefer math/memory-efficient SDP kernels; flash attention is not
        # useful on CPU. These toggles may be absent in older PyTorch builds.
        try:
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
        except AttributeError:
            pass

        # CUDA allocator hint; harmless no-op on CPU-only builds
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

        logger.info("PyTorch settings optimized for CPU and memory efficiency")

    def __enter__(self):
        self.start_monitoring()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_monitoring()
        self.force_cleanup()
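

# A minimal end-to-end sketch of how this manager might be driven. Everything
# below is illustrative: the operation name and `expected_memory_mb` value are
# made up, and the tensor allocation stands in for a real training step.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with AdvancedMemoryManager(max_memory_gb=14.0) as manager:
        manager.optimize_torch_settings()

        # Wrap a memory-heavy operation; cleanup triggers automatically
        # if usage crosses the critical/emergency thresholds.
        with manager.memory_context("dummy_allocation", expected_memory_mb=256):
            x = torch.randn(1024, 1024)  # placeholder for real work
            del x

        print("Status:", manager.check_memory_status())
        for tip in manager.get_memory_recommendations():
            print("Tip:", tip)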