OliverPerrin's picture
Refactor: Consolidate dependencies, improve testing, and add CI/CD
d18b34d
import os
import tempfile
import unittest
import torch
from src.utils.io import load_state, save_state
class TestIO(unittest.TestCase):
    """Round-trip tests for ``save_state`` / ``load_state`` from ``src.utils.io``."""

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away: unlike tearDown, callbacks added via
        # addCleanup still run if a later statement in setUp raises, so the
        # temporary directory is never leaked.
        self.addCleanup(self.temp_dir.cleanup)
        self.ckpt_path = os.path.join(self.temp_dir.name, "model.pt")
        self.model = torch.nn.Linear(10, 2)

    def _snapshot(self):
        """Return an independent, detached copy of every tensor in the model's state dict."""
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }

    def test_save_and_load_state(self):
        """Saving then loading must restore every state-dict entry, not just the weight."""
        # Save
        save_state(self.model, self.ckpt_path)
        self.assertTrue(os.path.exists(self.ckpt_path))
        original = self._snapshot()

        # Modify model: shift *all* parameters (weight and bias) by a fixed
        # offset so the change is deterministic and guaranteed to differ.
        with torch.no_grad():
            for param in self.model.parameters():
                param.add_(1.0)
        for name, tensor in self.model.state_dict().items():
            with self.subTest(state=name, phase="perturbed"):
                self.assertFalse(torch.equal(tensor, original[name]))

        # Load
        load_state(self.model, self.ckpt_path)
        restored = self.model.state_dict()
        self.assertEqual(set(restored), set(original))
        for name, tensor in original.items():
            with self.subTest(state=name, phase="restored"):
                self.assertTrue(torch.equal(restored[name], tensor))

    def test_save_creates_directories(self):
        """``save_state`` must create missing parent directories of the target path."""
        nested_path = os.path.join(self.temp_dir.name, "subdir", "model.pt")
        save_state(self.model, nested_path)
        self.assertTrue(os.path.exists(nested_path))
# Allow running this file directly (``python test_io.py``); when it is
# imported by a test runner, the guard is false and nothing executes here.
if __name__ == "__main__":
    unittest.main()