ZetangForward commited on
Commit
5d4a357
·
verified ·
1 Parent(s): 06f1988

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +25 -0
pipeline.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, Pipeline, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
2
+ from transformers.pipelines import PIPELINE_REGISTRY
3
+ import torch
4
+
5
+
6
+ class SpanClassificationPipeline(Pipeline):
7
+ def __init__(self, model, tokenizer, device="cpu", **kwargs):
8
+ super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
9
+ self.model.to(self.device)
10
+ self.model.eval()
11
+
12
+ def _sanitize_parameters(self, **kwargs):
13
+ return {}, kwargs, {}
14
+
15
+ def preprocess(self, inputs):
16
+ return self.tokenizer(inputs, return_tensors="pt").to(self.device)
17
+
18
+ def _forward(self, model_inputs):
19
+ with torch.no_grad():
20
+ outputs = self.model(**model_inputs)
21
+ return outputs
22
+
23
+ def postprocess(self, model_outputs):
24
+ logits = model_outputs.logits
25
+ return int(torch.argmax(logits, dim=1).item())