itamarlanger commited on
Commit
9d2cd90
·
1 Parent(s): 1b93d17

Add custom SageMaker inference script

Browse files
Files changed (1) hide show
  1. code/inference.py +22 -0
code/inference.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from typing import List, Union
4
+
5
+
6
+ def input_fn(input_data, content_type):
7
+ data = json.loads(input_data)
8
+ return data['inputs']
9
+
10
+
11
+ def predict_fn(data: Union[List[str], str], model):
12
+ outputs = model(data, padding=False, truncation=True)
13
+ embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
14
+ return embeddings
15
+
16
+
17
+ def output_fn(prediction, accept):
18
+ return json.dumps(
19
+ obj={
20
+ "outputs": prediction
21
+ }
22
+ )