import os
import time
import inspect
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import gradio as gr
import torch
import plotly.graph_objects as go
from chronos import Chronos2Pipeline
# Model id used when the caller supplies none; overridable through the
# CHRONOS_MODEL_ID environment variable.
MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
# Directory scanned for the bundled test CSV files.
DATA_DIR = "data"
# Directory where the forecast CSV export is written.
OUT_DIR = "/tmp"
# -------------------------
# Data
# -------------------------
def available_test_csv() -> List[str]:
    """List the CSV files found directly inside ``DATA_DIR``.

    Returns:
        File names (not full paths) whose name ends in ``.csv``, compared
        case-insensitively and sorted; an empty list when ``DATA_DIR`` is
        not a directory.
    """
    if os.path.isdir(DATA_DIR):
        csv_names = [name for name in os.listdir(DATA_DIR) if name.lower().endswith(".csv")]
        csv_names.sort()
        return csv_names
    return []
def pick_device(ui_choice: str) -> str:
    """Translate the device dropdown label into a torch device name.

    Args:
        ui_choice: Label selected in the UI; ``None`` or empty is treated
            as a non-CUDA choice.

    Returns:
        ``"cuda"`` only when the label starts with ``"cuda"`` and CUDA is
        available to torch, otherwise ``"cpu"``.
    """
    wants_cuda = (ui_choice or "").startswith("cuda")
    if wants_cuda and torch.cuda.is_available():
        return "cuda"
    return "cpu"
def make_sample_series(n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float) -> np.ndarray:
    """Generate a synthetic series: linear trend + sine seasonality + Gaussian noise.

    Args:
        n: Number of points to generate.
        seed: Seed for the NumPy random generator.
        trend: Slope applied to the time index.
        season_period: Period of the sine component (values below 1 are clamped to 1).
        season_amp: Amplitude of the sine component.
        noise: Standard deviation of the Gaussian noise.

    Returns:
        A float32 array of length ``n``; when the raw series dips below
        zero it is shifted up so that its minimum is zero.
    """
    count = int(n)
    period = max(1, int(season_period))
    rng = np.random.default_rng(int(seed))
    t = np.arange(count, dtype=np.float32)

    linear = trend * t
    seasonal = season_amp * np.sin(2 * np.pi * t / period)
    jitter = rng.normal(0, noise, size=count)
    y = (linear + seasonal + jitter).astype(np.float32)

    lowest = float(np.min(y))
    if lowest < 0:
        # Shift the whole series so that it never goes negative.
        y -= lowest
    return y
def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str]:
    """Read one numeric column of a CSV file as a float32 series.

    Args:
        csv_path: Path of the CSV file to read.
        column: Column to use. When empty or ``None`` the first numeric
            column is picked; if no column has a numeric dtype, the first
            column with at least one value coercible to a number is used.

    Returns:
        ``(values, column_name)`` where ``values`` is the column with
        non-numeric entries dropped, as a float32 array.

    Raises:
        ValueError: If no numeric column exists, the requested column is
            missing, or fewer than 10 numeric values remain.
    """
    frame = pd.read_csv(csv_path)
    requested = (column or "").strip()

    if requested:
        chosen = requested
    else:
        candidates = [name for name in frame.columns if pd.api.types.is_numeric_dtype(frame[name])]
        if not candidates:
            # No numeric dtype at all: accept columns that are at least partly numeric.
            candidates = [
                name
                for name in frame.columns
                if pd.to_numeric(frame[name], errors="coerce").notna().sum() > 0
            ]
        if not candidates:
            raise ValueError("Non trovo colonne numeriche nel CSV.")
        chosen = candidates[0]

    if chosen not in frame.columns:
        raise ValueError(f"Colonna '{chosen}' non trovata. Disponibili: {list(frame.columns)}")

    numeric = pd.to_numeric(frame[chosen], errors="coerce")
    values = numeric.dropna().astype(np.float32).to_numpy()
    if len(values) < 10:
        raise ValueError("Serie troppo corta.")
    return values, chosen
# -------------------------
# Model cache
# -------------------------
# Cached pipeline instance shared by all requests; None until the first load.
_PIPE = None
# Model id and device the cached pipeline was loaded with; compared against
# each request to decide whether the pipeline must be reloaded.
_META = {"model_id": None, "device": None}
def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
    """Return the cached pipeline, (re)loading it when model or device changed.

    Args:
        model_id: Model identifier; falls back to ``MODEL_ID_DEFAULT`` when
            empty or ``None``. Surrounding whitespace is stripped.
        device: Requested device; anything other than ``"cuda"`` with CUDA
            actually available resolves to ``"cpu"``.

    Returns:
        The module-level cached ``Chronos2Pipeline``.
    """
    global _PIPE, _META

    wanted_model = (model_id or MODEL_ID_DEFAULT).strip()
    cuda_usable = device == "cuda" and torch.cuda.is_available()
    wanted_device = "cuda" if cuda_usable else "cpu"

    cache_hit = (
        _PIPE is not None
        and _META["model_id"] == wanted_model
        and _META["device"] == wanted_device
    )
    if not cache_hit:
        _PIPE = Chronos2Pipeline.from_pretrained(wanted_model, device_map=wanted_device)
        _META = {"model_id": wanted_model, "device": wanted_device}
    return _PIPE
# -------------------------
# Predict (STABLE)
# -------------------------
def _to_numpy(x: Any) -> np.ndarray:
if isinstance(x, np.ndarray):
return x
if torch.is_tensor(x):
return x.detach().cpu().numpy()
return np.asarray(x)
def _extract_samples(raw: Any) -> np.ndarray:
if isinstance(raw, dict):
for k in ["samples", "predictions", "prediction", "output"]:
if k in raw:
return _to_numpy(raw[k])
if len(raw) > 0:
return _to_numpy(next(iter(raw.values())))
return np.asarray([], dtype=np.float32)
return _to_numpy(raw)
def _as_sample_matrix(arr: np.ndarray, horizon: int) -> np.ndarray:
    """Reshape a raw prediction array into 2-D ``(S, H)`` with ``H == horizon``."""
    if arr.size == 0:
        # Fail here with a clear message instead of letting an empty array
        # break the downstream statistics/table code.
        raise ValueError("Chronos2 non ha restituito alcuna previsione.")

    # Drop singleton axes. When a single step is requested the trailing
    # length-1 axis IS the horizon axis, so it must survive; a blind
    # np.squeeze would collapse (1, S, 1) to (S,) and lose all but one value.
    unit_axes = [axis for axis, size in enumerate(arr.shape) if size == 1]
    if horizon == 1 and unit_axes and unit_axes[-1] == arr.ndim - 1:
        unit_axes.pop()
    if unit_axes:
        arr = np.squeeze(arr, axis=tuple(unit_axes))

    if arr.ndim == 0:
        arr = arr.reshape(1, 1)
    elif arr.ndim == 1:
        # A lone vector is taken to be one trajectory over the horizon.
        arr = arr[None, :]
    while arr.ndim > 2:
        # Remaining leading axes are assumed to be batch-like: keep the first entry.
        arr = arr[0]

    produced = arr.shape[-1]
    if produced > horizon:
        arr = arr[..., :horizon]
    elif produced < horizon:
        # Pad by repeating the last predicted step.
        tail = np.repeat(arr[..., -1:], horizon - produced, axis=-1)
        arr = np.concatenate([arr, tail], axis=-1)
    return arr


def chronos2_predict(pipe: Chronos2Pipeline, y: np.ndarray, horizon: int, requested_samples: int) -> Tuple[np.ndarray, bool, str]:
    """Call ``pipe.predict`` on a single series and normalize its output.

    The keyword names for the horizon and the sample count are discovered
    by inspecting the signature of ``pipe.predict``, so the call adapts to
    whichever of the known parameter names that signature exposes.

    Args:
        pipe: Pipeline whose ``predict`` method is called.
        y: History values; sent as a batch containing this one series.
        horizon: Number of future steps to return.
        requested_samples: Sample count, passed only when ``predict``
            exposes a matching keyword.

    Returns:
        ``(samples, multi, note)``:
        samples: float32 array of shape ``(S, H)`` with ``H == horizon``.
        multi: ``True`` when more than one row was returned.
        note: Debug string describing the signature, the keywords used and
            the resulting shape.

    Raises:
        ValueError: If the model output contains no values.
    """
    horizon = int(horizon)
    sig = inspect.signature(pipe.predict)
    params = sig.parameters

    # Input format: always a batch holding exactly one series.
    inputs = [y.tolist()]

    horizon_kw = next(
        (name for name in ("prediction_length", "horizon", "steps", "n_steps", "pred_len") if name in params),
        None,
    )
    # Not every predict() signature exposes a sample-count keyword.
    sample_kw = next(
        (name for name in ("n_samples", "num_return_sequences", "num_samples") if name in params),
        None,
    )

    # When no known horizon keyword is visible, still send "prediction_length"
    # and let predict() accept or reject it.
    kwargs: Dict[str, Any] = {(horizon_kw or "prediction_length"): horizon}
    if sample_kw:
        kwargs[sample_kw] = int(requested_samples)

    if "inputs" in params:
        raw = pipe.predict(inputs=inputs, **kwargs)
    else:
        raw = pipe.predict(inputs, **kwargs)

    arr = _extract_samples(raw).astype(np.float32, copy=False)
    arr = _as_sample_matrix(arr, horizon)

    # With a single row a median can still be plotted, but a band is meaningless.
    real_multi = arr.shape[0] > 1
    note = f"predict_signature={sig} | used_horizon_kw={horizon_kw} | used_sample_kw={sample_kw} | got_shape={tuple(arr.shape)}"
    return arr, real_multi, note
# -------------------------
# Plotly
# -------------------------
def plot_forecast(y, median, low, high, title, show_band: bool, band_label: str) -> go.Figure:
    """Build the Plotly figure showing the history, the median forecast and an optional band.

    Args:
        y: Historical values, plotted at x = 0 .. len(y) - 1.
        median: Forecast values, plotted right after the history.
        low: Lower band values (used only when ``show_band`` is true).
        high: Upper band values (used only when ``show_band`` is true).
        title: Figure title.
        show_band: Whether to draw the filled area between ``low`` and ``high``.
        band_label: Legend name of the band.

    Returns:
        The assembled figure.
    """
    n_hist = len(y)
    x_hist = np.arange(n_hist)
    x_fcst = np.arange(n_hist, n_hist + len(median))

    fig = go.Figure()
    history_trace = go.Scatter(x=x_hist, y=y, mode="lines", name="History")
    median_trace = go.Scatter(x=x_fcst, y=median, mode="lines", name="Forecast (median)")
    fig.add_trace(history_trace)
    fig.add_trace(median_trace)
    # Dashed marker at the last historical point.
    fig.add_vline(x=n_hist - 1, line_width=1, line_dash="dash", opacity=0.6)

    if show_band:
        # The upper bound goes first and stays invisible; the lower bound
        # then fills up to it via "tonexty".
        upper_trace = go.Scatter(
            x=x_fcst, y=high, mode="lines", line=dict(width=0),
            showlegend=False, hoverinfo="skip",
        )
        lower_trace = go.Scatter(
            x=x_fcst, y=low, mode="lines", fill="tonexty",
            line=dict(width=0), name=band_label,
        )
        fig.add_trace(upper_trace)
        fig.add_trace(lower_trace)

    layout_options = dict(
        title=title,
        hovermode="x unified",
        margin=dict(l=10, r=10, t=55, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        xaxis_title="t",
        yaxis_title="value",
    )
    fig.update_layout(**layout_options)
    return fig
def kpi_card(label: str, value: str, hint: str = "") -> str:
    """Render one KPI card as an HTML fragment.

    Args:
        label: Card caption.
        value: Main value shown on the card.
        hint: Optional secondary line; omitted from the markup when empty.

    Returns:
        An HTML string with the label, the value and, when given, the hint,
        each in its own ``<div>`` inside an outer ``<div>``.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # The fragment is shown through an HTML component and some values come
    # from user-editable fields, so all interpolated text is escaped.
    # NOTE(review): the previous literals were split across lines (a syntax
    # error); plain <div> wrappers are used here, without styling.
    hint_html = f"<div>{escape(str(hint))}</div>" if hint else ""
    return f"""
<div>
<div>{escape(str(label))}</div>
<div>{escape(str(value))}</div>
{hint_html}
</div>
"""
def kpi_grid(cards: List[str]) -> str:
    """Wrap already-rendered KPI cards in a single container element.

    Args:
        cards: HTML fragments, one per card; inserted verbatim.

    Returns:
        The concatenated cards inside one ``<div>``.
    """
    # NOTE(review): the previous literal was split across lines (a syntax
    # error); a plain <div> container is used here, without styling.
    return f"<div>{''.join(cards)}</div>"
def explain(y, median, low, high, band_enabled: bool, q_low: float, q_high: float, extra: str) -> str:
    """Build the Markdown explanation of the forecast.

    Args:
        y: Historical values; their mean is the reference level.
        median: Median forecast, one value per step (must be non-empty).
        low: Lower band values (read only when ``band_enabled``).
        high: Upper band values (read only when ``band_enabled``).
        band_enabled: Whether a band is available.
        q_low: Lower quantile, shown as a percentage.
        q_high: Upper quantile, shown as a percentage.
        extra: Debug text appended at the end in a code span.

    Returns:
        Markdown text describing the trend of the median forecast, its last
        value, the band at the last step (or a notice that it is disabled)
        and the debug note.
    """
    horizon = len(median)
    base = float(np.mean(y))
    delta = float(median[-1] - median[0])
    # Scale by the magnitude of the historical level: with a negative mean,
    # max(1e-6, base) would fall back to 1e-6 and blow the percentage up.
    pct = (delta / max(1e-6, abs(base))) * 100.0

    if abs(pct) < 2:
        trend_txt = "sostanzialmente stabile"
    elif pct > 0:
        trend_txt = "in crescita"
    else:
        trend_txt = "in calo"

    txt = f"""
### 🧠 Spiegazione
Nei prossimi **{horizon} step** la previsione mediana è **{trend_txt}** (variazione ≈ **{pct:+.1f}%** rispetto al livello medio storico).
- **Ultimo valore mediano previsto:** **{median[-1]:.2f}**
"""
    if band_enabled:
        txt += f"- **Banda [{q_low:.0%}–{q_high:.0%}] (ultimo step):** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
    else:
        txt += "- **Banda di incertezza:** disattivata (questa versione di Chronos2 non restituisce campioni multipli con i parametri disponibili).\n"
    # The previous literal was split across two lines (a syntax error); the
    # same text is kept, now as a single valid string.
    txt += f"\nDebug\n\n\n`{extra}`\n\n \n"
    return txt
# -------------------------
# Run
# -------------------------
def run_all(
    input_mode, test_csv_name, upload_csv, csv_column,
    n, seed, trend, season_period, season_amp, noise,
    prediction_length, requested_samples, q_low, q_high,
    device_ui, model_id,
):
    """Run the whole flow behind the Run button: load data, forecast, build outputs.

    Args:
        input_mode: ``"Test CSV"``, ``"Upload CSV"``, or anything else for
            the synthetic sample series.
        test_csv_name: File name inside ``DATA_DIR`` (Test CSV mode).
        upload_csv: Uploaded file handle or path (Upload CSV mode).
        csv_column: Optional column name for the CSV modes.
        n, seed, trend, season_period, season_amp, noise: Parameters of the
            synthetic series (sample mode).
        prediction_length: Number of steps to forecast.
        requested_samples: Sample count requested from the model.
        q_low, q_high: Quantiles delimiting the band; ``q_low`` must be
            smaller than ``q_high``.
        device_ui: Device label selected in the UI.
        model_id: Model identifier; the default is used when empty.

    Returns:
        ``(kpis_html, explanation_md, fig, out_df, out_path, info)``: the
        KPI HTML, the Markdown explanation, the Plotly figure, the forecast
        table, the path of the exported CSV, and a dict of run details.

    Raises:
        gr.Error: On invalid quantiles, a missing CSV selection/upload, or
            a CSV that cannot be read as a series.
    """
    if q_low >= q_high:
        raise gr.Error("Quantile low deve essere < quantile high.")

    # Resolve and load the input BEFORE touching the model, so that bad
    # input fails fast without paying for a model load.
    csv_path = None
    source_prefix = ""
    if input_mode == "Test CSV":
        if not test_csv_name:
            raise gr.Error("Seleziona un Test CSV.")
        csv_path = os.path.join(DATA_DIR, test_csv_name)
        source_prefix = f"Test CSV: {test_csv_name}"
    elif input_mode == "Upload CSV":
        if upload_csv is None:
            raise gr.Error("Carica un CSV.")
        # Use the handle's .name, but also tolerate a plain path string.
        csv_path = getattr(upload_csv, "name", upload_csv)
        source_prefix = "Upload CSV"

    if csv_path is None:
        y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
        source = "Sample series"
    else:
        try:
            y, used_col = load_series_from_csv(csv_path, csv_column)
        except (ValueError, OSError) as exc:
            # Re-raise as gr.Error (the type used for every other input
            # problem here) so the loader's message reaches the user.
            raise gr.Error(str(exc)) from exc
        source = f"{source_prefix} • col={used_col}"

    device = pick_device(device_ui)
    pipe = get_pipeline(model_id, device)

    t0 = time.time()
    samples, real_multi, note = chronos2_predict(pipe, y, int(prediction_length), int(requested_samples))
    latency = time.time() - t0

    median = np.quantile(samples, 0.50, axis=0)
    # A band needs more than two genuinely distinct rows to be meaningful.
    band_enabled = real_multi and samples.shape[0] > 2
    if band_enabled:
        low = np.quantile(samples, float(q_low), axis=0)
        high = np.quantile(samples, float(q_high), axis=0)
    else:
        low = median.copy()
        high = median.copy()

    # KPI
    cards = [
        kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
        kpi_card("Latency", f"{latency:.2f}s", "predict()"),
        kpi_card("Samples", str(samples.shape[0]), "returned by model"),
        kpi_card("Band", "ON" if band_enabled else "OFF", "needs multi-samples"),
        kpi_card("Horizon", str(prediction_length)),
        kpi_card("Model", (model_id or MODEL_ID_DEFAULT)),
    ]
    kpis_html = kpi_grid(cards)

    # Plot
    fig = plot_forecast(
        y=y,
        median=median,
        low=low,
        high=high,
        title=f"Forecast — {source}",
        show_band=band_enabled,
        band_label=f"Band [{q_low:.2f}, {q_high:.2f}]",
    )

    # Table + export
    t_fcst = np.arange(len(y), len(y) + int(prediction_length))
    out_df = pd.DataFrame({
        "t": t_fcst,
        "median": median,
    })
    if band_enabled:
        out_df[f"q{q_low:.2f}"] = low
        out_df[f"q{q_high:.2f}"] = high
    out_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
    out_df.to_csv(out_path, index=False)

    explanation_md = explain(y, median, low, high, band_enabled, q_low, q_high, note)
    info = {
        "source": source,
        "history_points": int(len(y)),
        "prediction_length": int(prediction_length),
        "requested_samples": int(requested_samples),
        "returned_samples": int(samples.shape[0]),
        "band_enabled": bool(band_enabled),
        "predict_signature": str(inspect.signature(pipe.predict)),
        "debug_note": note,
    }
    return kpis_html, explanation_md, fig, out_df, out_path, info
# -------------------------
# UI
# -------------------------
# Page-level CSS: caps the width of the Gradio container.
css = """.gradio-container { max-width: 1200px !important; }"""
# UI definition: a two-column layout with inputs on the left and results on
# the right; the Run button is wired to run_all.
with gr.Blocks(title="Chronos-2 • Pro Dashboard (Stable)", css=css) as demo:
    gr.Markdown("# ⏱️ Chronos-2 Forecast Dashboard — Stable Edition")
    with gr.Row():
        # Left column: data source, device/model selection and forecast settings.
        with gr.Column(scale=1, min_width=360):
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input")
            # The choices are computed once, when the UI is built.
            test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
            upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
            # Default to CUDA only when torch reports it as available.
            device_ui = gr.Dropdown(
                ["cpu", "cuda (se disponibile)"],
                value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
                label="Device",
            )
            model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
            # Parameters of the synthetic series (passed to make_sample_series by run_all).
            with gr.Accordion("Sample generator", open=False):
                n = gr.Slider(60, 2000, value=300, step=10, label="History length")
                seed = gr.Number(value=42, precision=0, label="Seed")
                trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
                season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
                season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
                noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")
            prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
            requested_samples = gr.Slider(1, 800, value=200, step=25, label="Requested samples (best effort)")
            # The slider ranges keep the low quantile below 0.5 and the high one above it.
            q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
            q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
            run_btn = gr.Button("Run", variant="primary")
        # Right column: KPI strip plus tabbed results.
        with gr.Column(scale=2):
            kpis = gr.HTML()
            with gr.Tabs():
                with gr.Tab("Forecast"):
                    forecast_plot = gr.Plot()
                    forecast_table = gr.Dataframe(interactive=False)
                with gr.Tab("Spiegazione"):
                    explanation = gr.Markdown()
                with gr.Tab("Export"):
                    download = gr.File()
                with gr.Tab("Info"):
                    info = gr.JSON()
    # The order of inputs/outputs must match run_all's parameters and return tuple.
    run_btn.click(
        fn=run_all,
        inputs=[
            input_mode, test_csv_name, upload_csv, csv_column,
            n, seed, trend, season_period, season_amp, noise,
            prediction_length, requested_samples, q_low, q_high,
            device_ui, model_id,
        ],
        outputs=[kpis, explanation, forecast_plot, forecast_table, download, info],
    )
# Enable request queuing, then start the server.
demo.queue()
demo.launch(ssr_mode=False)