Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import numpy as np
import torch
import torchvision
from torch import nn
from huggingface_hub import snapshot_download
class LeNet(nn.Module):
    """Small LeNet-style CNN mapping a 1x28x28 image to 10 class scores.

    The submodule names and ``nn.Sequential`` indices (``convs.0``,
    ``convs.3``, ``linear.0``) define the ``state_dict`` keys of the
    pretrained checkpoint, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Two conv -> tanh -> 2x2 average-pool stages; spatial size goes
        # 28 -> 24 -> 12 -> 8 -> 4.
        feature_stages = [
            nn.Conv2d(1, 4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(4, 12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        ]
        self.convs = nn.Sequential(*feature_stages)
        # 12 feature maps of 4x4 flattened into one linear classifier layer.
        self.linear = nn.Sequential(nn.Linear(12 * 4 * 4, 10))

    def forward(self, x):
        """Return raw logits, one row of 10 scores per batch element."""
        features = torch.flatten(self.convs(x), start_dim=1)
        return self.linear(features)

    def predict(self, input):
        """Return softmax probabilities over the 10 classes for one image.

        ``input`` must hold exactly 28*28 elements; it is reshaped into a
        single-image batch of shape ``(1, 1, 28, 28)``.
        """
        batch = input.reshape(1, 1, 28, 28)
        logits = self(batch)[0]
        return nn.functional.softmax(logits, dim=0)
# Build the model and load the pretrained weights published on the Hugging Face Hub.
lenet = LeNet()
lenet_pt = os.path.join(snapshot_download('stanimirovb/ibob-lenet-v1'), 'lenet-v1.pth')
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint downloaded from the Hub cannot execute code when loaded.
lenet.load_state_dict(torch.load(lenet_pt, map_location='cpu', weights_only=True))
# LeNet has no dropout/batch-norm, so this does not change outputs; it just
# makes inference mode explicit.
lenet.eval()
# Downscale sketches of any size to the 28x28 input LeNet.predict expects.
resize = torchvision.transforms.Resize((28, 28), antialias=True)
def on_submit(img):
    """Classify a sketch and return the per-digit ranking as text.

    Args:
        img: The Sketchpad value. Its ``'composite'`` entry is converted to a
            float32 array. It is presumably 2-D (single channel) because the
            input uses ``image_mode='P'``; TODO confirm for the Gradio version
            in use.

    Returns:
        One ``[digit, probability]`` line per class, most probable first, or
        an empty string when there is no image to classify.
    """
    # NOTE(review): guard against an empty editor value instead of crashing
    # on ``None.astype`` — confirm when Gradio actually sends None here.
    if img is None or img.get('composite') is None:
        return ""
    with torch.no_grad():
        pixels = torch.from_numpy(img['composite'].astype(np.float32))
        # Resize works on the trailing (H, W) dims; add a leading channel axis.
        small = resize(pixels.unsqueeze(0))
        probs = lenet.predict(small).numpy()
    # Highest probability first; sorted() with reverse=True is stable, so ties
    # keep ascending digit order just like the original ``key=-p`` sort.
    ranking = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    # Format explicitly: str() of a list uses repr() of numpy scalars, which
    # NumPy >= 2 renders as "np.float32(...)". str(scalar) gives the plain value.
    return "\n".join(f"[{digit}, {prob}]" for digit, prob in ranking)
# Gradio UI: the drawing from a sketchpad is classified by on_submit and the
# ranking is shown as plain text.
iface = gr.Interface(
    fn=on_submit,
    inputs=gr.Sketchpad(image_mode='P'),
    outputs=gr.Text(),
    title="LeNet",
)

# Start the server only when run as a script (as HF Spaces does with
# `python app.py`), so importing this module does not block on launch().
if __name__ == "__main__":
    iface.launch()