Spaces:
Running
Running
| import os | |
| import tempfile | |
| import unittest | |
| import torch | |
| from src.utils.io import load_state, save_state | |
class TestIO(unittest.TestCase):
    """Tests for the ``save_state`` / ``load_state`` checkpoint helpers.

    Covers a save -> perturb -> load round trip on a small ``torch.nn.Linear``
    and checks that ``save_state`` creates missing parent directories.
    """

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup immediately so the directory is removed even if the
        # remainder of setUp raises (tearDown is skipped in that case).
        self.addCleanup(self.temp_dir.cleanup)
        self.ckpt_path = os.path.join(self.temp_dir.name, "model.pt")
        self.model = torch.nn.Linear(10, 2)

    def _snapshot_state(self):
        """Return detached copies of every tensor in ``self.model.state_dict()``."""
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }

    def test_save_and_load_state(self):
        # Save the freshly initialised parameters.
        save_state(self.model, self.ckpt_path)
        self.assertTrue(os.path.exists(self.ckpt_path))

        # Snapshot every tensor (weight AND bias), then shift all parameters by
        # a fixed offset so the modification is guaranteed, rather than relying
        # on a random re-initialisation happening to produce different values.
        original = self._snapshot_state()
        with torch.no_grad():
            for param in self.model.parameters():
                param.add_(1.0)
        for name, tensor in self.model.state_dict().items():
            self.assertFalse(torch.equal(tensor, original[name]), name)

        # Load and verify that every saved tensor was restored in place.
        load_state(self.model, self.ckpt_path)
        restored = self.model.state_dict()
        self.assertEqual(set(restored.keys()), set(original.keys()))
        for name, tensor in original.items():
            self.assertTrue(torch.equal(restored[name], tensor), name)

    def test_save_creates_directories(self):
        # "subdir" does not exist yet; save_state is expected to create it.
        nested_path = os.path.join(self.temp_dir.name, "subdir", "model.pt")
        save_state(self.model, nested_path)
        self.assertTrue(os.path.exists(nested_path))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()