root commited on
Commit
b8cb7fe
·
1 Parent(s): 80676ae

Mise à jour du fichier handler.py avec le code adapté

Browse files
Files changed (1) hide show
  1. handler.py +17 -7
handler.py CHANGED
@@ -1,26 +1,36 @@
1
  import json
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
- # Charger le modèle et le tokenizer
6
  model_name = "AIDC-AI/Ovis1.6-Gemma2-9B" # Remplacez par le nom de votre modèle
7
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
8
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
9
 
10
- # Fonction d'initialisation
 
 
 
 
11
  def init():
12
  global tokenizer, model
13
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
14
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
15
  model.eval() # Définir le modèle en mode évaluation
16
 
17
- # Fonction de prédiction
18
  def predict(data):
 
 
 
19
  inputs = data.get("inputs")
 
 
20
  if isinstance(inputs, str):
21
  inputs = tokenizer(inputs, return_tensors="pt")
22
 
 
23
  outputs = model.generate(**inputs)
 
 
24
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
 
 
26
  return json.dumps({"result": result})
 
1
  import json
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
 
4
+ # Nom du modèle que vous souhaitez utiliser
5
  model_name = "AIDC-AI/Ovis1.6-Gemma2-9B" # Remplacez par le nom de votre modèle
 
 
6
 
7
+ # Initialisation globale des variables
8
+ tokenizer = None
9
+ model = None
10
+
11
+ # Fonction d'initialisation qui sera appelée lors du démarrage du service
12
  def init():
13
  global tokenizer, model
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
15
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
16
  model.eval() # Définir le modèle en mode évaluation
17
 
18
+ # Fonction de prédiction qui sera appelée pour traiter les requêtes d'inférence
19
  def predict(data):
20
+ global tokenizer, model
21
+
22
+ # Extraire les données d'entrée
23
  inputs = data.get("inputs")
24
+
25
+ # Vérifier si les données d'entrée sont une chaîne de caractères
26
  if isinstance(inputs, str):
27
  inputs = tokenizer(inputs, return_tensors="pt")
28
 
29
+ # Générer les prédictions à partir du modèle
30
  outputs = model.generate(**inputs)
31
+
32
+ # Convertir les résultats en texte
33
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
 
35
+ # Retourner le résultat au format JSON
36
  return json.dumps({"result": result})