# Load a checkpoint from `model_file` onto `device` and produce a ready-to-run
# model in eval mode. The file may hold either a state_dict (a dict of tensors)
# or a fully pickled model object.
#
# Only the deserialization step is guarded by try/except. Errors raised while
# building the model or applying the state_dict (for example, missing or
# mismatched keys) now surface directly. Previously they were swallowed and
# replaced by a confusing retry failure.
try:
    # Safe path first. On PyTorch versions where weights_only defaults to True,
    # this refuses to unpickle arbitrary objects such as a full nn.Module.
    state_dict = torch.load(model_file, map_location=device)
except Exception as e:
    print("state_dict load failed, retrying with weights_only=False:", e)
    # SECURITY: weights_only=False runs arbitrary pickle code embedded in the
    # file. Only use this fallback for checkpoints from a trusted source.
    state_dict = torch.load(model_file, map_location=device, weights_only=False)

# Dispatch on what was actually deserialized, whichever load succeeded.
if isinstance(state_dict, dict):
    print("Loaded state_dict, initializing model...")
    from my_model import MyModel  # import your model class
    # TODO: replace `...` with the constructor arguments used when the
    # checkpoint was saved. The architecture must match the state_dict keys.
    model = MyModel(...)
    model.load_state_dict(state_dict)
else:
    print("Loaded full model object.")
    model = state_dict

model.to(device)
model.eval()  # disable dropout / use running batch-norm statistics