File size: 2,935 Bytes
9b240de
 
 
48b8d70
9b240de
 
 
6608c0e
 
 
 
 
 
 
 
9b240de
6608c0e
 
cf07dbe
6608c0e
 
 
9b240de
6608c0e
 
9b240de
65e06a7
6608c0e
 
 
 
 
 
9b240de
6608c0e
 
 
 
cf07dbe
 
 
48b8d70
6608c0e
cf07dbe
6608c0e
48b8d70
6608c0e
9b240de
6608c0e
 
9b240de
6608c0e
9b240de
6608c0e
 
 
9b240de
 
 
6608c0e
9b240de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe0093d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import numpy as np
import pandas as pd
from chronos import BaseChronosPipeline # <--- FIX: import directly from 'chronos'
import warnings
# Suppress every warning for the whole process (no category or module filter
# is given, so this applies to all warnings raised after import).
warnings.filterwarnings('ignore')

# Global pipeline instance: module-level cache written by load_timesfm_model()
# and reused by later calls once loading has succeeded. None means "not loaded".
_pipeline = None

def load_timesfm_model():
    """Return the shared Chronos pipeline, loading it on first use.

    The function name is kept for compatibility with existing callers.

    Returns:
        The cached pipeline object, or ``None`` when loading failed. A failed
        load leaves the cache empty, so the next call tries again.
    """
    global _pipeline
    checkpoint = "amazon/chronos-bolt-base"

    # Fast path: a previous call already populated the module-level cache.
    if _pipeline is not None:
        return _pipeline

    try:
        print(f"Loading Chronos Pipeline: {checkpoint}")
        # device_map="auto" is forwarded to from_pretrained unchanged.
        _pipeline = BaseChronosPipeline.from_pretrained(checkpoint, device_map="auto")
        print("Model loaded successfully.")
    except Exception as exc:
        print(f"Error loading model: {exc}. Please ensure the 'chronos-forecasting' library is installed.")
        # Make sure a partially completed attempt never leaves a stale value.
        _pipeline = None

    return _pipeline

def predict_stock_prices(model_pipeline, data, forecast_horizon):
    """Forecast future prices with a Chronos pipeline, with a naive fallback.

    Args:
        model_pipeline: Pipeline object exposing ``predict(context,
            prediction_length=...)``. ``None`` selects the fallback
            forecaster directly.
        data: Historical values, oldest first. Any array-like convertible to
            float32; a single-column (or single-row) 2-D array is flattened
            to 1-D before being handed to the pipeline.
        forecast_horizon: Number of future steps to predict.

    Returns:
        A numpy array of point forecasts. On the pipeline path this is the
        median over axis 0 of the first series' forecast; if the pipeline
        raises, the error is printed and ``_simple_forecast`` is used instead.
    """
    if model_pipeline is None:
        return _simple_forecast(data, forecast_horizon)

    try:
        # Normalise through numpy first so lists, ndarrays and pandas Series
        # all take the same conversion path.
        series = np.asarray(data, dtype=np.float32)
        if series.ndim > 1 and series.size == max(series.shape):
            # Column/row vector (e.g. df[["Close"]].values): same data, 1-D.
            series = series.reshape(-1)
        context_tensor = torch.tensor(series, dtype=torch.float32)

        # The context is passed positionally as a list holding one series.
        # No `num_samples` is passed: per the chronos-forecasting API the
        # Chronos-Bolt pipeline's predict() does not accept it, so passing it
        # raised TypeError and every call silently used the fallback below.
        raw_forecasts = model_pipeline.predict(
            [context_tensor],
            prediction_length=forecast_horizon,
        )

        # Collapse axis 0 of the first series' forecast with the median to
        # get one value per step. NOTE(review): per the Chronos docs axis 0
        # is samples for sampling pipelines and quantile levels for
        # Chronos-Bolt; confirm against the installed version.
        point_forecast = np.median(raw_forecasts[0].detach().cpu().numpy(), axis=0)

        return point_forecast

    except Exception as e:
        # Deliberate best-effort: any pipeline failure degrades to the
        # simple forecaster rather than crashing the caller.
        print(f"Prediction error with Chronos: {e}")
        return _simple_forecast(data, forecast_horizon)

def _simple_forecast(data, forecast_horizon):
    """Simple forecasting fallback (used if Chronos fails)"""
    if len(data) < 2:
        return np.full(forecast_horizon, data[-1] if len(data) > 0 else 0)
    
    # Use exponential smoothing
    alpha = 0.3
    predictions = []
    last_value = data[-1]
    
    for _ in range(forecast_horizon):
        # Simple exponential smoothing with trend
        if len(data) >= 2:
            trend = data[-1] - data[-2]
            next_value = last_value + alpha * trend
        else:
            next_value = last_value
        
        # Add some randomness
        noise = np.random.normal(0, np.std(data) * 0.05)
        next_value += noise
        
        predictions.append(next_value)
        last_value = next_value
    
    return np.array(predictions)