Spaces:
Running
Running
| """ | |
| Advanced Chunk Loader for large models with memory constraints | |
| Optimized for CPU-only training on 16GB RAM systems | |
| """ | |
| import os | |
| import gc | |
| import mmap | |
| import logging | |
| import asyncio | |
| from typing import Dict, Any, List, Optional, Iterator, Union | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, AutoConfig, AutoTokenizer | |
| from safetensors import safe_open | |
| import numpy as np | |
| from .memory_manager import AdvancedMemoryManager | |
| logger = logging.getLogger(__name__) | |
| class ModelChunk: | |
| """Represents a chunk of a large model""" | |
| def __init__(self, chunk_id: str, parameters: Dict[str, torch.Tensor], | |
| metadata: Dict[str, Any]): | |
| self.chunk_id = chunk_id | |
| self.parameters = parameters | |
| self.metadata = metadata | |
| self.is_loaded = True | |
| self.memory_size_mb = sum(p.numel() * p.element_size() for p in parameters.values()) / 1024**2 | |
| def unload(self): | |
| """Unload chunk from memory""" | |
| if self.is_loaded: | |
| del self.parameters | |
| self.parameters = {} | |
| self.is_loaded = False | |
| gc.collect() | |
| logger.debug(f"Unloaded chunk {self.chunk_id}") | |
| def __del__(self): | |
| if hasattr(self, 'is_loaded') and self.is_loaded: | |
| self.unload() | |
| class AdvancedChunkLoader: | |
| """ | |
| Advanced chunk loader for handling large models with memory constraints | |
| """ | |
| def __init__(self, memory_manager: AdvancedMemoryManager, | |
| chunk_size_mb: float = 500.0): | |
| """ | |
| Initialize chunk loader | |
| Args: | |
| memory_manager: Memory manager instance | |
| chunk_size_mb: Target size for each chunk in MB | |
| """ | |
| self.memory_manager = memory_manager | |
| self.chunk_size_mb = chunk_size_mb | |
| self.chunk_size_bytes = chunk_size_mb * 1024**2 | |
| self.loaded_chunks = {} | |
| self.chunk_cache = {} | |
| self.max_cached_chunks = 3 | |
| # Register cleanup callback | |
| self.memory_manager.register_cleanup_callback(self._cleanup_chunks) | |
| logger.info(f"Chunk loader initialized with {chunk_size_mb}MB chunks") | |
| async def load_model_in_chunks(self, model_path: str, **kwargs) -> Dict[str, Any]: | |
| """ | |
| Load a large model in chunks | |
| Args: | |
| model_path: Path to model (local or HF repo) | |
| **kwargs: Additional loading parameters | |
| Returns: | |
| Model metadata and chunk information | |
| """ | |
| with self.memory_manager.memory_context("load_model_in_chunks"): | |
| logger.info(f"Loading model in chunks: {model_path}") | |
| # First, get model config and size estimation | |
| config = await self._load_model_config(model_path, **kwargs) | |
| estimated_size_mb = self._estimate_model_size(config) | |
| logger.info(f"Estimated model size: {estimated_size_mb:.1f}MB") | |
| if estimated_size_mb <= self.chunk_size_mb * 2: | |
| # Small model, load normally | |
| return await self._load_small_model(model_path, config, **kwargs) | |
| else: | |
| # Large model, use chunking | |
| return await self._load_large_model_chunked(model_path, config, **kwargs) | |
| async def _load_model_config(self, model_path: str, **kwargs) -> AutoConfig: | |
| """Load model configuration""" | |
| try: | |
| hf_token = kwargs.get('token') or os.getenv('HF_TOKEN') | |
| trust_remote_code = kwargs.get('trust_remote_code', False) | |
| config = AutoConfig.from_pretrained( | |
| model_path, | |
| trust_remote_code=trust_remote_code, | |
| token=hf_token, | |
| timeout=30 | |
| ) | |
| return config | |
| except Exception as e: | |
| logger.error(f"Failed to load config for {model_path}: {e}") | |
| raise | |
| def _estimate_model_size(self, config: AutoConfig) -> float: | |
| """Estimate model size in MB""" | |
| try: | |
| # Get basic parameters | |
| hidden_size = getattr(config, 'hidden_size', 768) | |
| num_layers = getattr(config, 'num_hidden_layers', | |
| getattr(config, 'num_layers', 12)) | |
| vocab_size = getattr(config, 'vocab_size', 50000) | |
| # Rough estimation for transformer models | |
| embedding_params = vocab_size * hidden_size | |
| layer_params = num_layers * (hidden_size * hidden_size * 4) # Simplified | |
| total_params = embedding_params + layer_params | |
| # Convert to MB (4 bytes per parameter for float32) | |
| size_mb = (total_params * 4) / (1024 ** 2) | |
| return max(size_mb, 100) # Minimum 100MB | |
| except Exception: | |
| return 2000 # Default 2GB if estimation fails | |
| async def _load_small_model(self, model_path: str, config: AutoConfig, | |
| **kwargs) -> Dict[str, Any]: | |
| """Load small model normally""" | |
| logger.info(f"Loading small model normally: {model_path}") | |
| hf_token = kwargs.get('token') or os.getenv('HF_TOKEN') | |
| trust_remote_code = kwargs.get('trust_remote_code', False) | |
| try: | |
| # Load model with CPU optimization | |
| model = AutoModel.from_pretrained( | |
| model_path, | |
| config=config, | |
| torch_dtype=torch.float32, | |
| trust_remote_code=trust_remote_code, | |
| token=hf_token, | |
| low_cpu_mem_usage=True, | |
| device_map='cpu' | |
| ) | |
| # Load tokenizer/processor | |
| tokenizer = None | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_path, | |
| token=hf_token, | |
| trust_remote_code=trust_remote_code | |
| ) | |
| except: | |
| logger.warning(f"Could not load tokenizer for {model_path}") | |
| return { | |
| 'model': model, | |
| 'tokenizer': tokenizer, | |
| 'config': config, | |
| 'is_chunked': False, | |
| 'source': model_path, | |
| 'estimated_size_mb': self._estimate_model_size(config) | |
| } | |
| except Exception as e: | |
| logger.error(f"Failed to load small model {model_path}: {e}") | |
| raise | |
| async def _load_large_model_chunked(self, model_path: str, config: AutoConfig, | |
| **kwargs) -> Dict[str, Any]: | |
| """Load large model using chunking strategy""" | |
| logger.info(f"Loading large model with chunking: {model_path}") | |
| # Create chunks metadata | |
| chunks_info = await self._create_chunks_metadata(model_path, config, **kwargs) | |
| # Load first chunk to get model structure | |
| first_chunk = await self._load_chunk(model_path, chunks_info[0], **kwargs) | |
| return { | |
| 'model': None, # No single model object for chunked models | |
| 'chunks_info': chunks_info, | |
| 'first_chunk': first_chunk, | |
| 'config': config, | |
| 'is_chunked': True, | |
| 'source': model_path, | |
| 'total_chunks': len(chunks_info), | |
| 'estimated_size_mb': self._estimate_model_size(config) | |
| } | |
| async def _create_chunks_metadata(self, model_path: str, config: AutoConfig, | |
| **kwargs) -> List[Dict[str, Any]]: | |
| """Create metadata for model chunks""" | |
| # This is a simplified chunking strategy | |
| # In practice, you'd analyze the model structure more carefully | |
| estimated_size_mb = self._estimate_model_size(config) | |
| num_chunks = max(1, int(estimated_size_mb / self.chunk_size_mb)) | |
| chunks_info = [] | |
| for i in range(num_chunks): | |
| chunk_info = { | |
| 'chunk_id': f"chunk_{i}", | |
| 'start_layer': i * (config.num_hidden_layers // num_chunks), | |
| 'end_layer': min((i + 1) * (config.num_hidden_layers // num_chunks), | |
| config.num_hidden_layers), | |
| 'estimated_size_mb': estimated_size_mb / num_chunks, | |
| 'parameters': [] # Will be populated during loading | |
| } | |
| chunks_info.append(chunk_info) | |
| return chunks_info | |
| async def _load_chunk(self, model_path: str, chunk_info: Dict[str, Any], | |
| **kwargs) -> ModelChunk: | |
| """Load a specific chunk of the model""" | |
| chunk_id = chunk_info['chunk_id'] | |
| with self.memory_manager.memory_context(f"load_chunk_{chunk_id}"): | |
| logger.debug(f"Loading chunk {chunk_id}") | |
| # For now, this is a placeholder implementation | |
| # In practice, you'd implement layer-wise loading | |
| parameters = {} | |
| # Create dummy parameters for demonstration | |
| # Replace with actual chunk loading logic | |
| hidden_size = getattr(kwargs.get('config', {}), 'hidden_size', 768) | |
| chunk_params = torch.randn(hidden_size, hidden_size) * 0.02 | |
| parameters[f'{chunk_id}_weight'] = chunk_params | |
| metadata = { | |
| 'chunk_id': chunk_id, | |
| 'layer_range': (chunk_info['start_layer'], chunk_info['end_layer']), | |
| 'parameter_count': sum(p.numel() for p in parameters.values()) | |
| } | |
| chunk = ModelChunk(chunk_id, parameters, metadata) | |
| self.loaded_chunks[chunk_id] = chunk | |
| # Manage cache | |
| await self._manage_chunk_cache() | |
| return chunk | |
| async def _manage_chunk_cache(self): | |
| """Manage chunk cache to prevent memory overflow""" | |
| if len(self.loaded_chunks) > self.max_cached_chunks: | |
| # Remove oldest chunks | |
| chunks_to_remove = list(self.loaded_chunks.keys())[:-self.max_cached_chunks] | |
| for chunk_id in chunks_to_remove: | |
| chunk = self.loaded_chunks.pop(chunk_id) | |
| chunk.unload() | |
| logger.debug(f"Removed chunk {chunk_id} from cache") | |
| def _cleanup_chunks(self): | |
| """Cleanup callback for memory manager""" | |
| logger.info("Cleaning up loaded chunks") | |
| for chunk in self.loaded_chunks.values(): | |
| chunk.unload() | |
| self.loaded_chunks.clear() | |
| gc.collect() | |
| async def get_chunk_iterator(self, model_info: Dict[str, Any]) -> Iterator[ModelChunk]: | |
| """Get iterator for model chunks""" | |
| if not model_info.get('is_chunked', False): | |
| # Not a chunked model | |
| yield model_info['model'] | |
| return | |
| chunks_info = model_info['chunks_info'] | |
| model_path = model_info['source'] | |
| for chunk_info in chunks_info: | |
| chunk = await self._load_chunk(model_path, chunk_info) | |
| yield chunk | |
| # Optionally unload chunk after yielding | |
| # chunk.unload() | |
| def get_memory_usage(self) -> Dict[str, float]: | |
| """Get current memory usage of loaded chunks""" | |
| total_memory_mb = sum(chunk.memory_size_mb for chunk in self.loaded_chunks.values()) | |
| return { | |
| 'total_chunks_memory_mb': total_memory_mb, | |
| 'loaded_chunks_count': len(self.loaded_chunks), | |
| 'average_chunk_size_mb': total_memory_mb / len(self.loaded_chunks) if self.loaded_chunks else 0 | |
| } | |