"""CryptoVision: OKX spot candles, technical indicators and Prophet forecasts in a Gradio UI."""

import logging
import math

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import requests
from prophet import Prophet

logging.basicConfig(level=logging.INFO)

OKX_TICKERS_ENDPOINT = "https://www.okx.com/api/v5/market/tickers?instType=SPOT"
OKX_CANDLE_ENDPOINT = "https://www.okx.com/api/v5/market/candles"

# Page size requested from the candle endpoint on every call.
OKX_MAX_CANDLES_PER_REQUEST = 300

# UI timeframe label -> OKX `bar` parameter value.
TIMEFRAME_MAPPING = {
    "1m": "1m",
    "5m": "5m",
    "15m": "15m",
    "30m": "30m",
    "1h": "1H",
    "2h": "2H",
    "4h": "4H",
    "6h": "6H",
    "12h": "12H",
    "1d": "1D",
    "1w": "1W",
}

# UI timeframe label -> pandas offset alias used to step Prophet's future dates.
# The step must equal the candle interval, otherwise forecast timestamps are
# spaced differently from the history they continue. "7D" (rather than "W") keeps
# exact 7-day steps from the last candle instead of snapping to a weekday anchor.
TIMEFRAME_TO_FREQ = {
    "1m": "1min",
    "5m": "5min",
    "15m": "15min",
    "30m": "30min",
    "1h": "h",
    "2h": "2h",
    "4h": "4h",
    "6h": "6h",
    "12h": "12h",
    "1d": "D",
    "1w": "7D",
}


def calculate_technical_indicators(df):
    """Add RSI(14), MACD(12, 26, 9) and 20-period Bollinger Band columns.

    Adds the columns ``RSI``, ``MACD``, ``Signal_Line``, ``MA20``, ``BB_upper``
    and ``BB_lower`` to *df* in place and also returns *df*. Requires a numeric
    ``close`` column ordered oldest-first.
    """
    close = df["close"]

    # RSI from simple 14-period rolling means of gains and losses.
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    df["RSI"] = 100 - (100 / (1 + rs))

    # MACD: 12-period EMA minus 26-period EMA, with a 9-period EMA signal line.
    ema_fast = close.ewm(span=12, adjust=False).mean()
    ema_slow = close.ewm(span=26, adjust=False).mean()
    df["MACD"] = ema_fast - ema_slow
    df["Signal_Line"] = df["MACD"].ewm(span=9, adjust=False).mean()

    # Bollinger Bands: 20-period SMA +/- 2 rolling standard deviations.
    window20 = close.rolling(window=20)
    df["MA20"] = window20.mean()
    std20 = window20.std()
    df["BB_upper"] = df["MA20"] + 2 * std20
    df["BB_lower"] = df["MA20"] - 2 * std20
    return df


def create_technical_charts(df):
    """Build the price/Bollinger, RSI and MACD figures from an indicator frame.

    Expects the columns produced by ``calculate_technical_indicators`` plus
    ``timestamp``, ``open``, ``high``, ``low`` and ``close``.

    Returns:
        Tuple ``(price_fig, rsi_fig, macd_fig)`` of Plotly figures.
    """
    # Price candles with Bollinger Bands overlaid.
    fig1 = go.Figure()
    fig1.add_trace(go.Candlestick(
        x=df["timestamp"],
        open=df["open"],
        high=df["high"],
        low=df["low"],
        close=df["close"],
        name="Price",
    ))
    fig1.add_trace(go.Scatter(x=df["timestamp"], y=df["BB_upper"], name="Upper BB",
                              line=dict(color="gray", dash="dash")))
    fig1.add_trace(go.Scatter(x=df["timestamp"], y=df["BB_lower"], name="Lower BB",
                              line=dict(color="gray", dash="dash")))
    fig1.update_layout(title="Price and Bollinger Bands", xaxis_title="Date", yaxis_title="Price")

    # RSI with conventional 70 (overbought) / 30 (oversold) guide lines.
    fig2 = go.Figure()
    fig2.add_trace(go.Scatter(x=df["timestamp"], y=df["RSI"], name="RSI"))
    fig2.add_hline(y=70, line_dash="dash", line_color="red")
    fig2.add_hline(y=30, line_dash="dash", line_color="green")
    fig2.update_layout(title="RSI Indicator", xaxis_title="Date", yaxis_title="RSI")

    # MACD and its signal line.
    fig3 = go.Figure()
    fig3.add_trace(go.Scatter(x=df["timestamp"], y=df["MACD"], name="MACD"))
    fig3.add_trace(go.Scatter(x=df["timestamp"], y=df["Signal_Line"], name="Signal Line"))
    fig3.update_layout(title="MACD", xaxis_title="Date", yaxis_title="Value")

    return fig1, fig2, fig3


def fetch_okx_symbols():
    """Fetch spot instrument IDs from OKX, with ``BTC-USDT`` first when present.

    Falls back to ``["BTC-USDT"]`` on any request/parse error, a non-"0" OKX
    response code, or an empty symbol list, so the UI always has a choice.
    """
    logging.info("Fetching symbols from OKX Spot tickers...")
    try:
        resp = requests.get(OKX_TICKERS_ENDPOINT, timeout=30)
        resp.raise_for_status()
        json_data = resp.json()
        if json_data.get("code") != "0":
            logging.error(f"Non-zero code returned: {json_data}")
            return ["BTC-USDT"]  # Default fallback

        data = json_data.get("data", [])
        symbols = [item["instId"] for item in data if item.get("instType") == "SPOT"]
        if not symbols:
            return ["BTC-USDT"]

        # Ensure BTC-USDT is first in the list so it can serve as the default.
        if "BTC-USDT" in symbols:
            symbols.remove("BTC-USDT")
            symbols.insert(0, "BTC-USDT")

        logging.info(f"Fetched {len(symbols)} OKX spot symbols.")
        return symbols
    except Exception as e:
        logging.error(f"Error fetching OKX symbols: {e}")
        return ["BTC-USDT"]


def fetch_okx_candles_chunk(symbol, timeframe, limit=300, after=None, before=None):
    """Fetch one page of candles from OKX.

    Args:
        symbol: OKX instrument ID, e.g. ``"BTC-USDT"``.
        timeframe: OKX ``bar`` value, e.g. ``"1H"``.
        limit: Number of candles requested.
        after: Optional pagination cursor (ms timestamp) passed as ``after``.
        before: Optional pagination cursor (ms timestamp) passed as ``before``.

    Returns:
        ``(DataFrame, "")`` on success (the frame may be empty when OKX returned
        no rows), or ``(empty DataFrame, error message)`` on failure.
    """
    params = {"instId": symbol, "bar": timeframe, "limit": limit}
    if after is not None:
        params["after"] = str(after)
    if before is not None:
        params["before"] = str(before)

    logging.info(f"Fetching chunk: symbol={symbol}, bar={timeframe}, limit={limit}")
    try:
        resp = requests.get(OKX_CANDLE_ENDPOINT, params=params, timeout=30)
        resp.raise_for_status()
        json_data = resp.json()
        if json_data.get("code") != "0":
            msg = f"OKX returned code={json_data.get('code')}, msg={json_data.get('msg')}"
            logging.error(msg)
            return pd.DataFrame(), msg

        items = json_data.get("data", [])
        if not items:
            return pd.DataFrame(), ""

        columns = ["ts", "o", "h", "l", "c", "vol", "volCcy", "volCcyQuote", "confirm"]
        df = pd.DataFrame(items, columns=columns)
        df.rename(columns={
            "ts": "timestamp",
            "o": "open",
            "h": "high",
            "l": "low",
            "c": "close",
        }, inplace=True)
        df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
        numeric_cols = ["open", "high", "low", "close", "vol", "volCcy", "volCcyQuote", "confirm"]
        df[numeric_cols] = df[numeric_cols].astype(float)
        return df, ""
    except Exception as e:
        err_msg = f"Error fetching candles chunk for {symbol}: {e}"
        logging.error(err_msg)
        return pd.DataFrame(), err_msg


def fetch_okx_candles(symbol, timeframe="1H", total=2000):
    """Fetch up to *total* of the most recent candles, sorted oldest-first.

    Pages backwards with the ``after`` cursor, de-duplicates timestamps, adds
    the technical indicator columns, and keeps only the newest *total* rows.

    Returns:
        ``(DataFrame, "")`` on success, or ``(empty DataFrame, message)`` when a
        page request failed or no data came back at all.
    """
    # NOTE(review): OKX documents /market/candles as serving only recent
    # history; large `total` values may be capped by the API — confirm, and
    # consider /market/history-candles if deeper history is required.
    logging.info(f"Fetching ~{total} candles for {symbol} @ {timeframe}")
    calls_needed = math.ceil(total / float(OKX_MAX_CANDLES_PER_REQUEST))
    all_data = []
    after_ts = None

    for _ in range(calls_needed):
        df_chunk, err = fetch_okx_candles_chunk(
            symbol, timeframe, limit=OKX_MAX_CANDLES_PER_REQUEST, after=after_ts
        )
        if err:
            return pd.DataFrame(), err
        if df_chunk.empty:
            break

        all_data.append(df_chunk)

        # The last row of a page is its earliest candle; page further back from
        # it. Timestamp.value is integer nanoseconds, so this yields exact ms
        # without a float round-trip through seconds.
        earliest_ms = df_chunk["timestamp"].iloc[-1].value // 1_000_000
        after_ts = earliest_ms - 1

        # A short page means there is no older data to fetch.
        if len(df_chunk) < OKX_MAX_CANDLES_PER_REQUEST:
            break

    if not all_data:
        return pd.DataFrame(), "No data returned."

    df_all = pd.concat(all_data, ignore_index=True)
    # Guard against overlapping pages yielding the same candle twice.
    df_all = (
        df_all.drop_duplicates(subset="timestamp")
        .sort_values(by="timestamp")
        .reset_index(drop=True)
    )

    # Indicators are computed on everything fetched so the kept rows have as
    # little rolling-window warm-up (NaN) as possible.
    df_all = calculate_technical_indicators(df_all)

    # Whole pages are requested, so trim back to the requested count, keeping
    # the newest candles.
    df_all = df_all.tail(int(total)).reset_index(drop=True)

    logging.info(f"Fetched {len(df_all)} rows for {symbol}.")
    return df_all, ""


def prepare_data_for_prophet(df):
    """Return a two-column ``ds``/``y`` frame (timestamp/close) as Prophet expects."""
    if df.empty:
        return pd.DataFrame(columns=["ds", "y"])
    df_prophet = df.rename(columns={"timestamp": "ds", "close": "y"})
    return df_prophet[["ds", "y"]]


def prophet_forecast(
    df_prophet,
    periods=10,
    freq="h",
    daily_seasonality=False,
    weekly_seasonality=False,
    yearly_seasonality=False,
    seasonality_mode="additive",
    changepoint_prior_scale=0.05,
):
    """Fit Prophet on *df_prophet* and predict history plus *periods* future steps.

    Args:
        df_prophet: Frame with ``ds`` and ``y`` columns.
        periods: Number of future steps to forecast (coerced to ``int``).
        freq: pandas offset alias for spacing the future timestamps.

    Returns:
        ``(forecast DataFrame, "")`` on success, or ``(empty DataFrame, message)``.
    """
    if df_prophet.empty:
        return pd.DataFrame(), "No data for Prophet."

    try:
        model = Prophet(
            daily_seasonality=daily_seasonality,
            weekly_seasonality=weekly_seasonality,
            yearly_seasonality=yearly_seasonality,
            seasonality_mode=seasonality_mode,
            changepoint_prior_scale=changepoint_prior_scale,
        )
        model.fit(df_prophet)
        # UI sliders may deliver floats; a step count must be an integer.
        future = model.make_future_dataframe(periods=int(periods), freq=freq)
        forecast = model.predict(future)
        return forecast, ""
    except Exception as e:
        logging.error(f"Forecast error: {e}")
        return pd.DataFrame(), f"Forecast error: {e}"


def prophet_wrapper(
    df_prophet,
    forecast_steps,
    freq,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """Run ``prophet_forecast`` and keep only the rows after the last observation.

    Returns:
        ``(DataFrame[ds, yhat, yhat_lower, yhat_upper], "")`` on success, or
        ``(empty DataFrame, message)`` if there are fewer than 10 rows or the
        forecast failed.
    """
    if len(df_prophet) < 10:
        return pd.DataFrame(), "Not enough data for forecasting (need >=10 rows)."

    full_forecast, err = prophet_forecast(
        df_prophet,
        periods=forecast_steps,
        freq=freq,
        daily_seasonality=daily_seasonality,
        weekly_seasonality=weekly_seasonality,
        yearly_seasonality=yearly_seasonality,
        seasonality_mode=seasonality_mode,
        changepoint_prior_scale=changepoint_prior_scale,
    )
    if err:
        return pd.DataFrame(), err

    # Select future rows by timestamp rather than by position, so the result
    # does not depend on the forecast's row count matching len(df_prophet).
    last_observed = df_prophet["ds"].max()
    future_mask = full_forecast["ds"] > last_observed
    future_only = full_forecast.loc[future_mask, ["ds", "yhat", "yhat_lower", "yhat_upper"]]
    return future_only.reset_index(drop=True), ""


def create_forecast_plot(forecast_df):
    """Plot the forecast line with a shaded lower/upper uncertainty band.

    Returns an empty figure when *forecast_df* is empty.
    """
    if forecast_df.empty:
        return go.Figure()

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat"],
        mode="lines",
        name="Forecast",
        line=dict(color="blue", width=2),
    ))
    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat_lower"],
        fill=None,
        mode="lines",
        line=dict(width=0),
        showlegend=True,
        name="Lower Bound",
    ))
    # "tonexty" fills down to the previous (lower-bound) trace, shading the band.
    fig.add_trace(go.Scatter(
        x=forecast_df["ds"],
        y=forecast_df["yhat_upper"],
        fill="tonexty",
        mode="lines",
        line=dict(width=0),
        name="Upper Bound",
    ))
    fig.update_layout(
        title="Price Forecast",
        xaxis_title="Time",
        yaxis_title="Price",
        hovermode="x unified",
        template="plotly_white",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    )
    return fig


def predict(
    symbol,
    timeframe,
    forecast_steps,
    total_candles,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """Fetch candles for *symbol* and forecast *forecast_steps* future candles.

    Args:
        timeframe: UI label, a key of ``TIMEFRAME_MAPPING`` (e.g. ``"4h"``).

    Returns:
        ``(raw candles with indicators, future forecast, "")`` on success, or
        ``(empty DataFrame, empty DataFrame, error message)``.
    """
    okx_bar = TIMEFRAME_MAPPING.get(timeframe, "1H")
    df_raw, err = fetch_okx_candles(symbol, timeframe=okx_bar, total=total_candles)
    if err:
        return pd.DataFrame(), pd.DataFrame(), err

    df_prophet = prepare_data_for_prophet(df_raw)

    # Future timestamps must be spaced exactly like the candles (e.g. 15min,
    # 4h, 7D); a coarse hour/day choice misplaces every forecast point.
    freq = TIMEFRAME_TO_FREQ.get(timeframe, "h")

    future_df, err2 = prophet_wrapper(
        df_prophet,
        int(forecast_steps),
        freq,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if err2:
        return pd.DataFrame(), pd.DataFrame(), err2

    return df_raw, future_df, ""


def display_forecast(
    symbol,
    timeframe,
    forecast_steps,
    total_candles,
    daily_seasonality,
    weekly_seasonality,
    yearly_seasonality,
    seasonality_mode,
    changepoint_prior_scale,
):
    """Gradio click handler: run the pipeline and build every output component.

    Returns:
        ``(forecast_fig, price_fig, rsi_fig, macd_fig, forecast_table)``; the
        table's columns are renamed to match the Dataframe component's headers.

    Raises:
        gr.Error: if fetching or forecasting failed. Gradio shows this message
            to the user, instead of a string being passed to the Dataframe output.
    """
    logging.info(f"Processing forecast request for {symbol}")
    df_raw, forecast_df, error = predict(
        symbol,
        timeframe,
        forecast_steps,
        total_candles,
        daily_seasonality,
        weekly_seasonality,
        yearly_seasonality,
        seasonality_mode,
        changepoint_prior_scale,
    )
    if error:
        raise gr.Error(f"Error: {error}")

    forecast_plot = create_forecast_plot(forecast_df)
    tech_plot, rsi_plot, macd_plot = create_technical_charts(df_raw)

    forecast_table = forecast_df.rename(columns={
        "ds": "Date",
        "yhat": "Forecast",
        "yhat_lower": "Lower Bound",
        "yhat_upper": "Upper Bound",
    })
    return forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_table


def main():
    """Build and return the CryptoVision Gradio Blocks app (does not launch it)."""
    symbols = fetch_okx_symbols()

    with gr.Blocks(theme=gr.themes.Base()) as demo:
        with gr.Row():
            gr.Markdown("# CryptoVision")
            gr.HTML(""" """)

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Market Selection")
                    symbol_dd = gr.Dropdown(
                        label="Trading Pair",
                        choices=symbols,
                        # BTC-USDT is placed first when available; symbols[0]
                        # is always a valid choice even if it is missing.
                        value=symbols[0],
                    )
                    timeframe_dd = gr.Dropdown(
                        label="Timeframe",
                        choices=list(TIMEFRAME_MAPPING.keys()),
                        value="1h",
                    )

            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Forecast Parameters")
                    forecast_steps_slider = gr.Slider(
                        label="Forecast Steps",
                        minimum=1,
                        maximum=100,
                        value=24,
                        step=1,
                    )
                    total_candles_slider = gr.Slider(
                        label="Historical Candles",
                        minimum=300,
                        maximum=3000,
                        value=2000,
                        step=100,
                    )

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("### Advanced Settings")
                    with gr.Row():
                        daily_box = gr.Checkbox(label="Daily Seasonality", value=True)
                        weekly_box = gr.Checkbox(label="Weekly Seasonality", value=True)
                        yearly_box = gr.Checkbox(label="Yearly Seasonality", value=False)
                    seasonality_mode_dd = gr.Dropdown(
                        label="Seasonality Mode",
                        choices=["additive", "multiplicative"],
                        value="additive",
                    )
                    changepoint_scale_slider = gr.Slider(
                        label="Changepoint Prior Scale",
                        minimum=0.01,
                        maximum=1.0,
                        step=0.01,
                        value=0.05,
                    )

        with gr.Row():
            forecast_btn = gr.Button("Generate Forecast", variant="primary", size="lg")

        with gr.Row():
            forecast_plot = gr.Plot(label="Price Forecast")

        with gr.Row():
            tech_plot = gr.Plot(label="Technical Analysis")
            rsi_plot = gr.Plot(label="RSI Indicator")

        with gr.Row():
            macd_plot = gr.Plot(label="MACD")

        with gr.Row():
            forecast_df = gr.Dataframe(
                label="Forecast Data",
                headers=["Date", "Forecast", "Lower Bound", "Upper Bound"],
            )

        forecast_btn.click(
            fn=display_forecast,
            inputs=[
                symbol_dd,
                timeframe_dd,
                forecast_steps_slider,
                total_candles_slider,
                daily_box,
                weekly_box,
                yearly_box,
                seasonality_mode_dd,
                changepoint_scale_slider,
            ],
            outputs=[forecast_plot, tech_plot, rsi_plot, macd_plot, forecast_df],
        )

    return demo


if __name__ == "__main__":
    app = main()
    app.launch()