farzadab commited on
Commit
3867238
·
verified ·
1 Parent(s): 15893c0

Upload ultravox_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ultravox_model.py +6 -1
ultravox_model.py CHANGED
@@ -76,6 +76,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
76
  return model
77
 
78
  def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
 
 
 
 
 
79
  if (
80
  self.config.text_model_id is not None
81
  and self.language_model.device.type == "meta"
@@ -850,4 +855,4 @@ UltravoxModel.register_for_auto_class()
850
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
851
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
852
 
853
- transformers.activations.ACT2FN["swiglu"] = SwiGLU
 
76
  return model
77
 
78
  def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
79
+ if "torch_dtype" in kwargs:
80
+ self.config.torch_dtype = kwargs.pop("torch_dtype")
81
+
82
+ kwargs.pop("config", None)
83
+
84
  if (
85
  self.config.text_model_id is not None
86
  and self.language_model.device.type == "meta"
 
855
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
856
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
857
 
858
+ transformers.activations.ACT2FN["swiglu"] = SwiGLU