# splade-code-06B / splade.py
# Last commit b8b90b5 (sclincha): "Support Sentence Transformers via SparseEncoder (#1)"
"""
Compared to standard Qwen3, this model uses bidirectional rather than causal attention;
this is configured via `is_causal=False` in the model config.
This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-06B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-06B", trust_remote_code=True)` -> Splade
"""
import torch
import os
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode
class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    """Transformers' ``Qwen3ForCausalLM`` with explicit embedding/LM-head tying.

    Used directly by ``Splade`` and as the Sentence Transformers entry point.
    """

    def tie_weights(self, *args, **kwargs):
        """Point ``lm_head.weight`` at ``model.embed_tokens.weight`` when tying is enabled.

        Re-ties explicitly (rather than via the parent) to hopefully avoid
        meta-tensor errors during loading. When a ``missing_keys`` collection
        is passed as a keyword, ``"lm_head.weight"`` is discarded from it
        because the tie now supplies that tensor (``.discard`` is used, so a
        set is presumably expected).

        When tying is disabled, or either submodule is absent, the call is
        delegated unchanged to the parent implementation.
        """
        can_tie = (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        )
        if not can_tie:
            super().tie_weights(*args, **kwargs)
            return

        self.lm_head.weight = self.model.embed_tokens.weight
        pending_keys = kwargs.get("missing_keys")
        if pending_keys is not None:
            pending_keys.discard("lm_head.weight")

    def _init_weights(self, module):
        """Initialise ``module``, except the LM head when it will be tied later."""
        head = getattr(self, "lm_head", None)
        if module is head and self.config.tie_word_embeddings:
            # The head's weight is replaced by embed_tokens in tie_weights.
            return
        super()._init_weights(module)
class SpladeConfig(PretrainedConfig):
    """Configuration for ``Splade``.

    Args:
        model_name_or_path: Base checkpoint identifier stored on the config.
        attn_implementation: Requested attention backend. ``Splade.__init__``
            overwrites it with the backend it actually uses.
        bidirectional: Attention-direction flag; only relevant for decoder
            models.
        padding_side: Tokenizer padding side, passed to the tokenizer setup.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store the SPLADE-specific fields in declaration order.
        for field_name, field_value in (
            ("model_name_or_path", model_name_or_path),
            ("attn_implementation", attn_implementation),
            ("bidirectional", bidirectional),
            ("padding_side", padding_side),
        ):
            setattr(self, field_name, field_value)
class Splade(PreTrainedModel):
    """SPLADE sparse encoder built on this module's ``Qwen3ForCausalLM``.

    ``forward`` runs the backbone and reduces its LM logits with
    ``splade_max`` (from ``.utils``) using the attention mask.
    """

    config_class = SpladeConfig
    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        """Build the tokenizer and backbone from ``weights_path``.

        Args:
            config: ``SpladeConfig``. Its ``attn_implementation`` is overwritten
                with the backend actually used.
            weights_path: Local directory or Hub id holding the backbone
                config, tokenizer and weights. Required.
            token: Optional Hub access token, forwarded to both the config and
                the weight downloads.

        Raises:
            ValueError: if ``weights_path`` is not provided.
        """
        if weights_path is None:
            raise ValueError(
                "Splade requires `weights_path` (a local directory or Hub id)."
            )
        super().__init__(config)
        self.name = "splade"
        # Resolve the attention backend *before* building the base config so
        # the config and the loaded model agree. flash-attn is used only when
        # it is installed; otherwise fall back to SDPA.
        config.attn_implementation = (
            "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        )
        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,  # needed for gated/private repos, like the weights
        )
        # NOTE(review): ``token`` is not forwarded to ``prepare_tokenizer``;
        # gated/private repos may need it. Check that helper's signature.
        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )
        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )
        # Token id -> token string (inverse of the tokenizer vocab); set here
        # so it exists however the model was constructed.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save backbone weights under ``<save_directory>/lora`` and the config in ``save_directory``.

        Extra positional/keyword arguments are accepted but ignored.
        """
        # NOTE(review): ``from_pretrained`` loads weights and tokenizer from the
        # top-level directory, not ``lora/``, and no tokenizer is saved here,
        # so a save/load round-trip likely does not work as-is. Confirm the
        # intended layout. Extra args are deliberately NOT forwarded: e.g. a
        # Trainer-supplied ``state_dict`` would carry this wrapper's key
        # prefixes, not the backbone's.
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Load a ``SpladeConfig`` and the backbone from ``model_name_or_path``.

        Only the ``token`` keyword is honoured; other arguments are ignored.
        """
        token = kwargs.get("token", None)
        config = SpladeConfig.from_pretrained(model_name_or_path, token=token)
        return cls(config, weights_path=model_name_or_path, token=token)

    def forward(self, **tokens):
        """Encode a tokenized batch.

        ``tokens`` is passed straight to the backbone and must include
        ``attention_mask``. Returns a 1-tuple holding the ``splade_max``
        representations (presumably ``(batch, vocab_size)``; verify in
        ``.utils``).
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Return the backbone vocabulary size (presumably the sparse vector width)."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize ``input_texts`` into PyTorch tensors.

        Pads to the longest item in the batch, truncates to ``max_length``,
        adds special tokens and returns the attention mask.
        """
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )
# Public API: the backbone class, the SPLADE config and the SPLADE model.
__all__ = ["Qwen3ForCausalLM", "SpladeConfig", "Splade"]