Anvarbekkk commited on
Commit
d414f92
·
verified ·
1 Parent(s): 846ed80

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +206 -0
  2. requirements.txt +80 -0
  3. stock_price_model.h5 +3 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import yfinance as yf
4
+ import datetime as dt
5
+ import plotly.graph_objects as go
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from tensorflow.keras.models import load_model
8
+ import gradio as gr
9
+ import warnings
10
+ import os
11
+
12
+ # Suppress warnings
13
+ warnings.filterwarnings("ignore")
14
+
15
+ # Constants
16
+ PREDICTION_DAYS = 30
17
+ TIME_STEP = 60
18
+ DATA_YEARS = 3
19
+
20
+ # Load model
21
+ model = load_model('stock_price_model.h5')
22
+ model.make_predict_function() # For faster inference
23
+
24
+
25
+ def preprocess_data(df):
26
+ """Process yfinance data"""
27
+ df.columns = [col[0] if isinstance(col, tuple) else col for col in df.columns]
28
+ df = df.reset_index().rename(columns={'index': 'Date'})
29
+ df = df[['Date', 'High', 'Low', 'Open', 'Close', 'Volume']]
30
+ df['Date'] = pd.to_datetime(df['Date'])
31
+ df.set_index('Date', inplace=True)
32
+ return df
33
+
34
+
35
+ def get_stock_data(stock_symbol):
36
+ """Fetch stock data with caching"""
37
+ end_date = dt.datetime.now()
38
+ start_date = end_date - dt.timedelta(days=365 * DATA_YEARS)
39
+ df = yf.download(stock_symbol, start=start_date, end=end_date)
40
+ return preprocess_data(df)
41
+
42
+
43
+ def prepare_data(df):
44
+ """Prepare data for LSTM prediction"""
45
+ scaler = MinMaxScaler()
46
+ scaled_data = scaler.fit_transform(df['Close'].values.reshape(-1, 1))
47
+
48
+ # Create dataset using sliding window
49
+ X = np.array([scaled_data[i:i + TIME_STEP, 0]
50
+ for i in range(len(scaled_data) - TIME_STEP - 1)])
51
+ y = scaled_data[TIME_STEP + 1:, 0]
52
+
53
+ return X.reshape(X.shape[0], TIME_STEP, 1), y, scaler
54
+
55
+
56
+ def predict_future(model, data, scaler):
57
+ """Generate future predictions"""
58
+ last_data = data[-TIME_STEP:].reshape(1, TIME_STEP, 1)
59
+ future_preds = np.zeros(PREDICTION_DAYS, dtype='float32')
60
+
61
+ for i in range(PREDICTION_DAYS):
62
+ next_pred = model.predict(last_data, verbose=0)[0, 0]
63
+ future_preds[i] = next_pred
64
+ last_data = np.roll(last_data, -1, axis=1)
65
+ last_data[0, -1, 0] = next_pred
66
+
67
+ return scaler.inverse_transform(future_preds.reshape(-1, 1))
68
+
69
+
70
+ def create_plot(df, pred_data=None, future_data=None, title=""):
71
+ """Create interactive Plotly figure"""
72
+ fig = go.Figure()
73
+
74
+ # Main price line
75
+ fig.add_trace(go.Scatter(
76
+ x=df.index,
77
+ y=df['Close'],
78
+ name='Actual Price',
79
+ line=dict(color='blue')
80
+ ))
81
+
82
+ # Prediction line
83
+ if pred_data is not None:
84
+ fig.add_trace(go.Scatter(
85
+ x=df.index[TIME_STEP + 1:],
86
+ y=pred_data[:, 0],
87
+ name='Predicted',
88
+ line=dict(color='orange')
89
+ ))
90
+
91
+ # Future prediction
92
+ if future_data is not None:
93
+ future_dates = pd.date_range(
94
+ start=df.index[-1],
95
+ periods=PREDICTION_DAYS + 1
96
+ )[1:]
97
+ fig.add_trace(go.Scatter(
98
+ x=future_dates,
99
+ y=future_data[:, 0],
100
+ name='30-Day Forecast',
101
+ line=dict(color='green')
102
+ ))
103
+
104
+ fig.update_layout(
105
+ title=title,
106
+ template='plotly_dark',
107
+ margin=dict(l=20, r=20, t=40, b=20)
108
+ )
109
+ return fig
110
+
111
+
112
+ def predict_stock(stock_symbol):
113
+ """Main prediction function for Gradio"""
114
+ try:
115
+ df = get_stock_data(stock_symbol)
116
+ X, y, scaler = prepare_data(df)
117
+
118
+ # Make predictions
119
+ y_pred = model.predict(X)
120
+ y_pred = scaler.inverse_transform(y_pred)
121
+
122
+ # Future prediction
123
+ future_prices = predict_future(
124
+ model,
125
+ scaler.transform(df['Close'].values.reshape(-1, 1)),
126
+ scaler
127
+ )
128
+
129
+ # Create plots
130
+ main_plot = create_plot(
131
+ df,
132
+ pred_data=y_pred,
133
+ title=f"{stock_symbol} Price Prediction"
134
+ )
135
+
136
+ future_plot = create_plot(
137
+ df,
138
+ future_data=future_prices,
139
+ title=f"{stock_symbol} 30-Day Forecast"
140
+ )
141
+
142
+ # Technical indicators
143
+ df['SMA_50'] = df['Close'].rolling(50).mean()
144
+ df['SMA_200'] = df['Close'].rolling(200).mean()
145
+
146
+ tech_fig = go.Figure()
147
+ tech_fig.add_trace(go.Scatter(
148
+ x=df.index, y=df['Close'],
149
+ name='Price', line=dict(color='blue')))
150
+ tech_fig.add_trace(go.Scatter(
151
+ x=df.index, y=df['SMA_50'],
152
+ name='50-Day SMA', line=dict(color='orange')))
153
+ tech_fig.add_trace(go.Scatter(
154
+ x=df.index, y=df['SMA_200'],
155
+ name='200-Day SMA', line=dict(color='red')))
156
+ tech_fig.update_layout(
157
+ title=f"{stock_symbol} Technical Indicators",
158
+ template='plotly_dark'
159
+ )
160
+
161
+ return (
162
+ f"${df['Close'].iloc[-1]:.2f}",
163
+ df.index[-1].strftime('%Y-%m-%d'),
164
+ main_plot,
165
+ future_plot,
166
+ tech_fig
167
+ )
168
+
169
+ except Exception as e:
170
+ raise gr.Error(f"Prediction failed: {str(e)}")
171
+
172
+
173
+ # Gradio Interface
174
+ with gr.Blocks(title="Stock Prediction", theme=gr.themes.Default()) as demo:
175
+ gr.Markdown("# 📈 Real-Time Stock Predictor")
176
+ gr.Markdown("Predict stock prices using LSTM neural networks")
177
+
178
+ with gr.Row():
179
+ stock_input = gr.Textbox(
180
+ label="Stock Symbol (Examples: TSLA, AAPL, MSFT, AMZN, GOOG, AEP)",
181
+ value="TSLA",
182
+ placeholder="Enter stock symbol (e.g. AAPL, MSFT)"
183
+ )
184
+ submit_btn = gr.Button("Predict", variant="primary")
185
+
186
+ with gr.Row():
187
+ with gr.Column():
188
+ last_price = gr.Textbox(label="Last Price")
189
+ last_date = gr.Textbox(label="Last Date")
190
+
191
+ with gr.Tabs():
192
+ with gr.Tab("Price Prediction"):
193
+ main_plot = gr.Plot(label="Price Prediction")
194
+ with gr.Tab("30-Day Forecast"):
195
+ future_plot = gr.Plot(label="Future Prediction")
196
+ with gr.Tab("Technical Indicators"):
197
+ tech_plot = gr.Plot(label="Technical Analysis")
198
+
199
+ submit_btn.click(
200
+ fn=predict_stock,
201
+ inputs=stock_input,
202
+ outputs=[last_price, last_date, main_plot, future_plot, tech_plot]
203
+ )
204
+
205
+ # For Hugging Face Spaces
206
+ demo.launch(debug=False)
requirements.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.2.2
2
+ astunparse==1.6.3
3
+ beautifulsoup4==4.13.4
4
+ blinker==1.9.0
5
+ certifi==2025.4.26
6
+ cffi==1.17.1
7
+ charset-normalizer==3.4.2
8
+ click==8.1.8
9
+ colorama==0.4.6
10
+ contourpy==1.3.2
11
+ curl_cffi==0.10.0
12
+ cycler==0.12.1
13
+ Flask==3.1.0
14
+ flatbuffers==25.2.10
15
+ fonttools==4.57.0
16
+ frozendict==2.4.6
17
+ gast==0.6.0
18
+ google-pasta==0.2.0
19
+ grpcio==1.71.0
20
+ h5py==3.13.0
21
+ idna==3.10
22
+ itsdangerous==2.2.0
23
+ Jinja2==3.1.6
24
+ joblib==1.5.0
25
+ keras==3.9.2
26
+ kiwisolver==1.4.8
27
+ libclang==18.1.1
28
+ lxml==5.4.0
29
+ Markdown==3.8
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==3.0.2
32
+ matplotlib==3.10.1
33
+ mdurl==0.1.2
34
+ ml_dtypes==0.5.1
35
+ multitasking==0.0.11
36
+ namex==0.0.9
37
+ narwhals==1.38.0
38
+ numpy==2.1.3
39
+ opt_einsum==3.4.0
40
+ optree==0.15.0
41
+ packaging==25.0
42
+ pandas==2.2.3
43
+ pandas-datareader==0.10.0
44
+ peewee==3.18.1
45
+ pillow==11.2.1
46
+ platformdirs==4.3.7
47
+ plotly==6.0.1
48
+ protobuf==5.29.4
49
+ pycparser==2.22
50
+ Pygments==2.19.1
51
+ pyparsing==3.2.3
52
+ PyQt6==6.7.1
53
+ PyQt6-Qt6==6.7.3
54
+ PyQt6_sip==13.10.0
55
+ python-dateutil==2.9.0.post0
56
+ pytz==2025.2
57
+ pywin32==308
58
+ requests==2.32.3
59
+ rich==14.0.0
60
+ scikit-learn==1.6.1
61
+ scipy==1.15.2
62
+ setuptools==80.3.1
63
+ six==1.17.0
64
+ soupsieve==2.7
65
+ tensorboard==2.19.0
66
+ tensorboard-data-server==0.7.2
67
+ tensorflow==2.19.0
68
+ tensorflow-io-gcs-filesystem==0.31.0
69
+ termcolor==3.1.0
70
+ threadpoolctl==3.6.0
71
+ typing_extensions==4.13.2
72
+ tzdata==2025.2
73
+ urllib3==2.4.0
74
+ Werkzeug==3.1.3
75
+ wheel==0.45.1
76
+ yfinance==0.2.58
77
+ wrapt==1.17.2
78
+ beautifulsoup4==4.13.4
79
+ blinker==1.9.0
80
+ gradio>=3.0
stock_price_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f88398ddd0d648c749790aefbabfaa5c90b383a5b557011228e300274d797f6
3
+ size 682368