import html
import inspect
import os
import tempfile
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import gradio as gr
import torch
import plotly.graph_objects as go
from chronos import Chronos2Pipeline

MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
DATA_DIR = "data"
OUT_DIR = "/tmp"


# -------------------------
# Data
# -------------------------
def available_test_csv() -> List[str]:
    """Return the sorted names of ``.csv`` files in DATA_DIR (empty list if the dir is missing)."""
    if not os.path.isdir(DATA_DIR):
        return []
    return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))


def pick_device(ui_choice: str) -> str:
    """Map the UI device choice to "cuda" (only if requested AND available) or "cpu"."""
    wants_cuda = (ui_choice or "").startswith("cuda")
    return "cuda" if wants_cuda and torch.cuda.is_available() else "cpu"


def make_sample_series(n: int, seed: int, trend: float, season_period: int,
                       season_amp: float, noise: float) -> np.ndarray:
    """Generate a synthetic float32 series: linear trend + sine seasonality + Gaussian noise.

    The result is shifted up so that its minimum is >= 0.
    """
    rng = np.random.default_rng(int(seed))
    t = np.arange(int(n), dtype=np.float32)
    y = (
        trend * t
        + season_amp * np.sin(2 * np.pi * t / max(1, int(season_period)))
        + rng.normal(0, noise, size=int(n))
    ).astype(np.float32)
    if float(np.min(y)) < 0:
        y -= float(np.min(y))
    return y


def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str]:
    """Read one numeric column of a CSV as a float32 series.

    If ``column`` is blank, the first numeric-dtype column is used; failing that,
    the first column with at least one value coercible to a number.
    Non-numeric cells are dropped.

    Returns:
        (series, name of the column actually used)

    Raises:
        ValueError: no usable column, unknown column, or fewer than 10 values.
    """
    df = pd.read_csv(csv_path)
    col = (column or "").strip()
    if not col:
        numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
        if not numeric_cols:
            # No native numeric dtype: accept columns that at least partially coerce.
            numeric_cols = [
                c for c in df.columns
                if pd.to_numeric(df[c], errors="coerce").notna().any()
            ]
        if not numeric_cols:
            raise ValueError("Non trovo colonne numeriche nel CSV.")
        col = numeric_cols[0]
    if col not in df.columns:
        raise ValueError(f"Colonna '{col}' non trovata. Disponibili: {list(df.columns)}")
    y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
    if len(y) < 10:
        raise ValueError("Serie troppo corta.")
    return y, col


# -------------------------
# Model cache
# -------------------------
_PIPE = None
_META = {"model_id": None, "device": None}


def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
    """Return a process-wide cached pipeline, reloading only when model id or device changes.

    A blank/whitespace ``model_id`` falls back to MODEL_ID_DEFAULT; "cuda" is
    downgraded to "cpu" when CUDA is unavailable.
    """
    global _PIPE, _META
    model_id = (model_id or "").strip() or MODEL_ID_DEFAULT
    device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
    if _PIPE is None or _META["model_id"] != model_id or _META["device"] != device:
        # Drop the old pipeline first so it can be freed before the new one loads.
        _PIPE = None
        _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
        _META = {"model_id": model_id, "device": device}
    return _PIPE


# -------------------------
# Predict (STABLE)
# -------------------------
def _to_numpy(x: Any) -> np.ndarray:
    """Convert an ndarray / torch tensor (any device) / array-like to a NumPy array."""
    if isinstance(x, np.ndarray):
        return x
    if torch.is_tensor(x):
        t = x.detach().cpu()
        # NumPy has no bfloat16; Tensor.numpy() raises on it.
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.numpy()
    return np.asarray(x)


def _extract_samples(raw: Any) -> np.ndarray:
    """Pull the forecast array out of whatever ``predict`` returned.

    Handles:
    - dicts: known keys first, else the first value;
    - lists/tuples: converted element-wise (stacked when shapes agree, else the first element);
    - bare arrays/tensors.

    Empty containers yield an empty float32 array.
    """
    if isinstance(raw, dict):
        for k in ("samples", "predictions", "prediction", "output"):
            if k in raw:
                return _to_numpy(raw[k])
        if raw:
            return _to_numpy(next(iter(raw.values())))
        return np.asarray([], dtype=np.float32)
    if isinstance(raw, (list, tuple)):
        # Element-wise conversion: np.asarray on a list of CUDA tensors fails.
        if not raw:
            return np.asarray([], dtype=np.float32)
        parts = [_to_numpy(item) for item in raw]
        try:
            return np.stack(parts)
        except ValueError:
            return parts[0]
    return _to_numpy(raw)


def chronos2_predict(pipe: Chronos2Pipeline, y: np.ndarray, horizon: int,
                     requested_samples: int) -> Tuple[np.ndarray, bool, str]:
    """Forecast one series with ``pipe.predict``, adapting to the installed API.

    Keyword names for the horizon and for the sample count are discovered from
    ``inspect.signature(pipe.predict)`` because they differ between versions.

    NOTE(review): Chronos-2 is documented upstream as a quantile forecaster. If
    ``predict`` returns fixed quantile levels rather than samples, the rows
    here are quantile curves, and np.quantile across them only approximates
    the requested band. Verify against the installed version.

    Returns:
        samples: float32 array of shape (S, H), with H == ``horizon``.
        multi: True when the model returned more than one row (S > 1).
        note: debug string (signature, keywords used, output shape).

    Raises:
        RuntimeError: if the model output contains no forecast values.
    """
    sig = inspect.signature(pipe.predict)
    params = sig.parameters

    # Input format: always a batch containing exactly one series.
    inputs = [y.tolist()]

    horizon_kw = next(
        (c for c in ("prediction_length", "horizon", "steps", "n_steps", "pred_len") if c in params),
        None,
    )
    # Many versions have no sample-count argument at all.
    sample_kw = next(
        (c for c in ("n_samples", "num_return_sequences", "num_samples") if c in params),
        None,
    )

    kwargs: Dict[str, Any] = {}
    # Fallback when no horizon keyword was recognised: pass prediction_length by
    # keyword anyway (accepted if predict takes **kwargs; otherwise predict raises).
    kwargs[horizon_kw or "prediction_length"] = int(horizon)
    if sample_kw:
        kwargs[sample_kw] = int(requested_samples)

    if "inputs" in params:
        raw = pipe.predict(inputs=inputs, **kwargs)
    else:
        raw = pipe.predict(inputs, **kwargs)
    arr = _extract_samples(raw).astype(np.float32, copy=False)

    # Normalise to (S, H):
    # - drop singleton dims;
    # - peel leading batch/variate dims by taking index 0 until 2 dims remain;
    # - promote 0-D/1-D output to a single row (a 1-D result is assumed to be (H,)).
    arr = np.squeeze(arr)
    while arr.ndim > 2:
        arr = arr[0]
    if arr.ndim < 2:
        arr = arr.reshape(1, -1)

    if arr.size == 0:
        raise RuntimeError(f"predict() returned no forecast values (signature={sig}).")

    # Ensure exactly `horizon` steps: truncate, or pad by repeating the last value.
    width = arr.shape[-1]
    if width > horizon:
        arr = arr[..., :horizon]
    elif width < horizon:
        pad = np.repeat(arr[..., -1:], horizon - width, axis=-1)
        arr = np.concatenate([arr, pad], axis=-1)

    # With a single row the median is still plottable, but a band is meaningless.
    real_multi = arr.shape[0] > 1
    note = f"predict_signature={sig} | used_horizon_kw={horizon_kw} | used_sample_kw={sample_kw} | got_shape={tuple(arr.shape)}"
    return arr, real_multi, note


# -------------------------
# Plotly
# -------------------------
def plot_forecast(y, median, low, high, title, show_band: bool, band_label: str) -> go.Figure:
    """Plot the history, the median forecast and (optionally) a filled [low, high] band.

    Forecast x-positions continue right after the history; a dashed vertical
    line marks the last observed point.
    """
    t_hist = np.arange(len(y))
    t_fcst = np.arange(len(y), len(y) + len(median))
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=t_hist, y=y, mode="lines", name="History"))
    fig.add_trace(go.Scatter(x=t_fcst, y=median, mode="lines", name="Forecast (median)"))
    fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
    if show_band:
        # Invisible upper edge first; the lower edge then fills "tonexty" up to it.
        fig.add_trace(go.Scatter(x=t_fcst, y=high, mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"))
        fig.add_trace(go.Scatter(
            x=t_fcst, y=low, mode="lines", fill="tonexty", line=dict(width=0), name=band_label
        ))
    fig.update_layout(
        title=title,
        hovermode="x unified",
        margin=dict(l=10, r=10, t=55, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        xaxis_title="t",
        yaxis_title="value",
    )
    return fig


def kpi_card(label: str, value: str, hint: str = "") -> str:
    """Render one KPI tile as an HTML snippet.

    All text is HTML-escaped because ``value`` can carry user input (e.g. the model id).
    """
    hint_html = (
        '<div style="font-size:0.78rem;opacity:0.65;margin-top:4px;">'
        f"{html.escape(str(hint))}</div>"
        if hint
        else ""
    )
    return (
        '<div style="border:1px solid rgba(127,127,127,0.35);border-radius:12px;padding:10px 14px;">'
        f'<div style="font-size:0.8rem;opacity:0.75;">{html.escape(str(label))}</div>'
        '<div style="font-size:1.3rem;font-weight:700;overflow-wrap:anywhere;">'
        f"{html.escape(str(value))}</div>"
        f"{hint_html}"
        "</div>"
    )


def kpi_grid(cards: List[str]) -> str:
    """Lay out pre-rendered KPI cards in a responsive CSS grid."""
    return (
        '<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(170px,1fr));gap:10px;">'
        + "".join(cards)
        + "</div>"
    )


def explain(y, median, low, high, band_enabled: bool, q_low: float, q_high: float, extra: str) -> str:
    """Build the Markdown for the explanation tab (user-facing text is Italian).

    The trend is the first-to-last change of the median forecast, expressed as
    a percentage of the magnitude of the historical mean. Changes under 2% are
    called "stable". ``extra`` is shown verbatim in a collapsible debug section.
    """
    horizon = len(median)
    base = float(np.mean(y))
    delta = float(median[-1] - median[0])
    # abs(): a negative historical mean must not collapse the denominator to 1e-6.
    pct = (delta / max(1e-6, abs(base))) * 100.0

    if abs(pct) < 2:
        trend_txt = "sostanzialmente stabile"
    elif pct > 0:
        trend_txt = "in crescita"
    else:
        trend_txt = "in calo"

    txt = f"""
### 🧠 Spiegazione

Nei prossimi **{horizon} step** la previsione mediana è **{trend_txt}** (variazione ≈ **{pct:+.1f}%** rispetto al livello medio storico).

- **Ultimo valore mediano previsto:** **{median[-1]:.2f}**
"""
    if band_enabled:
        txt += f"- **Banda [{q_low:.0%}–{q_high:.0%}] (ultimo step):** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
    else:
        txt += "- **Banda di incertezza:** disattivata (questa versione di Chronos2 non restituisce campioni multipli con i parametri disponibili).\n"
    txt += f"\n<details><summary>Debug</summary>\n\n`{extra}`\n\n</details>\n"
    return txt


# -------------------------
# Run
# -------------------------
def _resolve_upload_path(upload_csv: Any) -> Optional[str]:
    """Return the filesystem path of a gr.File value.

    Gradio 4+ passes a str path; Gradio 3.x passes a tempfile-like object with ``.name``.
    """
    if upload_csv is None:
        return None
    if isinstance(upload_csv, (str, os.PathLike)):
        return os.fspath(upload_csv)
    return getattr(upload_csv, "name", None)


def _load_input_series(input_mode, test_csv_name, upload_csv, csv_column,
                       n, seed, trend, season_period, season_amp, noise) -> Tuple[np.ndarray, str]:
    """Build the history series for the selected input mode; return (series, source label)."""
    if input_mode == "Test CSV":
        if not test_csv_name:
            raise gr.Error("Seleziona un Test CSV.")
        path = os.path.join(DATA_DIR, test_csv_name)
        label_prefix = f"Test CSV: {test_csv_name}"
    elif input_mode == "Upload CSV":
        path = _resolve_upload_path(upload_csv)
        if not path:
            raise gr.Error("Carica un CSV.")
        label_prefix = "Upload CSV"
    else:
        y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
        return y, "Sample series"

    try:
        y, used_col = load_series_from_csv(path, csv_column)
    except (ValueError, OSError) as err:
        # Plain exceptions surface as a generic "Error" in the UI; gr.Error shows the message.
        raise gr.Error(str(err)) from err
    return y, f"{label_prefix} • col={used_col}"


def run_all(
    input_mode, test_csv_name, upload_csv, csv_column,
    n, seed, trend, season_period, season_amp, noise,
    prediction_length, requested_samples,
    q_low, q_high,
    device_ui, model_id,
):
    """Gradio callback: load data, forecast, and build every dashboard output.

    Returns:
        (kpis_html, explanation_md, figure, forecast_df, csv_path, info_dict),
        in the order of the ``outputs`` wired to the Run button.

    Raises:
        gr.Error: invalid quantiles, missing/unreadable CSV, or unusable column.
    """
    if q_low >= q_high:
        raise gr.Error("Quantile low deve essere < quantile high.")

    # Validate and load the data first, so bad input fails before any model load.
    y, source = _load_input_series(
        input_mode, test_csv_name, upload_csv, csv_column,
        n, seed, trend, season_period, season_amp, noise,
    )

    device = pick_device(device_ui)
    pipe = get_pipeline(model_id, device)

    t0 = time.perf_counter()
    samples, real_multi, note = chronos2_predict(pipe, y, int(prediction_length), int(requested_samples))
    latency = time.perf_counter() - t0

    median = np.quantile(samples, 0.50, axis=0)
    band_enabled = real_multi and samples.shape[0] > 2
    if band_enabled:
        low = np.quantile(samples, float(q_low), axis=0)
        high = np.quantile(samples, float(q_high), axis=0)
    else:
        low = median.copy()
        high = median.copy()

    # KPI
    cards = [
        kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
        kpi_card("Latency", f"{latency:.2f}s", "predict()"),
        kpi_card("Samples", str(samples.shape[0]), "returned by model"),
        kpi_card("Band", "ON" if band_enabled else "OFF", "needs multi-samples"),
        kpi_card("Horizon", str(int(prediction_length))),
        kpi_card("Model", (model_id or MODEL_ID_DEFAULT)),
    ]
    kpis_html = kpi_grid(cards)

    # Plot
    fig = plot_forecast(
        y=y,
        median=median,
        low=low,
        high=high,
        title=f"Forecast — {source}",
        show_band=band_enabled,
        band_label=f"Band [{q_low:.2f}, {q_high:.2f}]",
    )

    # Table + export
    t_fcst = np.arange(len(y), len(y) + int(prediction_length))
    out_df = pd.DataFrame({"t": t_fcst, "median": median})
    if band_enabled:
        out_df[f"q{q_low:.2f}"] = low
        out_df[f"q{q_high:.2f}"] = high
    # Per-request directory so concurrent sessions don't overwrite each other's
    # export while keeping the same download file name.
    out_path = os.path.join(tempfile.mkdtemp(prefix="chronos2_", dir=OUT_DIR), "chronos2_forecast.csv")
    out_df.to_csv(out_path, index=False)

    explanation_md = explain(y, median, low, high, band_enabled, q_low, q_high, note)

    info = {
        "source": source,
        "history_points": int(len(y)),
        "prediction_length": int(prediction_length),
        "requested_samples": int(requested_samples),
        "returned_samples": int(samples.shape[0]),
        "band_enabled": bool(band_enabled),
        "predict_signature": str(inspect.signature(pipe.predict)),
        "debug_note": note,
    }
    return kpis_html, explanation_md, fig, out_df, out_path, info


# -------------------------
# UI
# -------------------------
css = """.gradio-container { max-width: 1200px !important; }"""

with gr.Blocks(title="Chronos-2 • Pro Dashboard (Stable)", css=css) as demo:
    gr.Markdown("# ⏱️ Chronos-2 Forecast Dashboard — Stable Edition")

    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input")
            test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
            upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")

            device_ui = gr.Dropdown(
                ["cpu", "cuda (se disponibile)"],
                value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
                label="Device",
            )
            model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")

            with gr.Accordion("Sample generator", open=False):
                n = gr.Slider(60, 2000, value=300, step=10, label="History length")
                seed = gr.Number(value=42, precision=0, label="Seed")
                trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
                season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
                season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
                noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")

            prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
            requested_samples = gr.Slider(1, 800, value=200, step=25, label="Requested samples (best effort)")
            q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
            q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")

            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=2):
            kpis = gr.HTML()
            with gr.Tabs():
                with gr.Tab("Forecast"):
                    forecast_plot = gr.Plot()
                    forecast_table = gr.Dataframe(interactive=False)
                with gr.Tab("Spiegazione"):
                    explanation = gr.Markdown()
                with gr.Tab("Export"):
                    download = gr.File()
                with gr.Tab("Info"):
                    info = gr.JSON()

    run_btn.click(
        fn=run_all,
        inputs=[
            input_mode, test_csv_name, upload_csv, csv_column,
            n, seed, trend, season_period, season_amp, noise,
            prediction_length, requested_samples,
            q_low, q_high,
            device_ui, model_id,
        ],
        outputs=[kpis, explanation, forecast_plot, forecast_table, download, info],
    )

demo.queue()
demo.launch(ssr_mode=False)