brendanm12345 commited on
Commit
8f352d6
·
verified ·
1 Parent(s): 707c5c2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +66 -64
README.md CHANGED
@@ -1,96 +1,98 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
 
 
 
 
3
  ---
4
 
5
- # Weaver Distilled - MATH500 (ModernBERT-large)
6
 
7
- This is a distilled cross-encoder model based on ModernBERT-large, trained to predict the correctness of answers on MATH500. This specialized verifier was trained on Weaver scores aggregated over 35 different verifiers and reward models.
8
 
9
  ## Model Details
10
 
11
- - **Base Model**: [answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)
12
  - **Architecture**: Cross-encoder with MLP head (1024 → 512 → 256 → 1)
13
- - **Max Sequence Length**: 4096
14
- - **Training Data**: [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) scored by 35 different LM Judges and reward models, aggregated to form sample-level scores with Weaver
15
- - **Training Objective**: Binary classification (correct/incorrect answer prediction)
16
 
17
- ## Usage
18
- TODO: ADD POINTER TO CUSTOM_CROSSENCODER.PY SCRIPT
19
 
20
- ```python
21
- import torch
22
- import logging
23
- from custom_crossencoder import CustomCrossEncoder, TrainingConfig
24
 
25
- # Setup logging
26
- logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
27
- logger = logging.getLogger(__name__)
28
 
29
- # Model configuration
30
- config = TrainingConfig(
31
- model_name="answerdotai/ModernBERT-large", # Base model to use
32
- max_length=4096,
33
- mlp_hidden_dims=[1024, 512, 256], # Default for ModernBERT
34
- dropout_rate=0.1,
35
- dataset_path="hazyresearch/MATH500_with_Llama_3.1_70B_Instruct_v1",
36
- )
37
 
38
- # Model path - using HuggingFace model repository
39
- checkpoint_path = "hazyresearch/Weaver_Distilled_ModernBERT_Large_for_MATH500"
 
40
 
41
- # Load model
42
- logger.info(f"Loading model from checkpoint: {checkpoint_path}")
43
- model = CustomCrossEncoder(config)
44
- model.load_finetuned_checkpoint(checkpoint_path)
45
- model.eval() # Set to evaluation mode
46
 
47
- # Dummy example
48
- instruction = "Solve the following math problem: What is 2 + 2?"
49
- response = "The answer is 4. This is because when we add 2 and 2 together, we get 4."
50
 
51
- # Tokenize input
52
- encoded = model.tokenizer(
53
- text=instruction,
54
- text_pair=response,
55
  truncation=True,
56
- max_length=config.max_length,
57
- padding="max_length",
58
  return_tensors="pt"
59
  )
60
 
61
- # Get prediction
62
- logger.info("\nMaking prediction on dummy example:")
63
- logger.info(f"Instruction: {instruction}")
64
- logger.info(f"Response: {response}")
65
-
66
- # Move tensors to the same device as model
67
- device = next(model.parameters()).device
68
- input_ids = encoded["input_ids"].to(device)
69
- attention_mask = encoded["attention_mask"].to(device)
70
-
71
- # Get raw score
72
  with torch.no_grad():
73
- score = model(input_ids, attention_mask).item()
74
-
75
- logger.info(f"\nRaw prediction score: {score:.4f}")
76
-
77
- # Get binary prediction (using 0.5 threshold)
78
- binary_prediction = "Correct" if score >= 0.5 else "Incorrect"
79
- logger.info(f"Binary prediction (threshold 0.5): {binary_prediction}")
80
  ```
81
 
82
- ## Running Evaluation
83
 
84
- TODO: ADD EVALUATION_SIMPLE COMMAND HERE
85
 
86
- ## License
87
 
88
- [Your chosen license]
89
 
90
- ## Citation
 
 
 
 
 
 
 
 
91
 
92
- If you use this model in your research, please cite:
93
 
94
  ```bibtex
95
- TODO
 
 
 
 
 
96
  ```
 
1
  ---
2
  license: mit
3
+ pipeline_tag: text-classification
4
+ library_name: transformers
5
+ base_model: answerdotai/ModernBERT-large
6
+ tags:
7
+ - math
8
+ - reasoning
9
+ - verification
10
+ - weaver
11
+ - cross-encoder
12
+ language:
13
+ - en
14
  ---
15
 
16
+ # Weaver Distilled for MATH500
17
 
18
+ A distilled cross-encoder model that captures 98.7% of Weaver's accuracy while reducing verification compute by 99.97%. This model is fine-tuned from ModernBERT-large to predict the correctness of mathematical reasoning responses, trained on Weaver ensemble scores from 35 different verifiers.
19
 
20
  ## Model Details
21
 
22
+ - **Base Model**: [answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large) (395M parameters)
23
  - **Architecture**: Cross-encoder with MLP head (1024 → 512 → 256 → 1)
24
+ - **Max Sequence Length**: 4096 tokens
25
+ - **Training Data**: MATH500 problems with Weaver scores from 35 LM judges and reward models
26
+ - **Task**: Binary classification for answer correctness prediction
27
 
28
+ ## Performance
 
29
 
30
+ On MATH500 with Llama 3.1 70B generations:
31
+ - **Weaver (Full)**: 93.4% accuracy, high compute cost
32
+ - **Weaver (Distilled)**: 92.2% accuracy, 99.97% compute reduction
33
+ - **Majority Voting**: 83.0% accuracy
34
 
35
+ TODO: replace these with the actual numbers
 
 
36
 
37
+ ## Quick Start
 
 
 
 
 
 
 
38
 
39
+ ```python
40
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
41
+ import torch
42
 
43
+ # Load model and tokenizer
44
+ model_name = "hazyresearch/Weaver_Distilled_for_MATH500"
45
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
46
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
47
 
48
+ # Example usage
49
+ instruction = "Solve: What is the derivative of x^2 + 3x + 2?"
50
+ response = "The derivative is 2x + 3. Using the power rule..."
51
 
52
+ # Tokenize input pair
53
+ inputs = tokenizer(
54
+ instruction,
55
+ response,
56
  truncation=True,
57
+ max_length=4096,
58
+ padding=True,
59
  return_tensors="pt"
60
  )
61
 
62
+ # Get correctness score
 
 
 
 
 
 
 
 
 
 
63
  with torch.no_grad():
64
+ outputs = model(**inputs)
65
+ score = torch.sigmoid(outputs.logits).item()
66
+
67
+ print(f"Correctness score: {score:.3f}")
68
+ print(f"Prediction: {'Correct' if score > 0.5 else 'Incorrect'}")
 
 
69
  ```
70
 
71
+ ## Training Details
72
 
73
+ This model was trained using the [Weaver distillation pipeline](https://github.com/ScalingIntelligence/scaling-verification/tree/main/distillation). For training your own distilled models, see the [distillation README](https://github.com/ScalingIntelligence/scaling-verification/blob/main/distillation/README.md).
74
 
75
+ ## Evaluation
76
 
77
+ Evaluate this model using:
78
 
79
+ ```bash
80
+ python evaluate_crossencoder.py \
81
+ --model_name "answerdotai/ModernBERT-large" \
82
+ --checkpoint_path "hazyresearch/Weaver_Distilled_for_MATH500" \
83
+ --dataset_path "hazyresearch/MATH500_with_Llama_3.1_70B_Instruct_v1" \
84
+ --dataset_split "data" \
85
+ --max_length 4096 \
86
+ --batch_size 64
87
+ ```
88
 
89
+ ## Citation
90
 
91
  ```bibtex
92
+ @article{weaver2025,
93
+ title={Weaver: Shrinking the Generation-Verification Gap with Weak Verifiers},
94
+ author={},
95
+ journal={arXiv preprint},
96
+ year={2025}
97
+ }
98
  ```