Update README.md
Browse files
README.md
CHANGED
@@ -15,35 +15,68 @@ This is a distilled cross-encoder model based on ModernBERT-large, trained to pr
|
|
15 |
- **Training Objective**: Binary classification (correct/incorrect answer prediction)
|
16 |
|
17 |
## Usage
|
|
|
18 |
|
19 |
```python
|
|
|
|
|
20 |
from custom_crossencoder import CustomCrossEncoder, TrainingConfig
|
21 |
|
22 |
-
#
|
|
|
|
|
|
|
|
|
23 |
config = TrainingConfig(
|
24 |
-
model_name="answerdotai/ModernBERT-large",
|
25 |
max_length=4096,
|
26 |
-
mlp_hidden_dims=[1024, 512, 256]
|
|
|
|
|
27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model = CustomCrossEncoder(config)
|
|
|
|
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
#
|
35 |
-
instruction = "Your instruction here"
|
36 |
-
answer = "Your answer here"
|
37 |
encoded = model.tokenizer(
|
38 |
text=instruction,
|
39 |
-
text_pair=
|
40 |
truncation=True,
|
41 |
-
max_length=
|
42 |
padding="max_length",
|
43 |
return_tensors="pt"
|
44 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
with torch.no_grad():
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
```
|
48 |
|
49 |
## Running Evaluation
|
|
|
15 |
- **Training Objective**: Binary classification (correct/incorrect answer prediction)
|
16 |
|
17 |
## Usage
|
18 |
+
TODO: Add a pointer to the `custom_crossencoder.py` script, which provides the `CustomCrossEncoder` and `TrainingConfig` classes imported below.
|
19 |
|
20 |
```python
|
21 |
+
import torch
|
22 |
+
import logging
|
23 |
from custom_crossencoder import CustomCrossEncoder, TrainingConfig
|
24 |
|
25 |
+
# Setup logging
|
26 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
# Model configuration
|
30 |
config = TrainingConfig(
|
31 |
+
model_name="answerdotai/ModernBERT-large", # Base model to use
|
32 |
max_length=4096,
|
33 |
+
mlp_hidden_dims=[1024, 512, 256], # Default for ModernBERT
|
34 |
+
dropout_rate=0.1,
|
35 |
+
dataset_path="hazyresearch/MATH500_with_Llama_3.1_70B_Instruct_v1",
|
36 |
)
|
37 |
+
|
38 |
+
# Fine-tuned checkpoint: Hugging Face Hub repository ID
|
39 |
+
checkpoint_path = "hazyresearch/Weaver_Distilled_ModernBERT_Large_for_MATH500"
|
40 |
+
|
41 |
+
# Load model
|
42 |
+
logger.info(f"Loading model from checkpoint: {checkpoint_path}")
|
43 |
model = CustomCrossEncoder(config)
|
44 |
+
model.load_finetuned_checkpoint(checkpoint_path)
|
45 |
+
model.eval() # Set to evaluation mode
|
46 |
|
47 |
+
# Dummy example
|
48 |
+
instruction = "Solve the following math problem: What is 2 + 2?"
|
49 |
+
response = "The answer is 4. This is because when we add 2 and 2 together, we get 4."
|
50 |
|
51 |
+
# Tokenize input
|
|
|
|
|
52 |
encoded = model.tokenizer(
|
53 |
text=instruction,
|
54 |
+
text_pair=response,
|
55 |
truncation=True,
|
56 |
+
max_length=config.max_length,
|
57 |
padding="max_length",
|
58 |
return_tensors="pt"
|
59 |
)
|
60 |
+
|
61 |
+
# Get prediction
|
62 |
+
logger.info("\nMaking prediction on dummy example:")
|
63 |
+
logger.info(f"Instruction: {instruction}")
|
64 |
+
logger.info(f"Response: {response}")
|
65 |
+
|
66 |
+
# Move the input tensors to the same device as the model
|
67 |
+
device = next(model.parameters()).device
|
68 |
+
input_ids = encoded["input_ids"].to(device)
|
69 |
+
attention_mask = encoded["attention_mask"].to(device)
|
70 |
+
|
71 |
+
# Get raw score
|
72 |
with torch.no_grad():
|
73 |
+
score = model(input_ids, attention_mask).item()
|
74 |
+
|
75 |
+
logger.info(f"\nRaw prediction score: {score:.4f}")
|
76 |
+
|
77 |
+
# Get binary prediction (using 0.5 threshold)
|
78 |
+
binary_prediction = "Correct" if score >= 0.5 else "Incorrect"
|
79 |
+
logger.info(f"Binary prediction (threshold 0.5): {binary_prediction}")
|
80 |
```
|
81 |
|
82 |
## Running Evaluation
|