CandleFusion Model
A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
Links
- GitHub Repository: https://github.com/tuankg1028/CandleFusion
- Demo on Hugging Face Spaces: https://huggingface.co/spaces/tuankg1028/candlefusion
Training Results
- Best Epoch: 18
- Best Validation Loss: 316165.5985
- Epochs Trained: 23 (stopped early, out of a maximum of 100)
- Early Stopping: Yes
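The gap between the best epoch (18) and the last epoch trained (23) matches the early-stopping patience of 5 listed under Training Details. The following is a minimal sketch of such a rule, assuming patience counts consecutive epochs without a new best validation loss. The `EarlyStopper` class, the `run` helper, and the loss values are hypothetical and are not taken from the repository.

```python
class EarlyStopper:
    """Stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def run(val_losses, patience=5):
    """Return (best_epoch, last_epoch) for per-epoch losses, epochs 1-indexed."""
    stopper = EarlyStopper(patience)
    last_epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        last_epoch = epoch
        if stopper.step(epoch, loss):
            break
    return stopper.best_epoch, last_epoch


# Hypothetical loss curve over 100 epochs: improves through epoch 18, then plateaus.
losses = [100.0 - e for e in range(1, 19)] + [90.0] * 82
print(run(losses))  # (18, 23)
```

With this made-up curve the loop stops at epoch 23, five epochs after the best one, which is the pattern reported above.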
Architecture Overview
Core Components
- Text Encoder: BERT (bert-base-uncased) for processing market sentiment and news
- Vision Encoder: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
- Cross-Attention Fusion: Multi-head attention mechanism (8 heads, 768 dim) for text-image integration
- Dual Task Heads:
  - Classification head for trading signals (bullish/bearish)
  - Regression head for next closing price prediction
Data Flow
- Text Processing: Market sentiment -> BERT -> CLS token (768-dim)
- Image Processing: Candlestick charts -> ViT -> Token embeddings (197 tokens: 196 patches plus the CLS token, 768-dim each)
- Cross-Modal Fusion: Text CLS as query, Image patches as keys/values -> Fused representation
- Dual Predictions:
  - Fused features -> Classification head -> Trading signal logits
  - Fused features -> Regression head -> Price forecast
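The data flow above can be sketched in PyTorch as follows. This is an illustration only, not the repository's `CrossAttentionModel`: the `FusionSketch` class is hypothetical, the layout of each head (dropout followed by a single linear layer) is an assumption, and random tensors stand in for the BERT and ViT outputs so the block runs without downloading any weights.

```python
import torch
import torch.nn as nn


class FusionSketch(nn.Module):
    """Cross-attention fusion over precomputed encoder outputs (illustrative)."""

    def __init__(self, hidden_dim=768, num_heads=8, num_classes=2, dropout=0.3):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True
        )
        # Assumed head layout: dropout followed by one linear layer
        self.classifier = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes)
        )
        self.regressor = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(hidden_dim, 1)
        )

    def forward(self, text_cls, image_tokens):
        # text_cls: (batch, 768) text CLS embedding, used as a single query
        # image_tokens: (batch, 197, 768) ViT tokens, used as keys and values
        query = text_cls.unsqueeze(1)  # (batch, 1, 768)
        fused, _ = self.cross_attn(query, image_tokens, image_tokens)
        fused = fused.squeeze(1)  # (batch, 768)
        return {
            "logits": self.classifier(fused),               # (batch, 2)
            "forecast": self.regressor(fused).squeeze(-1),  # (batch,)
        }


# Random tensors stand in for the BERT and ViT encoder outputs.
sketch = FusionSketch().eval()
with torch.no_grad():
    out = sketch(torch.randn(4, 768), torch.randn(4, 197, 768))
print(out["logits"].shape, out["forecast"].shape)
```

Using the text CLS token as the only query means the fused representation is a single 768-dim vector per sample, a text-conditioned weighted average of the image tokens.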
Model Specifications
- Input Text: Tokenized to max 64 tokens
- Input Images: Resized to 224x224 RGB
- Hidden Dimension: 768 (consistent across encoders)
- Output Classes: 2 (bullish/bearish)
- Dropout: 0.3 in both heads
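The 224x224 input size and the 16-pixel patches named in the ViT checkpoint explain the 197 image tokens mentioned under Data Flow. A short worked calculation, using a hypothetical helper `vit_token_count`:

```python
def vit_token_count(image_size=224, patch_size=16):
    """Number of ViT output tokens: one per patch plus the CLS token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    return patches_per_side ** 2 + 1             # 14 * 14 + 1


print(vit_token_count())  # 197
```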
Training Details
- Maximum Epochs: 100 (early stopping ended training after 23)
- Learning Rate: 2e-05
- Loss Function: CrossEntropy (classification) + MSE (regression)
- Loss Weight (alpha): 0.5 for regression term
- Optimizer: AdamW with linear scheduling
- Early Stopping Patience: 5
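The loss and optimizer settings above might be wired up roughly as shown below. This is a sketch under stated assumptions: the total loss is taken to be CrossEntropy + alpha * MSE, the linear schedule is taken to decay to zero with no warmup, and `combined_loss`, `total_steps`, and the stand-in linear model are hypothetical. The repository's training script may differ on each point.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def combined_loss(logits, forecast, labels, target_price, alpha=0.5):
    """CrossEntropy on signal logits plus alpha * MSE on the price forecast."""
    cls_loss = F.cross_entropy(logits, labels)
    reg_loss = F.mse_loss(forecast, target_price)
    return cls_loss + alpha * reg_loss


# Stand-in module so the block runs; the real model would be used here.
model = nn.Linear(768, 2)
total_steps = 1000  # hypothetical: epochs * batches per epoch

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Linear decay of the learning rate from 2e-5 towards 0 (assumed, no warmup)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: max(0.0, 1.0 - step / total_steps)
)

# One dummy optimization step on random data
features = torch.randn(8, 768)
labels = torch.randint(0, 2, (8,))
target_price = torch.randn(8)
logits = model(features)
forecast = logits.sum(dim=-1)  # placeholder forecast, only to exercise the loss
loss = combined_loss(logits, forecast, labels, target_price)
loss.backward()
optimizer.step()
scheduler.step()
```

Because the MSE term is computed on prices, its scale depends on the units of the target, so it can dominate the total loss when prices are not normalized.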
Usage
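import torch
from model import CrossAttentionModel  # defined in the project's model module

# Load the trained weights on CPU and switch to evaluation mode
model = CrossAttentionModel()
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Inference: input_ids and attention_mask come from the BERT tokenizer (max 64 tokens),
# pixel_values from a candlestick chart resized to 224x224 RGB
with torch.no_grad():
    outputs = model(input_ids, attention_mask, pixel_values)
trading_signals = outputs["logits"]   # bullish/bearish logits
price_forecast = outputs["forecast"]  # predicted next closing price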
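The raw outputs can be turned into probabilities and labels with a softmax over the logits. The sketch below uses a hypothetical helper, `decode_outputs`, on dummy tensors. This card does not say which class index means bullish and which means bearish, so the placeholder names `class_0` and `class_1` are assumptions to be replaced once the mapping is checked against the training code.

```python
import torch


def decode_outputs(outputs, class_names=("class_0", "class_1")):
    """Turn raw model outputs into probabilities, labels, and prices."""
    probs = torch.softmax(outputs["logits"], dim=-1)  # (batch, 2)
    pred_idx = probs.argmax(dim=-1)                   # (batch,)
    return {
        "probabilities": probs,
        "signal": [class_names[i] for i in pred_idx.tolist()],
        "price": outputs["forecast"].reshape(-1).tolist(),
    }


# Dummy outputs with the same keys as the model's output dictionary
dummy = {
    "logits": torch.tensor([[2.0, 0.5], [-1.0, 1.0]]),
    "forecast": torch.tensor([101.2, 99.8]),
}
decoded = decode_outputs(dummy)
print(decoded["signal"])  # ['class_0', 'class_1']
```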
Performance
The model is trained on two objectives at once:
- Classification Task: Trading signal prediction (CrossEntropy loss, evaluated by accuracy)
- Regression Task: Next closing price prediction (MSE)
This dual-task approach enables the model to learn both categorical market direction and continuous price movements.