# AbMelt inference pipeline CLI: structure prep -> MD simulation -> descriptors -> model inference.
# Standard library
import argparse
import logging
import sys
from pathlib import Path

# Third-party
import yaml

# Add src to path for imports
sys.path.append(str(Path(__file__).parent / "src"))

# Local pipeline steps
from structure_prep import prepare_structure, load_existing_structure_files
from md_simulation import run_md_simulation, load_existing_simulation_results
from compute_descriptors import compute_descriptors, load_existing_descriptors
from model_inference import run_model_inference, load_existing_predictions
from cleanup_temp_files import cleanup_temp_directory
from timing import get_timing_report, reset_timing_report, time_step
def main():
    """CLI entry point for the AbMelt inference pipeline.

    Parses arguments, loads configuration, sets up logging/directories,
    runs the pipeline, prints a result summary, and always reports timing
    (even when a pipeline step raises).

    Returns:
        The pipeline result dictionary (with a ``'timing'`` entry added).
    """
    args = _parse_args()

    # 1. Load configuration
    config = load_config(args.config)

    # 2. Setup logging and directories
    setup_logging(config)
    create_directories(config)

    # Initialize timing report
    reset_timing_report()
    timing_report = get_timing_report()
    timing_report.start()

    # 3. Create antibody input based on input type (PDB file or sequences)
    antibody = _build_antibody(args)

    # 4. Run inference pipeline
    try:
        result = run_inference_pipeline(
            antibody,
            config,
            skip_structure=args.skip_structure,
            skip_md=args.skip_md,
            skip_descriptors=args.skip_descriptors,
            skip_inference=args.skip_inference,
        )
        _print_result(result, args.name)
        # Add timing data to result.
        # NOTE(review): this snapshot is taken before timing_report.stop()
        # runs in the finally block — confirm to_dict() is the intended view.
        result['timing'] = timing_report.to_dict()
    finally:
        # Stop timing - always runs even on exception
        timing_report.stop()
        # Print timing report - always runs even on exception
        print(timing_report.format_summary())
        # Save timing report if requested - always runs even on exception
        if args.timing_report:
            timing_report.save_json(args.timing_report)
            print(f"\nTiming report saved to: {args.timing_report}")
    return result


def _parse_args():
    """Parse and validate command-line arguments.

    Exactly one of --h/--heavy (paired with --l/--light) or --pdb is
    required; mismatched sequence arguments abort via parser.error().
    """
    parser = argparse.ArgumentParser(description='AbMelt Inference Pipeline')
    # Input options - either sequences or PDB file
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--h', '--heavy', type=str,
                             help='Heavy chain amino acid sequence (use with --l)')
    input_group.add_argument('--pdb', type=str,
                             help='Input PDB file path')
    parser.add_argument('--l', '--light', type=str,
                        help='Light chain amino acid sequence (use with --h)')
    parser.add_argument('--name', type=str, default='antibody',
                        help='Antibody name/identifier')
    parser.add_argument('--config', type=str,
                        help='Configuration file path')
    parser.add_argument('--output', type=str, default='results',
                        help='Output directory')
    # Skip step flags
    parser.add_argument('--skip-structure', action='store_true',
                        help='Skip structure preparation step (load existing files)')
    parser.add_argument('--skip-md', action='store_true',
                        help='Skip MD simulation step (load existing trajectory files)')
    parser.add_argument('--skip-descriptors', action='store_true',
                        help='Skip descriptor computation step (load existing descriptors)')
    parser.add_argument('--skip-inference', action='store_true',
                        help='Skip model inference step (load existing predictions)')
    # Timing options
    parser.add_argument('--timing-report', type=str, metavar='PATH',
                        help='Save timing report to JSON file (also prints summary to console)')
    args = parser.parse_args()

    # Cross-argument constraints argparse cannot express directly:
    # heavy and light sequences must be given together.
    if args.h and not args.l:
        parser.error("--l/--light is required when using --h/--heavy")
    if args.l and not args.h:
        parser.error("--h/--heavy is required when using --l/--light")
    return args


def _build_antibody(args):
    """Build the antibody input dict from parsed CLI arguments."""
    if args.pdb:
        # Input from PDB file
        return {
            "name": args.name,
            "pdb_file": args.pdb,
            "type": "pdb"
        }
    # Input from sequences
    return {
        "name": args.name,
        "heavy_chain": args.h,
        "light_chain": args.l,
        "type": "sequences"
    }


def _print_result(result, name):
    """Print a human-readable summary of the pipeline result to stdout."""
    print(f"Inference pipeline for {name}:")
    print(f" Status: {result['status']}")
    print(f" Message: {result['message']}")
    print(f" PDB file: {result['structure_files']['pdb_file']}")
    print(f" Work directory: {result['structure_files']['work_dir']}")
    if 'chains' in result['structure_files']:
        print(f" Chains found: {list(result['structure_files']['chains'].keys())}")
    if 'simulation_result' in result:
        print(f" MD simulations completed at temperatures: {list(result['simulation_result']['trajectory_files'].keys())}")
        for temp, files in result['simulation_result']['trajectory_files'].items():
            print(f" {temp}K: {files['final_xtc']}")
    if 'descriptor_result' in result:
        print(f" Descriptors computed: {result['descriptor_result']['descriptors_df'].shape[1]} features")
        print(f" XVG files generated: {len(result['descriptor_result']['xvg_files'])}")
    if 'inference_result' in result:
        # f-string had no placeholders; plain literal produces identical output.
        print("\n=== PREDICTIONS ===")
        predictions = result['inference_result']['predictions']
        for model_name, pred in predictions.items():
            if pred is not None:
                print(f" {model_name.upper()}: {pred[0]:.3f}")
            else:
                print(f" {model_name.upper()}: FAILED")
def load_config(config_path: str) -> dict:
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Path to a YAML configuration file.

    Returns:
        The parsed configuration dictionary.

    Raises:
        RuntimeError: If the file cannot be read or parsed. The original
            exception is chained as ``__cause__`` so the root cause (missing
            file, YAML syntax error, ...) stays visible. RuntimeError is a
            subclass of Exception, so existing ``except Exception`` callers
            are unaffected.
    """
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        return config
    except Exception as e:
        # Was `raise Exception(...)`: too generic and dropped the cause chain.
        raise RuntimeError(f"Failed to load config: {e}") from e
def setup_logging(config: dict):
    """Configure root logging from ``config["logging"]``.

    Reads the level name and log-file path, ensures the log file's parent
    directory exists, then installs a file handler plus a stdout handler
    sharing one timestamped format.
    """
    logging_cfg = config["logging"]
    level = getattr(logging, logging_cfg["level"].upper())
    log_path = Path(logging_cfg["file"])

    # The FileHandler opens the file immediately, so the directory must
    # exist before basicConfig is called.
    log_path.parent.mkdir(parents=True, exist_ok=True)

    handlers = [
        logging.FileHandler(log_path),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
def create_directories(config: dict):
    """Resolve and create the pipeline's output, temp and log directories.

    Mutates ``config["paths"]`` in place: ``output_dir``, ``temp_dir`` and
    ``log_dir`` are rewritten as Path objects rooted at
    ``<script dir>/<run_dir>``, and each directory is created if missing.
    """
    base = Path(__file__).parent.resolve() / config["paths"]["run_dir"]
    for key in ("output_dir", "temp_dir", "log_dir"):
        resolved = base / config["paths"][key]
        config["paths"][key] = resolved
        resolved.mkdir(parents=True, exist_ok=True)
def run_inference_pipeline(antibody, config, skip_structure=False, skip_md=False, skip_descriptors=False, skip_inference=False):
    """
    Run the complete inference pipeline.

    Executes four steps in order — structure preparation, MD simulation,
    descriptor computation, model inference — each wrapped in a time_step
    timing context. A skip flag makes the corresponding step load previously
    produced artifacts instead of recomputing them. After step 4, temp files
    are optionally cleaned up based on config["performance"].

    Args:
        antibody: Dictionary containing antibody information
        config: Configuration dictionary
        skip_structure: If True, load existing structure files instead of preparing
        skip_md: If True, load existing MD simulation results instead of running
        skip_descriptors: If True, load existing descriptors instead of computing
        skip_inference: If True, load existing predictions instead of computing

    Returns:
        Dictionary containing pipeline results (status, structure_files,
        simulation_result, descriptor_result, inference_result, message)

    Raises:
        Exception: Any failure in a pipeline step is logged and re-raised.
    """
    logging.info(f"Starting inference pipeline for antibody: {antibody['name']}")
    if skip_structure:
        logging.info("Skipping structure preparation (using --skip-structure flag)")
    if skip_md:
        logging.info("Skipping MD simulation (using --skip-md flag)")
    if skip_descriptors:
        logging.info("Skipping descriptor computation (using --skip-descriptors flag)")
    if skip_inference:
        logging.info("Skipping model inference (using --skip-inference flag)")
    try:
        # Step 1: Structure preparation (produces the PDB/work-dir layout
        # consumed by every later step)
        with time_step("Structure Preparation"):
            if skip_structure:
                logging.info("Step 1: Loading existing structure files...")
                structure_files = load_existing_structure_files(antibody, config)
                logging.info("Structure files loaded successfully")
            else:
                logging.info("Step 1: Preparing structure...")
                structure_files = prepare_structure(antibody, config)
                logging.info("Structure preparation completed")
            # Log structure files ("chains" is a nested dict, logged separately)
            logging.info(f"Structure files:")
            for key, path in structure_files.items():
                if key != "chains":
                    logging.info(f" {key}: {path}")
            if "chains" in structure_files:
                logging.info(f" chains: {list(structure_files['chains'].keys())}")
        # Step 2: MD simulation (one trajectory per configured temperature)
        with time_step("MD Simulation"):
            if skip_md:
                logging.info("Step 2: Loading existing MD simulation results...")
                simulation_result = load_existing_simulation_results(structure_files, config)
                logging.info("MD simulation results loaded successfully")
            else:
                logging.info("Step 2: Running MD simulations...")
                simulation_result = run_md_simulation(structure_files, config)
                logging.info("MD simulations completed")
            # Log trajectory files
            logging.info(f"Trajectory files:")
            for temp, files in simulation_result["trajectory_files"].items():
                logging.info(f" {temp}K: {files['final_xtc']}")
        # Step 3: Descriptor computation (features DataFrame + GROMACS XVG files)
        with time_step("Descriptor Computation"):
            if skip_descriptors:
                logging.info("Step 3: Loading existing descriptor computation results...")
                descriptor_result = load_existing_descriptors(simulation_result, config)
                logging.info("Descriptor computation results loaded successfully")
            else:
                logging.info("Step 3: Computing descriptors...")
                descriptor_result = compute_descriptors(simulation_result, config)
                logging.info("Descriptor computation completed")
            # Log descriptor computation results
            logging.info(f"Descriptors:")
            logging.info(f" DataFrame shape: {descriptor_result['descriptors_df'].shape}")
            logging.info(f" Number of features: {len(descriptor_result['descriptors_df'].columns)}")
            logging.info(f" XVG files: {len(descriptor_result['xvg_files'])}")
        # Step 4: Model inference (predictions dict; a None value marks a
        # model that failed to produce a prediction)
        with time_step("Model Inference"):
            if skip_inference:
                logging.info("Step 4: Loading existing model predictions...")
                work_dir = Path(descriptor_result['work_dir'])
                inference_result = load_existing_predictions(work_dir, antibody['name'])
                logging.info("Model predictions loaded successfully")
            else:
                logging.info("Step 4: Running model inference...")
                inference_result = run_model_inference(descriptor_result, config)
                logging.info("Model inference completed")
            # Log prediction results
            logging.info(f"Predictions:")
            for model_name, pred in inference_result['predictions'].items():
                if pred is not None:
                    logging.info(f" {model_name}: {pred[0]:.3f}")
                else:
                    logging.info(f" {model_name}: FAILED")
        # Cleanup intermediate files if configured
        # (performance.cleanup_temp=True and cleanup_after="inference")
        cleanup_config = config.get("performance", {})
        if cleanup_config.get("cleanup_temp", False):
            cleanup_after = cleanup_config.get("cleanup_after", "inference")
            if cleanup_after == "inference":
                logging.info("Cleaning up intermediate files...")
                try:
                    temperatures = [str(t) for t in config["simulation"]["temperatures"]]
                    cleanup_stats = cleanup_temp_directory(
                        work_dir=Path(descriptor_result['work_dir']),
                        antibody_name=antibody['name'],
                        temperatures=temperatures,
                        dry_run=False,
                        keep_order_params=not cleanup_config.get("delete_order_params", False)
                    )
                    logging.info(f"Cleanup completed: deleted {cleanup_stats.get('deleted', 0)} files")
                except Exception as e:
                    # Best-effort cleanup: a failure here must not fail a
                    # pipeline run that already produced its results.
                    logging.warning(f"Cleanup failed (non-fatal): {e}")
        result = {
            "status": "success",
            "structure_files": structure_files,
            "simulation_result": simulation_result,
            "descriptor_result": descriptor_result,
            "inference_result": inference_result,
            "message": "Complete inference pipeline finished successfully."
        }
        logging.info("Inference pipeline completed successfully")
        return result
    except Exception as e:
        # Log and re-raise so the caller's finally-based timing report still runs.
        logging.error(f"Inference pipeline failed: {e}")
        raise
# Script entry point: run the CLI pipeline when executed directly.
if __name__ == "__main__":
    main()