Ray Leung commited on
Commit
b9aea98
·
verified ·
1 Parent(s): 53f357f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -8,8 +8,8 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
8
  pipeline = transformers.pipeline(
9
  "text-generation",
10
  model=model_id,
11
- model_kwargs={"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
12
- device_map="auto"
13
  )
14
 
15
  def respond(message, history, system_message, max_tokens, temperature, top_p):
@@ -45,7 +45,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
45
  return response_text, history
46
 
47
  # Define the Gradio interface
48
- demo = gr.ChatInterface(
49
  respond,
50
  additional_inputs=[
51
  gr.Textbox(value="You are a pirate chatbot who always responds in pirate speak!", label="System message"),
@@ -55,6 +55,5 @@ demo = gr.ChatInterface(
55
  ]
56
  )
57
 
58
-
59
  if __name__ == "__main__":
60
- demo.launch()
 
8
  pipeline = transformers.pipeline(
9
  "text-generation",
10
  model=model_id,
11
+ model_kwargs={"torch_dtype": torch.bfloat16},
12
+ device_map="auto",
13
  )
14
 
15
  def respond(message, history, system_message, max_tokens, temperature, top_p):
 
45
  return response_text, history
46
 
47
  # Define the Gradio interface
48
+ chatiface = gr.ChatInterface(
49
  respond,
50
  additional_inputs=[
51
  gr.Textbox(value="You are a pirate chatbot who always responds in pirate speak!", label="System message"),
 
55
  ]
56
  )
57
 
 
58
  if __name__ == "__main__":
59
+ chatiface.launch()