T5 for Search Query Generation

class T5ForSQG:
    """Generate search queries for a topic with a fine-tuned T5 model.

    The model (``T5ForConditionalGeneration``) and tokenizer (``T5Tokenizer``)
    are both loaded from the same ``model_path``. Generation uses sampling
    (``do_sample=True``), so results may differ between calls.
    """

    def __init__(self, model_path):
        """Load the model and tokenizer via ``from_pretrained(model_path)``."""
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    @staticmethod
    def _build_frame(topic, n):
        """Return an ``n``-row frame of (prompt, empty target) pairs."""
        # One independent empty list per row. The former ``[[]*n]`` evaluated
        # to ``[[]]`` (length 1), so pandas rejected the frame whenever n > 1.
        return pd.DataFrame(
            {
                'topic': ['make queries: ' + topic] * n,
                'queries': [[] for _ in range(n)],
            },
            index=range(n),
        )

    def make_queries(self, topic, n=1, device='cpu', batch_size=16):
        """Sample up to ``n`` distinct search queries for ``topic``.

        Args:
            topic: Free-text topic. It is prefixed with ``'make queries: '``
                to form the model prompt.
            n: Number of generations to sample. Duplicates are removed, so
                fewer than ``n`` queries may be returned. ``n <= 0`` returns
                an empty list without touching the model.
            device: Torch device for generation. The model is moved to this
                device (in place) along with the inputs.
            batch_size: Maximum number of prompts per generation batch.

        Returns:
            A list of unique decoded queries in first-generated order.
        """
        if n <= 0:
            return []

        # NOTE(review): the two 64s are presumably the max source/target
        # token lengths expected by YourDataSetClass -- confirm.
        ds = YourDataSetClass(
            self._build_frame(topic, n), self.tokenizer, 64, 64, 'topic', 'queries'
        )
        loader = DataLoader(
            ds, batch_size=min(n, batch_size), shuffle=False, num_workers=0
        )

        # Inputs are moved to ``device`` below, so the weights must live there too.
        self.model.to(device)
        self.model.eval()

        # A dict keeps insertion order, giving a deterministic dedup
        # (iteration order of a set of str depends on hash randomisation).
        unique = {}
        with torch.no_grad():
            for data in loader:
                ids = data['source_ids'].to(device, dtype=torch.long)
                mask = data['source_mask'].to(device, dtype=torch.long)

                generated_ids = self.model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_length=64,
                    num_beams=1,
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    do_sample=True,
                    temperature=1.5,
                    top_k=10,
                    top_p=0.95,
                )

                for g in generated_ids:
                    text = self.tokenizer.decode(
                        g, skip_special_tokens=True, clean_up_tokenization_spaces=True
                    )
                    unique.setdefault(text, None)

        return list(unique)
Downloads last month: 2

Safetensors
Model size: 223M params
Tensor type: F32
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Collection including 1rsh/t5-base-search-query-generation