tuankg1028 commited on
Commit
d04c64c
Β·
verified Β·
1 Parent(s): d079a44

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +84 -84
README.md CHANGED
@@ -1,84 +1,84 @@
1
-
2
- ---
3
- license: apache-2.0
4
- tags:
5
- - pytorch
6
- - candlestick
7
- - financial-analysis
8
- - multimodal
9
- - bert
10
- - vit
11
- - cross-attention
12
- - trading
13
- - forecasting
14
- ---
15
-
16
- # CandleFusion Model
17
-
18
- A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
19
-
20
- ## Links
21
- - πŸ”— **GitHub Repository**: https://github.com/tuankg1028/CandleFusion
22
- - πŸš€ **Demo on Hugging Face Spaces**: https://huggingface.co/spaces/tuankg1028/candlefusion
23
-
24
- ## Training Results
25
- - **Best Epoch**: 18
26
- - **Best Validation Loss**: 316165.5985
27
- - **Training Epochs**: 23
28
- - **Early Stopping**: Yes
29
-
30
- ## Architecture Overview
31
-
32
- ### Core Components
33
- - **Text Encoder**: BERT (bert-base-uncased) for processing market sentiment and news
34
- - **Vision Encoder**: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
35
- - **Cross-Attention Fusion**: Multi-head attention mechanism (8 heads, 768 dim) for text-image integration
36
- - **Dual Task Heads**:
37
- - Classification head for trading signals (buy/sell/hold)
38
- - Regression head for next closing price prediction
39
-
40
- ### Data Flow
41
- 1. **Text Processing**: Market sentiment -> BERT -> CLS token (768-dim)
42
- 2. **Image Processing**: Candlestick charts -> ViT -> Patch embeddings (197 tokens, 768-dim each)
43
- 3. **Cross-Modal Fusion**: Text CLS as query, Image patches as keys/values -> Fused representation
44
- 4. **Dual Predictions**:
45
- - Fused features -> Classification head -> Trading signal logits
46
- - Fused features -> Regression head -> Price forecast
47
-
48
- ### Model Specifications
49
- - **Input Text**: Tokenized to max 64 tokens
50
- - **Input Images**: Resized to 224x224 RGB
51
- - **Hidden Dimension**: 768 (consistent across encoders)
52
- - **Output Classes**: 3 (buy/sell/hold)
53
- - **Dropout**: 0.3 in both heads
54
-
55
- ## Training Details
56
- - **Epochs**: 100
57
- - **Learning Rate**: 2e-05
58
- - **Loss Function**: CrossEntropy (classification) + MSE (regression)
59
- - **Loss Weight (alpha)**: 0.5 for regression term
60
- - **Optimizer**: AdamW with linear scheduling
61
- - **Early Stopping Patience**: 5
62
-
63
- ## Usage
64
- ```python
65
- from model import CrossAttentionModel
66
- import torch
67
-
68
- # Load model
69
- model = CrossAttentionModel()
70
- model.load_state_dict(torch.load("pytorch_model.bin"))
71
- model.eval()
72
-
73
- # Inference
74
- outputs = model(input_ids, attention_mask, pixel_values)
75
- trading_signals = outputs["logits"]
76
- price_forecast = outputs["forecast"]
77
- ```
78
-
79
- ## Performance
80
- The model simultaneously optimizes for:
81
- - **Classification Task**: Trading signal accuracy
82
- - **Regression Task**: Price prediction MSE
83
-
84
- This dual-task approach enables the model to learn both categorical market direction and continuous price movements.
 
1
+
2
+ ---
3
+ license: apache-2.0
4
+ tags:
5
+ - pytorch
6
+ - candlestick
7
+ - financial-analysis
8
+ - multimodal
9
+ - bert
10
+ - vit
11
+ - cross-attention
12
+ - trading
13
+ - forecasting
14
+ ---
15
+
16
+ # CandleFusion Model
17
+
18
+ A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
19
+
20
+ ## Links
21
+ - πŸ”— **GitHub Repository**: https://github.com/tuankg1028/CandleFusion
22
+ - πŸš€ **Demo on Hugging Face Spaces**: https://huggingface.co/spaces/tuankg1028/candlefusion
23
+
24
+ ## Training Results
25
+ - **Best Epoch**: 18
26
+ - **Best Validation Loss**: 316165.5985
27
+ - **Training Epochs**: 23
28
+ - **Early Stopping**: Yes
29
+
30
+ ## Architecture Overview
31
+
32
+ ### Core Components
33
+ - **Text Encoder**: BERT (bert-base-uncased) for processing market sentiment and news
34
+ - **Vision Encoder**: Vision Transformer (ViT-base-patch16-224) for candlestick pattern recognition
35
+ - **Cross-Attention Fusion**: Multi-head attention mechanism (8 heads, 768 dim) for text-image integration
36
+ - **Dual Task Heads**:
37
+ - Classification head for trading signals (bullish/bearish)
38
+ - Regression head for next closing price prediction
39
+
40
+ ### Data Flow
41
+ 1. **Text Processing**: Market sentiment -> BERT -> CLS token (768-dim)
42
+ 2. **Image Processing**: Candlestick charts -> ViT -> Patch embeddings (197 tokens, 768-dim each)
43
+ 3. **Cross-Modal Fusion**: Text CLS as query, Image patches as keys/values -> Fused representation
44
+ 4. **Dual Predictions**:
45
+ - Fused features -> Classification head -> Trading signal logits
46
+ - Fused features -> Regression head -> Price forecast
47
+
48
+ ### Model Specifications
49
+ - **Input Text**: Tokenized to max 64 tokens
50
+ - **Input Images**: Resized to 224x224 RGB
51
+ - **Hidden Dimension**: 768 (consistent across encoders)
52
+ - **Output Classes**: 2 (bullish/bearish)
53
+ - **Dropout**: 0.3 in both heads
54
+
55
+ ## Training Details
56
+ - **Epochs**: 100
57
+ - **Learning Rate**: 2e-05
58
+ - **Loss Function**: CrossEntropy (classification) + MSE (regression)
59
+ - **Loss Weight (alpha)**: 0.5 for regression term
60
+ - **Optimizer**: AdamW with linear scheduling
61
+ - **Early Stopping Patience**: 5
62
+
63
+ ## Usage
64
+ ```python
65
+ from model import CrossAttentionModel
66
+ import torch
67
+
68
+ # Load model
69
+ model = CrossAttentionModel()
70
+ model.load_state_dict(torch.load("pytorch_model.bin"))
71
+ model.eval()
72
+
73
+ # Inference
74
+ outputs = model(input_ids, attention_mask, pixel_values)
75
+ trading_signals = outputs["logits"]
76
+ price_forecast = outputs["forecast"]
77
+ ```
78
+
79
+ ## Performance
80
+ The model simultaneously optimizes for:
81
+ - **Classification Task**: Trading signal accuracy
82
+ - **Regression Task**: Price prediction MSE
83
+
84
+ This dual-task approach enables the model to learn both categorical market direction and continuous price movements.