added missing imports

- import statements were missing
- max_length in generate() was set lower than the number of input tokens (10, vs. a prompt of about 48 tokens), which raised an error; raised it slightly above the input length, from 10 to 64
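For context on the second fix: in transformers, generate()'s max_length is a total token budget covering the prompt plus everything generated, so any value at or below the prompt length fails immediately. A minimal sketch to verify the budget before generating, assuming the same tokenizer and prompt as in the README diff below (tokenizer only, no model weights needed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

# max_length counts prompt tokens plus generated tokens, so it must exceed the prompt length.
prompt_len = input_ids.shape[1]  # roughly 48 tokens for this prompt
print(f"prompt tokens: {prompt_len}")
assert 64 > prompt_len, "max_length leaves no room for generated tokens"
```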
README.md CHANGED

````diff
@@ -33,6 +33,8 @@ pip install git+https://github.com/huggingface/transformers.git@refs/pull/33410/
 ```
 And then load the model :
 ```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
 model = AutoModelForCausalLM.from_pretrained("HF1BitLLM/Llama3-8B-1.58-100B-tokens", device_map="cuda", torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
@@ -40,7 +42,7 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
 input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"
 
 input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
-output = model.generate(input_ids, max_length=10, do_sample=False)
+output = model.generate(input_ids, max_length=64, do_sample=False)
 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print(generated_text)
 ```
````
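A note on the choice of fix: max_length=64 works for this prompt, but a hard-coded total budget would break again for a longer prompt. generate() also accepts max_new_tokens, which budgets only the newly generated tokens regardless of prompt length; a sketch of that variant, reusing model and input_ids from the diff above (16 is an arbitrary illustrative budget):

```python
# Budget only the generated tokens; prompt length no longer matters.
output = model.generate(input_ids, max_new_tokens=16, do_sample=False)
```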