---
license: apache-2.0
tags:
  - pytorch
  - candlestick
  - financial-analysis
  - multimodal
  - bert
  - vit
  - cross-attention
  - trading
  - forecasting
datasets:
  - tuankg1028/btc-candlestick-dataset
---

# CandleFusion Model

A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.

## Training Results

- Best Epoch: 18
- Best Validation Loss: 316165.5985
- Training Epochs: 23
- Early Stopping: Yes

## Architecture Overview

### Core Components

- Text Encoder: BERT (bert-base-uncased) for processing market sentiment and news
- Vision Encoder: Vision Transformer (vit-base-patch16-224) for candlestick pattern recognition
- Cross-Attention Fusion: Multi-head attention mechanism (8 heads, 768-dim) for text-image integration
- Dual Task Heads:
  - Classification head for trading signals (bullish/bearish)
  - Regression head for next closing price prediction
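
A minimal sketch of how these components might be assembled, assuming standard Hugging Face `transformers` modules and PyTorch's `nn.MultiheadAttention`; the actual layer names and wiring live in the repository's `model.py` and may differ:

```python
import torch.nn as nn
from transformers import BertModel, ViTModel

class CrossAttentionModel(nn.Module):
    """Illustrative skeleton only; see model.py in the repository for the real implementation."""

    def __init__(self, hidden_dim=768, num_heads=8, num_classes=2, dropout=0.3):
        super().__init__()
        # Text encoder: BERT with 768-dim hidden states
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        # Vision encoder: ViT-base-patch16-224, 197 patch tokens of 768 dims each
        self.vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Cross-attention fusion: 8 heads over the shared 768-dim space
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True
        )
        # Dual task heads, each with dropout 0.3
        self.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes))
        self.regressor = nn.Sequential(nn.Dropout(dropout), nn.Linear(hidden_dim, 1))
```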

### Data Flow

1. Text Processing: Market sentiment -> BERT -> CLS token (768-dim)
2. Image Processing: Candlestick charts -> ViT -> Patch embeddings (197 tokens, 768-dim each)
3. Cross-Modal Fusion: Text CLS as query, image patches as keys/values -> Fused representation
4. Dual Predictions:
   - Fused features -> Classification head -> Trading signal logits
   - Fused features -> Regression head -> Price forecast
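
To make the shapes concrete, here is a shape-only walkthrough of the fusion step with random tensors standing in for the encoder outputs; it mirrors the numbered flow above but is not taken from the repository's code:

```python
import torch
import torch.nn as nn

batch = 4
text_cls = torch.randn(batch, 1, 768)        # step 1: BERT CLS token as a single query
image_tokens = torch.randn(batch, 197, 768)  # step 2: ViT patch embeddings as keys/values

# Step 3: text CLS attends over the image patches
cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)
fused, attn_weights = cross_attention(query=text_cls, key=image_tokens, value=image_tokens)
fused = fused.squeeze(1)                     # fused representation, shape (batch, 768)

# Step 4: dual predictions from the fused features
classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(768, 2))
regressor = nn.Sequential(nn.Dropout(0.3), nn.Linear(768, 1))
logits = classifier(fused)                   # trading signal logits, shape (batch, 2)
forecast = regressor(fused)                  # price forecast, shape (batch, 1)
```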

## Model Specifications

- Input Text: Tokenized to a maximum of 64 tokens
- Input Images: Resized to 224x224 RGB
- Hidden Dimension: 768 (consistent across encoders)
- Output Classes: 2 (bullish/bearish)
- Dropout: 0.3 in both heads
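
Given these specifications, inputs can be prepared with the stock Hugging Face preprocessors for the two base checkpoints; the text snippet and image path below are placeholders, and the repository may ship its own preprocessing code:

```python
from PIL import Image
from transformers import BertTokenizer, ViTImageProcessor

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Text: padded/truncated to 64 tokens
text = "BTC holds above support after a strong rally."  # placeholder sentiment text
encoded = tokenizer(text, padding="max_length", truncation=True, max_length=64, return_tensors="pt")

# Image: candlestick chart resized to 224x224 RGB and normalized
chart = Image.open("candlestick_chart.png").convert("RGB")  # placeholder path
pixels = image_processor(images=chart, return_tensors="pt")

input_ids = encoded["input_ids"]            # (1, 64)
attention_mask = encoded["attention_mask"]  # (1, 64)
pixel_values = pixels["pixel_values"]       # (1, 3, 224, 224)
```

These tensors feed directly into the Usage example below.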

## Training Details

- Epochs: 100 (maximum; training stopped early at epoch 23)
- Learning Rate: 2e-05
- Loss Function: CrossEntropy (classification) + MSE (regression)
- Loss Weight (alpha): 0.5 for the regression term
- Optimizer: AdamW with linear scheduling
- Early Stopping Patience: 5
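
A hedged sketch of how this combined objective and schedule might look inside the training loop; `model`, the batch tensors, the label tensors, and `num_training_steps` are assumed to be defined elsewhere, and the warmup setting is a guess:

```python
import torch
import torch.nn as nn
from transformers import get_linear_schedule_with_warmup

ce_loss = nn.CrossEntropyLoss()  # classification term
mse_loss = nn.MSELoss()          # regression term
alpha = 0.5                      # weight on the regression term

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps  # warmup steps assumed
)

# One training step (signal_labels: direction labels, next_close: next closing prices)
outputs = model(input_ids, attention_mask, pixel_values)
loss = ce_loss(outputs["logits"], signal_labels) + alpha * mse_loss(outputs["forecast"], next_close)
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```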

## Usage

```python
from model import CrossAttentionModel
import torch

# Load model
model = CrossAttentionModel()
model.load_state_dict(torch.load("pytorch_model.bin"))
model.eval()

# Inference
outputs = model(input_ids, attention_mask, pixel_values)
trading_signals = outputs["logits"]
price_forecast = outputs["forecast"]
```
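
The raw outputs can then be turned into a readable prediction, for example as below; the index-to-label mapping and output shapes are assumptions, so check the training code before relying on them:

```python
# Assumes a batch of one example; the bearish/bullish index order is an assumption
probs = torch.softmax(trading_signals, dim=-1)
signal = ["bearish", "bullish"][trading_signals.argmax(dim=-1).item()]
print(f"Signal: {signal} (p={probs.max().item():.2f}), next close forecast: {price_forecast.item():.2f}")
```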

## Performance

The model simultaneously optimizes for:

- Classification Task: Trading signal accuracy
- Regression Task: Price prediction MSE

This dual-task approach enables the model to learn both categorical market direction and continuous price movements.