Add VibeVoice Acoustic Tokenizer #43400

Merged

ebezzam merged 41 commits into huggingface:main from ebezzam:vibevoice_acoustic_tokenizer
Feb 6, 2026

Conversation

ebezzam (Contributor) commented Jan 22, 2026

What does this PR do?

Splits off the acoustic tokenizer from #40546, so that VibeVoice ASR can be done in a separate, independent PR.

Model card: https://huggingface.co/bezzam/VibeVoice-AcousticTokenizer

cc @eustlb

ebezzam (Author) commented Jan 22, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-actions bot

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • vibevoice_acoustic_tokenizer:
    tests/models/vibevoice_acoustic_tokenizer/test_modeling_vibevoice_acoustic_tokenizer.py::VibeVoiceAcousticTokenizerIntegrationTest::test_batch_integration

ebezzam (Author) commented Jan 22, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

github-actions bot

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

ebezzam (Author) commented Jan 22, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

github-actions bot

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • vibevoice_acoustic_tokenizer:
    tests/models/vibevoice_acoustic_tokenizer/test_modeling_vibevoice_acoustic_tokenizer.py::VibeVoiceAcousticTokenizerIntegrationTest::test_batch_integration

ebezzam (Author) commented Jan 22, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

github-actions bot

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

ebezzam (Author) left a comment

@eustlb a self-review with pointers to related discussions on the tokenizer from the (original) main model PR!


One key feature of VibeVoice is the use of two continuous speech tokenizers, one for extracting acoustic features and another for semantic features.

A model checkpoint is available at [bezzam/VibeVoice-AcousticTokenizer](https://huggingface.co/bezzam/VibeVoice-AcousticTokenizer)
ebezzam (Author), Jan 22, 2026:
TODO update to official, current draft: https://huggingface.co/bezzam/VibeVoice-AcousticTokenizer

from transformers.audio_utils import load_audio_librosa


model_id = "bezzam/VibeVoice-AcousticTokenizer"
ebezzam (Author), Jan 22, 2026:

TODO update to official


@can_return_tuple
@auto_docstring
def sample(self, latents):
ebezzam (Author):

Wdyt of this new method? Related to discussion here
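The `sample` method under discussion appears to be VAE-style latent sampling (the commit history mentions "vae sampling"). As a rough, hypothetical sketch of the reparameterization trick such a method typically uses (the actual signature, scale handling, and framework in the PR differ — this is a numpy stand-in):

```python
import numpy as np

def sample(latents: np.ndarray, rng: np.random.Generator, std: float = 1.0) -> np.ndarray:
    """Reparameterized sampling around predicted latents.

    Hypothetical sketch: treats `latents` as the mean of a Gaussian with a
    fixed scale `std`, and draws z = mean + std * eps with eps ~ N(0, I).
    """
    eps = rng.standard_normal(latents.shape).astype(latents.dtype)
    return latents + std * eps

latents = np.zeros((2, 4), dtype=np.float32)
z = sample(latents, np.random.default_rng(0), std=0.5)
print(z.shape)  # (2, 4)
```

With `std=0.0` this degenerates to returning the mean, which is why implementations often gate sampling behind a flag for deterministic inference.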

return hidden_states


class VibeVoiceAcousticTokenizerDecoder(nn.Module):
ebezzam (Author):

Wdyt of refactoring of Decoder (and Encoder)? Related to this discussion


# Ensure torch tensors and mono
for idx, example in enumerate(audio):
    example = torch.tensor(example, dtype=torch.float32)
ebezzam (Author):

Direct casting to torch tensors. Related to this discussion

FYI I moved the feature extractor to the tokenizer, as it actually makes more sense there (it's needed by the tokenizer rather than by the main model, which only needs it because of the tokenizer)
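The normalization being discussed — casting each raw example to float32 and ensuring mono — can be sketched as follows (a numpy stand-in for the torch-based casting in the PR; the channel-averaging downmix rule is an assumption for illustration, not taken from the PR):

```python
import numpy as np

def to_mono_float32(example) -> np.ndarray:
    """Cast one audio example to float32 and downmix multi-channel to mono.

    Illustrative stand-in for the torch casting in the PR; accepts a list or
    array shaped (samples,) or (channels, samples).
    """
    x = np.asarray(example, dtype=np.float32)
    if x.ndim == 2:  # (channels, samples) -> average channels to mono
        x = x.mean(axis=0)
    return x

stereo = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
print(to_mono_float32(stereo))  # [0.5 0.5 0. ]
```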

updated_state_dict = {}

for key, value in state_dict.items():
    new_key = key
Contributor:

Note that ideally we would prefer a key mapping, which would be much cleaner and clearer, but that's OK for now.

ebezzam (Author):

OK, I'll do that for the next models!

ebezzam and others added 6 commits January 22, 2026 19:59
…ibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
…ibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
…ibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
…ibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
@ebezzam ebezzam mentioned this pull request Jan 30, 2026
6 tasks
ebezzam (Author) commented Feb 3, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot commented Feb 3, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

github-actions bot commented Feb 3, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 9b208eeb merge commit
PR 9eb54f3c branch commit
main b6a202f8 base commit

✅ No failing test specific to this PR 🎉 👏 !

ebezzam (Author) commented Feb 3, 2026

run-slow: vibevoice_acoustic_tokenizer

github-actions bot commented Feb 3, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/vibevoice_acoustic_tokenizer"]
quantizations: []

Comment on lines +43 to +45
audio: torch.FloatTensor | None = None
latents: torch.FloatTensor | None = None
padding_cache: Optional["VibeVoiceAcousticTokenizerConv1dPaddingCache"] = None
ebezzam (Author):

do we compute a loss for training?

  • Encodec doesn't have one
  • DAC "has" one, but not on the decoder output
  • Xcodec doesn't have one

Contributor:

if there is no obvious way to do it, we can leave it for later :)

github-actions bot commented Feb 3, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN a40bccef merge commit
PR 1465b606 branch commit
main 36ec3bfa base commit

✅ No failing test specific to this PR 🎉 👏 !

eustlb (Contributor) left a comment

Nice work! Thanks a lot for iterating and being patient with my reviews 🤗
ready for a final review @ArthurZucker

Comment on lines +470 to +476
if use_cache and padding_cache is None:
    padding_cache = VibeVoiceAcousticTokenizerConv1dPaddingCache(
        num_layers=self.encoder.num_conv_layers,
        per_layer_padding=self.encoder.per_conv_layer_padding,
        per_layer_padding_mode=self.encoder.per_conv_layer_padding_mode,
        per_layer_in_channels=self.encoder.per_conv_layer_in_channels,
    )
Contributor:

agree!
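The pattern in this hunk — build the padding cache lazily when `use_cache` is set and no cache was passed in — is generic; a minimal hypothetical version (class name, layer count, and signature are illustrative, not the PR's API):

```python
class PaddingCache:
    """Hypothetical stand-in for the per-layer padding cache in the PR."""
    def __init__(self, num_layers: int):
        self.states = [None] * num_layers  # one cached context per conv layer

def forward(x, use_cache: bool = False, cache: "PaddingCache | None" = None):
    # Create the cache lazily, only on the first call that needs it;
    # a caller-supplied cache is reused as-is.
    if use_cache and cache is None:
        cache = PaddingCache(num_layers=3)
    # ... run the layers, reading/updating `cache.states` per layer ...
    return x, cache

_, cache = forward([1.0], use_cache=True)
print(cache is not None)  # True
```

The upside over eager construction is that callers who never stream pay nothing, and repeated calls can thread the same cache object through.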

Comment on lines +66 to +71
self,
audio: AudioInput,
sampling_rate: int | None = None,
padding: bool | str | PaddingStrategy | None = True,
pad_to_multiple_of: int | None = None,
return_attention_mask: bool | None = True,
Contributor:

yes ok for me

ArthurZucker (Collaborator) left a comment

Great work! LGTM. I am a bit bothered by the way the cache is handled, but this is something that will be refactored later on!

Main comment is whether we can find a way to put the stuff that the cache requires with the cache itself!

Otherwise LGTM 🤗

Comment on lines +298 to +313
# Parameters for cache creation
self.num_conv_layers = sum(depth + 1 for depth in config.depths) + 1
self.per_conv_layer_padding = [self.stem.conv.causal_padding]
self.per_conv_layer_in_channels = [self.stem.conv.conv.in_channels]
self.per_conv_layer_padding.extend([block.mixer.causal_padding for block in self.stem.stage])
self.per_conv_layer_in_channels.extend([block.mixer.conv.in_channels for block in self.stem.stage])

for layer in self.conv_layers:
    self.per_conv_layer_padding.append(layer.conv.causal_padding)
    self.per_conv_layer_in_channels.append(layer.conv.conv.in_channels)
    self.per_conv_layer_padding.extend([block.mixer.causal_padding for block in layer.stage])
    self.per_conv_layer_in_channels.extend([block.mixer.conv.in_channels for block in layer.stage])

self.per_conv_layer_padding.append(self.head.causal_padding)
self.per_conv_layer_in_channels.append(self.head.conv.in_channels)
self.per_conv_layer_padding_mode = ["constant" for _ in self.per_conv_layer_padding]
Collaborator:

this is TBH a bit weird to have!

Collaborator:

especially because it's completely unused

ebezzam (Author):

Good point, I've moved the creation of these variables to when the cache is created, like in Mimi: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mimi/modeling_mimi.py#L1589-L1599

however, as discussed offline, we should think of a way to refactor MimiConv1dPaddingCache so it doesn't require so much overhead when creating it, cc @eustlb
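For intuition on what such a Conv1d padding cache buys you: a causal convolution left-pads each chunk with the tail of the previous chunk, so processing chunk-by-chunk reproduces a full-sequence pass. A single-layer, single-channel numpy sketch (an assumption-level illustration of the idea, not the PR's implementation):

```python
import numpy as np

def causal_conv1d(x, kernel, context=None):
    """Causal 1D convolution over one channel with a streaming context cache.

    `context` holds the last (len(kernel) - 1) samples of the previous chunk;
    when None (first chunk), zeros are used, i.e. constant padding mode.
    """
    pad = len(kernel) - 1
    left = np.zeros(pad, dtype=x.dtype) if context is None else context
    padded = np.concatenate([left, x])
    # np.convolve flips its second argument, so reversing the kernel gives
    # cross-correlation, matching conv-layer semantics.
    out = np.convolve(padded, kernel[::-1], mode="valid")
    new_context = padded[-pad:]  # cache the tail for the next chunk
    return out, new_context

kernel = np.array([0.25, 0.5, 0.25])
x = np.arange(8, dtype=np.float64)
full, _ = causal_conv1d(x, kernel)        # one-shot pass
a, ctx = causal_conv1d(x[:4], kernel)     # streaming, chunk 1
b, _ = causal_conv1d(x[4:], kernel, ctx)  # streaming, chunk 2
print(np.allclose(np.concatenate([a, b]), full))  # True
```

The per-layer padding, padding-mode, and in-channel lists in the PR generalize exactly this bookkeeping across a stack of conv layers.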

github-actions bot commented Feb 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, vibevoice_acoustic_tokenizer

@ebezzam ebezzam merged commit 281eeef into huggingface:main Feb 6, 2026
25 checks passed
jiosephlee pushed a commit to jiosephlee/transformers_latest that referenced this pull request Feb 11, 2026
* Add vibevoice tokenizer files.

* Address style tests.

* Revert to expected outputs previously computed on runner.

* Enable encoder output test.

* Update expected output from runner

* Add note on expected outputs

* remove code link and better init

* Update src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* Update src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* Update src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* Update src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* modular

* Same changes to decoder layers.

* Update src/transformers/models/vibevoice_acoustic_tokenizer/modular_vibevoice_acoustic_tokenizer.py

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* doc nits

* Use decoder_depths for decoder!

* Doc nits

* Nits

* Trim feature extraction for tensor only usage.

* Add cache logic to encoder.

* Nit

* Revert to previous sampling approach.

* Nits

* Better logic for vae sampling?

* More standard conversion script.

* Revert to sample flag

* Nits

* Docs, cleanup, nits.

* Nit

* Nit

* Skip parallelism

* Shift cache creation to when it's used.

* Updated checkpoint path

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>