Correct prompt padding side

#1
by ylacombe - opened
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -29,7 +29,8 @@ model = ParlerTTSForConditionalGeneration.from_pretrained(
29
 
30
  client = InferenceClient()
31
 
32
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
 
33
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
34
 
35
  SAMPLE_RATE = feature_extractor.sampling_rate
@@ -87,8 +88,8 @@ def generate_base(subject, setting):
87
 
88
  gr.Info("Generating Audio")
89
  description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
90
- story_tokens = tokenizer(model_input_tokens, return_tensors="pt", padding=True).input_ids.to(device)
91
- description_tokens = tokenizer([description for _ in range(len(model_input_tokens))], return_tensors="pt").input_ids.to(device)
92
  speech_output = model.generate(input_ids=description_tokens, prompt_input_ids=story_tokens)
93
  speech_output = [output.cpu().numpy() for output in speech_output]
94
  gr.Info("Generated Audio")
 
29
 
30
  client = InferenceClient()
31
 
32
+ description_tokenizer = AutoTokenizer.from_pretrained(repo_id)
33
+ prompt_tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
34
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
35
 
36
  SAMPLE_RATE = feature_extractor.sampling_rate
 
88
 
89
  gr.Info("Generating Audio")
90
  description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
91
+ story_tokens = prompt_tokenizer(model_input_tokens, return_tensors="pt", padding=True).input_ids.to(device)
92
+ description_tokens = description_tokenizer([description for _ in range(len(model_input_tokens))], return_tensors="pt").input_ids.to(device)
93
  speech_output = model.generate(input_ids=description_tokens, prompt_input_ids=story_tokens)
94
  speech_output = [output.cpu().numpy() for output in speech_output]
95
  gr.Info("Generated Audio")