shaheerawan3 commited on
Commit
b64d6f6
Β·
verified Β·
1 Parent(s): d8baa98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -8
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from datetime import date, datetime
3
  import yfinance as yf
4
  import pandas as pd
5
  import numpy as np
@@ -35,6 +35,12 @@ st.markdown("""
35
  .reportview-container {
36
  background: #f0f2f6
37
  }
 
 
 
 
 
 
38
  </style>
39
  """, unsafe_allow_html=True)
40
 
@@ -101,7 +107,6 @@ def plot_technical_analysis(data, selected_asset):
101
  """Create technical analysis plot."""
102
  fig = go.Figure()
103
 
104
- # Add candlestick chart
105
  fig.add_trace(go.Candlestick(
106
  x=data['Date'],
107
  open=data['Open'],
@@ -111,7 +116,6 @@ def plot_technical_analysis(data, selected_asset):
111
  name='Price'
112
  ))
113
 
114
- # Add moving averages
115
  fig.add_trace(go.Scatter(
116
  x=data['Date'],
117
  y=data['SMA_20'],
@@ -134,18 +138,80 @@ def plot_technical_analysis(data, selected_asset):
134
 
135
  return fig
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  @st.cache_data
138
  def convert_df_to_csv(df):
139
  """Convert dataframe to CSV for download."""
140
  return df.to_csv(index=False).encode('utf-8')
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def main():
143
  st.title('πŸ“ˆ Advanced Stock & Cryptocurrency Forecast')
144
 
 
 
 
 
 
 
 
 
 
 
145
  # Sidebar configuration
146
  st.sidebar.title("βš™οΈ Configuration")
147
- asset_type = st.sidebar.radio("Select Asset Type", list(ASSETS.keys()))
148
- selected_asset = st.sidebar.selectbox('Select Asset', ASSETS[asset_type])
149
 
150
  # Main content layout
151
  col1, col2 = st.columns(2)
@@ -156,6 +222,15 @@ def main():
156
 
157
  period = n_years * 365
158
 
 
 
 
 
 
 
 
 
 
159
  # Load and process data
160
  with st.spinner('Loading data...'):
161
  data = load_data(selected_asset)
@@ -178,6 +253,19 @@ def main():
178
  model, future = train_prophet_model(df_prophet, period)
179
  forecast = model.predict(future)
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  # Calculate metrics
182
  historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
183
  mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
@@ -197,9 +285,9 @@ def main():
197
  fig_forecast.update_layout(template='plotly_dark')
198
  st.plotly_chart(fig_forecast, use_container_width=True)
199
 
200
- # Display components
201
  st.subheader("πŸ“Š Forecast Components")
202
- fig_components = model.plot_components(forecast)
203
  st.plotly_chart(fig_components, use_container_width=True)
204
 
205
  # Add download button
@@ -213,7 +301,7 @@ def main():
213
 
214
  except Exception as e:
215
  st.error(f"Error in prediction: {str(e)}")
216
- st.exception(e) # This will show the full traceback
217
 
218
  if __name__ == "__main__":
219
  main()
 
1
  import streamlit as st
2
+ from datetime import date, datetime, timedelta
3
  import yfinance as yf
4
  import pandas as pd
5
  import numpy as np
 
35
  .reportview-container {
36
  background: #f0f2f6
37
  }
38
+ .custom-date {
39
+ margin-top: 1rem;
40
+ padding: 1rem;
41
+ background-color: #f8f9fa;
42
+ border-radius: 0.5rem;
43
+ }
44
  </style>
45
  """, unsafe_allow_html=True)
46
 
 
107
  """Create technical analysis plot."""
108
  fig = go.Figure()
109
 
 
110
  fig.add_trace(go.Candlestick(
111
  x=data['Date'],
112
  open=data['Open'],
 
116
  name='Price'
117
  ))
118
 
 
119
  fig.add_trace(go.Scatter(
120
  x=data['Date'],
121
  y=data['SMA_20'],
 
138
 
139
  return fig
140
 
141
+ def plot_forecast_components(model, forecast):
142
+ """Create custom forecast components plot."""
143
+ fig = go.Figure()
144
+
145
+ # Trend
146
+ fig.add_trace(go.Scatter(
147
+ x=forecast['ds'],
148
+ y=forecast['trend'],
149
+ name='Trend',
150
+ line=dict(color='blue')
151
+ ))
152
+
153
+ # Yearly seasonality
154
+ if 'yearly' in forecast.columns:
155
+ fig.add_trace(go.Scatter(
156
+ x=forecast['ds'],
157
+ y=forecast['yearly'],
158
+ name='Yearly Seasonality',
159
+ line=dict(color='green')
160
+ ))
161
+
162
+ # Weekly seasonality
163
+ if 'weekly' in forecast.columns:
164
+ fig.add_trace(go.Scatter(
165
+ x=forecast['ds'],
166
+ y=forecast['weekly'],
167
+ name='Weekly Seasonality',
168
+ line=dict(color='red')
169
+ ))
170
+
171
+ fig.update_layout(
172
+ title='Forecast Components',
173
+ template='plotly_dark',
174
+ height=800,
175
+ showlegend=True
176
+ )
177
+
178
+ return fig
179
+
180
  @st.cache_data
181
  def convert_df_to_csv(df):
182
  """Convert dataframe to CSV for download."""
183
  return df.to_csv(index=False).encode('utf-8')
184
 
185
+ def get_specific_date_prediction(model, date_input, forecast):
186
+ """Get prediction for a specific date."""
187
+ try:
188
+ date_prediction = forecast[forecast['ds'] == pd.to_datetime(date_input)].iloc[0]
189
+ return {
190
+ 'Predicted Value': f"${date_prediction['yhat']:.2f}",
191
+ 'Lower Bound': f"${date_prediction['yhat_lower']:.2f}",
192
+ 'Upper Bound': f"${date_prediction['yhat_upper']:.2f}",
193
+ 'Trend': f"${date_prediction['trend']:.2f}"
194
+ }
195
+ except IndexError:
196
+ return None
197
+
198
  def main():
199
  st.title('πŸ“ˆ Advanced Stock & Cryptocurrency Forecast')
200
 
201
+ # Search bar for assets
202
+ search_term = st.text_input('πŸ” Search for assets (e.g., "AAPL" for Apple Inc.)', '')
203
+
204
+ # Filter assets based on search
205
+ filtered_assets = {
206
+ category: [asset for asset in assets
207
+ if search_term.upper() in asset.upper()]
208
+ for category, assets in ASSETS.items()
209
+ }
210
+
211
  # Sidebar configuration
212
  st.sidebar.title("βš™οΈ Configuration")
213
+ asset_type = st.sidebar.radio("Select Asset Type", list(filtered_assets.keys()))
214
+ selected_asset = st.sidebar.selectbox('Select Asset', filtered_assets[asset_type])
215
 
216
  # Main content layout
217
  col1, col2 = st.columns(2)
 
222
 
223
  period = n_years * 365
224
 
225
+ # Date-specific prediction section
226
+ st.subheader('🎯 Get Prediction for Specific Date')
227
+ prediction_date = st.date_input(
228
+ "Select a date for prediction",
229
+ min_value=date.today(),
230
+ max_value=date.today() + timedelta(days=period),
231
+ value=date.today() + timedelta(days=30)
232
+ )
233
+
234
  # Load and process data
235
  with st.spinner('Loading data...'):
236
  data = load_data(selected_asset)
 
253
  model, future = train_prophet_model(df_prophet, period)
254
  forecast = model.predict(future)
255
 
256
+ # Get specific date prediction
257
+ specific_prediction = get_specific_date_prediction(
258
+ model,
259
+ prediction_date,
260
+ forecast
261
+ )
262
+
263
+ if specific_prediction:
264
+ st.subheader(f"Prediction for {prediction_date}")
265
+ cols = st.columns(4)
266
+ for i, (metric, value) in enumerate(specific_prediction.items()):
267
+ cols[i].metric(metric, value)
268
+
269
  # Calculate metrics
270
  historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
271
  mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
 
285
  fig_forecast.update_layout(template='plotly_dark')
286
  st.plotly_chart(fig_forecast, use_container_width=True)
287
 
288
+ # Display components using custom plotting function
289
  st.subheader("πŸ“Š Forecast Components")
290
+ fig_components = plot_forecast_components(model, forecast)
291
  st.plotly_chart(fig_components, use_container_width=True)
292
 
293
  # Add download button
 
301
 
302
  except Exception as e:
303
  st.error(f"Error in prediction: {str(e)}")
304
+ st.exception(e)
305
 
306
  if __name__ == "__main__":
307
  main()