velocity-ai commited on
Commit
febdd76
·
verified ·
1 Parent(s): a38f107

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +10 -2
code/inference.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import json
3
  import torch
4
  import torch.nn as nn
5
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
@@ -42,9 +42,17 @@ def model_fn(model_dir, context=None):
42
  # Load tokenizer
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
44
 
45
- # Load base model
 
 
 
 
 
 
 
46
  base_model = AutoModelForSequenceClassification.from_pretrained(
47
  model_id,
 
48
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
49
  trust_remote_code=True
50
  )
 
2
  import json
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
 
42
  # Load tokenizer
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
44
 
45
+ # Configure model as RoBERTa
46
+ config = AutoConfig.from_pretrained(model_id,
47
+ num_labels=2,
48
+ architectures=["RobertaForSequenceClassification"],
49
+ model_type="roberta",
50
+ trust_remote_code=True)
51
+
52
+ # Load base model with RoBERTa config
53
  base_model = AutoModelForSequenceClassification.from_pretrained(
54
  model_id,
55
+ config=config,
56
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
57
  trust_remote_code=True
58
  )