Spaces:
Sleeping
Sleeping
File size: 2,935 Bytes
9b240de 48b8d70 9b240de 6608c0e 9b240de 6608c0e cf07dbe 6608c0e 9b240de 6608c0e 9b240de 65e06a7 6608c0e 9b240de 6608c0e cf07dbe 48b8d70 6608c0e cf07dbe 6608c0e 48b8d70 6608c0e 9b240de 6608c0e 9b240de 6608c0e 9b240de 6608c0e 9b240de 6608c0e 9b240de fe0093d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import torch
import numpy as np
import pandas as pd
from chronos import BaseChronosPipeline # <--- PERBAIKAN: Impor langsung dari 'chronos'
import warnings
# Suppress every warning category for the whole process (this installs a
# global "ignore" filter at import time).
warnings.simplefilter("ignore")

# Module-level cache for the pipeline built by load_timesfm_model(); it stays
# None until a load succeeds.
_pipeline = None
def load_timesfm_model():
    """Return the shared Chronos pipeline, loading it on first use.

    The function name is historical and kept so existing callers keep working;
    the model actually loaded is ``amazon/chronos-bolt-base`` through
    ``BaseChronosPipeline.from_pretrained``.

    Returns:
        The cached pipeline object, or ``None`` if loading raised. A failed
        load is not cached, so a later call will attempt the load again.
    """
    global _pipeline
    checkpoint = "amazon/chronos-bolt-base"

    # Fast path: a previous call already populated the module-level cache.
    if _pipeline is not None:
        return _pipeline

    try:
        print(f"Loading Chronos Pipeline: {checkpoint}")
        # device_map="auto" lets the loader choose the available device.
        _pipeline = BaseChronosPipeline.from_pretrained(checkpoint, device_map="auto")
        print("Model loaded successfully.")
    except Exception as err:
        print(f"Error loading model: {err}. Please ensure the 'chronos-forecasting' library is installed.")
        # Make sure a half-finished load never leaves a value in the cache.
        _pipeline = None
        return None
    return _pipeline
def predict_stock_prices(model_pipeline, data, forecast_horizon):
    """Forecast the next ``forecast_horizon`` values of ``data``.

    The point forecast is the Chronos median (0.5-quantile). If no pipeline is
    supplied, or the Chronos call raises, the naive ``_simple_forecast``
    fallback is used instead.

    Args:
        model_pipeline: Pipeline returned by ``load_timesfm_model``, or ``None``.
        data: 1-D sequence of historical values (NumPy array, list, Series, ...).
        forecast_horizon: Number of future steps to predict.

    Returns:
        1-D NumPy array of ``forecast_horizon`` predicted values.
    """
    if model_pipeline is None:
        return _simple_forecast(data, forecast_horizon)
    try:
        # Convert through NumPy first so lists and pandas Series are read
        # positionally; Chronos takes float32 context tensors in a list (a
        # batch containing this single series).
        context = torch.tensor(np.asarray(data, dtype=np.float32))
        # predict_quantiles is provided by both ChronosPipeline (sample-based)
        # and ChronosBoltPipeline (quantile-based). Bolt's predict() has no
        # num_samples argument, so the previous predict(..., num_samples=20)
        # call raised TypeError and silently fell back on every request.
        quantiles, _mean = model_pipeline.predict_quantiles(
            context=[context],
            prediction_length=forecast_horizon,
            quantile_levels=[0.5],
        )
        # Per the Chronos API docs, quantiles is
        # (batch, prediction_length, num_quantile_levels).
        return quantiles[0, :, 0].cpu().numpy()
    except Exception as e:
        print(f"Prediction error with Chronos: {e}")
        # Best-effort: degrade to the naive forecaster rather than fail.
        return _simple_forecast(data, forecast_horizon)
def _simple_forecast(data, forecast_horizon):
"""Simple forecasting fallback (used if Chronos fails)"""
if len(data) < 2:
return np.full(forecast_horizon, data[-1] if len(data) > 0 else 0)
# Use exponential smoothing
alpha = 0.3
predictions = []
last_value = data[-1]
for _ in range(forecast_horizon):
# Simple exponential smoothing with trend
if len(data) >= 2:
trend = data[-1] - data[-2]
next_value = last_value + alpha * trend
else:
next_value = last_value
# Add some randomness
noise = np.random.normal(0, np.std(data) * 0.05)
next_value += noise
predictions.append(next_value)
last_value = next_value
return np.array(predictions) |