import spaces


@spaces.GPU
def run():
    """Show how AOTI constant FQNs diverge from the eager state_dict after float8 quantization.

    Exports and AOTI-compiles a tiny linear model twice — once as plain bf16,
    once after torchao float8 dynamic-activation/weight quantization — and
    prints the exported state_dict keys next to the compiled artifact's
    constant FQNs for comparison.
    """
    import torch
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
    from torchao.quantization import quantize_

    class ToyLinearModel(torch.nn.Module):
        """Minimal single-linear-layer module used as the export subject."""

        def __init__(self, m: int, n: int):
            super().__init__()
            self.linear = torch.nn.Linear(m, n)

        def forward(self, x):
            return self.linear(x)

    def _export_and_report(model, example_inputs):
        # Export, AOTI-compile, check state_dict parity, and print both key sets.
        ep = torch.export.export(model, args=example_inputs)
        aoti = torch._inductor.aoti_load_package(
            torch._inductor.aoti_compile_and_package(ep)
        )
        assert set(ep.state_dict.keys()) == set(model.state_dict().keys())
        print(set(ep.state_dict.keys()))
        print(set(aoti.get_constant_fqns()))

    example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device='cuda'),)

    # Baseline bf16 model: exported state_dict keys and AOTI constant FQNs agree,
    # both being {'linear.bias', 'linear.weight'}.
    model = ToyLinearModel(1024, 1024).to(device='cuda', dtype=torch.bfloat16)
    _export_and_report(model, example_inputs)

    # Float8-quantized model: the exported state_dict still reports
    # {'linear.bias', 'linear.weight'}, but the AOTI artifact exposes the
    # parametrization FQNs instead:
    # {'linear.parametrizations.weight.original1', 'linear.bias',
    #  'linear.parametrizations.weight.original0'}
    model = ToyLinearModel(1024, 1024).to(device='cuda', dtype=torch.bfloat16)
    quantize_(model, Float8DynamicActivationFloat8WeightConfig())
    _export_and_report(model, example_inputs)


run()