Upload ultravox_model.py with huggingface_hub
Browse files- ultravox_model.py +6 -1
ultravox_model.py
CHANGED
@@ -76,6 +76,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
|
|
76 |
return model
|
77 |
|
78 |
def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
|
|
|
|
|
|
|
|
|
|
|
79 |
if (
|
80 |
self.config.text_model_id is not None
|
81 |
and self.language_model.device.type == "meta"
|
@@ -850,4 +855,4 @@ UltravoxModel.register_for_auto_class()
|
|
850 |
transformers.AutoConfig.register("ultravox", UltravoxConfig)
|
851 |
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
|
852 |
|
853 |
-
transformers.activations.ACT2FN["swiglu"] = SwiGLU
|
|
|
76 |
return model
|
77 |
|
78 |
def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
|
79 |
+
if "torch_dtype" in kwargs:
|
80 |
+
self.config.torch_dtype = kwargs.pop("torch_dtype")
|
81 |
+
|
82 |
+
kwargs.pop("config", None)
|
83 |
+
|
84 |
if (
|
85 |
self.config.text_model_id is not None
|
86 |
and self.language_model.device.type == "meta"
|
|
|
855 |
transformers.AutoConfig.register("ultravox", UltravoxConfig)
|
856 |
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
|
857 |
|
858 |
+
transformers.activations.ACT2FN["swiglu"] = SwiGLU
|