Spaces:
Running
Running
File size: 1,157 Bytes
d18b34d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import os
import tempfile
import unittest
import torch
from src.utils.io import load_state, save_state
class TestIO(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.ckpt_path = os.path.join(self.temp_dir.name, "model.pt")
self.model = torch.nn.Linear(10, 2)
def tearDown(self):
self.temp_dir.cleanup()
def test_save_and_load_state(self):
# Save
save_state(self.model, self.ckpt_path)
self.assertTrue(os.path.exists(self.ckpt_path))
# Modify model
original_weight = self.model.weight.clone()
torch.nn.init.xavier_uniform_(self.model.weight)
self.assertFalse(torch.equal(self.model.weight, original_weight))
# Load
load_state(self.model, self.ckpt_path)
self.assertTrue(torch.equal(self.model.weight, original_weight))
def test_save_creates_directories(self):
nested_path = os.path.join(self.temp_dir.name, "subdir", "model.pt")
save_state(self.model, nested_path)
self.assertTrue(os.path.exists(nested_path))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|