sthenno committed
Commit e35e3bc · 1 parent: 9d6d7b1

update(core): fix code

Files changed (3)
  1. __pycache__/utils.cpython-312.pyc +0 -0
  2. app.py +37 -23
  3. utils.py +51 -67
__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.1 kB).
 
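Note: the added .pyc is a Python byte-code cache, which is rarely meant to be version-controlled. If that is unintended here, a conventional .gitignore entry would keep such artifacts out of future commits:

# .gitignore (hypothetical addition, not part of this commit)
__pycache__/
*.py[cod]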
app.py CHANGED
@@ -1,69 +1,83 @@

Before:

  import gradio as gr

- from utils import checkpoints, load_model, log_perplexity


- class ModelManager:
      """Class to manage model loading and perplexity calculation state."""

      def __init__(self):
-         self.loaded_models = None

      def load_models(self, checkpoint_input_str: str) -> str:
          """Load models from a comma-separated string of checkpoint names."""
-         checkpoint_list = [
-             c.strip() for c in checkpoint_input_str.split(",") if c.strip()
          ]

-         if not checkpoint_list:
              return "Please enter at least one model checkpoint name."

          try:
-             self.loaded_models = load_model(checkpoint_list)
              return "Models loaded successfully!"
          except Exception as e:
              return f"Model loading failed: {e}"

-     def calculate_perplexity(self) -> dict | str:
          """Calculate perplexity using the loaded models."""
-         if self.loaded_models is None:
              return "Please load models first."
-
          try:
-             result = log_perplexity()
-             return result
          except Exception as e:
              return f"Perplexity calculation failed: {e}"


- def create_interface() -> gr.Blocks:
      """Create and return the Gradio interface."""
-     manager = ModelManager()

      with gr.Blocks() as demo:
-         gr.Markdown("# LLM PPL")

-         checkpoint_input = gr.Textbox(
-             label="Checkpoints",
-             value=", ".join(checkpoints),
          )

          load_btn = gr.Button("Load Models", variant="primary")
-         perplexity_btn = gr.Button("Compute PPL")

-         load_output = gr.Textbox(label="Model Loading Status", interactive=False)
          perplexity_output = gr.JSON(label="PPL Results")

          # Connect event handlers
          load_btn.click(
-             fn=manager.load_models, inputs=checkpoint_input, outputs=load_output
          )

-         perplexity_btn.click(fn=manager.calculate_perplexity, outputs=perplexity_output)

      return demo


  if __name__ == "__main__":
-     demo = create_interface()
      demo.launch()
 
After:

  import gradio as gr

+ from utils import load_model, log_perplexity


+ class Manager:
      """Class to manage model loading and perplexity calculation state."""

      def __init__(self):
+         self.loaded = None

      def load_models(self, checkpoint_input_str: str) -> str:
          """Load models from a comma-separated string of checkpoint names."""
+         checkpoints = [
+             ckpt.strip() for ckpt in checkpoint_input_str.split(",") if ckpt.strip()
          ]

+         if not checkpoints:
              return "Please enter at least one model checkpoint name."

          try:
+             self.loaded = load_model(checkpoints)
              return "Models loaded successfully!"
          except Exception as e:
              return f"Model loading failed: {e}"

+     def perplexity(
+         self,
+         num_samples: int | None = None,
+         sample_length: int | None = None,
+     ) -> dict | str:
          """Calculate perplexity using the loaded models."""
+         if self.loaded is None:
              return "Please load models first."
+         if num_samples is None or sample_length is None:
+             return "Please set the number of samples and sample length."
          try:
+             return log_perplexity(self.loaded, num_samples, sample_length)
          except Exception as e:
              return f"Perplexity calculation failed: {e}"


+ def make_interface() -> gr.Blocks:
      """Create and return the Gradio interface."""
+     manager = Manager()

      with gr.Blocks() as demo:
+         gr.Markdown("# LLM PPLs")

+         checkpoints = gr.Textbox(
+             label="Checkpoints", value="HuggingFaceTB/SmolLM2-135M"
          )

          load_btn = gr.Button("Load Models", variant="primary")

+         with gr.Row():
+             # precision=0 so the Numbers yield ints (range() downstream needs ints)
+             num_samples = gr.Number(label="Number of Samples", value=1500, precision=0)
+             sample_length = gr.Number(label="Sample Length", value=128, precision=0)
+
+         perplexity_btn = gr.Button("Compute PPLs")
+
+         load_output = gr.Textbox(label="Model Loading Status")
          perplexity_output = gr.JSON(label="PPL Results")

          # Connect event handlers
          load_btn.click(
+             fn=manager.load_models,
+             inputs=checkpoints,
+             outputs=load_output,
          )

+         perplexity_btn.click(
+             fn=manager.perplexity,
+             inputs=[num_samples, sample_length],
+             outputs=perplexity_output,
+         )

      return demo


  if __name__ == "__main__":
+     demo = make_interface()
      demo.launch()
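For reference, the refactored Manager can be exercised without launching the UI. A minimal smoke test, assuming utils.py and a texts.json corpus sit in the working directory (the file name smoke_test.py is hypothetical, not part of this commit):

# smoke_test.py
from app import Manager

manager = Manager()
print(manager.load_models("HuggingFaceTB/SmolLM2-135M"))

# A small sample budget keeps a CPU-only run short; errors come back as strings.
result = manager.perplexity(num_samples=10, sample_length=64)
print(result)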
utils.py CHANGED
@@ -1,94 +1,78 @@

Before:

  from typing import Final

  import numpy as np
  import torch
  import ujson as json
  from transformers import AutoModelForCausalLM, AutoTokenizer

- dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
- texts: Final = json.load(open("texts.json", "r"))

- checkpoints = ["HuggingFaceTB/SmolLM2-135M"]  # Inputs


- def load_model(checkpoints: list[str]) -> dict:
-     tokenizers = [
-         AutoTokenizer.from_pretrained(checkpoint) for checkpoint in checkpoints
-     ]
      models = [
-         AutoModelForCausalLM.from_pretrained(
-             checkpoint,
-             device_map="auto",
-             torch_dtype=torch.bfloat16,
-         )
-         .to(dev)
-         .eval()
-         for checkpoint in checkpoints
      ]

      # Load the models and tokenizers into a dictionary
      return {
-         checkpoint: {"model": model, "tokenizer": tokenizer}
-         for checkpoint, model, tokenizer in zip(checkpoints, models, tokenizers)
      }


- def _perplexity(model, tokenizer, text):
-     encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
-     input_ids = encodings.input_ids.to(dev)
-     with torch.no_grad():
-         outputs = model(input_ids, labels=input_ids)
-         loss = outputs.loss.item()
-     return torch.exp(torch.tensor(loss)).item()
-
-
- num_samples: Final[int] = 500  # Sample size for perplexity calculation
- sample_length: Final[int] = 100  # Maximum length of text to consider for perplexity
-
- loaded = load_model(checkpoints)


- def log_perplexity() -> dict:
-     # Initialize a dictionary to store perplexity
-     ppls = {checkpoint: [] for checkpoint in loaded.keys()}
      for i in range(num_samples):
-         text = texts[i]
-         if len(text.strip()) == 0:
-             continue
-
-         text = text.strip()[:sample_length]
-
-         # Calculate perplexity for each model
-         current_ppls = {}
-         for checkpoint, info in loaded.items():
-             ppl = _perplexity(
-                 info["model"],
-                 info["tokenizer"],
-                 text,
-             )
-             current_ppls[checkpoint] = ppl
-
-         # Filter out outliers
-         if all(1 < ppl < 1e4 for ppl in current_ppls.values()):
-             for checkpoint, ppl in current_ppls.items():
-                 ppls[checkpoint].append(ppl)
-
-     # Convert perplexity into log scale
-     log_ppls: dict = {checkpoint: np.log(ppl) for checkpoint, ppl in ppls.items()}

      # Calculate the mean perplexity for each model
-     mean_log_ppls: dict = {
-         checkpoint: np.mean(ppl) for checkpoint, ppl in log_ppls.items()
-     }

      # Calculate the standard deviation of perplexity for each model
-     std_log_ppls: dict = {
-         checkpoint: np.std(ppl) for checkpoint, ppl in log_ppls.items()
-     }

-     return {
-         "ppls": ppls,
-         "mean_ppls": mean_log_ppls,
-         "std_ppls": std_log_ppls,
-     }
After:

  from typing import Final

+ import gradio as gr
  import numpy as np
  import torch
  import ujson as json
  from transformers import AutoModelForCausalLM, AutoTokenizer

+ _dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
+ _dtype: Final = torch.bfloat16


+ def _perplexity(model, tokenizer, text) -> float:
+     encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+     input_ids = encodings.input_ids.to(_dev)
+     with torch.no_grad():
+         outputs = model(input_ids, labels=input_ids)
+         loss = outputs.loss.item()
+     return np.log(torch.exp(torch.tensor(loss)).item())  # == loss; see note below


+ def load_model(checkpoints: list[str]) -> dict:
+     tokenizers = [AutoTokenizer.from_pretrained(c) for c in checkpoints]
      models = [
+         AutoModelForCausalLM.from_pretrained(c, device_map="auto", torch_dtype=_dtype)
+         for c in checkpoints
      ]

      # Load the models and tokenizers into a dictionary
      return {
+         ckpt: {"model": model.to(_dev).eval(), "tokenizer": tokenizer}
+         for ckpt, model, tokenizer in zip(checkpoints, models, tokenizers)
      }


+ def log_perplexity(
+     loaded: dict,
+     num_samples: int,
+     sample_length: int,
+     progress=gr.Progress(),
+ ) -> dict:
+     # Initialize a dictionary to store log-perplexity samples per checkpoint
+     ppls: dict[str, list] = {ckpt: [] for ckpt in loaded.keys()}

+     # Initialize samples
+     texts: Final[list[str]] = [
+         text.strip()[:sample_length]
+         for text in json.load(open("texts.json", "r"))
+         if text.strip()
+     ]
+     num_samples = min(num_samples, len(texts))  # guard against an IndexError below

+     # Start the iteration
+     progress(0, desc="Starting")
      for i in range(num_samples):
+         progress(i / num_samples, desc="Processing samples")
+         for ckpt, info in loaded.items():  # Calculate perplexity for each model
+             ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
+             if 0 < ppl < np.log(1e4):  # Filter out outliers (bounds in log space)
+                 ppls[ckpt].append(ppl)

      # Calculate the mean perplexity for each model
+     means: dict = {ckpt: np.mean(ppl) for ckpt, ppl in ppls.items()}

      # Calculate the standard deviation of perplexity for each model
+     stds: dict = {ckpt: np.std(ppl) for ckpt, ppl in ppls.items()}

+     return {"ppls": ppls, "means": means, "stds": stds}
+
+
+ if __name__ == "__main__":
+     from pprint import pprint
+
+     # Example usage
+     checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
+     loaded = load_model(checkpoints)
+     num_samples = 500
+     sample_length = 128
+     pprint(log_perplexity(loaded, num_samples, sample_length))
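Note on the math: the causal-LM loss returned by transformers is already the mean token negative log-likelihood, so log-perplexity is the loss itself and the exp()/log() round-trip in _perplexity cancels out (kept above only for clarity). A minimal check of the identity, with 3.2 as an arbitrary example loss:

# perplexity/loss identity
import math

loss = 3.2                  # outputs.loss.item(): mean token NLL
ppl = math.exp(loss)        # raw perplexity
assert abs(math.log(ppl) - loss) < 1e-9  # log-perplexity == loss

This is also why the outlier filter's bounds move to log space: the old raw-scale window 1 < ppl < 1e4 corresponds to 0 < log_ppl < log(1e4) ≈ 9.21, which is what the corrected condition in the diff checks. One caveat on the progress bar: Gradio wires a tracked gr.Progress only into the event handler it calls directly (Manager.perplexity in app.py), so the default instance created inside log_perplexity may not be reflected in the UI.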