satyaiyer committed
Commit 1368f34 · verified · 1 Parent(s): 77a5ebe

Update app.py

Files changed (1):
  app.py +8 -9
app.py CHANGED

@@ -1,20 +1,19 @@
 import pandas as pd
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
 import os
-
+bnb_config = BitsAndBytesConfig(load_in_4bit=True)
 model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
 hf_token = os.environ.get("HF_TOKEN")

 tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=hf_token)
-
-# model = AutoModelForCausalLM.from_pretrained(
-#     model_name,
-#     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-#     device_map="auto"
-# )
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=bnb_config,
+    device_map="auto",
+    use_auth_token=hf_token
+)

 def generate_prompt(original, translation):
     return f"### Task: Machine Translation Quality Estimation\n\nSource: {original}\nTranslation: {translation}\n\nScore (0-1):"