Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -16,11 +16,17 @@ st.set_page_config(
|
|
16 |
layout="wide"
|
17 |
)
|
18 |
|
19 |
-
# Constants
|
20 |
START = "2015-01-01"
|
21 |
TODAY = date.today().strftime("%Y-%m-%d")
|
22 |
|
23 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
st.markdown("""
|
25 |
<style>
|
26 |
.stButton>button {
|
@@ -32,79 +38,116 @@ st.markdown("""
|
|
32 |
</style>
|
33 |
""", unsafe_allow_html=True)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
"""Load and validate financial data."""
|
45 |
-
try:
|
46 |
-
data = yf.download(ticker, START, TODAY)
|
47 |
-
if data.empty:
|
48 |
-
raise ValueError(f"No data found for {ticker}")
|
49 |
-
|
50 |
-
data.reset_index(inplace=True)
|
51 |
-
# Ensure all required columns exist and are numeric
|
52 |
-
required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
53 |
-
for col in required_columns:
|
54 |
-
if col not in data.columns:
|
55 |
-
raise ValueError(f"Missing required column: {col}")
|
56 |
-
if col != 'Date':
|
57 |
-
data[col] = pd.to_numeric(data[col], errors='coerce')
|
58 |
-
|
59 |
-
data.dropna(inplace=True)
|
60 |
-
return data
|
61 |
-
except Exception as e:
|
62 |
-
st.error(f"Error loading data: {str(e)}")
|
63 |
-
return None
|
64 |
-
|
65 |
-
def prepare_prophet_data(self, data):
|
66 |
-
"""Prepare data for Prophet model."""
|
67 |
-
df_prophet = data[['Date', 'Close']].copy()
|
68 |
-
df_prophet.columns = ['ds', 'y']
|
69 |
-
return df_prophet
|
70 |
-
|
71 |
-
def train_prophet_model(self, data, period):
|
72 |
-
"""Train and return Prophet model with customized parameters."""
|
73 |
-
model = Prophet(
|
74 |
-
yearly_seasonality=True,
|
75 |
-
weekly_seasonality=True,
|
76 |
-
daily_seasonality=True,
|
77 |
-
changepoint_prior_scale=0.05,
|
78 |
-
seasonality_prior_scale=10.0,
|
79 |
-
changepoint_range=0.9
|
80 |
-
)
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
92 |
|
93 |
-
def
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
#
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
)
|
103 |
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
st.title('๐ Advanced Stock & Cryptocurrency Forecast')
|
106 |
|
107 |
-
#
|
|
|
|
|
|
|
|
|
|
|
108 |
col1, col2 = st.columns(2)
|
109 |
with col1:
|
110 |
n_years = st.slider('Forecast Period (Years):', 1, 4)
|
@@ -115,79 +158,51 @@ def main():
|
|
115 |
|
116 |
# Load and process data
|
117 |
with st.spinner('Loading data...'):
|
118 |
-
data =
|
119 |
|
120 |
if data is not None:
|
121 |
-
# Display technical indicators
|
122 |
-
st.subheader('๐ Technical Analysis')
|
123 |
-
|
124 |
# Calculate technical indicators
|
125 |
data['SMA_20'] = data['Close'].rolling(window=20).mean()
|
126 |
data['SMA_50'] = data['Close'].rolling(window=50).mean()
|
127 |
data['RSI'] = calculate_rsi(data['Close'])
|
128 |
|
129 |
-
#
|
130 |
-
|
131 |
-
fig_technical
|
132 |
-
x=data['Date'],
|
133 |
-
open=data['Open'],
|
134 |
-
high=data['High'],
|
135 |
-
low=data['Low'],
|
136 |
-
close=data['Close'],
|
137 |
-
name='Price'
|
138 |
-
))
|
139 |
-
fig_technical.add_trace(go.Scatter(
|
140 |
-
x=data['Date'],
|
141 |
-
y=data['SMA_20'],
|
142 |
-
name='SMA 20',
|
143 |
-
line=dict(color='orange')
|
144 |
-
))
|
145 |
-
fig_technical.add_trace(go.Scatter(
|
146 |
-
x=data['Date'],
|
147 |
-
y=data['SMA_50'],
|
148 |
-
name='SMA 50',
|
149 |
-
line=dict(color='blue')
|
150 |
-
))
|
151 |
-
|
152 |
-
fig_technical.update_layout(
|
153 |
-
title=f'{selected_asset} Technical Analysis',
|
154 |
-
yaxis_title='Price',
|
155 |
-
template='plotly_dark'
|
156 |
-
)
|
157 |
st.plotly_chart(fig_technical, use_container_width=True)
|
158 |
|
159 |
# Prepare and train Prophet model
|
160 |
-
df_prophet =
|
161 |
|
162 |
try:
|
163 |
-
model, future =
|
164 |
forecast = model.predict(future)
|
165 |
|
166 |
-
# Calculate
|
167 |
historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
|
168 |
mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
|
169 |
rmse = np.sqrt(mean_squared_error(df_prophet['y'], historical_predictions['yhat']))
|
170 |
mape = np.mean(np.abs((df_prophet['y'] - historical_predictions['yhat']) / df_prophet['y'])) * 100
|
171 |
|
172 |
-
# Display metrics
|
173 |
st.subheader('๐ Model Performance Metrics')
|
174 |
col1, col2, col3 = st.columns(3)
|
175 |
col1.metric("MAE", f"${mae:.2f}")
|
176 |
col2.metric("RMSE", f"${rmse:.2f}")
|
177 |
col3.metric("MAPE", f"{mape:.2f}%")
|
178 |
|
179 |
-
#
|
180 |
st.subheader('๐ฎ Price Forecast')
|
181 |
fig_forecast = plot_plotly(model, forecast)
|
182 |
fig_forecast.update_layout(template='plotly_dark')
|
183 |
st.plotly_chart(fig_forecast, use_container_width=True)
|
184 |
|
185 |
-
#
|
186 |
st.subheader("๐ Forecast Components")
|
187 |
fig_components = model.plot_components(forecast)
|
188 |
st.plotly_chart(fig_components, use_container_width=True)
|
189 |
|
190 |
-
#
|
191 |
csv = convert_df_to_csv(forecast)
|
192 |
st.download_button(
|
193 |
label="Download Forecast Data",
|
@@ -198,19 +213,7 @@ def main():
|
|
198 |
|
199 |
except Exception as e:
|
200 |
st.error(f"Error in prediction: {str(e)}")
|
201 |
-
|
202 |
-
def calculate_rsi(prices, period=14):
|
203 |
-
"""Calculate Relative Strength Index."""
|
204 |
-
delta = prices.diff()
|
205 |
-
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
206 |
-
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
207 |
-
rs = gain / loss
|
208 |
-
return 100 - (100 / (1 + rs))
|
209 |
-
|
210 |
-
@st.cache_data
|
211 |
-
def convert_df_to_csv(df):
|
212 |
-
"""Convert dataframe to CSV for download."""
|
213 |
-
return df.to_csv(index=False).encode('utf-8')
|
214 |
|
215 |
if __name__ == "__main__":
|
216 |
main()
|
|
|
16 |
layout="wide"
|
17 |
)
|
18 |
|
19 |
+
# Constants
|
20 |
START = "2015-01-01"
|
21 |
TODAY = date.today().strftime("%Y-%m-%d")
|
22 |
|
23 |
+
# Asset categories
|
24 |
+
ASSETS = {
|
25 |
+
'Stocks': ['GOOG', 'AAPL', 'MSFT', 'GME'],
|
26 |
+
'Cryptocurrencies': ['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD']
|
27 |
+
}
|
28 |
+
|
29 |
+
# Custom CSS
|
30 |
st.markdown("""
|
31 |
<style>
|
32 |
.stButton>button {
|
|
|
38 |
</style>
|
39 |
""", unsafe_allow_html=True)
|
40 |
|
41 |
+
@st.cache_data(ttl=3600)
|
42 |
+
def load_data(ticker):
|
43 |
+
"""Load and validate financial data."""
|
44 |
+
try:
|
45 |
+
data = yf.download(ticker, START, TODAY)
|
46 |
+
if data.empty:
|
47 |
+
raise ValueError(f"No data found for {ticker}")
|
48 |
|
49 |
+
data.reset_index(inplace=True)
|
50 |
+
required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
for col in required_columns:
|
53 |
+
if col not in data.columns:
|
54 |
+
raise ValueError(f"Missing required column: {col}")
|
55 |
+
if col != 'Date':
|
56 |
+
data[col] = pd.to_numeric(data[col], errors='coerce')
|
|
|
57 |
|
58 |
+
data.dropna(inplace=True)
|
59 |
+
return data
|
60 |
+
except Exception as e:
|
61 |
+
st.error(f"Error loading data: {str(e)}")
|
62 |
+
return None
|
63 |
|
64 |
+
def calculate_rsi(prices, period=14):
|
65 |
+
"""Calculate Relative Strength Index."""
|
66 |
+
delta = prices.diff()
|
67 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
68 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
69 |
+
rs = gain / loss
|
70 |
+
return 100 - (100 / (1 + rs))
|
71 |
+
|
72 |
+
def prepare_prophet_data(data):
|
73 |
+
"""Prepare data for Prophet model."""
|
74 |
+
df_prophet = data[['Date', 'Close']].copy()
|
75 |
+
df_prophet.columns = ['ds', 'y']
|
76 |
+
return df_prophet
|
77 |
+
|
78 |
+
def train_prophet_model(data, period):
|
79 |
+
"""Train and return Prophet model with customized parameters."""
|
80 |
+
model = Prophet(
|
81 |
+
yearly_seasonality=True,
|
82 |
+
weekly_seasonality=True,
|
83 |
+
daily_seasonality=True,
|
84 |
+
changepoint_prior_scale=0.05,
|
85 |
+
seasonality_prior_scale=10.0,
|
86 |
+
changepoint_range=0.9
|
87 |
+
)
|
88 |
|
89 |
+
# Add custom seasonalities
|
90 |
+
model.add_seasonality(
|
91 |
+
name='monthly',
|
92 |
+
period=30.5,
|
93 |
+
fourier_order=5
|
94 |
+
)
|
95 |
+
|
96 |
+
model.fit(data)
|
97 |
+
future = model.make_future_dataframe(periods=period)
|
98 |
+
return model, future
|
99 |
+
|
100 |
+
def plot_technical_analysis(data, selected_asset):
|
101 |
+
"""Create technical analysis plot."""
|
102 |
+
fig = go.Figure()
|
103 |
+
|
104 |
+
# Add candlestick chart
|
105 |
+
fig.add_trace(go.Candlestick(
|
106 |
+
x=data['Date'],
|
107 |
+
open=data['Open'],
|
108 |
+
high=data['High'],
|
109 |
+
low=data['Low'],
|
110 |
+
close=data['Close'],
|
111 |
+
name='Price'
|
112 |
+
))
|
113 |
+
|
114 |
+
# Add moving averages
|
115 |
+
fig.add_trace(go.Scatter(
|
116 |
+
x=data['Date'],
|
117 |
+
y=data['SMA_20'],
|
118 |
+
name='SMA 20',
|
119 |
+
line=dict(color='orange')
|
120 |
+
))
|
121 |
+
|
122 |
+
fig.add_trace(go.Scatter(
|
123 |
+
x=data['Date'],
|
124 |
+
y=data['SMA_50'],
|
125 |
+
name='SMA 50',
|
126 |
+
line=dict(color='blue')
|
127 |
+
))
|
128 |
+
|
129 |
+
fig.update_layout(
|
130 |
+
title=f'{selected_asset} Technical Analysis',
|
131 |
+
yaxis_title='Price',
|
132 |
+
template='plotly_dark'
|
133 |
)
|
134 |
|
135 |
+
return fig
|
136 |
+
|
137 |
+
@st.cache_data
|
138 |
+
def convert_df_to_csv(df):
|
139 |
+
"""Convert dataframe to CSV for download."""
|
140 |
+
return df.to_csv(index=False).encode('utf-8')
|
141 |
+
|
142 |
+
def main():
|
143 |
st.title('๐ Advanced Stock & Cryptocurrency Forecast')
|
144 |
|
145 |
+
# Sidebar configuration
|
146 |
+
st.sidebar.title("โ๏ธ Configuration")
|
147 |
+
asset_type = st.sidebar.radio("Select Asset Type", list(ASSETS.keys()))
|
148 |
+
selected_asset = st.sidebar.selectbox('Select Asset', ASSETS[asset_type])
|
149 |
+
|
150 |
+
# Main content layout
|
151 |
col1, col2 = st.columns(2)
|
152 |
with col1:
|
153 |
n_years = st.slider('Forecast Period (Years):', 1, 4)
|
|
|
158 |
|
159 |
# Load and process data
|
160 |
with st.spinner('Loading data...'):
|
161 |
+
data = load_data(selected_asset)
|
162 |
|
163 |
if data is not None:
|
|
|
|
|
|
|
164 |
# Calculate technical indicators
|
165 |
data['SMA_20'] = data['Close'].rolling(window=20).mean()
|
166 |
data['SMA_50'] = data['Close'].rolling(window=50).mean()
|
167 |
data['RSI'] = calculate_rsi(data['Close'])
|
168 |
|
169 |
+
# Display technical analysis
|
170 |
+
st.subheader('๐ Technical Analysis')
|
171 |
+
fig_technical = plot_technical_analysis(data, selected_asset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
st.plotly_chart(fig_technical, use_container_width=True)
|
173 |
|
174 |
# Prepare and train Prophet model
|
175 |
+
df_prophet = prepare_prophet_data(data)
|
176 |
|
177 |
try:
|
178 |
+
model, future = train_prophet_model(df_prophet, period)
|
179 |
forecast = model.predict(future)
|
180 |
|
181 |
+
# Calculate metrics
|
182 |
historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
|
183 |
mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
|
184 |
rmse = np.sqrt(mean_squared_error(df_prophet['y'], historical_predictions['yhat']))
|
185 |
mape = np.mean(np.abs((df_prophet['y'] - historical_predictions['yhat']) / df_prophet['y'])) * 100
|
186 |
|
187 |
+
# Display metrics
|
188 |
st.subheader('๐ Model Performance Metrics')
|
189 |
col1, col2, col3 = st.columns(3)
|
190 |
col1.metric("MAE", f"${mae:.2f}")
|
191 |
col2.metric("RMSE", f"${rmse:.2f}")
|
192 |
col3.metric("MAPE", f"{mape:.2f}%")
|
193 |
|
194 |
+
# Display forecast
|
195 |
st.subheader('๐ฎ Price Forecast')
|
196 |
fig_forecast = plot_plotly(model, forecast)
|
197 |
fig_forecast.update_layout(template='plotly_dark')
|
198 |
st.plotly_chart(fig_forecast, use_container_width=True)
|
199 |
|
200 |
+
# Display components
|
201 |
st.subheader("๐ Forecast Components")
|
202 |
fig_components = model.plot_components(forecast)
|
203 |
st.plotly_chart(fig_components, use_container_width=True)
|
204 |
|
205 |
+
# Add download button
|
206 |
csv = convert_df_to_csv(forecast)
|
207 |
st.download_button(
|
208 |
label="Download Forecast Data",
|
|
|
213 |
|
214 |
except Exception as e:
|
215 |
st.error(f"Error in prediction: {str(e)}")
|
216 |
+
st.exception(e) # This will show the full traceback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
if __name__ == "__main__":
|
219 |
main()
|