aiexplorations committed on
Commit
1841103
·
verified ·
1 Parent(s): f891147

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +503 -0
app.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Vidai HuggingFace Spaces Demo
4
+
5
+ Self-contained demo for parsing mathematical expressions to prefix notation.
6
+ Includes all necessary model code to run standalone on HuggingFace Spaces.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ import gradio as gr
13
+ import sympy
14
+ import torch
15
+ import torch.nn as nn
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ # =============================================================================
19
+ # Model Configuration
20
+ # =============================================================================
21
+
22
@dataclass
class TreeComputeConfig:
    """Hyperparameters for the Tree Compute Transformer.

    Instances are normally built from the ``model_config`` dict stored in a
    checkpoint, so field names and order must stay stable.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    # Transformer dimensions shared by the encoder and the decoder.
    d_model: int = 256
    n_context_layers: int = 4
    n_heads: int = 8
    d_ff: int = 1024
    dropout: float = 0.1
    # NOTE(review): the expert_* fields, max_nodes, the *_token_id fields and
    # the value_clamp_* fields are not read by the modules visible in this
    # file; presumably they belong to compute modules defined elsewhere.
    expert_hidden_dim: int = 128
    expert_layers: int = 2
    # Sizes of the encoder position and depth embedding tables.
    max_seq_len: int = 512
    max_depth: int = 32
    max_nodes: int = 64
    # Size of the encoder token embedding table.
    vocab_size: int = 35
    add_token_id: int = 16
    sub_token_id: int = 15
    mul_token_id: int = 18
    div_token_id: int = 19
    pow_token_id: int = 25
    mod_token_id: int = 26
    sqrt_token_id: int = 27
    abs_token_id: int = 28
    floor_token_id: int = 29
    ceil_token_id: int = 30
    value_clamp_min: float = -1e6
    value_clamp_max: float = 1e6
    # Decoder (parser) settings: depth, output vocabulary, length limit and
    # special token ids.
    n_decoder_layers: int = 4
    parser_vocab_size: int = 128
    max_output_len: int = 256
    parser_pad_id: int = 0
    parser_bos_id: int = 1
    parser_eos_id: int = 2

    def __post_init__(self) -> None:
        # Multi-head attention splits d_model evenly across heads. Raise
        # explicitly instead of asserting so the check survives `python -O`.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
57
+
58
+
59
+ # =============================================================================
60
+ # Model Components
61
+ # =============================================================================
62
+
63
class ContextEncoder(nn.Module):
    """Encode input token ids into contextual vectors with a Transformer encoder.

    Each position's embedding is the sum of a token, a position and a
    tree-depth embedding; the sum is passed through dropout, the encoder
    stack and a final layer norm.
    """

    def __init__(self, config: TreeComputeConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.position_embedding = nn.Embedding(config.max_seq_len, width)
        self.depth_embedding = nn.Embedding(config.max_depth, width)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=config.n_heads,
                dim_feedforward=config.d_ff,
                dropout=config.dropout,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=config.n_context_layers,
        )
        self.layer_norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, tree_depths, attention_mask=None):
        """Return encoded states for a (batch, seq) tensor of token ids.

        Args:
            input_ids: 2-D integer tensor of token ids (batch, seq).
            tree_depths: integer tensor of per-token depths; values are
                clamped into ``[0, max_depth - 1]`` before lookup.
            attention_mask: optional boolean mask, True at real tokens. It is
                inverted to form the encoder's key-padding mask.
        """
        n_rows, n_cols = input_ids.shape
        position_ids = torch.arange(n_cols, device=input_ids.device).expand(n_rows, -1)
        depth_ids = tree_depths.clamp(0, self.config.max_depth - 1)

        hidden = (
            self.token_embedding(input_ids)
            + self.position_embedding(position_ids)
            + self.depth_embedding(depth_ids)
        )
        hidden = self.dropout(hidden)

        if attention_mask is None:
            padding_mask = None
        else:
            padding_mask = ~attention_mask
        hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)
        return self.layer_norm(hidden)
97
+
98
+
99
class SymbolicParserDecoder(nn.Module):
    """Transformer decoder that produces prefix-notation token ids.

    ``forward`` scores a full target sequence (teacher forcing);
    ``generate`` decodes greedily, one token at a time, starting from BOS.
    """

    def __init__(self, config: TreeComputeConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.parser_vocab_size, config.d_model)
        self.position_embedding = nn.Embedding(config.max_output_len, config.d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.n_decoder_layers)
        self.output_projection = nn.Linear(config.d_model, config.parser_vocab_size)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, target_ids, encoder_memory, target_mask=None, memory_mask=None):
        """Return vocabulary logits for every position of ``target_ids``.

        Args:
            target_ids: 2-D integer tensor (batch, tgt_len) of output token ids.
            encoder_memory: encoder output attended to by cross-attention.
            target_mask: optional boolean mask, True at real target tokens.
            memory_mask: optional boolean mask, True at real encoder tokens.

        Both masks are inverted before being passed as key-padding masks.
        """
        batch_size, tgt_len = target_ids.shape
        device = target_ids.device
        x = self.token_embedding(target_ids)
        positions = torch.arange(tgt_len, device=device).unsqueeze(0).expand(batch_size, -1)
        x = x + self.position_embedding(positions)
        x = self.dropout(x)
        # Causal mask: each position may only attend to itself and earlier ones.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len, device=device)
        tgt_key_padding_mask = ~target_mask if target_mask is not None else None
        memory_key_padding_mask = ~memory_mask if memory_mask is not None else None
        x = self.transformer_decoder(
            tgt=x, memory=encoder_memory, tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        x = self.layer_norm(x)
        return self.output_projection(x)

    @torch.no_grad()
    def generate(self, encoder_memory, memory_mask=None, max_len=None, temperature=1.0):
        """Greedily decode token ids, starting each row with BOS.

        Args:
            encoder_memory: encoder output; its first dimension is the batch.
            memory_mask: optional boolean mask, True at real encoder tokens.
            max_len: maximum number of ids per row including BOS. Defaults to
                ``config.max_output_len`` and is capped at that value, because
                the position embedding table has no entries beyond it.
            temperature: accepted for interface compatibility; decoding is
                argmax-based, so the value is not used.

        Returns:
            Integer tensor (batch, length). A row that has emitted EOS is
            filled with the pad id afterwards; decoding stops once every row
            has emitted EOS or the length limit is reached.
        """
        limit = self.config.max_output_len
        max_len = limit if max_len is None else min(max_len, limit)
        batch_size = encoder_memory.shape[0]
        device = encoder_memory.device
        output_ids = torch.full((batch_size, 1), self.config.parser_bos_id, dtype=torch.long, device=device)
        memory_key_padding_mask = ~memory_mask if memory_mask is not None else None
        # Per-row flag: True once that row has produced EOS.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            tgt_len = output_ids.shape[1]
            x = self.token_embedding(output_ids)
            positions = torch.arange(tgt_len, device=device).unsqueeze(0).expand(batch_size, -1)
            x = x + self.position_embedding(positions)
            causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len, device=device)
            x = self.transformer_decoder(tgt=x, memory=encoder_memory, tgt_mask=causal_mask,
                                         memory_key_padding_mask=memory_key_padding_mask)
            x = self.layer_norm(x)
            # Only the last position predicts the next token.
            logits = self.output_projection(x[:, -1, :])
            next_token = logits.argmax(dim=-1, keepdim=True)
            # Rows that already ended keep emitting padding instead of
            # arbitrary post-EOS tokens.
            next_token = next_token.masked_fill(finished.unsqueeze(1), self.config.parser_pad_id)
            output_ids = torch.cat([output_ids, next_token], dim=1)
            finished = finished | (next_token.squeeze(1) == self.config.parser_eos_id)
            if finished.all():
                break
        return output_ids
163
+
164
+
165
class TreeComputeTransformer(nn.Module):
    """Vidai parsing model: a context encoder feeding a prefix-notation decoder."""

    def __init__(self, config: TreeComputeConfig):
        super().__init__()
        self.config = config
        self.context_encoder = ContextEncoder(config)
        self.parser_decoder = SymbolicParserDecoder(config)

    @torch.no_grad()
    def parse(self, input_ids, input_mask=None, max_len=256, temperature=1.0, beam_size=1):
        """Encode ``input_ids`` and decode output token ids from them.

        Args:
            input_ids: integer tensor of input token ids.
            input_mask: optional boolean mask, True at real tokens; when
                omitted, every non-zero id counts as a real token.
            max_len: forwarded to the decoder as its length limit.
            temperature: forwarded to the decoder.
            beam_size: accepted but not used by this method.

        Returns:
            Whatever the decoder's ``generate`` returns for the encoded input.
        """
        mask = input_mask if input_mask is not None else input_ids != 0
        # The encoder requires per-token depths; raw text is given depth 0.
        flat_depths = torch.zeros_like(input_ids)
        memory = self.context_encoder(input_ids, flat_depths, mask)
        return self.parser_decoder.generate(
            encoder_memory=memory,
            memory_mask=mask,
            max_len=max_len,
            temperature=temperature,
        )
182
+
183
+
184
+ # =============================================================================
185
+ # Tokenizer
186
+ # =============================================================================
187
+
188
class ParserTokenizer:
    """Character-level input encoder and prefix-notation output decoder."""

    PAD_TOKEN = "<pad>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"
    UNK_TOKEN = "<unk>"

    def __init__(self):
        self.input_vocab_size = 256
        self.output_vocab = self._build_output_vocab()
        self.output_token_to_id = {}
        self.output_id_to_token = {}
        # A token that appears twice in the vocab keeps its LAST index in
        # token -> id, while both indices decode back to it.
        for index, token in enumerate(self.output_vocab):
            self.output_token_to_id[token] = index
            self.output_id_to_token[index] = token

    def _build_output_vocab(self):
        """Return the ordered output vocabulary; list position is the token id."""
        specials = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        operators = ["+", "-", "*", "/", "**", "%"]
        functions = ["sqrt", "abs", "floor", "ceil", "sin", "cos", "tan", "log", "exp"]
        latin = list("xyzabcdnmtrvgk")
        greek = ["alpha", "beta", "gamma", "delta", "theta", "phi", "psi", "omega", "lambda", "mu", "sigma", "tau"]
        constants = ["pi", "e", "i"]
        numerals = list("0123456789.")
        # "/" is listed a second time here; it is kept because removing it
        # would shift the ids of the vocabulary.
        punctuation = ["(", ")", " ", ",", "/"]
        return specials + operators + functions + latin + greek + constants + numerals + punctuation

    @property
    def output_vocab_size(self):
        """Number of entries in the output vocabulary (duplicates included)."""
        return len(self.output_vocab)

    @property
    def pad_id(self):
        """Id of the padding token."""
        return self.output_token_to_id[self.PAD_TOKEN]

    @property
    def bos_id(self):
        """Id of the beginning-of-sequence token."""
        return self.output_token_to_id[self.BOS_TOKEN]

    @property
    def eos_id(self):
        """Id of the end-of-sequence token."""
        return self.output_token_to_id[self.EOS_TOKEN]

    def encode_input(self, text: str, max_len: int = 256) -> list:
        """Encode ``text`` as code points, truncated/zero-padded to ``max_len``.

        Characters with a code point of 256 or above are replaced by ``?``.
        """
        codes = []
        for ch in text[:max_len]:
            code = ord(ch)
            codes.append(code if code < 256 else ord('?'))
        codes.extend([0] * (max_len - len(codes)))
        return codes

    def decode_output(self, ids: list, skip_special: bool = True) -> str:
        """Join the tokens for ``ids`` into a string, stopping at the first EOS.

        With ``skip_special`` set, PAD and BOS ids are dropped; ids at or
        beyond the vocabulary size are always dropped.
        """
        hidden_ids = {self.pad_id, self.bos_id, self.eos_id} if skip_special else set()
        pieces = []
        for token_id in ids:
            if token_id == self.eos_id:
                break
            if token_id in hidden_ids:
                continue
            if token_id < len(self.output_id_to_token):
                pieces.append(self.output_id_to_token[token_id])
        return "".join(pieces)
246
+
247
+
248
+ # =============================================================================
249
+ # Prefix Notation to SymPy
250
+ # =============================================================================
251
+
252
# Binary operators of the prefix notation, mapped to the Python operation
# applied to their two already-parsed operands. '^' is an alias for '**'.
# '%' is included because the tokenizer's output vocabulary can emit it;
# without an entry here it would be parsed as a plain symbol.
PREFIX_OPS = {
    '+': lambda a, b: a + b,
    '-': lambda a, b: a - b,
    '*': lambda a, b: a * b,
    '/': lambda a, b: a / b,
    '**': lambda a, b: a ** b,
    '^': lambda a, b: a ** b,
    '%': lambda a, b: a % b,
}
260
+
261
# Single-operand function tokens of the prefix notation, mapped to the SymPy
# callable applied to the parsed operand.
PREFIX_UNARY = {
    'sqrt': sympy.sqrt,
    'abs': sympy.Abs,
    'floor': sympy.floor,
    'ceil': sympy.ceiling,
    'sin': sympy.sin,
    'cos': sympy.cos,
    'tan': sympy.tan,
    'exp': sympy.exp,
    'log': sympy.log,
}
272
+
273
# Tokens that stand for exact SymPy constants rather than free symbols.
# NOTE(review): the tokenizer vocabulary lists 'i' next to 'pi' and 'e', but
# 'i' has no entry here, so it is parsed as an ordinary symbol - confirm
# that this is intended.
PREFIX_CONSTANTS = {
    'pi': sympy.pi,
    'e': sympy.E,
}
277
+
278
+
279
def prefix_to_sympy(prefix_str: str):
    """Build a SymPy expression from whitespace-separated prefix notation.

    Raises:
        ValueError: if the string holds no tokens, or if tokens are left
            over after one complete expression has been parsed.
    """
    token_list = prefix_str.split()
    if len(token_list) == 0:
        raise ValueError("Empty prefix notation")

    expression, leftover = _parse_prefix_tokens(token_list)
    if leftover:
        raise ValueError(f"Unexpected tokens: {leftover}")
    return expression
288
+
289
+
290
def _parse_prefix_tokens(tokens):
    """Parse one expression from the front of ``tokens``.

    Returns:
        A ``(expression, remaining_tokens)`` pair, where ``remaining_tokens``
        is the list of tokens not consumed by the expression.

    Raises:
        ValueError: if ``tokens`` is empty where an operand is required.
    """
    if not tokens:
        raise ValueError("Unexpected end of tokens")
    head, tail = tokens[0], tokens[1:]

    # Binary operator: consume two sub-expressions, left then right.
    binary_op = PREFIX_OPS.get(head)
    if binary_op is not None:
        lhs, tail = _parse_prefix_tokens(tail)
        rhs, tail = _parse_prefix_tokens(tail)
        return binary_op(lhs, rhs), tail

    # Unary function: consume a single sub-expression.
    unary_fn = PREFIX_UNARY.get(head)
    if unary_fn is not None:
        argument, tail = _parse_prefix_tokens(tail)
        return unary_fn(argument), tail

    if head in PREFIX_CONSTANTS:
        return PREFIX_CONSTANTS[head], tail

    # Numeric literal: a '.' selects Float, otherwise Integer. Anything that
    # fails to convert is treated as a symbol name.
    numeric_type = sympy.Float if '.' in head else sympy.Integer
    try:
        number = numeric_type(head)
    except (ValueError, TypeError):
        return sympy.Symbol(head), tail
    return number, tail
316
+
317
+
318
+ # =============================================================================
319
+ # Global State
320
+ # =============================================================================
321
+
322
# Process-wide cache filled by load_model() on first use:
# MODEL is the loaded network, TOKENIZER its tokenizer, and DEVICE the
# device string ("cuda" or "cpu") the model was placed on.
MODEL = None
TOKENIZER = None
DEVICE = None
325
+
326
+
327
def load_model():
    """Download the checkpoint once and return the cached model and tokenizer.

    The first successful call fetches the checkpoint from the HuggingFace
    Hub, builds the model on CUDA when available (otherwise CPU) and stores
    the result in the module-level ``MODEL``, ``TOKENIZER`` and ``DEVICE``.
    Later calls return the cached objects.

    Returns:
        A ``(model, tokenizer)`` pair.

    Raises:
        Any exception raised while downloading or loading the checkpoint. In
        that case the globals are left untouched, so the next call retries
        instead of returning a half-initialised model.
    """
    global MODEL, TOKENIZER, DEVICE

    if MODEL is not None and TOKENIZER is not None:
        return MODEL, TOKENIZER

    device = "cuda" if torch.cuda.is_available() else "cpu"

    checkpoint_path = hf_hub_download(
        repo_id="aiexplorations/vidai",
        filename="finetune_v1_step3500.pt",
    )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # from the downloaded file, which can execute code. Keep this only while
    # the checkpoint repository is trusted; consider weights_only=True if the
    # checkpoint contents allow it.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = TreeComputeConfig(**ckpt['config']['model_config'])

    model = TreeComputeTransformer(config)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    model.to(device)
    tokenizer = ParserTokenizer()

    # Publish only after every step succeeded. MODEL is assigned last because
    # it is the flag checked by the fast path above.
    DEVICE = device
    TOKENIZER = tokenizer
    MODEL = model
    return model, tokenizer
354
+
355
+
356
def parse_expression(expression: str, evaluate: bool = False, substitutions: str = ""):
    """Parse a mathematical expression to prefix notation and optionally evaluate it.

    Args:
        expression: the expression text; only its first 128 characters are
            passed to the model.
        evaluate: when true, the prefix output is also evaluated with SymPy.
        substitutions: variable assignments such as ``"x=1, y=2"``.

    Returns:
        A ``(prefix, evaluation, status)`` tuple of strings. Failures are
        reported through ``status`` (or through ``evaluation`` for
        evaluation failures) instead of being raised.
    """
    if not expression or not expression.strip():
        return "", "", "Please enter an expression"

    try:
        model, tokenizer = load_model()
    except Exception as e:
        return "", "", f"Model loading error: {str(e)}"

    try:
        encoded = tokenizer.encode_input(expression, max_len=128)
        input_ids = torch.tensor([encoded], device=DEVICE)
        # Id 0 is the padding value written by encode_input.
        input_mask = (input_ids != 0).bool()

        with torch.no_grad():
            output_ids = model.parse(input_ids, input_mask, max_len=64)

        prefix = tokenizer.decode_output(output_ids[0].tolist())
    except Exception as e:
        return "", "", f"Error: {str(e)}"

    eval_result = ""
    if evaluate and prefix:
        # An evaluation failure must not hide the parse result, so it is
        # reported in the evaluation field while the status stays "Success".
        try:
            eval_result = _evaluate_prefix(prefix, substitutions or "")
        except Exception as e:
            eval_result = f"Evaluation error: {str(e)}"

    return prefix, eval_result, "Success"


def _evaluate_prefix(prefix, substitutions):
    """Evaluate ``prefix`` with SymPy, applying ``name=value`` substitutions.

    Returns the numeric value as a string, or ``"Symbolic: <expr>"`` when
    free symbols remain after substitution.
    """
    import re

    sympy_expr = prefix_to_sympy(prefix)

    subs = {}
    if substitutions.strip():
        # Handle various formats: "x=1, y=2" or "x=1 y=2" or "x = 1, y = 2"
        pairs = re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*([+-]?[\d.]+)', substitutions)
        for var, val in pairs:
            subs[sympy.Symbol(var)] = float(val)

    if subs:
        sympy_expr = sympy_expr.subs(subs)

    # A partial substitution leaves free symbols; converting that to float
    # would fail, so show the remaining symbolic expression instead.
    if sympy_expr.free_symbols:
        return f"Symbolic: {sympy_expr}"
    return str(float(sympy_expr))
403
+
404
+
405
+ # =============================================================================
406
+ # Gradio Interface
407
+ # =============================================================================
408
+
409
# Rows shown in the examples panel. Each row is
# [expression, evaluate flag, substitutions], matching the order of the
# inputs passed to gr.Examples.
EXAMPLES = [
    # These work reliably
    ["3 + 5 * 2", True, ""],
    ["(x^2) + (3*y)", False, ""],
    ["(x^2) + y", True, "x=3, y=4"],
    ["sin(pi/2)", True, ""],
    ["sqrt(16)", True, ""],
    ["(a + b) * (a - b)", True, "a=5, b=3"],
    ["(2*x) + (3*y) - z", True, "x=1, y=2, z=3"],
]
419
+
420
# Build the Gradio UI: an intro, an input column and an output column, an
# examples panel, explanatory notes, and the button wiring.
with gr.Blocks(title="Vidai - Neural Math Parser", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Vidai: Neural Mathematical Parsing

    > **Work in Progress**: Simple expressions work well; complex expressions need more training data.
    > [Read the full story](https://rajeshrs.in/blog/ai-explorations/posts/2026-01-04-vidai-teaching-machines-arithmetic/)

    Vidai (Tamil for "answer") uses transformers for what they're good at: recognizing the tree structure
    in mathematical expressions. Instead of learning arithmetic from text, it learns to parse notation
    into trees, then SymPy computes exact results.

    - **Input**: Mathematical expression (e.g., `(x^2) + (3*y)`)
    - **Output**: Prefix notation tree (e.g., `+ ** x 2 * 3 y` where `+` is the root)
    - **Tip**: Use parentheses for reliable results
    """)

    with gr.Row():
        # Input column: expression, evaluation options and the Parse button.
        with gr.Column(scale=2):
            expression_input = gr.Textbox(
                label="Mathematical Expression",
                placeholder="Enter an expression like: x^2 + 3*y",
                lines=1,
            )

            with gr.Row():
                evaluate_checkbox = gr.Checkbox(label="Evaluate", value=False)
                substitutions_input = gr.Textbox(
                    label="Variable Substitutions (optional)",
                    placeholder="x=3, y=4",
                    lines=1,
                )

            parse_button = gr.Button("Parse", variant="primary")

        # Output column: one read-only box per value returned by
        # parse_expression (prefix, evaluation, status).
        with gr.Column(scale=2):
            prefix_output = gr.Textbox(label="Prefix Notation", interactive=False)
            eval_output = gr.Textbox(label="Evaluation Result", interactive=False)
            status_output = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("### Examples")
    # Example results are not precomputed (cache_examples=False).
    gr.Examples(
        examples=EXAMPLES,
        inputs=[expression_input, evaluate_checkbox, substitutions_input],
        outputs=[prefix_output, eval_output, status_output],
        fn=parse_expression,
        cache_examples=False,
    )

    gr.Markdown("""
    ---
    ### How It Works

    1. **Character-level encoding**: Input is encoded as ASCII characters
    2. **Transformer parsing**: Encoder-decoder model (44.6M params) converts to prefix notation
    3. **SymPy evaluation**: Deterministic symbolic computation (0 learned parameters)

    **Supported operations**: +, -, *, /, ^ (power), sqrt, sin, cos, tan, log, exp, abs

    **Variables**: x, y, z, a, b, c, d, n, m, t, r, pi, e, and Greek letters

    ---
    ### Known Limitations

    | Expression Type | Accuracy | Recommendation |
    |-----------------|----------|----------------|
    | Parenthesized expressions | **100%** | Always works |
    | Simple expressions (2 terms) | ~95% | Usually works |
    | Complex without parens (3+ terms) | ~86% | Add parentheses |
    | Functions + operators | ~86% | Wrap functions: `(sqrt(x)) + y` |

    **For reliable results**: `(sqrt(16)) + (2^3)` instead of `sqrt(16) + 2^3`

    [GitHub](https://github.com/aiexplorations/vidai) | [Model Card](https://huggingface.co/aiexplorations/vidai)
    """)

    # Clicking Parse runs parse_expression on the three inputs and writes its
    # three return values to the output boxes.
    parse_button.click(
        fn=parse_expression,
        inputs=[expression_input, evaluate_checkbox, substitutions_input],
        outputs=[prefix_output, eval_output, status_output],
    )
500
+
501
+
502
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()