wiklif commited on
Commit
f7fc778
·
1 Parent(s): b7844b5

Zamiast używać InferenceClient, ładujemy model lokalnie za pomocą AutoModelForCausalLM i AutoTokenizer.

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -1,29 +1,37 @@
 
1
  import spaces
2
- from huggingface_hub import InferenceClient
3
  import gradio as gr
4
- import os
 
5
 
6
- # Inicjalizacja klienta
7
- client = InferenceClient(
8
- model='meta-llama/Meta-Llama-3.1-8B',
9
- token=os.environ.get("MY_API_LLAMA_3_1")
10
- )
 
 
 
 
 
 
 
 
 
 
11
 
12
  @spaces.GPU(duration=60)
13
  def generate_response(chat, kwargs):
14
- output = ''
15
- stream = client.text_generation(chat, **kwargs, stream=True, details=True, return_full_text=False)
16
- for response in stream:
17
- output += response.token.text
18
- if output.endswith("</s>"): # Sprawdzamy, czy odpowiedź kończy się tagiem </s>
19
- output = output[:-4] # Usuwamy tag </s> z końca odpowiedzi
20
  return output
21
 
22
  def function(prompt, history=[]):
23
  chat = "<s>"
24
  for user_prompt, bot_response in history:
25
  chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
26
- chat += f"[INST] {prompt} [/INST]" # Zostawiamy tylko tag otwierający <s> na początku i kończymy ciąg zwykłym znacznikiem
27
  kwargs = dict(
28
  temperature=0.5,
29
  max_new_tokens=4096,
 
1
+ import os
2
  import spaces
 
3
  import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
 
7
+ model_id = "meta-llama/Meta-Llama-3.1-8B"
8
+
9
+ @spaces.GPU(duration=60)
10
+ def load_model():
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_id,
14
+ token=os.environ.get("MY_API_LLAMA_3_1"),
15
+ torch_dtype=torch.bfloat16,
16
+ device_map="auto",
17
+ low_cpu_mem_usage=True
18
+ )
19
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
20
+
21
+ pipe = load_model()
22
 
23
  @spaces.GPU(duration=60)
24
  def generate_response(chat, kwargs):
25
+ output = pipe(chat, **kwargs)[0]['generated_text']
26
+ if output.endswith("</s>"):
27
+ output = output[:-4]
 
 
 
28
  return output
29
 
30
  def function(prompt, history=[]):
31
  chat = "<s>"
32
  for user_prompt, bot_response in history:
33
  chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
34
+ chat += f"[INST] {prompt} [/INST]"
35
  kwargs = dict(
36
  temperature=0.5,
37
  max_new_tokens=4096,