---
license: apache-2.0
tags:
- pytorch
- candlestick
- financial-analysis
- multimodal
- bert
- vit
- cross-attention
- trading
- forecasting
datasets:
- tuankg1028/btc-candlestick-dataset
---

# CandleFusion Model

A multimodal financial analysis model that fuses textual market sentiment with candlestick chart images to jointly predict trading signals (bullish/bearish) and the next closing price.

## Links
- 🔗 **GitHub Repository**: https://github.com/tuankg1028/CandleFusion
- 🚀 **Demo on Hugging Face Spaces**: https://huggingface.co/spaces/tuankg1028/candlefusion

## Training Results
- **Best Epoch**: 18
- **Best Validation Loss**: 316165.5985
- **Training Epochs**: 23
- **Early Stopping**: Yes

## Architecture Overview

### Core Components
- **Text Encoder**: BERT (bert-base-uncased) for processing market sentiment and news
- **Vision Encoder**: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
- **Cross-Attention Fusion**: Multi-head attention mechanism (8 heads, 768 dim) for text-image integration
- **Dual Task Heads**: 
  - Classification head for trading signals (bullish/bearish)
  - Regression head for next closing price prediction

### Data Flow
1. **Text Processing**: Market sentiment -> BERT -> CLS token (768-dim)
2. **Image Processing**: Candlestick charts -> ViT -> Patch embeddings (197 tokens, 768-dim each)
3. **Cross-Modal Fusion**: Text CLS as query, Image patches as keys/values -> Fused representation
4. **Dual Predictions**: 
   - Fused features -> Classification head -> Trading signal logits
   - Fused features -> Regression head -> Price forecast
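
For reference, the fusion step described above maps onto standard Hugging Face and PyTorch building blocks. The following is a minimal illustrative sketch, not the released implementation: the class name `CandleFusionSketch` and the exact layer layout are assumptions, while the encoders, the 8-head 768-dim cross-attention, the query/key/value roles, and the two task heads follow the description in this card.

```python
import torch.nn as nn
from transformers import BertModel, ViTModel

class CandleFusionSketch(nn.Module):
    """Illustrative sketch: the BERT CLS token attends over ViT patch embeddings."""

    def __init__(self, hidden_dim=768, num_heads=8, num_classes=2, dropout=0.3):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Text CLS embedding is the query; ViT patch embeddings are keys/values.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True
        )
        self.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes))
        self.regressor = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden_dim, 1))

    def forward(self, input_ids, attention_mask, pixel_values):
        # (batch, 1, 768): CLS embedding of the market-sentiment text
        text_cls = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, :1, :]
        # (batch, 197, 768): CLS + patch embeddings of the candlestick chart
        image_tokens = self.vision_encoder(pixel_values=pixel_values).last_hidden_state
        # Fused representation: the text query attends over the chart tokens
        fused, _ = self.cross_attention(text_cls, image_tokens, image_tokens)
        fused = fused.squeeze(1)
        return {"logits": self.classifier(fused), "forecast": self.regressor(fused)}
```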

### Model Specifications
- **Input Text**: Tokenized to max 64 tokens
- **Input Images**: Resized to 224x224 RGB
- **Hidden Dimension**: 768 (consistent across encoders)
- **Output Classes**: 2 (bullish/bearish)
- **Dropout**: 0.3 in both heads
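
The preprocessing pipeline itself is not included in this card; as a hedged example, inputs matching the shapes above can be produced with the standard `bert-base-uncased` tokenizer and the ViT image processor. The headline text and file name below are placeholders.

```python
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Text: padded/truncated to 64 tokens, as specified above
text_inputs = tokenizer(
    "BTC rebounds after heavy selling pressure",  # placeholder sentiment text
    max_length=64, padding="max_length", truncation=True, return_tensors="pt",
)

# Image: candlestick chart resized and normalized to 224x224 RGB
chart = Image.open("candlestick_chart.png").convert("RGB")  # placeholder path
image_inputs = image_processor(images=chart, return_tensors="pt")

input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
pixel_values = image_inputs["pixel_values"]  # shape: (1, 3, 224, 224)
```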

## Training Details
- **Epochs**: 100 (maximum; training stopped at epoch 23 by early stopping, see Training Results)
- **Learning Rate**: 2e-05
- **Loss Function**: CrossEntropy (classification) + MSE (regression)
- **Loss Weight (alpha)**: 0.5 for regression term
- **Optimizer**: AdamW with linear scheduling
- **Early Stopping Patience**: 5
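
As a rough sketch of how the combined objective fits together (not the original training script; `model`, `total_steps`, `signal_labels`, and `next_close_prices` are assumed to be defined elsewhere, and the warmup-step count is an assumption):

```python
import torch
import torch.nn as nn
from transformers import get_linear_schedule_with_warmup

ce_loss = nn.CrossEntropyLoss()  # classification: bullish/bearish
mse_loss = nn.MSELoss()          # regression: next closing price
alpha = 0.5                      # weight on the regression term

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps
)

outputs = model(input_ids, attention_mask, pixel_values)
loss = ce_loss(outputs["logits"], signal_labels) + alpha * mse_loss(
    outputs["forecast"].squeeze(-1), next_close_prices
)
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```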

## Usage
```python
from model import CrossAttentionModel
import torch

# Load trained weights (use map_location="cuda" if running on GPU)
model = CrossAttentionModel()
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Inference: input_ids / attention_mask come from the BERT tokenizer,
# pixel_values from the ViT image processor (see Model Specifications)
with torch.no_grad():
    outputs = model(input_ids, attention_mask, pixel_values)

trading_signals = outputs["logits"]    # bullish/bearish classification logits
price_forecast = outputs["forecast"]   # predicted next closing price
```
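
To interpret the outputs, the classification logits can be converted into probabilities and a predicted direction. The mapping of class indices 0/1 to bearish/bullish is an assumption here and should be verified against the training labels.

```python
import torch.nn.functional as F

probs = F.softmax(trading_signals, dim=-1)    # per-class probabilities
direction = probs.argmax(dim=-1)              # predicted class index (bearish/bullish)
predicted_close = price_forecast.squeeze(-1)  # predicted next closing price per sample
```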

## Performance
The model simultaneously optimizes for:
- **Classification Task**: Trading signal accuracy
- **Regression Task**: Price prediction MSE

This dual-task approach enables the model to learn both categorical market direction and continuous price movements.