jonsaadfalcon committed (verified)
Commit 0e40d5d · 1 Parent(s): 8ed81f4

Update README.md

Files changed (1)
  1. README.md +45 -12
README.md CHANGED
@@ -15,35 +15,68 @@ This is a distilled cross-encoder model based on ModernBERT-large, trained to pr
- **Training Objective**: Binary classification (correct/incorrect answer prediction)

## Usage
+ TODO: ADD POINTER TO CUSTOM_CROSSENCODER.PY SCRIPT

```python
+ import torch
+ import logging
from custom_crossencoder import CustomCrossEncoder, TrainingConfig

- # Initialize model
+ # Setup logging
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Model configuration
config = TrainingConfig(
-     model_name="answerdotai/ModernBERT-large",
+     model_name="answerdotai/ModernBERT-large",  # Base model to use
    max_length=4096,
-     mlp_hidden_dims=[1024, 512, 256]
+     mlp_hidden_dims=[1024, 512, 256],  # Default for ModernBERT
+     dropout_rate=0.1,
+     dataset_path="hazyresearch/MATH500_with_Llama_3.1_70B_Instruct_v1",
)
+
+ # Model path - using HuggingFace model repository
+ checkpoint_path = "hazyresearch/Weaver_Distilled_ModernBERT_Large_for_MATH500"
+
+ # Load model
+ logger.info(f"Loading model from checkpoint: {checkpoint_path}")
model = CustomCrossEncoder(config)
+ model.load_finetuned_checkpoint(checkpoint_path)
+ model.eval()  # Set to evaluation mode

- # Load checkpoint
- model.load_state_dict(torch.load("hazyresearch/Weaver_Distilled_ModernBERT_Large_for_MATH500"))
- model.eval()
-
- # Get prediction
- instruction = "Your instruction here"
- answer = "Your answer here"
+ # Dummy example
+ instruction = "Solve the following math problem: What is 2 + 2?"
+ response = "The answer is 4. This is because when we add 2 and 2 together, we get 4."
+
+ # Tokenize input
encoded = model.tokenizer(
    text=instruction,
-     text_pair=answer,
+     text_pair=response,
    truncation=True,
-     max_length=4096,
+     max_length=config.max_length,
    padding="max_length",
    return_tensors="pt"
)
+
+ # Get prediction
+ logger.info("\nMaking prediction on dummy example:")
+ logger.info(f"Instruction: {instruction}")
+ logger.info(f"Response: {response}")
+
+ # Move tensors to the same device as model
+ device = next(model.parameters()).device
+ input_ids = encoded["input_ids"].to(device)
+ attention_mask = encoded["attention_mask"].to(device)
+
+ # Get raw score
with torch.no_grad():
-     prediction = model(encoded["input_ids"], encoded["attention_mask"])
+     score = model(input_ids, attention_mask).item()
+
+ logger.info(f"\nRaw prediction score: {score:.4f}")
+
+ # Get binary prediction (using 0.5 threshold)
+ binary_prediction = "Correct" if score >= 0.5 else "Incorrect"
+ logger.info(f"Binary prediction (threshold 0.5): {binary_prediction}")
```

## Running Evaluation
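
The updated snippet scores a single (instruction, response) pair. A natural follow-on use of a verifier like this is reranking several candidate answers to the same problem and keeping the highest-scoring one. The sketch below is illustrative, not part of this commit: it reuses only the calls that appear in the README diff above (`TrainingConfig`, `CustomCrossEncoder`, `load_finetuned_checkpoint`, `model.tokenizer`, and the forward pass), and the one-pair-at-a-time loop is an assumption, since the snippet only demonstrates single-pair scoring.

```python
import torch
from custom_crossencoder import CustomCrossEncoder, TrainingConfig

# Same configuration and checkpoint as the README snippet above
config = TrainingConfig(
    model_name="answerdotai/ModernBERT-large",
    max_length=4096,
    mlp_hidden_dims=[1024, 512, 256],
)
model = CustomCrossEncoder(config)
model.load_finetuned_checkpoint("hazyresearch/Weaver_Distilled_ModernBERT_Large_for_MATH500")
model.eval()

instruction = "Solve the following math problem: What is 2 + 2?"
candidates = [
    "The answer is 4.",
    "The answer is 5.",
    "Adding 2 and 2 gives 22.",
]

# Score each (instruction, candidate) pair independently and keep the best.
# Assumes the forward pass returns one scalar score per input row, as in
# the single-example snippet above.
device = next(model.parameters()).device
scores = []
with torch.no_grad():
    for candidate in candidates:
        encoded = model.tokenizer(
            text=instruction,
            text_pair=candidate,
            truncation=True,
            max_length=config.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        score = model(
            encoded["input_ids"].to(device),
            encoded["attention_mask"].to(device),
        ).item()
        scores.append(score)

best = max(range(len(candidates)), key=lambda i: scores[i])
print(f"Best candidate (score {scores[best]:.4f}): {candidates[best]}")
```

Scoring one pair per forward pass keeps memory bounded at `max_length=4096` tokens; batching the candidates in a single call may be possible if `CustomCrossEncoder` accepts stacked inputs, but that is not shown in the README.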