israel committed
Commit 5f95d4e · verified · 1 Parent(s): 5f299c4

Update app.py

Files changed (1): app.py (+22 -5)
app.py CHANGED
@@ -6,18 +6,35 @@ from transformers import pipeline
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+# quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 
 
-model_name = "masakhane/zephyr-7b-gemma-sft-african-alpaca"
+# model_name = "masakhane/zephyr-7b-gemma-sft-african-alpaca"
 
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
+# model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
 
 
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto")
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto")
 
 
+import torch
+from transformers import pipeline
+
+pipe = pipeline("text-generation", model="masakhane/zephyr-7b-gemma-sft-african-alpaca", torch_dtype=torch.bfloat16, device_map="auto")
+
+# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
+# messages = [
+#     {
+#         "role": "system",
+#         "content": "You are a friendly chatbot who answers questions in a given language",
+#     },
+#     {"role": "user", "content": "What are the 3 biggest countries in Africa?"},
+# ]
+# prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+# outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+# print(outputs[0]["generated_text"])
+
 
 if 'messages' not in st.session_state:
     st.session_state.messages = [
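
The commented-out block at the end of the hunk hints at how the new pipe is meant to be driven. Below is a minimal sketch (not part of the commit) of wiring that pattern into the Streamlit chat history visible in the surrounding context lines; the st.chat_input/st.chat_message layout and the system-prompt seeding are assumptions, not code from app.py.

import streamlit as st
import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="masakhane/zephyr-7b-gemma-sft-african-alpaca", torch_dtype=torch.bfloat16, device_map="auto")

if 'messages' not in st.session_state:
    # Seed the history with a system prompt, mirroring the commented example.
    st.session_state.messages = [
        {"role": "system", "content": "You are a friendly chatbot who answers questions in a given language"},
    ]

if user_input := st.chat_input("Ask a question"):  # assumed widget, not from app.py
    st.session_state.messages.append({"role": "user", "content": user_input})
    # Format the whole history with the model's chat template, then generate.
    prompt = pipe.tokenizer.apply_chat_template(st.session_state.messages, tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    # The pipeline returns prompt + completion; keep only the new text.
    reply = outputs[0]["generated_text"][len(prompt):]
    st.session_state.messages.append({"role": "assistant", "content": reply})

for message in st.session_state.messages[1:]:  # skip the system prompt
    with st.chat_message(message["role"]):
        st.write(message["content"])

Note the tradeoff the commit makes: dropping the BitsAndBytesConfig(load_in_4bit=True) path means the weights now load in bfloat16 via device_map="auto", which takes roughly four times the weight memory of the 4-bit path but removes the bitsandbytes dependency and the separate tokenizer/model loading.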