Upload 4 files

- f5-tts/model/backbones/README.md +20 -0
- f5-tts/model/backbones/dit.py +163 -0
- f5-tts/model/backbones/mmdit.py +146 -0
- f5-tts/model/backbones/unett.py +219 -0
f5-tts/model/backbones/README.md
ADDED
@@ -0,0 +1,20 @@
+## Backbones quick introduction
+
+
+### unett.py
+- flat unet transformer
+- structure same as in e2-tts & voicebox paper, except using rotary pos emb
+- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
+
+### dit.py
+- adaln-zero dit
+- embedded timestep as condition
+- concatted noised_input + masked_cond + embedded_text, linear proj in
+- possible abs pos emb & convnextv2 blocks for embedded text before concat
+- possible long skip connection (first layer to last layer)
+
+### mmdit.py
+- sd3 structure
+- timestep as condition
+- left stream: text embedded and applied an abs pos emb
+- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
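The three backbones added below expose the same mel-spectrogram interface, so they are meant to be interchangeable. A minimal instantiation sketch using the constructor signatures from the files in this commit (the import path `f5_tts.model.backbones.*` assumes the package is installed as `f5_tts`; the hyperparameter values are illustrative picks, not prescribed defaults beyond those visible in the code):

```python
import torch
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT

# adaln-zero DiT with convnextv2 text blocks and an optional long skip connection
dit = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
          long_skip_connection=True)

# sd3-style dual-stream MM-DiT
mmdit = MMDiT(dim=512, depth=8, heads=8)

# flat unet transformer with concat skip connections between mirrored layers
unett = UNetT(dim=512, depth=8, heads=8, skip_connect_type="concat")

print(sum(p.numel() for p in dit.parameters()))  # rough parameter-count check
```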
    	
f5-tts/model/backbones/dit.py
ADDED
@@ -0,0 +1,163 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvNeXtV2Block,
+    ConvPositionEmbedding,
+    DiTBlock,
+    AdaLayerNormZero_Final,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+        if conv_layers > 0:
+            self.extra_modeling = True
+            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+            self.text_blocks = nn.Sequential(
+                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+            )
+        else:
+            self.extra_modeling = False
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
+        text = F.pad(text, (0, seq_len - text_len), value=0)
+
+        if drop_text:  # cfg for text
+            text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b n -> b n d
+
+        # possible extra modeling
+        if self.extra_modeling:
+            # sinus pos emb
+            batch_start = torch.zeros((batch,), dtype=torch.long)
+            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+            text_pos_embed = self.freqs_cis[pos_idx]
+            text = text + text_pos_embed
+
+            # convnextv2 blocks
+            text = self.text_blocks(text)
+
+        return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim, text_dim, out_dim):
+        super().__init__()
+        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:  # cfg for cond audio
+            cond = torch.zeros_like(cond)
+
+        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Transformer backbone using DiT blocks
+
+
+class DiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=100,
+        text_num_embeds=256,
+        text_dim=None,
+        conv_layers=0,
+        long_skip_connection=False,
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        if text_dim is None:
+            text_dim = mel_dim
+        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
+        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
+        )
+        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # noised input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if self.long_skip_connection is not None:
+            residual = x
+
+        for block in self.transformer_blocks:
+            x = block(x, t, mask=mask, rope=rope)
+
+        if self.long_skip_connection is not None:
+            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+        x = self.norm_out(x, t)
+        output = self.proj_out(x)
+
+        return output
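As a quick sanity check of the tensor contract above (per the ein notation: `x` and `cond` are `b n d` mel frames with `d = mel_dim`, `text` is `b nt` token ids, `time` is a scalar or a length-`b` vector), here is a hedged forward-pass sketch. It assumes `f5_tts.model.modules` and its dependencies are installed; the small dimensions are only for illustration:

```python
import torch
from f5_tts.model.backbones.dit import DiT

model = DiT(dim=256, depth=4, heads=4, mel_dim=100, text_num_embeds=256, conv_layers=2)

batch, seq_len, text_len = 2, 320, 40
x = torch.randn(batch, seq_len, 100)              # noised mel input
cond = torch.randn(batch, seq_len, 100)           # masked cond audio
text = torch.randint(0, 256, (batch, text_len))   # token ids (batch padding uses -1 upstream)
time = torch.rand(batch)                          # flow time step per sample

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)  # torch.Size([2, 320, 100]) — one predicted mel frame per input frame
```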
    	
f5-tts/model/backbones/mmdit.py
ADDED
@@ -0,0 +1,146 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvPositionEmbedding,
+    MMDiTBlock,
+    AdaLayerNormZero_Final,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, out_dim, text_num_embeds):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token
+
+        self.precompute_max_pos = 1024
+        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
+
+    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
+        text = text + 1
+        if drop_text:
+            text = torch.zeros_like(text)
+        text = self.text_embed(text)
+
+        # sinus pos emb
+        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
+        batch_text_len = text.shape[1]
+        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
+        text_pos_embed = self.freqs_cis[pos_idx]
+
+        text = text + text_pos_embed
+
+        return text
+
+
+# noised input & masked cond audio embedding
+
+
+class AudioEmbedding(nn.Module):
+    def __init__(self, in_dim, out_dim):
+        super().__init__()
+        self.linear = nn.Linear(2 * in_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:
+            cond = torch.zeros_like(cond)
+        x = torch.cat((x, cond), dim=-1)
+        x = self.linear(x)
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Transformer backbone using MM-DiT blocks
+
+
+class MMDiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        text_num_embeds=256,
+        mel_dim=100,
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        self.text_embed = TextEmbedding(dim, text_num_embeds)
+        self.audio_embed = AudioEmbedding(mel_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                MMDiTBlock(
+                    dim=dim,
+                    heads=heads,
+                    dim_head=dim_head,
+                    dropout=dropout,
+                    ff_mult=ff_mult,
+                    context_pre_only=i == depth - 1,
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # noised input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch = x.shape[0]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        c = self.text_embed(text, drop_text=drop_text)
+        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
+
+        seq_len = x.shape[1]
+        text_len = text.shape[1]
+        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
+        rope_text = self.rotary_embed.forward_from_seq_len(text_len)
+
+        for block in self.transformer_blocks:
+            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
+
+        x = self.norm_out(x, t)
+        output = self.proj_out(x)
+
+        return output
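The `drop_audio_cond` / `drop_text` flags in all three backbones exist to support classifier-free guidance. The files in this commit only expose the flags; the sketch below shows the standard two-pass CFG pattern they enable at inference time. The combination rule and `cfg_strength` value are assumptions about usage, not taken from these files:

```python
import torch
from f5_tts.model.backbones.mmdit import MMDiT

model = MMDiT(dim=256, depth=4, heads=4, mel_dim=100, text_num_embeds=256)

batch, seq_len, text_len = 2, 320, 40
x = torch.randn(batch, seq_len, 100)
cond = torch.randn(batch, seq_len, 100)
text = torch.randint(0, 256, (batch, text_len))
time = torch.rand(batch)
cfg_strength = 2.0  # illustrative value

# conditional pass: keep text and cond audio
pred = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
# unconditional pass: zero out both conditions via the cfg flags
null_pred = model(x, cond, text, time, drop_audio_cond=True, drop_text=True)

# standard classifier-free guidance combination (assumed usage pattern)
guided = pred + (pred - null_pred) * cfg_strength
print(guided.shape)  # torch.Size([2, 320, 100])
```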
    	
f5-tts/model/backbones/unett.py
ADDED
@@ -0,0 +1,219 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+from typing import Literal
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from x_transformers import RMSNorm
+from x_transformers.x_transformers import RotaryEmbedding
+
+from f5_tts.model.modules import (
+    TimestepEmbedding,
+    ConvNeXtV2Block,
+    ConvPositionEmbedding,
+    Attention,
+    AttnProcessor,
+    FeedForward,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+        if conv_layers > 0:
+            self.extra_modeling = True
+            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+            self.text_blocks = nn.Sequential(
+                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+            )
+        else:
+            self.extra_modeling = False
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
+        text = F.pad(text, (0, seq_len - text_len), value=0)
+
+        if drop_text:  # cfg for text
+            text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b n -> b n d
+
+        # possible extra modeling
+        if self.extra_modeling:
+            # sinus pos emb
+            batch_start = torch.zeros((batch,), dtype=torch.long)
+            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+            text_pos_embed = self.freqs_cis[pos_idx]
+            text = text + text_pos_embed
+
+            # convnextv2 blocks
+            text = self.text_blocks(text)
+
+        return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim, text_dim, out_dim):
+        super().__init__()
+        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
+
+    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+        if drop_audio_cond:  # cfg for cond audio
+            cond = torch.zeros_like(cond)
+
+        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Flat UNet Transformer backbone
+
+
+class UNetT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=100,
+        text_num_embeds=256,
+        text_dim=None,
+        conv_layers=0,
+        skip_connect_type: Literal["add", "concat", "none"] = "concat",
+    ):
+        super().__init__()
+        assert depth % 2 == 0, "UNet-Transformer's depth should be even."
+
+        self.time_embed = TimestepEmbedding(dim)
+        if text_dim is None:
+            text_dim = mel_dim
+        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
+        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        # transformer layers & skip connections
+
+        self.dim = dim
+        self.skip_connect_type = skip_connect_type
+        needs_skip_proj = skip_connect_type == "concat"
+
+        self.depth = depth
+        self.layers = nn.ModuleList([])
+
+        for idx in range(depth):
+            is_later_half = idx >= (depth // 2)
+
+            attn_norm = RMSNorm(dim)
+            attn = Attention(
+                processor=AttnProcessor(),
+                dim=dim,
+                heads=heads,
+                dim_head=dim_head,
+                dropout=dropout,
+            )
+
+            ff_norm = RMSNorm(dim)
+            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+            skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
+
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        skip_proj,
+                        attn_norm,
+                        attn,
+                        ff_norm,
+                        ff,
+                    ]
+                )
+            )
+
+        self.norm_out = RMSNorm(dim)
+        self.proj_out = nn.Linear(dim, mel_dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # noised input audio  # noqa: F722
+        cond: float["b n d"],  # masked cond audio  # noqa: F722
+        text: int["b nt"],  # text  # noqa: F722
+        time: float["b"] | float[""],  # time step  # noqa: F821 F722
+        drop_audio_cond,  # cfg for cond audio
+        drop_text,  # cfg for text
+        mask: bool["b n"] | None = None,  # noqa: F722
+    ):
+        batch, seq_len = x.shape[0], x.shape[1]
+        if time.ndim == 0:
+            time = time.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+
+        # prepend time t to input x, [b n d] -> [b n+1 d]
+        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
+        if mask is not None:
+            mask = F.pad(mask, (1, 0), value=1)
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
+
+        # flat unet transformer
+        skip_connect_type = self.skip_connect_type
+        skips = []
+        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
+            layer = idx + 1
+
+            # skip connection logic
+            is_first_half = layer <= (self.depth // 2)
+            is_later_half = not is_first_half
+
+            if is_first_half:
+                skips.append(x)
+
+            if is_later_half:
+                skip = skips.pop()
+                if skip_connect_type == "concat":
+                    x = torch.cat((x, skip), dim=-1)
+                    x = maybe_skip_proj(x)
+                elif skip_connect_type == "add":
+                    x = x + skip
+
+            # attention and feedforward blocks
+            x = attn(attn_norm(x), rope=rope, mask=mask) + x
+            x = ff(ff_norm(x)) + x
+
+        assert len(skips) == 0
+
+        x = self.norm_out(x)[:, 1:, :]  # unpack t from x
+
+        return self.proj_out(x)
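UNetT folds the timestep in as an extra token rather than via adaLN: `t` is prepended to the sequence, attended over, and stripped again before the output projection, so the output length still matches the input. A hedged sketch illustrating that behaviour and the `skip_connect_type` option (assumes `f5_tts` and its module dependencies are installed; dimensions are illustrative):

```python
import torch
from f5_tts.model.backbones.unett import UNetT

# depth must be even: the first half pushes skips, the second half pops them
model = UNetT(dim=256, depth=4, heads=4, mel_dim=100, skip_connect_type="add")

batch, seq_len, text_len = 2, 320, 40
x = torch.randn(batch, seq_len, 100)
cond = torch.randn(batch, seq_len, 100)
text = torch.randint(0, 256, (batch, text_len))
time = torch.tensor(0.5)  # scalar time is broadcast to the batch inside forward
mask = torch.ones(batch, seq_len, dtype=torch.bool)

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False, mask=mask)
# the internally packed time token ([b, n+1, d]) is removed before proj_out,
# so the output is again one mel frame per input frame
print(out.shape)  # torch.Size([2, 320, 100])
```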