Upload 3 files
Browse files- app.py +206 -0
- requirements.txt +80 -0
- stock_price_model.h5 +3 -0
app.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import yfinance as yf
|
4 |
+
import datetime as dt
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
from sklearn.preprocessing import MinMaxScaler
|
7 |
+
from tensorflow.keras.models import load_model
|
8 |
+
import gradio as gr
|
9 |
+
import warnings
|
10 |
+
import os
|
11 |
+
|
12 |
+
# Suppress warnings
|
13 |
+
warnings.filterwarnings("ignore")
|
14 |
+
|
15 |
+
# Constants
|
16 |
+
PREDICTION_DAYS = 30
|
17 |
+
TIME_STEP = 60
|
18 |
+
DATA_YEARS = 3
|
19 |
+
|
20 |
+
# Load model
|
21 |
+
model = load_model('stock_price_model.h5')
|
22 |
+
model.make_predict_function() # For faster inference
|
23 |
+
|
24 |
+
|
25 |
+
def preprocess_data(df):
|
26 |
+
"""Process yfinance data"""
|
27 |
+
df.columns = [col[0] if isinstance(col, tuple) else col for col in df.columns]
|
28 |
+
df = df.reset_index().rename(columns={'index': 'Date'})
|
29 |
+
df = df[['Date', 'High', 'Low', 'Open', 'Close', 'Volume']]
|
30 |
+
df['Date'] = pd.to_datetime(df['Date'])
|
31 |
+
df.set_index('Date', inplace=True)
|
32 |
+
return df
|
33 |
+
|
34 |
+
|
35 |
+
def get_stock_data(stock_symbol):
|
36 |
+
"""Fetch stock data with caching"""
|
37 |
+
end_date = dt.datetime.now()
|
38 |
+
start_date = end_date - dt.timedelta(days=365 * DATA_YEARS)
|
39 |
+
df = yf.download(stock_symbol, start=start_date, end=end_date)
|
40 |
+
return preprocess_data(df)
|
41 |
+
|
42 |
+
|
43 |
+
def prepare_data(df):
|
44 |
+
"""Prepare data for LSTM prediction"""
|
45 |
+
scaler = MinMaxScaler()
|
46 |
+
scaled_data = scaler.fit_transform(df['Close'].values.reshape(-1, 1))
|
47 |
+
|
48 |
+
# Create dataset using sliding window
|
49 |
+
X = np.array([scaled_data[i:i + TIME_STEP, 0]
|
50 |
+
for i in range(len(scaled_data) - TIME_STEP - 1)])
|
51 |
+
y = scaled_data[TIME_STEP + 1:, 0]
|
52 |
+
|
53 |
+
return X.reshape(X.shape[0], TIME_STEP, 1), y, scaler
|
54 |
+
|
55 |
+
|
56 |
+
def predict_future(model, data, scaler):
|
57 |
+
"""Generate future predictions"""
|
58 |
+
last_data = data[-TIME_STEP:].reshape(1, TIME_STEP, 1)
|
59 |
+
future_preds = np.zeros(PREDICTION_DAYS, dtype='float32')
|
60 |
+
|
61 |
+
for i in range(PREDICTION_DAYS):
|
62 |
+
next_pred = model.predict(last_data, verbose=0)[0, 0]
|
63 |
+
future_preds[i] = next_pred
|
64 |
+
last_data = np.roll(last_data, -1, axis=1)
|
65 |
+
last_data[0, -1, 0] = next_pred
|
66 |
+
|
67 |
+
return scaler.inverse_transform(future_preds.reshape(-1, 1))
|
68 |
+
|
69 |
+
|
70 |
+
def create_plot(df, pred_data=None, future_data=None, title=""):
|
71 |
+
"""Create interactive Plotly figure"""
|
72 |
+
fig = go.Figure()
|
73 |
+
|
74 |
+
# Main price line
|
75 |
+
fig.add_trace(go.Scatter(
|
76 |
+
x=df.index,
|
77 |
+
y=df['Close'],
|
78 |
+
name='Actual Price',
|
79 |
+
line=dict(color='blue')
|
80 |
+
))
|
81 |
+
|
82 |
+
# Prediction line
|
83 |
+
if pred_data is not None:
|
84 |
+
fig.add_trace(go.Scatter(
|
85 |
+
x=df.index[TIME_STEP + 1:],
|
86 |
+
y=pred_data[:, 0],
|
87 |
+
name='Predicted',
|
88 |
+
line=dict(color='orange')
|
89 |
+
))
|
90 |
+
|
91 |
+
# Future prediction
|
92 |
+
if future_data is not None:
|
93 |
+
future_dates = pd.date_range(
|
94 |
+
start=df.index[-1],
|
95 |
+
periods=PREDICTION_DAYS + 1
|
96 |
+
)[1:]
|
97 |
+
fig.add_trace(go.Scatter(
|
98 |
+
x=future_dates,
|
99 |
+
y=future_data[:, 0],
|
100 |
+
name='30-Day Forecast',
|
101 |
+
line=dict(color='green')
|
102 |
+
))
|
103 |
+
|
104 |
+
fig.update_layout(
|
105 |
+
title=title,
|
106 |
+
template='plotly_dark',
|
107 |
+
margin=dict(l=20, r=20, t=40, b=20)
|
108 |
+
)
|
109 |
+
return fig
|
110 |
+
|
111 |
+
|
112 |
+
def predict_stock(stock_symbol):
|
113 |
+
"""Main prediction function for Gradio"""
|
114 |
+
try:
|
115 |
+
df = get_stock_data(stock_symbol)
|
116 |
+
X, y, scaler = prepare_data(df)
|
117 |
+
|
118 |
+
# Make predictions
|
119 |
+
y_pred = model.predict(X)
|
120 |
+
y_pred = scaler.inverse_transform(y_pred)
|
121 |
+
|
122 |
+
# Future prediction
|
123 |
+
future_prices = predict_future(
|
124 |
+
model,
|
125 |
+
scaler.transform(df['Close'].values.reshape(-1, 1)),
|
126 |
+
scaler
|
127 |
+
)
|
128 |
+
|
129 |
+
# Create plots
|
130 |
+
main_plot = create_plot(
|
131 |
+
df,
|
132 |
+
pred_data=y_pred,
|
133 |
+
title=f"{stock_symbol} Price Prediction"
|
134 |
+
)
|
135 |
+
|
136 |
+
future_plot = create_plot(
|
137 |
+
df,
|
138 |
+
future_data=future_prices,
|
139 |
+
title=f"{stock_symbol} 30-Day Forecast"
|
140 |
+
)
|
141 |
+
|
142 |
+
# Technical indicators
|
143 |
+
df['SMA_50'] = df['Close'].rolling(50).mean()
|
144 |
+
df['SMA_200'] = df['Close'].rolling(200).mean()
|
145 |
+
|
146 |
+
tech_fig = go.Figure()
|
147 |
+
tech_fig.add_trace(go.Scatter(
|
148 |
+
x=df.index, y=df['Close'],
|
149 |
+
name='Price', line=dict(color='blue')))
|
150 |
+
tech_fig.add_trace(go.Scatter(
|
151 |
+
x=df.index, y=df['SMA_50'],
|
152 |
+
name='50-Day SMA', line=dict(color='orange')))
|
153 |
+
tech_fig.add_trace(go.Scatter(
|
154 |
+
x=df.index, y=df['SMA_200'],
|
155 |
+
name='200-Day SMA', line=dict(color='red')))
|
156 |
+
tech_fig.update_layout(
|
157 |
+
title=f"{stock_symbol} Technical Indicators",
|
158 |
+
template='plotly_dark'
|
159 |
+
)
|
160 |
+
|
161 |
+
return (
|
162 |
+
f"${df['Close'].iloc[-1]:.2f}",
|
163 |
+
df.index[-1].strftime('%Y-%m-%d'),
|
164 |
+
main_plot,
|
165 |
+
future_plot,
|
166 |
+
tech_fig
|
167 |
+
)
|
168 |
+
|
169 |
+
except Exception as e:
|
170 |
+
raise gr.Error(f"Prediction failed: {str(e)}")
|
171 |
+
|
172 |
+
|
173 |
+
# Gradio Interface
|
174 |
+
with gr.Blocks(title="Stock Prediction", theme=gr.themes.Default()) as demo:
|
175 |
+
gr.Markdown("# 📈 Real-Time Stock Predictor")
|
176 |
+
gr.Markdown("Predict stock prices using LSTM neural networks")
|
177 |
+
|
178 |
+
with gr.Row():
|
179 |
+
stock_input = gr.Textbox(
|
180 |
+
label="Stock Symbol (Examples: TSLA, AAPL, MSFT, AMZN, GOOG, AEP)",
|
181 |
+
value="TSLA",
|
182 |
+
placeholder="Enter stock symbol (e.g. AAPL, MSFT)"
|
183 |
+
)
|
184 |
+
submit_btn = gr.Button("Predict", variant="primary")
|
185 |
+
|
186 |
+
with gr.Row():
|
187 |
+
with gr.Column():
|
188 |
+
last_price = gr.Textbox(label="Last Price")
|
189 |
+
last_date = gr.Textbox(label="Last Date")
|
190 |
+
|
191 |
+
with gr.Tabs():
|
192 |
+
with gr.Tab("Price Prediction"):
|
193 |
+
main_plot = gr.Plot(label="Price Prediction")
|
194 |
+
with gr.Tab("30-Day Forecast"):
|
195 |
+
future_plot = gr.Plot(label="Future Prediction")
|
196 |
+
with gr.Tab("Technical Indicators"):
|
197 |
+
tech_plot = gr.Plot(label="Technical Analysis")
|
198 |
+
|
199 |
+
submit_btn.click(
|
200 |
+
fn=predict_stock,
|
201 |
+
inputs=stock_input,
|
202 |
+
outputs=[last_price, last_date, main_plot, future_plot, tech_plot]
|
203 |
+
)
|
204 |
+
|
205 |
+
# For Hugging Face Spaces
|
206 |
+
demo.launch(debug=False)
|
requirements.txt
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.2.2
|
2 |
+
astunparse==1.6.3
|
3 |
+
beautifulsoup4==4.13.4
|
4 |
+
blinker==1.9.0
|
5 |
+
certifi==2025.4.26
|
6 |
+
cffi==1.17.1
|
7 |
+
charset-normalizer==3.4.2
|
8 |
+
click==8.1.8
|
9 |
+
colorama==0.4.6
|
10 |
+
contourpy==1.3.2
|
11 |
+
curl_cffi==0.10.0
|
12 |
+
cycler==0.12.1
|
13 |
+
Flask==3.1.0
|
14 |
+
flatbuffers==25.2.10
|
15 |
+
fonttools==4.57.0
|
16 |
+
frozendict==2.4.6
|
17 |
+
gast==0.6.0
|
18 |
+
google-pasta==0.2.0
|
19 |
+
grpcio==1.71.0
|
20 |
+
h5py==3.13.0
|
21 |
+
idna==3.10
|
22 |
+
itsdangerous==2.2.0
|
23 |
+
Jinja2==3.1.6
|
24 |
+
joblib==1.5.0
|
25 |
+
keras==3.9.2
|
26 |
+
kiwisolver==1.4.8
|
27 |
+
libclang==18.1.1
|
28 |
+
lxml==5.4.0
|
29 |
+
Markdown==3.8
|
30 |
+
markdown-it-py==3.0.0
|
31 |
+
MarkupSafe==3.0.2
|
32 |
+
matplotlib==3.10.1
|
33 |
+
mdurl==0.1.2
|
34 |
+
ml_dtypes==0.5.1
|
35 |
+
multitasking==0.0.11
|
36 |
+
namex==0.0.9
|
37 |
+
narwhals==1.38.0
|
38 |
+
numpy==2.1.3
|
39 |
+
opt_einsum==3.4.0
|
40 |
+
optree==0.15.0
|
41 |
+
packaging==25.0
|
42 |
+
pandas==2.2.3
|
43 |
+
pandas-datareader==0.10.0
|
44 |
+
peewee==3.18.1
|
45 |
+
pillow==11.2.1
|
46 |
+
platformdirs==4.3.7
|
47 |
+
plotly==6.0.1
|
48 |
+
protobuf==5.29.4
|
49 |
+
pycparser==2.22
|
50 |
+
Pygments==2.19.1
|
51 |
+
pyparsing==3.2.3
|
52 |
+
PyQt6==6.7.1
|
53 |
+
PyQt6-Qt6==6.7.3
|
54 |
+
PyQt6_sip==13.10.0
|
55 |
+
python-dateutil==2.9.0.post0
|
56 |
+
pytz==2025.2
|
57 |
+
pywin32==308
|
58 |
+
requests==2.32.3
|
59 |
+
rich==14.0.0
|
60 |
+
scikit-learn==1.6.1
|
61 |
+
scipy==1.15.2
|
62 |
+
setuptools==80.3.1
|
63 |
+
six==1.17.0
|
64 |
+
soupsieve==2.7
|
65 |
+
tensorboard==2.19.0
|
66 |
+
tensorboard-data-server==0.7.2
|
67 |
+
tensorflow==2.19.0
|
68 |
+
tensorflow-io-gcs-filesystem==0.31.0
|
69 |
+
termcolor==3.1.0
|
70 |
+
threadpoolctl==3.6.0
|
71 |
+
typing_extensions==4.13.2
|
72 |
+
tzdata==2025.2
|
73 |
+
urllib3==2.4.0
|
74 |
+
Werkzeug==3.1.3
|
75 |
+
wheel==0.45.1
|
76 |
+
yfinance==0.2.58
|
77 |
+
wrapt==1.17.2
|
78 |
+
beautifulsoup4==4.13.4
|
79 |
+
blinker==1.9.0
|
80 |
+
gradio>=3.0
|
stock_price_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f88398ddd0d648c749790aefbabfaa5c90b383a5b557011228e300274d797f6
|
3 |
+
size 682368
|