shaheerawan3 commited on
Commit
d8baa98
ยท
verified ยท
1 Parent(s): 77eb99e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -118
app.py CHANGED
@@ -16,11 +16,17 @@ st.set_page_config(
16
  layout="wide"
17
  )
18
 
19
- # Constants and configurations
20
  START = "2015-01-01"
21
  TODAY = date.today().strftime("%Y-%m-%d")
22
 
23
- # Custom CSS for better styling
 
 
 
 
 
 
24
  st.markdown("""
25
  <style>
26
  .stButton>button {
@@ -32,79 +38,116 @@ st.markdown("""
32
  </style>
33
  """, unsafe_allow_html=True)
34
 
35
- class AssetPredictor:
36
- def __init__(self):
37
- self.assets = {
38
- 'Stocks': ['GOOG', 'AAPL', 'MSFT', 'GME'],
39
- 'Cryptocurrencies': ['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD']
40
- }
 
41
 
42
- @st.cache_data(ttl=3600) # Cache data for 1 hour
43
- def load_data(self, ticker):
44
- """Load and validate financial data."""
45
- try:
46
- data = yf.download(ticker, START, TODAY)
47
- if data.empty:
48
- raise ValueError(f"No data found for {ticker}")
49
-
50
- data.reset_index(inplace=True)
51
- # Ensure all required columns exist and are numeric
52
- required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
53
- for col in required_columns:
54
- if col not in data.columns:
55
- raise ValueError(f"Missing required column: {col}")
56
- if col != 'Date':
57
- data[col] = pd.to_numeric(data[col], errors='coerce')
58
-
59
- data.dropna(inplace=True)
60
- return data
61
- except Exception as e:
62
- st.error(f"Error loading data: {str(e)}")
63
- return None
64
-
65
- def prepare_prophet_data(self, data):
66
- """Prepare data for Prophet model."""
67
- df_prophet = data[['Date', 'Close']].copy()
68
- df_prophet.columns = ['ds', 'y']
69
- return df_prophet
70
-
71
- def train_prophet_model(self, data, period):
72
- """Train and return Prophet model with customized parameters."""
73
- model = Prophet(
74
- yearly_seasonality=True,
75
- weekly_seasonality=True,
76
- daily_seasonality=True,
77
- changepoint_prior_scale=0.05,
78
- seasonality_prior_scale=10.0,
79
- changepoint_range=0.9
80
- )
81
 
82
- # Add custom seasonalities
83
- model.add_seasonality(
84
- name='monthly',
85
- period=30.5,
86
- fourier_order=5
87
- )
88
 
89
- model.fit(data)
90
- future = model.make_future_dataframe(periods=period)
91
- return model, future
 
 
92
 
93
- def main():
94
- predictor = AssetPredictor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- # Sidebar for user inputs
97
- st.sidebar.title("โš™๏ธ Configuration")
98
- asset_type = st.sidebar.radio("Select Asset Type", list(predictor.assets.keys()))
99
- selected_asset = st.sidebar.selectbox(
100
- 'Select Asset',
101
- predictor.assets[asset_type]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
 
104
- # Main content
 
 
 
 
 
 
 
105
  st.title('๐Ÿ“ˆ Advanced Stock & Cryptocurrency Forecast')
106
 
107
- # Time range selection
 
 
 
 
 
108
  col1, col2 = st.columns(2)
109
  with col1:
110
  n_years = st.slider('Forecast Period (Years):', 1, 4)
@@ -115,79 +158,51 @@ def main():
115
 
116
  # Load and process data
117
  with st.spinner('Loading data...'):
118
- data = predictor.load_data(selected_asset)
119
 
120
  if data is not None:
121
- # Display technical indicators
122
- st.subheader('๐Ÿ“Š Technical Analysis')
123
-
124
  # Calculate technical indicators
125
  data['SMA_20'] = data['Close'].rolling(window=20).mean()
126
  data['SMA_50'] = data['Close'].rolling(window=50).mean()
127
  data['RSI'] = calculate_rsi(data['Close'])
128
 
129
- # Create technical analysis plot
130
- fig_technical = go.Figure()
131
- fig_technical.add_trace(go.Candlestick(
132
- x=data['Date'],
133
- open=data['Open'],
134
- high=data['High'],
135
- low=data['Low'],
136
- close=data['Close'],
137
- name='Price'
138
- ))
139
- fig_technical.add_trace(go.Scatter(
140
- x=data['Date'],
141
- y=data['SMA_20'],
142
- name='SMA 20',
143
- line=dict(color='orange')
144
- ))
145
- fig_technical.add_trace(go.Scatter(
146
- x=data['Date'],
147
- y=data['SMA_50'],
148
- name='SMA 50',
149
- line=dict(color='blue')
150
- ))
151
-
152
- fig_technical.update_layout(
153
- title=f'{selected_asset} Technical Analysis',
154
- yaxis_title='Price',
155
- template='plotly_dark'
156
- )
157
  st.plotly_chart(fig_technical, use_container_width=True)
158
 
159
  # Prepare and train Prophet model
160
- df_prophet = predictor.prepare_prophet_data(data)
161
 
162
  try:
163
- model, future = predictor.train_prophet_model(df_prophet, period)
164
  forecast = model.predict(future)
165
 
166
- # Calculate performance metrics
167
  historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
168
  mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
169
  rmse = np.sqrt(mean_squared_error(df_prophet['y'], historical_predictions['yhat']))
170
  mape = np.mean(np.abs((df_prophet['y'] - historical_predictions['yhat']) / df_prophet['y'])) * 100
171
 
172
- # Display metrics in columns
173
  st.subheader('๐Ÿ“‰ Model Performance Metrics')
174
  col1, col2, col3 = st.columns(3)
175
  col1.metric("MAE", f"${mae:.2f}")
176
  col2.metric("RMSE", f"${rmse:.2f}")
177
  col3.metric("MAPE", f"{mape:.2f}%")
178
 
179
- # Forecast visualization
180
  st.subheader('๐Ÿ”ฎ Price Forecast')
181
  fig_forecast = plot_plotly(model, forecast)
182
  fig_forecast.update_layout(template='plotly_dark')
183
  st.plotly_chart(fig_forecast, use_container_width=True)
184
 
185
- # Show forecast components
186
  st.subheader("๐Ÿ“Š Forecast Components")
187
  fig_components = model.plot_components(forecast)
188
  st.plotly_chart(fig_components, use_container_width=True)
189
 
190
- # Download forecast data
191
  csv = convert_df_to_csv(forecast)
192
  st.download_button(
193
  label="Download Forecast Data",
@@ -198,19 +213,7 @@ def main():
198
 
199
  except Exception as e:
200
  st.error(f"Error in prediction: {str(e)}")
201
-
202
- def calculate_rsi(prices, period=14):
203
- """Calculate Relative Strength Index."""
204
- delta = prices.diff()
205
- gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
206
- loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
207
- rs = gain / loss
208
- return 100 - (100 / (1 + rs))
209
-
210
- @st.cache_data
211
- def convert_df_to_csv(df):
212
- """Convert dataframe to CSV for download."""
213
- return df.to_csv(index=False).encode('utf-8')
214
 
215
  if __name__ == "__main__":
216
  main()
 
16
  layout="wide"
17
  )
18
 
19
+ # Constants
20
  START = "2015-01-01"
21
  TODAY = date.today().strftime("%Y-%m-%d")
22
 
23
+ # Asset categories
24
+ ASSETS = {
25
+ 'Stocks': ['GOOG', 'AAPL', 'MSFT', 'GME'],
26
+ 'Cryptocurrencies': ['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD']
27
+ }
28
+
29
+ # Custom CSS
30
  st.markdown("""
31
  <style>
32
  .stButton>button {
 
38
  </style>
39
  """, unsafe_allow_html=True)
40
 
41
+ @st.cache_data(ttl=3600)
42
+ def load_data(ticker):
43
+ """Load and validate financial data."""
44
+ try:
45
+ data = yf.download(ticker, START, TODAY)
46
+ if data.empty:
47
+ raise ValueError(f"No data found for {ticker}")
48
 
49
+ data.reset_index(inplace=True)
50
+ required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ for col in required_columns:
53
+ if col not in data.columns:
54
+ raise ValueError(f"Missing required column: {col}")
55
+ if col != 'Date':
56
+ data[col] = pd.to_numeric(data[col], errors='coerce')
 
57
 
58
+ data.dropna(inplace=True)
59
+ return data
60
+ except Exception as e:
61
+ st.error(f"Error loading data: {str(e)}")
62
+ return None
63
 
64
+ def calculate_rsi(prices, period=14):
65
+ """Calculate Relative Strength Index."""
66
+ delta = prices.diff()
67
+ gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
68
+ loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
69
+ rs = gain / loss
70
+ return 100 - (100 / (1 + rs))
71
+
72
+ def prepare_prophet_data(data):
73
+ """Prepare data for Prophet model."""
74
+ df_prophet = data[['Date', 'Close']].copy()
75
+ df_prophet.columns = ['ds', 'y']
76
+ return df_prophet
77
+
78
+ def train_prophet_model(data, period):
79
+ """Train and return Prophet model with customized parameters."""
80
+ model = Prophet(
81
+ yearly_seasonality=True,
82
+ weekly_seasonality=True,
83
+ daily_seasonality=True,
84
+ changepoint_prior_scale=0.05,
85
+ seasonality_prior_scale=10.0,
86
+ changepoint_range=0.9
87
+ )
88
 
89
+ # Add custom seasonalities
90
+ model.add_seasonality(
91
+ name='monthly',
92
+ period=30.5,
93
+ fourier_order=5
94
+ )
95
+
96
+ model.fit(data)
97
+ future = model.make_future_dataframe(periods=period)
98
+ return model, future
99
+
100
+ def plot_technical_analysis(data, selected_asset):
101
+ """Create technical analysis plot."""
102
+ fig = go.Figure()
103
+
104
+ # Add candlestick chart
105
+ fig.add_trace(go.Candlestick(
106
+ x=data['Date'],
107
+ open=data['Open'],
108
+ high=data['High'],
109
+ low=data['Low'],
110
+ close=data['Close'],
111
+ name='Price'
112
+ ))
113
+
114
+ # Add moving averages
115
+ fig.add_trace(go.Scatter(
116
+ x=data['Date'],
117
+ y=data['SMA_20'],
118
+ name='SMA 20',
119
+ line=dict(color='orange')
120
+ ))
121
+
122
+ fig.add_trace(go.Scatter(
123
+ x=data['Date'],
124
+ y=data['SMA_50'],
125
+ name='SMA 50',
126
+ line=dict(color='blue')
127
+ ))
128
+
129
+ fig.update_layout(
130
+ title=f'{selected_asset} Technical Analysis',
131
+ yaxis_title='Price',
132
+ template='plotly_dark'
133
  )
134
 
135
+ return fig
136
+
137
+ @st.cache_data
138
+ def convert_df_to_csv(df):
139
+ """Convert dataframe to CSV for download."""
140
+ return df.to_csv(index=False).encode('utf-8')
141
+
142
+ def main():
143
  st.title('๐Ÿ“ˆ Advanced Stock & Cryptocurrency Forecast')
144
 
145
+ # Sidebar configuration
146
+ st.sidebar.title("โš™๏ธ Configuration")
147
+ asset_type = st.sidebar.radio("Select Asset Type", list(ASSETS.keys()))
148
+ selected_asset = st.sidebar.selectbox('Select Asset', ASSETS[asset_type])
149
+
150
+ # Main content layout
151
  col1, col2 = st.columns(2)
152
  with col1:
153
  n_years = st.slider('Forecast Period (Years):', 1, 4)
 
158
 
159
  # Load and process data
160
  with st.spinner('Loading data...'):
161
+ data = load_data(selected_asset)
162
 
163
  if data is not None:
 
 
 
164
  # Calculate technical indicators
165
  data['SMA_20'] = data['Close'].rolling(window=20).mean()
166
  data['SMA_50'] = data['Close'].rolling(window=50).mean()
167
  data['RSI'] = calculate_rsi(data['Close'])
168
 
169
+ # Display technical analysis
170
+ st.subheader('๐Ÿ“Š Technical Analysis')
171
+ fig_technical = plot_technical_analysis(data, selected_asset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  st.plotly_chart(fig_technical, use_container_width=True)
173
 
174
  # Prepare and train Prophet model
175
+ df_prophet = prepare_prophet_data(data)
176
 
177
  try:
178
+ model, future = train_prophet_model(df_prophet, period)
179
  forecast = model.predict(future)
180
 
181
+ # Calculate metrics
182
  historical_predictions = forecast[forecast['ds'].isin(df_prophet['ds'])]
183
  mae = mean_absolute_error(df_prophet['y'], historical_predictions['yhat'])
184
  rmse = np.sqrt(mean_squared_error(df_prophet['y'], historical_predictions['yhat']))
185
  mape = np.mean(np.abs((df_prophet['y'] - historical_predictions['yhat']) / df_prophet['y'])) * 100
186
 
187
+ # Display metrics
188
  st.subheader('๐Ÿ“‰ Model Performance Metrics')
189
  col1, col2, col3 = st.columns(3)
190
  col1.metric("MAE", f"${mae:.2f}")
191
  col2.metric("RMSE", f"${rmse:.2f}")
192
  col3.metric("MAPE", f"{mape:.2f}%")
193
 
194
+ # Display forecast
195
  st.subheader('๐Ÿ”ฎ Price Forecast')
196
  fig_forecast = plot_plotly(model, forecast)
197
  fig_forecast.update_layout(template='plotly_dark')
198
  st.plotly_chart(fig_forecast, use_container_width=True)
199
 
200
+ # Display components
201
  st.subheader("๐Ÿ“Š Forecast Components")
202
  fig_components = model.plot_components(forecast)
203
  st.plotly_chart(fig_components, use_container_width=True)
204
 
205
+ # Add download button
206
  csv = convert_df_to_csv(forecast)
207
  st.download_button(
208
  label="Download Forecast Data",
 
213
 
214
  except Exception as e:
215
  st.error(f"Error in prediction: {str(e)}")
216
+ st.exception(e) # This will show the full traceback
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  if __name__ == "__main__":
219
  main()