philschmid (HF staff) committed
Commit 8c68191
1 Parent(s): 2dd9b98

Update app.py

Files changed (1): app.py +7 -3
app.py CHANGED
@@ -24,9 +24,13 @@ if HF_TOKEN:
 
 
 # Load peft config for pre-trained checkpoint etc.
+device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 model_id = "philschmid/instruct-igel-001"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map="auto")
+if device == "cpu":
+    model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
+else:
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 prompt_template = f"""### Anweisung:
@@ -65,7 +69,7 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
     streamer = IteratorStreamer(tokenizer)
     model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048)
     # move to gpu
-    model_inputs = {k: v.cuda() for k, v in model_inputs.items()}
+    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
 
     generate_kwargs = dict(
         top_p=top_p,
@@ -186,4 +190,4 @@ with gr.Blocks(theme=theme) as demo:
     )
 
 demo.queue()
-demo.launch(share=True)
+demo.launch()
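
For context, below is a minimal, self-contained sketch of the device-aware loading pattern this commit introduces; it is not part of the commit. The CPU branch (low_cpu_mem_usage=True) and the .to(device) input handling mirror the diff, while the imports, the placeholder prompt, and the generate() call are illustrative. One caveat visible in the committed file: torch_dtype is still derived from torch.cuda.get_device_capability() before the CPU/GPU branch, which raises on machines without CUDA; the sketch evaluates it only on the GPU path.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "philschmid/instruct-igel-001"  # model used by this Space
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cpu":
    # CPU fallback as in the commit: default dtype, lower peak RAM while loading
    model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
else:
    # bf16 on compute-capability-8.x GPUs (Ampere and newer), fp16 otherwise
    torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Placeholder prompt (illustrative); the Space builds prompts from its own template.
prompt = "### Anweisung:\nWas ist IGEL?"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
# Inputs follow the chosen device instead of the previous hard-coded .cuda()
inputs = {k: v.to(device) for k, v in inputs.items()}

output = model.generate(**inputs, max_new_tokens=64)  # illustrative generation call
print(tokenizer.decode(output[0], skip_special_tokens=True))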