Upload folder using huggingface_hub
- app.py +35 -0
- gradio_demo.py +218 -0
- requirements.txt +7 -0
- training/__pycache__/model.cpython-311.pyc +0 -0
- training/dataset.py +57 -0
- training/main.py +57 -0
- training/model.py +60 -0
- training/train.py +248 -0
app.py
ADDED
@@ -0,0 +1,35 @@
"""
CandleFusion Demo App for Hugging Face Spaces
Entry point for the Gradio demo
"""

import os
import sys

# Since HF Spaces runs from the demo directory, we need to add the parent directory
# to access the training modules
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

# Import and run the demo
try:
    from gradio_demo import main

    if __name__ == "__main__":
        main()
except Exception as e:
    print(f"Error launching demo: {e}")
    # Fallback: create a simple error page
    import gradio as gr

    # Capture the message explicitly so the fallback interface does not depend on
    # the except-scoped name `e` when the handler runs later.
    error_message = str(e)

    def error_interface():
        return gr.Interface(
            fn=lambda x: f"Demo temporarily unavailable. Error: {error_message}",
            inputs=gr.Textbox(label="Input"),
            outputs=gr.Textbox(label="Output"),
            title="CandleFusion Demo - Error"
        )

    error_demo = error_interface()
    error_demo.launch(server_name="0.0.0.0")
gradio_demo.py
ADDED
@@ -0,0 +1,218 @@
import gradio as gr
import torch
import sys
import os
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download

# Import spaces for GPU support on Hugging Face Spaces
try:
    import spaces
    HF_SPACES = True
except ImportError:
    HF_SPACES = False
    # Create a dummy decorator if not on Spaces
    def spaces_gpu_decorator(func):
        return func
    # staticmethod keeps the stand-in from being bound as a method,
    # so `@spaces.GPU` receives only the decorated function.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()

# Add parent directory to path to import our modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from training.model import CrossAttentionModel
from transformers import BertTokenizer, ViTImageProcessor

class CandleFusionDemo:
    def __init__(self, model_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load model from Hugging Face
        self.model = CrossAttentionModel()

        try:
            # Download model from Hugging Face Hub
            print("📥 Downloading model from Hugging Face...")
            model_file = hf_hub_download(
                repo_id="tuankg1028/candlefusion",
                filename="pytorch_model.bin",
                cache_dir="./model_cache"
            )

            # Load the downloaded model
            self.model.load_state_dict(torch.load(model_file, map_location=self.device))
            print("✅ Model loaded from Hugging Face: tuankg1028/candlefusion")

        except Exception as e:
            print(f"❌ Error loading model from Hugging Face: {str(e)}")
            print("⚠️ Using untrained model instead.")

        self.model.to(self.device)
        self.model.eval()

        # Initialize processors
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Class labels
        self.class_labels = ["Bearish", "Bullish"]

    def preprocess_inputs(self, image, text):
        """Preprocess image and text inputs for the model"""
        # Process image
        if image is None:
            raise ValueError("Please upload a candlestick chart image")

        image = Image.fromarray(image).convert("RGB")
        image_inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].to(self.device)

        # Process text
        if not text.strip():
            text = "Market analysis"  # Default text if empty

        text_inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=64
        )
        input_ids = text_inputs["input_ids"].to(self.device)
        attention_mask = text_inputs["attention_mask"].to(self.device)

        return pixel_values, input_ids, attention_mask

    @spaces.GPU
    def predict(self, image, text):
        """Make prediction using the model"""
        try:
            # Preprocess inputs
            pixel_values, input_ids, attention_mask = self.preprocess_inputs(image, text)

            # Model prediction
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values
                )

            logits = outputs["logits"]
            forecast = outputs["forecast"]

            # Get classification results
            probabilities = torch.softmax(logits, dim=1)
            predicted_class = torch.argmax(logits, dim=1).item()
            confidence = probabilities[0][predicted_class].item()

            # Get price forecast
            predicted_price = forecast.squeeze().item()

            # Format results
            classification_result = f"**Prediction:** {self.class_labels[predicted_class]}\n"
            classification_result += f"**Confidence:** {confidence:.2%}\n\n"
            classification_result += "**Class Probabilities:**\n"
            for label, prob in zip(self.class_labels, probabilities[0]):
                classification_result += f"- {label}: {prob:.2%}\n"

            forecast_result = f"**Predicted Next Close Price:** ${predicted_price:.2f}"

            return classification_result, forecast_result

        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            return error_msg, error_msg

def create_demo():
    """Create and launch the Gradio demo"""
    demo_instance = CandleFusionDemo()

    # Create Gradio interface
    with gr.Blocks(title="CandleFusion - Candlestick Chart Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🕯️ CandleFusion Demo

        Upload a candlestick chart image and provide market context to get:
        - **Market Direction Prediction** (Bullish/Bearish)
        - **Next Close Price Forecast**

        This model combines visual analysis of candlestick charts with textual market context using a BERT + ViT architecture.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Input")

                image_input = gr.Image(
                    label="Candlestick Chart",
                    type="numpy",
                    height=300
                )

                text_input = gr.Textbox(
                    label="Market Context",
                    placeholder="Enter market analysis, news, or context (e.g., 'Strong volume with positive earnings report')",
                    lines=3,
                    value="Technical analysis of price action"
                )

                predict_btn = gr.Button("🔮 Analyze Chart", variant="primary")

                gr.Markdown("""
                ### 💡 Tips:
                - Upload clear candlestick chart images
                - Provide relevant market context
                - Charts should show recent price action
                """)

            with gr.Column(scale=1):
                gr.Markdown("### 📈 Results")

                classification_output = gr.Markdown(
                    value="Upload an image and click 'Analyze Chart' to see the prediction"
                )

                forecast_output = gr.Markdown(
                    value=""
                )

        # Example section
        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                ["example_chart.png", "Strong bullish momentum with high volume"],
                ["example_chart2.png", "Bearish reversal pattern forming"]
            ],
            inputs=[image_input, text_input],
            label="Try these examples:"
        )

        # Connect the prediction function
        predict_btn.click(
            fn=demo_instance.predict,
            inputs=[image_input, text_input],
            outputs=[classification_output, forecast_output]
        )

        gr.Markdown("""
        ---
        **Note:** This is a demo model. For production trading decisions, always consult financial professionals and use additional analysis tools.
        """)

    return demo

def main():
    """Main function to launch the demo"""
    try:
        demo = create_demo()
        # Launch with server_name for compatibility on HF Spaces
        demo.launch(server_name="0.0.0.0")
    except Exception as e:
        print(f"Failed to launch Gradio demo: {e}")
        # Fallback launch with minimal configuration
        demo = create_demo()
        demo.launch(server_name="0.0.0.0")

if __name__ == "__main__":
    main()
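For local testing outside the Gradio UI, the same prediction path can be exercised directly. A minimal sketch, assuming a chart image saved next to the script (the file name mirrors the `gr.Examples` entry above; the first run downloads the BERT/ViT weights):

```python
import numpy as np
from PIL import Image
from gradio_demo import CandleFusionDemo

# Build the demo wrapper (loads tokenizer, image processor, and model weights)
demo = CandleFusionDemo()

# Load a candlestick chart as the uint8 RGB array that gr.Image(type="numpy") would pass in
chart = np.array(Image.open("example_chart.png").convert("RGB"))

classification_md, forecast_md = demo.predict(chart, "Strong bullish momentum with high volume")
print(classification_md)
print(forecast_md)
```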
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio>=5.32.0
torch>=1.9.0
transformers>=4.20.0
Pillow>=8.0.0
numpy>=1.21.0
pandas>=1.3.0
huggingface_hub>=0.16.0
training/__pycache__/model.cpython-311.pyc
ADDED
Binary file (2.99 kB)
training/dataset.py
ADDED
@@ -0,0 +1,57 @@
# dataset.py

import os
import sys
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
from transformers import BertTokenizer, ViTImageProcessor

class CandlestickDataset(Dataset):
    def __init__(self, csv_path: str, image_size: int = 224):
        """
        Args:
            csv_path (str): Path to CSV with image_path, text, label, and next_close columns
            image_size (int): Size to resize chart images to (default 224)
        """
        self.data = pd.read_csv(csv_path)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        self.image_size = image_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # === Load and preprocess image ===
        image_path = row["image_path"]
        image = Image.open(image_path).convert("RGB")
        image_inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].squeeze(0)  # (3, 224, 224)

        # === Tokenize text ===
        text = row["text"]
        text_inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=64  # can be adjusted
        )
        input_ids = text_inputs["input_ids"].squeeze(0)
        attention_mask = text_inputs["attention_mask"].squeeze(0)

        # === Labels ===
        label = torch.tensor(row["label"], dtype=torch.long)
        next_close = torch.tensor(row["next_close"], dtype=torch.float)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": label,
            "next_close": next_close
        }
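For reference, `CandlestickDataset` reads `image_path`, `text`, `label`, and `next_close` from the index CSV. A minimal sketch of building such an index with pandas; the file names and values below are made up for illustration:

```python
import pandas as pd

# Hypothetical rows: image_path points to rendered chart images,
# label is 0 (Bearish) or 1 (Bullish), next_close is the regression target.
index = pd.DataFrame([
    {"image_path": "charts/AAPL_0001.png",
     "text": "Strong bullish momentum with high volume",
     "label": 1, "next_close": 191.25},
    {"image_path": "charts/AAPL_0002.png",
     "text": "Bearish reversal pattern forming",
     "label": 0, "next_close": 188.10},
])
index.to_csv("dataset_index.csv", index=False)
```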
training/main.py
ADDED
@@ -0,0 +1,57 @@
import argparse
import pandas as pd
import os

from dataset import CandlestickDataset
from model import CrossAttentionModel
from train import train

from torch.utils.data import DataLoader

def main():
    parser = argparse.ArgumentParser(description="Train candlestick classifier using BERT + ViT")
    parser.add_argument("--data_dir", type=str, default="../data", help="Directory containing dataset")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, help="Hugging Face model ID (e.g., 'username/candlefusion')")
    parser.add_argument("--hub_token", type=str, help="Hugging Face token (or set HF_TOKEN env var)")
    args = parser.parse_args()

    # === Paths
    index_csv = os.path.join(args.data_dir, "dataset_index.csv")

    if not os.path.exists(index_csv):
        print(f"❌ Dataset index not found at {index_csv}")
        print("Please run the build_dataset script first.")
        return

    # === Create checkpoints directory
    os.makedirs("./checkpoints", exist_ok=True)

    # === Dataset & Loader
    dataset = CandlestickDataset(csv_path=index_csv)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # === Model
    model = CrossAttentionModel()

    # === Get HF token from env if not provided
    hub_token = args.hub_token or os.getenv("HF_TOKEN")

    # === Train
    train(
        model,
        dataloader,
        epochs=args.epochs,
        lr=args.lr,
        device=args.device,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_token=hub_token
    )

if __name__ == "__main__":
    main()
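Assuming `dataset_index.csv` has been built (see the sketch above), training could then be launched from the `training/` directory with something like `python main.py --data_dir ../data --epochs 3 --batch_size 8 --push_to_hub --hub_model_id username/candlefusion`; every flag here is defined by the parser above, and the model ID is a placeholder.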
training/model.py
ADDED
@@ -0,0 +1,60 @@
# model.py

import torch
import torch.nn as nn
from transformers import BertModel, ViTModel

class CrossAttentionModel(nn.Module):
    def __init__(self,
                 text_model_name="bert-base-uncased",
                 image_model_name="google/vit-base-patch16-224",
                 hidden_dim=768,
                 num_classes=2):
        super().__init__()

        # Encoders
        self.bert = BertModel.from_pretrained(text_model_name)
        self.vit = ViTModel.from_pretrained(image_model_name)

        # Cross-Attention layer
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8, batch_first=True)

        # Classification Head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

        # Forecasting Head (regression)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1)  # Predict next closing price
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        # === Text Encoding ===
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_cls = text_outputs.last_hidden_state[:, 0:1, :]  # (B, 1, H)

        # === Image Encoding ===
        image_outputs = self.vit(pixel_values=pixel_values)
        image_tokens = image_outputs.last_hidden_state[:, 1:, :]  # skip CLS token

        # === Cross-Attention ===
        fused_cls, _ = self.cross_attention(
            query=text_cls,
            key=image_tokens,
            value=image_tokens
        )  # (B, 1, H)

        fused_cls = fused_cls.squeeze(1)  # (B, H)

        # === Dual Heads ===
        logits = self.classifier(fused_cls)    # Classification
        forecast = self.regressor(fused_cls)   # Regression (next price)

        return {"logits": logits, "forecast": forecast}
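As a quick sanity check of the tensor shapes flowing through `CrossAttentionModel`, a minimal forward-pass sketch with random inputs (batch size and sequence length are arbitrary; the pretrained BERT/ViT weights are downloaded on first run):

```python
import torch
from model import CrossAttentionModel  # assumes training/ is on the import path

model = CrossAttentionModel()
model.eval()

batch_size, seq_len = 2, 64
dummy_input_ids = torch.randint(0, model.bert.config.vocab_size, (batch_size, seq_len))
dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
dummy_pixel_values = torch.randn(batch_size, 3, 224, 224)

with torch.no_grad():
    out = model(input_ids=dummy_input_ids,
                attention_mask=dummy_attention_mask,
                pixel_values=dummy_pixel_values)

print(out["logits"].shape)    # expected: torch.Size([2, 2])
print(out["forecast"].shape)  # expected: torch.Size([2, 1])
```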
training/train.py
ADDED
@@ -0,0 +1,248 @@
# train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_scheduler
from tqdm import tqdm
import os

from dataset import CandlestickDataset
from model import CrossAttentionModel

def train(model, dataloader, val_loader=None, epochs=5, lr=2e-5, alpha=0.5, device="cuda",
          push_to_hub=False, hub_model_id=None, hub_token=None):
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    total_steps = len(dataloader) * epochs
    scheduler = get_scheduler("linear", optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    loss_fn_cls = nn.CrossEntropyLoss()
    loss_fn_reg = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_cls_loss = 0
        total_reg_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["label"].to(device)
            target_price = batch["next_close"].to(device)  # shape: (B,)

            optimizer.zero_grad()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values
            )

            logits = outputs["logits"]
            forecast = outputs["forecast"].squeeze(1)  # shape: (B,)

            loss_cls = loss_fn_cls(logits, labels)
            loss_reg = loss_fn_reg(forecast, target_price)
            loss = loss_cls + alpha * loss_reg

            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_reg_loss += loss_reg.item()

            progress_bar.set_postfix(loss=loss.item(), cls=loss_cls.item(), reg=loss_reg.item())

        avg_loss = total_loss / len(dataloader)
        print(f"✅ Epoch {epoch+1} done | Total Loss: {avg_loss:.4f} | CLS: {total_cls_loss/len(dataloader):.4f} | REG: {total_reg_loss/len(dataloader):.4f}")

        if val_loader:
            evaluate(model, val_loader, device)

    torch.save(model.state_dict(), "./checkpoints/candlefusion_model.pt")
    print("✅ Model saved to ./checkpoints/candlefusion_model.pt")

    # Push to Hugging Face Hub if requested
    if push_to_hub and hub_model_id:
        try:
            from huggingface_hub import HfApi
            import json

            # Login to HF Hub
            if hub_token:
                from huggingface_hub import login
                login(token=hub_token)

            # Create model card and config
            # The trailing backslash keeps the YAML front matter on the first line,
            # which the Hub requires for model card metadata.
            model_card_content = f"""\
---
license: apache-2.0
tags:
- pytorch
- candlestick
- financial-analysis
- multimodal
- bert
- vit
- cross-attention
- trading
- forecasting
---

# CandleFusion Model

A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.

## Architecture Overview

### Core Components
- **Text Encoder**: BERT (bert-base-uncased) for processing market sentiment and news
- **Vision Encoder**: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
- **Cross-Attention Fusion**: Multi-head attention mechanism (8 heads, 768 dim) for text-image integration
- **Dual Task Heads**:
  - Classification head for trading signals (Bullish/Bearish)
  - Regression head for next closing price prediction

### Data Flow
1. **Text Processing**: Market sentiment -> BERT -> CLS token (768-dim)
2. **Image Processing**: Candlestick charts -> ViT -> Patch embeddings (197 tokens, 768-dim each)
3. **Cross-Modal Fusion**: Text CLS as query, image patches as keys/values -> Fused representation
4. **Dual Predictions**:
   - Fused features -> Classification head -> Trading signal logits
   - Fused features -> Regression head -> Price forecast

### Model Specifications
- **Input Text**: Tokenized to max 64 tokens
- **Input Images**: Resized to 224x224 RGB
- **Hidden Dimension**: 768 (consistent across encoders)
- **Output Classes**: 2 (binary: Bullish/Bearish)
- **Dropout**: 0.3 in both heads

## Training Details
- **Epochs**: {epochs}
- **Learning Rate**: {lr}
- **Loss Function**: CrossEntropy (classification) + MSE (regression)
- **Loss Weight (alpha)**: {alpha} for regression term
- **Optimizer**: AdamW with linear scheduling

## Usage
```python
from model import CrossAttentionModel
import torch

# Load model
model = CrossAttentionModel()
model.load_state_dict(torch.load("pytorch_model.bin"))
model.eval()

# Inference
outputs = model(input_ids, attention_mask, pixel_values)
trading_signals = outputs["logits"]
price_forecast = outputs["forecast"]
```

## Performance
The model simultaneously optimizes for:
- **Classification Task**: Trading signal accuracy
- **Regression Task**: Price prediction MSE

This dual-task approach enables the model to learn both categorical market direction and continuous price movements.
"""

            config = {
                "model_type": "candlefusion",
                "architecture": "bert+vit+cross_attention",
                "num_labels": 2,  # binary Bullish/Bearish, matching CrossAttentionModel(num_classes=2)
                "epochs": epochs,
                "learning_rate": lr,
                "alpha": alpha
            }

            # Create repository
            api = HfApi()
            api.create_repo(repo_id=hub_model_id, exist_ok=True)

            # Upload files
            api.upload_file(
                path_or_fileobj="./checkpoints/candlefusion_model.pt",
                path_in_repo="pytorch_model.bin",
                repo_id=hub_model_id,
            )

            # Upload model card
            with open("./checkpoints/README.md", "w", encoding="utf-8") as f:
                f.write(model_card_content)
            api.upload_file(
                path_or_fileobj="./checkpoints/README.md",
                path_in_repo="README.md",
                repo_id=hub_model_id,
            )

            # Upload config
            with open("./checkpoints/config.json", "w") as f:
                json.dump(config, f, indent=2)
            api.upload_file(
                path_or_fileobj="./checkpoints/config.json",
                path_in_repo="config.json",
                repo_id=hub_model_id,
            )

            print(f"✅ Model pushed to Hugging Face Hub: https://huggingface.co/{hub_model_id}")

        except ImportError:
            print("❌ huggingface_hub not installed. Install with: pip install huggingface_hub")
        except Exception as e:
            print(f"❌ Error pushing to Hub: {e}")

def evaluate(model, dataloader, device="cuda"):
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model.eval()

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    all_forecasts = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["label"].to(device)
            target_price = batch["next_close"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values
            )

            logits = outputs["logits"]
            forecast = outputs["forecast"].squeeze(1)

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())
            all_forecasts.extend(forecast.tolist())
            all_targets.extend(target_price.tolist())

    acc = correct / total
    print(f"📊 Evaluation Accuracy: {acc*100:.2f}%")

    # Optional: print forecasting MSE
    forecast_mse = nn.MSELoss()(torch.tensor(all_forecasts), torch.tensor(all_targets)).item()
    print(f"📈 Forecast MSE: {forecast_mse:.4f}")
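If a held-out split is available, the optional `val_loader` argument wires the `evaluate` pass into each epoch. A minimal sketch, assuming a second index CSV for validation (the file name is a placeholder):

```python
import os
from torch.utils.data import DataLoader
from dataset import CandlestickDataset
from model import CrossAttentionModel
from train import train

# train() saves the checkpoint here, so make sure the directory exists
os.makedirs("./checkpoints", exist_ok=True)

train_ds = CandlestickDataset(csv_path="../data/dataset_index.csv")
val_ds = CandlestickDataset(csv_path="../data/dataset_index_val.csv")  # hypothetical validation index

train(
    CrossAttentionModel(),
    DataLoader(train_ds, batch_size=8, shuffle=True),
    val_loader=DataLoader(val_ds, batch_size=8),
    epochs=3,
)
```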