velocity-ai commited on
Commit
a38f107
·
verified ·
1 Parent(s): 8e90d33

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +2 -2
code/inference.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import json
3
  import torch
4
  import torch.nn as nn
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
@@ -43,7 +43,7 @@ def model_fn(model_dir, context=None):
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
44
 
45
  # Load base model
46
- base_model = AutoModelForCausalLM.from_pretrained(
47
  model_id,
48
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
49
  trust_remote_code=True
 
2
  import json
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
 
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
44
 
45
  # Load base model
46
+ base_model = AutoModelForSequenceClassification.from_pretrained(
47
  model_id,
48
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
49
  trust_remote_code=True