tuankg1028 committed
Commit 8f1d3f9 · verified · 1 Parent(s): 021dc20

Upload folder using huggingface_hub
app.py ADDED
@@ -0,0 +1,39 @@
+ """
+ CandleFusion Demo App for Hugging Face Spaces
+ Entry point for the Gradio demo
+ """
+
+ import os
+ import sys
+
+ # HF Spaces runs the app from the demo directory, so add the parent
+ # directory to sys.path to make the training modules importable
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ parent_dir = os.path.dirname(current_dir)
+ sys.path.append(parent_dir)
+
+ # Import and run the demo
+ try:
+     from gradio_demo import main
+
+     if __name__ == "__main__":
+         main()
+ except Exception as e:
+     # Capture the message now: Python unbinds `e` when the except block
+     # exits, so a lambda closing over `e` would raise NameError when called.
+     error_message = str(e)
+     print(f"Error launching demo: {error_message}")
+
+     # Fallback: create a simple error page
+     import gradio as gr
+
+     def error_interface():
+         return gr.Interface(
+             fn=lambda x: f"Demo temporarily unavailable. Error: {error_message}",
+             inputs=gr.Textbox(label="Input"),
+             outputs=gr.Textbox(label="Output"),
+             title="CandleFusion Demo - Error"
+         )
+
+     error_demo = error_interface()
+     error_demo.launch(server_name="0.0.0.0")
gradio_demo.py ADDED
@@ -0,0 +1,214 @@
+ import gradio as gr
+ import torch
+ import sys
+ import os
+ import types
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ # Import spaces for GPU support on Hugging Face Spaces
+ try:
+     import spaces
+     HF_SPACES = True
+ except ImportError:
+     HF_SPACES = False
+     # Off Spaces, fall back to a no-op decorator. A SimpleNamespace keeps
+     # spaces.GPU a plain function; storing it on a throwaway class would
+     # turn it into a bound method and break decoration.
+     def spaces_gpu_decorator(func):
+         return func
+     spaces = types.SimpleNamespace(GPU=spaces_gpu_decorator)
+
+ # Add parent directory to path to import our modules
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ from training.model import CrossAttentionModel
+ from transformers import BertTokenizer, ViTImageProcessor
+
+ class CandleFusionDemo:
+     def __init__(self, model_path=None):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Build the model architecture; trained weights come from the Hub below
+         self.model = CrossAttentionModel()
+
+         try:
+             # Download trained weights from the Hugging Face Hub
+             print("📥 Downloading model from Hugging Face...")
+             model_file = hf_hub_download(
+                 repo_id="tuankg1028/candlefusion",
+                 filename="pytorch_model.bin",
+                 cache_dir="./model_cache"
+             )
+
+             # Load the downloaded weights
+             self.model.load_state_dict(torch.load(model_file, map_location=self.device))
+             print("✅ Model loaded from Hugging Face: tuankg1028/candlefusion")
+
+         except Exception as e:
+             print(f"❌ Error loading model from Hugging Face: {str(e)}")
+             print("⚠️ Using untrained model instead.")
+
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Initialize processors
+         self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+
+         # Class labels
+         self.class_labels = ["Bearish", "Bullish"]
+
+     def preprocess_inputs(self, image, text):
+         """Preprocess image and text inputs for the model"""
+         # Process image
+         if image is None:
+             raise ValueError("Please upload a candlestick chart image")
+
+         image = Image.fromarray(image).convert("RGB")
+         image_inputs = self.processor(images=image, return_tensors="pt")
+         pixel_values = image_inputs["pixel_values"].to(self.device)
+
+         # Process text
+         if not text.strip():
+             text = "Market analysis"  # Default text if empty
+
+         text_inputs = self.tokenizer(
+             text,
+             return_tensors="pt",
+             truncation=True,
+             padding="max_length",
+             max_length=64
+         )
+         input_ids = text_inputs["input_ids"].to(self.device)
+         attention_mask = text_inputs["attention_mask"].to(self.device)
+
+         return pixel_values, input_ids, attention_mask
+
+     @spaces.GPU
+     def predict(self, image, text):
+         """Make a prediction with the model"""
+         try:
+             # Preprocess inputs
+             pixel_values, input_ids, attention_mask = self.preprocess_inputs(image, text)
+
+             # Model prediction
+             with torch.no_grad():
+                 outputs = self.model(
+                     input_ids=input_ids,
+                     attention_mask=attention_mask,
+                     pixel_values=pixel_values
+                 )
+
+             logits = outputs["logits"]
+             forecast = outputs["forecast"]
+
+             # Classification results
+             probabilities = torch.softmax(logits, dim=1)
+             predicted_class = torch.argmax(logits, dim=1).item()
+             confidence = probabilities[0][predicted_class].item()
+
+             # Price forecast
+             predicted_price = forecast.squeeze().item()
+
+             # Format results
+             classification_result = f"**Prediction:** {self.class_labels[predicted_class]}\n"
+             classification_result += f"**Confidence:** {confidence:.2%}\n\n"
+             classification_result += "**Class Probabilities:**\n"
+             for label, prob in zip(self.class_labels, probabilities[0]):
+                 classification_result += f"- {label}: {prob:.2%}\n"
+
+             forecast_result = f"**Predicted Next Close Price:** ${predicted_price:.2f}"
+
+             return classification_result, forecast_result
+
+         except Exception as e:
+             error_msg = f"Error during prediction: {str(e)}"
+             return error_msg, error_msg
+
+ def create_demo():
+     """Create the Gradio demo"""
+     demo_instance = CandleFusionDemo()
+
+     # Create Gradio interface
+     with gr.Blocks(title="CandleFusion - Candlestick Chart Analysis", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         # 🕯️ CandleFusion Demo
+
+         Upload a candlestick chart image and provide market context to get:
+         - **Market Direction Prediction** (Bullish/Bearish)
+         - **Next Close Price Forecast**
+
+         The model fuses visual analysis of candlestick charts with textual market context using a BERT + ViT cross-attention architecture.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📊 Input")
+
+                 image_input = gr.Image(
+                     label="Candlestick Chart",
+                     type="numpy",
+                     height=300
+                 )
+
+                 text_input = gr.Textbox(
+                     label="Market Context",
+                     placeholder="Enter market analysis, news, or context (e.g., 'Strong volume with positive earnings report')",
+                     lines=3,
+                     value="Technical analysis of price action"
+                 )
+
+                 predict_btn = gr.Button("🔮 Analyze Chart", variant="primary")
+
+                 gr.Markdown("""
+                 ### 💡 Tips:
+                 - Upload clear candlestick chart images
+                 - Provide relevant market context
+                 - Charts should show recent price action
+                 """)
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📈 Results")
+
+                 classification_output = gr.Markdown(
+                     value="Upload an image and click 'Analyze Chart' to see the prediction"
+                 )
+
+                 forecast_output = gr.Markdown(
+                     value=""
+                 )
+
+         # Example section
+         gr.Markdown("### 📚 Example")
+         gr.Examples(
+             examples=[
+                 ["example_chart.png", "Strong bullish momentum with high volume"],
+                 ["example_chart2.png", "Bearish reversal pattern forming"]
+             ],
+             inputs=[image_input, text_input],
+             label="Try these examples:"
+         )
+
+         # Connect the prediction function
+         predict_btn.click(
+             fn=demo_instance.predict,
+             inputs=[image_input, text_input],
+             outputs=[classification_output, forecast_output]
+         )
+
+         gr.Markdown("""
+         ---
+         **Note:** This is a demo model. For production trading decisions, always consult financial professionals and use additional analysis tools.
+         """)
+
+     return demo
+
+ def main():
+     """Launch the demo"""
+     demo = create_demo()
+     # server_name="0.0.0.0" makes the server reachable inside the Spaces container
+     demo.launch(server_name="0.0.0.0")
+
+ if __name__ == "__main__":
+     main()
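
The demo class can also be exercised outside the Gradio UI. Below is a minimal smoke-test sketch, assuming `gradio_demo.py` is on the import path; the random "chart" is a stand-in, so the prediction is meaningless, and if the Hub checkpoint cannot be fetched the untrained weights are used.

```python
# Smoke test for CandleFusionDemo without Gradio (a sketch; the random image
# is a placeholder, so the outputs carry no signal).
import numpy as np
from gradio_demo import CandleFusionDemo

demo = CandleFusionDemo()

# predict() expects an RGB image as a NumPy array, matching gr.Image(type="numpy")
dummy_chart = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)
classification, forecast = demo.predict(dummy_chart, "Strong volume after earnings")
print(classification)
print(forecast)
```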
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio>=5.32.0
+ torch>=1.9.0
+ transformers>=4.20.0
+ Pillow>=8.0.0
+ numpy>=1.21.0
+ pandas>=1.3.0
+ huggingface_hub>=0.16.0
training/__pycache__/model.cpython-311.pyc ADDED
Binary file (2.99 kB)
 
training/dataset.py ADDED
@@ -0,0 +1,55 @@
+ # dataset.py
+
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image
+ import pandas as pd
+ from transformers import BertTokenizer, ViTImageProcessor
+
+ class CandlestickDataset(Dataset):
+     def __init__(self, csv_path: str, image_size: int = 224):
+         """
+         Args:
+             csv_path (str): Path to a CSV with image_path, text, label, and next_close columns
+             image_size (int): Size to resize chart images to (default 224)
+         """
+         self.data = pd.read_csv(csv_path)
+         self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+         self.image_size = image_size
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         row = self.data.iloc[idx]
+
+         # === Load and preprocess image ===
+         image_path = row["image_path"]
+         image = Image.open(image_path).convert("RGB")
+         image_inputs = self.processor(images=image, return_tensors="pt")
+         pixel_values = image_inputs["pixel_values"].squeeze(0)  # (3, 224, 224)
+
+         # === Tokenize text ===
+         text = row["text"]
+         text_inputs = self.tokenizer(
+             text,
+             return_tensors="pt",
+             truncation=True,
+             padding="max_length",
+             max_length=64  # can be adjusted
+         )
+         input_ids = text_inputs["input_ids"].squeeze(0)
+         attention_mask = text_inputs["attention_mask"].squeeze(0)
+
+         # === Targets ===
+         label = torch.tensor(row["label"], dtype=torch.long)
+         next_close = torch.tensor(row["next_close"], dtype=torch.float)
+
+         return {
+             "pixel_values": pixel_values,
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "label": label,
+             "next_close": next_close
+         }
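
`CandlestickDataset` reads its samples from a CSV index. Below is a sketch of the expected schema; the chart paths and prices are illustrative placeholders, not files shipped in this repository.

```python
# Illustrative dataset_index.csv for CandlestickDataset; paths and values
# here are hypothetical.
import pandas as pd

index = pd.DataFrame([
    {"image_path": "charts/AAPL_0001.png", "text": "Strong bullish momentum", "label": 1, "next_close": 182.45},
    {"image_path": "charts/AAPL_0002.png", "text": "Bearish reversal forming", "label": 0, "next_close": 179.10},
])
index.to_csv("dataset_index.csv", index=False)
```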
training/main.py ADDED
@@ -0,0 +1,56 @@
+ import argparse
+ import os
+
+ from dataset import CandlestickDataset
+ from model import CrossAttentionModel
+ from train import train
+
+ from torch.utils.data import DataLoader
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train the candlestick classifier (BERT + ViT)")
+     parser.add_argument("--data_dir", type=str, default="../data", help="Directory containing the dataset")
+     parser.add_argument("--batch_size", type=int, default=8)
+     parser.add_argument("--epochs", type=int, default=3)
+     parser.add_argument("--lr", type=float, default=2e-5)
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
+     parser.add_argument("--hub_model_id", type=str, help="Hugging Face model ID (e.g., 'username/candlefusion')")
+     parser.add_argument("--hub_token", type=str, help="Hugging Face token (or set HF_TOKEN env var)")
+     args = parser.parse_args()
+
+     # === Paths
+     index_csv = os.path.join(args.data_dir, "dataset_index.csv")
+
+     if not os.path.exists(index_csv):
+         print(f"❌ Dataset index not found at {index_csv}")
+         print("Please run the build_dataset script first.")
+         return
+
+     # === Create checkpoints directory
+     os.makedirs("./checkpoints", exist_ok=True)
+
+     # === Dataset & loader
+     dataset = CandlestickDataset(csv_path=index_csv)
+     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
+
+     # === Model
+     model = CrossAttentionModel()
+
+     # === Get the HF token from the environment if not provided
+     hub_token = args.hub_token or os.getenv("HF_TOKEN")
+
+     # === Train
+     train(
+         model,
+         dataloader,
+         epochs=args.epochs,
+         lr=args.lr,
+         device=args.device,
+         push_to_hub=args.push_to_hub,
+         hub_model_id=args.hub_model_id,
+         hub_token=hub_token
+     )
+
+ if __name__ == "__main__":
+     main()
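
The same pipeline can be wired up directly without the CLI. A sketch, assuming `../data/dataset_index.csv` exists and using a hypothetical Hub repo id:

```python
# Programmatic equivalent of main.py (a sketch; "username/candlefusion" is a
# hypothetical repo id and HF_TOKEN is read from the environment).
import os
from torch.utils.data import DataLoader
from dataset import CandlestickDataset
from model import CrossAttentionModel
from train import train

dataset = CandlestickDataset(csv_path="../data/dataset_index.csv")
loader = DataLoader(dataset, batch_size=8, shuffle=True)

train(
    CrossAttentionModel(),
    loader,
    epochs=3,
    lr=2e-5,
    device="cuda",
    push_to_hub=True,
    hub_model_id="username/candlefusion",  # hypothetical
    hub_token=os.getenv("HF_TOKEN"),
)
```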
training/model.py ADDED
@@ -0,0 +1,60 @@
+ # model.py
+
+ import torch
+ import torch.nn as nn
+ from transformers import BertModel, ViTModel
+
+ class CrossAttentionModel(nn.Module):
+     def __init__(self,
+                  text_model_name="bert-base-uncased",
+                  image_model_name="google/vit-base-patch16-224",
+                  hidden_dim=768,
+                  num_classes=2):
+         super().__init__()
+
+         # Encoders
+         self.bert = BertModel.from_pretrained(text_model_name)
+         self.vit = ViTModel.from_pretrained(image_model_name)
+
+         # Cross-attention layer
+         self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=8, batch_first=True)
+
+         # Classification head
+         self.classifier = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(hidden_dim, num_classes)
+         )
+
+         # Forecasting head (regression)
+         self.regressor = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(hidden_dim, 1)  # Predict next closing price
+         )
+
+     def forward(self, input_ids, attention_mask, pixel_values):
+         # === Text encoding ===
+         text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         text_cls = text_outputs.last_hidden_state[:, 0:1, :]  # (B, 1, H)
+
+         # === Image encoding ===
+         image_outputs = self.vit(pixel_values=pixel_values)
+         image_tokens = image_outputs.last_hidden_state[:, 1:, :]  # skip CLS token
+
+         # === Cross-attention ===
+         fused_cls, _ = self.cross_attention(
+             query=text_cls,
+             key=image_tokens,
+             value=image_tokens
+         )  # (B, 1, H)
+
+         fused_cls = fused_cls.squeeze(1)  # (B, H)
+
+         # === Dual heads ===
+         logits = self.classifier(fused_cls)   # Classification
+         forecast = self.regressor(fused_cls)  # Regression (next price)
+
+         return {"logits": logits, "forecast": forecast}
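
A quick shape sanity check for the model above. This is a sketch with random inputs, so the outputs are meaningless; the first run downloads the pretrained BERT and ViT weights.

```python
# Shape check for CrossAttentionModel: text CLS (B, 1, 768) attends over
# 196 ViT patch tokens, and the fused vector feeds both heads.
import torch
from model import CrossAttentionModel

model = CrossAttentionModel()
model.eval()

batch = 2
input_ids = torch.randint(0, 30522, (batch, 64))   # bert-base-uncased vocab ids, max_length=64
attention_mask = torch.ones(batch, 64, dtype=torch.long)
pixel_values = torch.randn(batch, 3, 224, 224)     # ViT-base 224x224 RGB input

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)

print(out["logits"].shape)    # torch.Size([2, 2]) - bearish/bullish logits
print(out["forecast"].shape)  # torch.Size([2, 1]) - next-close regression
```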
training/train.py ADDED
@@ -0,0 +1,244 @@
+ # train.py
+
+ import torch
+ import torch.nn as nn
+ from transformers import get_scheduler
+ from tqdm import tqdm
+ import os
+
+ def train(model, dataloader, val_loader=None, epochs=5, lr=2e-5, alpha=0.5, device="cuda",
+           push_to_hub=False, hub_model_id=None, hub_token=None):
+     device = torch.device(device if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+     total_steps = len(dataloader) * epochs
+     scheduler = get_scheduler("linear", optimizer, num_warmup_steps=0, num_training_steps=total_steps)
+
+     loss_fn_cls = nn.CrossEntropyLoss()
+     loss_fn_reg = nn.MSELoss()
+
+     for epoch in range(epochs):
+         model.train()
+         total_loss = 0
+         total_cls_loss = 0
+         total_reg_loss = 0
+
+         progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
+
+         for batch in progress_bar:
+             input_ids = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             pixel_values = batch["pixel_values"].to(device)
+             labels = batch["label"].to(device)
+             target_price = batch["next_close"].to(device)  # shape: (B,)
+
+             optimizer.zero_grad()
+
+             outputs = model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 pixel_values=pixel_values
+             )
+
+             logits = outputs["logits"]
+             forecast = outputs["forecast"].squeeze(1)  # shape: (B,)
+
+             loss_cls = loss_fn_cls(logits, labels)
+             loss_reg = loss_fn_reg(forecast, target_price)
+             loss = loss_cls + alpha * loss_reg
+
+             loss.backward()
+             optimizer.step()
+             scheduler.step()
+
+             total_loss += loss.item()
+             total_cls_loss += loss_cls.item()
+             total_reg_loss += loss_reg.item()
+
+             progress_bar.set_postfix(loss=loss.item(), cls=loss_cls.item(), reg=loss_reg.item())
+
+         avg_loss = total_loss / len(dataloader)
+         print(f"✅ Epoch {epoch+1} done | Total Loss: {avg_loss:.4f} | CLS: {total_cls_loss/len(dataloader):.4f} | REG: {total_reg_loss/len(dataloader):.4f}")
+
+         if val_loader:
+             evaluate(model, val_loader, device)
+
+     os.makedirs("./checkpoints", exist_ok=True)
+     torch.save(model.state_dict(), "./checkpoints/candlefusion_model.pt")
+     print("✅ Model saved to ./checkpoints/candlefusion_model.pt")
+
+     # Push to Hugging Face Hub if requested
+     if push_to_hub and hub_model_id:
+         try:
+             from huggingface_hub import HfApi
+             import json
+
+             # Log in to the HF Hub
+             if hub_token:
+                 from huggingface_hub import login
+                 login(token=hub_token)
+
+             # Create model card and config; the front matter must start on the
+             # first line of README.md, so the string opens directly with ---
+             model_card_content = f"""---
+ license: apache-2.0
+ tags:
+ - pytorch
+ - candlestick
+ - financial-analysis
+ - multimodal
+ - bert
+ - vit
+ - cross-attention
+ - trading
+ - forecasting
+ ---
+
+ # CandleFusion Model
+
+ A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
+
+ ## Architecture Overview
+
+ ### Core Components
+ - **Text Encoder**: BERT (bert-base-uncased) for processing market sentiment and news
+ - **Vision Encoder**: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
+ - **Cross-Attention Fusion**: Multi-head attention (8 heads, 768 dim) for text-image integration
+ - **Dual Task Heads**:
+   - Classification head for market direction (bullish/bearish)
+   - Regression head for next closing price prediction
+
+ ### Data Flow
+ 1. **Text Processing**: Market sentiment -> BERT -> CLS token (768-dim)
+ 2. **Image Processing**: Candlestick charts -> ViT -> Patch embeddings (196 tokens after dropping CLS, 768-dim each)
+ 3. **Cross-Modal Fusion**: Text CLS as query, image patches as keys/values -> Fused representation
+ 4. **Dual Predictions**:
+    - Fused features -> Classification head -> Market direction logits
+    - Fused features -> Regression head -> Price forecast
+
+ ### Model Specifications
+ - **Input Text**: Tokenized to max 64 tokens
+ - **Input Images**: Resized to 224x224 RGB
+ - **Hidden Dimension**: 768 (consistent across encoders)
+ - **Output Classes**: 2 (binary: bullish/bearish)
+ - **Dropout**: 0.3 in both heads
+
+ ## Training Details
+ - **Epochs**: {epochs}
+ - **Learning Rate**: {lr}
+ - **Loss Function**: CrossEntropy (classification) + MSE (regression)
+ - **Loss Weight (alpha)**: {alpha} for the regression term
+ - **Optimizer**: AdamW with linear scheduling
+
+ ## Usage
+ ```python
+ from model import CrossAttentionModel
+ import torch
+
+ # Load model
+ model = CrossAttentionModel()
+ model.load_state_dict(torch.load("pytorch_model.bin"))
+ model.eval()
+
+ # Inference
+ outputs = model(input_ids, attention_mask, pixel_values)
+ trading_signals = outputs["logits"]
+ price_forecast = outputs["forecast"]
+ ```
+
+ ## Performance
+ The model simultaneously optimizes for:
+ - **Classification Task**: Trading signal accuracy
+ - **Regression Task**: Price prediction MSE
+
+ This dual-task approach lets the model learn both categorical market direction and continuous price movements.
+ """
+
+             config = {
+                 "model_type": "candlefusion",
+                 "architecture": "bert+vit+cross_attention",
+                 "num_labels": 2,
+                 "epochs": epochs,
+                 "learning_rate": lr,
+                 "alpha": alpha
+             }
+
+             # Create repository
+             api = HfApi()
+             api.create_repo(repo_id=hub_model_id, exist_ok=True)
+
+             # Upload weights
+             api.upload_file(
+                 path_or_fileobj="./checkpoints/candlefusion_model.pt",
+                 path_in_repo="pytorch_model.bin",
+                 repo_id=hub_model_id,
+             )
+
+             # Upload model card
+             with open("./checkpoints/README.md", "w", encoding="utf-8") as f:
+                 f.write(model_card_content)
+             api.upload_file(
+                 path_or_fileobj="./checkpoints/README.md",
+                 path_in_repo="README.md",
+                 repo_id=hub_model_id,
+             )
+
+             # Upload config
+             with open("./checkpoints/config.json", "w") as f:
+                 json.dump(config, f, indent=2)
+             api.upload_file(
+                 path_or_fileobj="./checkpoints/config.json",
+                 path_in_repo="config.json",
+                 repo_id=hub_model_id,
+             )
+
+             print(f"✅ Model pushed to Hugging Face Hub: https://huggingface.co/{hub_model_id}")
+
+         except ImportError:
+             print("❌ huggingface_hub not installed. Install with: pip install huggingface_hub")
+         except Exception as e:
+             print(f"❌ Error pushing to Hub: {e}")
+
+ def evaluate(model, dataloader, device="cuda"):
+     device = torch.device(device if torch.cuda.is_available() else "cpu")
+     model.eval()
+
+     correct = 0
+     total = 0
+     all_preds = []
+     all_labels = []
+     all_forecasts = []
+     all_targets = []
+
+     with torch.no_grad():
+         for batch in dataloader:
+             input_ids = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             pixel_values = batch["pixel_values"].to(device)
+             labels = batch["label"].to(device)
+             target_price = batch["next_close"].to(device)
+
+             outputs = model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 pixel_values=pixel_values
+             )
+
+             logits = outputs["logits"]
+             forecast = outputs["forecast"].squeeze(1)
+
+             preds = torch.argmax(logits, dim=1)
+             correct += (preds == labels).sum().item()
+             total += labels.size(0)
+
+             all_preds.extend(preds.tolist())
+             all_labels.extend(labels.tolist())
+             all_forecasts.extend(forecast.tolist())
+             all_targets.extend(target_price.tolist())
+
+     acc = correct / total
+     print(f"📊 Evaluation Accuracy: {acc*100:.2f}%")
+
+     # Forecasting MSE
+     forecast_mse = nn.MSELoss()(torch.tensor(all_forecasts), torch.tensor(all_targets)).item()
+     print(f"📈 Forecast MSE: {forecast_mse:.4f}")