# openaudio-s1-mini / fish_speech / content_sequence.py
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
import torch
from fish_speech.tokenizer import (
IM_END_TOKEN,
MODALITY_TOKENS,
FishTokenizer,
)
def restore_ndarray(obj, to_tensor: bool = False):
    """Restore a serialized ``{"__ndarray__": ...}`` dict back into a NumPy
    array, optionally converting the result to a ``torch.Tensor``."""
    if isinstance(obj, dict) and "__ndarray__" in obj:
        obj = np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])
    if to_tensor and isinstance(obj, np.ndarray):
        obj = torch.from_numpy(obj.copy())
    return obj
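
# Hedged usage sketch (added for illustration, not part of the upstream file).
# The dict layout is inferred from the keys restore_ndarray reads; whatever
# serializer produces it lives elsewhere in the pipeline.
#
#     arr = np.arange(6, dtype=np.float32)
#     packed = {
#         "__ndarray__": True,
#         "data": arr.tobytes(),
#         "dtype": "float32",
#         "shape": (2, 3),
#     }
#     restored = restore_ndarray(packed, to_tensor=True)  # tensor of shape (2, 3)
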
@dataclass
class BasePart:
type: Literal["text", "vq", "audio"] | None = None
cal_loss: bool = False
@dataclass(kw_only=True)
class VQPart(BasePart):
type = "vq"
codes: torch.Tensor
def __post_init__(self: "VQPart"):
self.type = "vq"
self.codes = restore_ndarray(self.codes, to_tensor=True)
@dataclass(kw_only=True)
class TextPart(BasePart):
type = "text"
text: str | None = None
tokens: list[int] | None = None
def __post_init__(self: "TextPart"):
self.type = "text"
if self.text is None and self.tokens is None:
raise ValueError("Either text or tokens must be provided")
@dataclass(kw_only=True)
class AudioPart(BasePart):
type = "audio"
features: torch.Tensor
def __post_init__(self: "AudioPart"):
self.type = "audio"
self.features = restore_ndarray(self.features, to_tensor=True)
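
# Illustrative construction (a sketch; shapes are assumptions, not upstream
# documentation). encode() indexes VQPart.codes as codes[0] and
# encode_for_inference() concatenates along dim=1, which implies a
# (num_codebooks, length) layout; AudioPart.features is a dense feature matrix
# whose shape depends on the audio encoder.
#
#     text = TextPart(text="Hello there.", cal_loss=True)
#     vq = VQPart(codes=torch.zeros(4, 10, dtype=torch.int))  # 4 codebooks assumed
#     audio = AudioPart(features=torch.randn(50, 512))        # placeholder shape
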
@dataclass(kw_only=True)
class EncodedMessage:
    tokens: torch.Tensor
    labels: torch.Tensor
    vq_parts: list[torch.Tensor]
    audio_parts: list[torch.Tensor]
    vq_mask_tokens: torch.Tensor | None = None
    vq_mask_labels: torch.Tensor | None = None
    vq_require_losses: torch.Tensor | None = None
    audio_masks: torch.Tensor | None = None
    metadata: dict | None = None
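
# Shape notes (inferred from encode() below): tokens and labels are 1-D tensors
# of equal length; labels use -100 as the ignore index for loss masking; and
# vq_parts holds one (num_codebooks, length_i) tensor per VQPart, in order.
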
@dataclass
class ContentSequence:
"""
Flexible sequence of content parts that supports interleaved multimodal format.
Example format: <|interleave|><|speaker:1|> TEXT AUDIO <|im_end|><|speaker:2|> TEXT AUDIO <|im_end|>
"""
parts: list[BasePart] = field(default_factory=list)
modality: Literal["text", "voice", "interleave"] | None = None
metadata: dict | None = None
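
    # The hand-written __init__ below intentionally replaces the one @dataclass
    # would generate, so that parts supplied as plain dicts (e.g. parsed from
    # JSON) are coerced into the matching *Part dataclasses.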
def __init__(
self: "ContentSequence",
parts: list[BasePart | dict] | None = None,
modality: Literal["text", "voice", "interleave"] | None = None,
metadata: dict | None = None,
):
self.modality = modality
self.metadata = metadata or {}
fixed_parts = []
for part in parts or []:
if isinstance(part, dict):
if part["type"] == "vq":
part = VQPart(**part)
elif part["type"] == "audio":
part = AudioPart(**part)
elif part["type"] == "text":
part = TextPart(**part)
else:
raise ValueError(f"Unsupported part type: {part['type']}")
fixed_parts.append(part)
self.parts = fixed_parts
# If modality is specified, add it at the beginning if it's not already there
        if self.modality and not (
            len(self.parts) > 0
            and isinstance(self.parts[0], TextPart)
            and self.parts[0].text is not None
            and self.parts[0].text.startswith(MODALITY_TOKENS[self.modality])
        ):
modality_token = MODALITY_TOKENS[self.modality]
self.parts.insert(0, TextPart(text=modality_token))
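
    # Illustrative note (sketch): ContentSequence(modality="voice") therefore
    # starts with a single TextPart holding the modality token. The literal
    # token strings (e.g. "<|voice|>") are an assumption here; the authoritative
    # values are MODALITY_TOKENS in fish_speech.tokenizer.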
    def append(
        self: "ContentSequence",
        part_or_parts: BasePart | list[BasePart],
        add_end: bool = False,
        speaker: str | int | None = None,
    ):
"""
Append a part or list of parts to the sequence.
Args:
part_or_parts: A single part or list of parts to add
add_end: Whether to add the IM_END_TOKEN after these parts
speaker: Optional speaker identifier (name or ID) to add before the parts
"""
# Convert single part to list
parts_to_add = (
[part_or_parts] if not isinstance(part_or_parts, list) else part_or_parts
)
# Add speaker token if specified
if speaker is not None:
speaker_token = f"<|speaker:{speaker}|>"
self.parts.append(TextPart(text=speaker_token))
# Add all the parts
self.parts.extend(parts_to_add)
# Add end token if requested
if add_end:
self.parts.append(
TextPart(text=IM_END_TOKEN, cal_loss=self.parts[-1].cal_loss)
)
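
    # Usage sketch (illustrative; mirrors the format in the class docstring):
    #
    #     seq = ContentSequence(modality="interleave")
    #     seq.append(TextPart(text="Hi"), speaker=1)     # adds <|speaker:1|> Hi
    #     seq.append(VQPart(codes=codes), add_end=True)  # ... then <|im_end|>
    #
    # where `codes` is a (num_codebooks, length) int tensor.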
    def encode(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        add_shift: bool = True,
        ignore_loss_tokens: list[str] | None = None,
    ) -> EncodedMessage:
"""
Encode the sequence parts into tokens for the model.
Args:
tokenizer: The tokenizer to use
add_shift: Whether to shift tokens for next-token prediction
ignore_loss_tokens: List of token strings to ignore when calculating loss
Returns:
EncodedMessage with tensors ready for the model
"""
all_tokens = []
all_labels = []
# Multi-modal elements
vq_parts = []
vq_masks = []
vq_require_losses = []
audio_parts = []
audio_masks = []
        ignore_loss_token_ids = [
            tokenizer.get_token_id(i) for i in (ignore_loss_tokens or [])
        ]
for part in self.parts:
if isinstance(part, TextPart):
if part.tokens is None:
assert part.text is not None
tokens = tokenizer.encode(part.text)
else:
tokens = part.tokens
tokens = torch.tensor(tokens, dtype=torch.int)
elif isinstance(part, VQPart):
curr_codes = part.codes.clone().to(torch.int)
tokens = torch.tensor(
[
tokenizer.semantic_id_to_token_id[int(i.item())]
for i in curr_codes[0].int()
],
dtype=torch.int,
)
vq_parts.append(curr_codes)
vq_require_losses.append(part.cal_loss)
            else:
                # Only TextPart and VQPart are tokenized in this loop; AudioPart
                # (and anything else) is rejected here.
                raise ValueError(f"Unsupported part type: {type(part)}")
all_tokens.append(tokens)
# Set masks for different part types
if isinstance(part, VQPart):
vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
            # NOTE: as written, the tokenization loop above raises on AudioPart,
            # so this branch is only reachable if an AudioPart encoding is added.
            elif isinstance(part, AudioPart):
vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
audio_mask = torch.ones_like(tokens, dtype=torch.bool)
audio_mask[0] = False # Skip start token
audio_mask[-1] = False # Skip end token
audio_masks.append(audio_mask)
else:
vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
audio_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
# Set labels based on whether we want to calculate loss for this part
if part.cal_loss and not isinstance(part, AudioPart):
all_labels.append(tokens.clone())
else:
all_labels.append(torch.full_like(tokens, -100))
# Concatenate all tensors
tokens = torch.cat(all_tokens, dim=0)
labels = torch.cat(all_labels, dim=0)
vq_masks = torch.cat(vq_masks, dim=0)
audio_masks = torch.cat(audio_masks, dim=0)
vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
# Apply shift if needed for next-token prediction
vq_mask_tokens = vq_masks
vq_mask_labels = vq_masks
if add_shift:
tokens = tokens[:-1]
labels = labels[1:]
vq_masks = vq_masks[:-1]
vq_mask_tokens = vq_mask_tokens[:-1]
vq_mask_labels = vq_mask_labels[1:]
audio_masks = audio_masks[:-1]
# Ignore specified tokens
for i in ignore_loss_token_ids:
assert i != -100 and i is not None
labels[labels == i] = -100
assert tokens.dtype in [
torch.int,
torch.long,
], f"Invalid dtype: {tokens.dtype}"
return EncodedMessage(
tokens=tokens,
labels=labels,
vq_parts=vq_parts,
vq_mask_tokens=vq_mask_tokens,
vq_mask_labels=vq_mask_labels,
vq_require_losses=vq_require_losses,
audio_parts=audio_parts,
audio_masks=audio_masks,
metadata=self.metadata,
)
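
    # Shift sketch: with add_shift=True, (tokens[:-1], labels[1:]) aligns input
    # position t with target t+1 for next-token prediction. E.g. for ids
    # [a, b, c], the model sees inputs [a, b] and is supervised on [b, c].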
def encode_for_inference(
self: "ContentSequence",
tokenizer: FishTokenizer,
num_codebooks: int,
) -> torch.Tensor:
encoded = self.encode(tokenizer, add_shift=False)
tokens = encoded.tokens
values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
values[0] = tokens
if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
encoded.audio_parts is None or len(encoded.audio_parts) == 0
):
return values
if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
vq_parts = encoded.vq_parts
vq_parts = torch.cat(vq_parts, dim=1)
values[0, encoded.vq_mask_tokens] = (
vq_parts[0] + tokenizer.semantic_begin_id
)
values[1:, encoded.vq_mask_tokens] = vq_parts
return values
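
    # Layout sketch for the returned tensor: row 0 carries the text/control
    # token ids (with semantic placeholder ids, offset by
    # tokenizer.semantic_begin_id, at VQ positions); rows 1..num_codebooks carry
    # the raw codebook values, written only where vq_mask_tokens is True.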
    def visualize(
        self: "ContentSequence",
        tokenizer: FishTokenizer,
        ignore_loss_tokens: list[str] | None = None,
        merge_semantic_tokens: bool = False,
    ):
"""
Visualize the encoded sequence with color-coded tokens.
Blue/cyan tokens contribute to loss, green tokens do not.
"""
encoded = self.encode(
tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
)
# Colors for alternating tokens
colors = {
"blue": "\033[94m", # Light blue
"cyan": "\033[96m", # Cyan
"green": "\033[92m", # Light green
"dark_green": "\033[32m", # Dark green
}
blue_idx = 0
green_idx = 0
def print_in_blue(x):
nonlocal blue_idx
color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
print(f"{color}{x}\033[0m", end="")
blue_idx += 1
def print_in_green(x):
nonlocal green_idx
color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
print(f"{color}{x}\033[0m", end="")
green_idx += 1
def print_semantic_token(x, count):
val = f"[<|semantic|>x{count}]"
if x == -100:
print_in_green(val)
else:
print_in_blue(val)
count_semantic_tokens = 0
semantic_label = None
for tok, lab in zip(encoded.tokens, encoded.labels):
token_id = int(tok.item())
if merge_semantic_tokens:
if (
tokenizer.semantic_begin_id <= token_id <= tokenizer.semantic_end_id
and (semantic_label is None or semantic_label == lab)
):
count_semantic_tokens += 1
semantic_label = lab
continue
elif count_semantic_tokens > 0:
print_semantic_token(semantic_label, count_semantic_tokens)
count_semantic_tokens = 0
semantic_label = None
            val = tokenizer.decode([token_id])
if lab == -100:
print_in_green(val)
else:
print_in_blue(val)
if merge_semantic_tokens and count_semantic_tokens > 0:
print_semantic_token(semantic_label, count_semantic_tokens)
print()
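

if __name__ == "__main__":
    # Smoke-test sketch (added for illustration; not part of the upstream file).
    # Only the pure-Python paths are exercised -- encode()/visualize() need a
    # real FishTokenizer instance, so they are not called here. The codebook
    # count (4) and code length (8) are placeholders.
    seq = ContentSequence(modality="voice")
    seq.append(TextPart(text="Hello there."), speaker=1)
    seq.append(
        VQPart(codes=torch.zeros(4, 8, dtype=torch.int), cal_loss=True),
        add_end=True,
    )
    for p in seq.parts:
        print(type(p).__name__, "cal_loss =", p.cal_loss)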