Skip to content

pocket_tts.models.tts_model.TTSModel

Bases: Module

Source code in pocket_tts/models/tts_model.py
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
class TTSModel(nn.Module):
    # Rough estimate of text tokens consumed per second of generated speech.
    # NOTE(review): presumably used by _estimate_max_gen_len (not visible in this
    # chunk) to bound autoregressive generation length — confirm.
    _TOKENS_PER_SECOND_ESTIMATE = 3.0
    # Extra seconds of generation budget added on top of the estimate — TODO confirm.
    _GEN_SECONDS_PADDING = 2.0

    def __init__(
        self,
        flow_lm: FlowLMModel,
        temp: float,
        lsd_decode_steps: int,
        noise_clamp: float | None,
        eos_threshold,
        config: Config,
    ):
        """Wrap a FlowLM backbone together with the sampling hyper-parameters.

        Args:
            flow_lm: The flow language model used for latent generation.
            temp: Sampling temperature.
            lsd_decode_steps: Number of LSD decoding steps per latent.
            noise_clamp: Optional clamp applied to sampled noise (None = no clamp).
            eos_threshold: Threshold used for end-of-sequence detection.
            config: Full pydantic configuration (FlowLM + Mimi sections).
        """
        super().__init__()
        # Sampling / decoding hyper-parameters.
        self.temp = temp
        self.lsd_decode_steps = lsd_decode_steps
        self.noise_clamp = noise_clamp
        self.eos_threshold = eos_threshold
        # Sub-model and configuration.
        self.flow_lm = flow_lm
        self.config = config
        # Flipped to False when only the no-voice-cloning checkpoint can be fetched.
        self.has_voice_cloning = True

    @property
    def device(self) -> str:
        """Device type string (e.g. "cpu" or "cuda") of the model's parameters."""
        first_param = next(iter(self.parameters()))
        return first_param.device.type

    @property
    def sample_rate(self) -> int:
        """Audio sample rate in Hz, as configured for the Mimi codec."""
        mimi_config = self.config.mimi
        return mimi_config.sample_rate

    @classmethod
    def _from_pydantic_config(
        cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
    ) -> Self:
        """Build an uninitialized TTSModel (no weights loaded) from a pydantic Config."""
        latent_dim = config.mimi.quantizer.dimension
        flow_lm = FlowLMModel.from_pydantic_config(config.flow_lm, latent_dim=latent_dim)
        return cls(flow_lm, temp, lsd_decode_steps, noise_clamp, eos_threshold, config)

    @classmethod
    def _from_pydantic_config_with_weights(
        cls, config: Config, temp, lsd_decode_steps, noise_clamp: float | None, eos_threshold
    ) -> Self:
        """Build a TTSModel from a pydantic Config and load all pretrained weights.

        Loads, in order: the FlowLM weights (optional), the Mimi codec built from
        ``config.mimi`` with its weights (optional), then a combined TTSModel
        checkpoint (optional). The FlowLM and Mimi weight paths must be given
        together or not at all.

        Raises:
            ValueError: If only one of flow_lm.weights_path / mimi.weights_path
                is specified.
        """
        tts_model = cls._from_pydantic_config(
            config, temp, lsd_decode_steps, noise_clamp, eos_threshold
        )
        # Projection mapping Mimi latents into the FlowLM speaker-conditioning space.
        tts_model.flow_lm.speaker_proj_weight = torch.nn.Parameter(
            torch.zeros((1024, 512), dtype=torch.float32)
        )
        if config.flow_lm.weights_path is not None:
            if config.mimi.weights_path is None:
                raise ValueError(
                    "If you specify flow_lm.weights_path you should specify mimi.weights_path"
                )
            logger.info(f"Loading FlowLM weights from {config.flow_lm.weights_path}")
            state_dict_flowlm = get_flow_lm_state_dict(
                download_if_necessary(config.flow_lm.weights_path)
            )
            tts_model.flow_lm.load_state_dict(state_dict_flowlm, strict=True)

        # Create mimi config directly from the provided config using model_dump.
        mimi_config = config.mimi.model_dump()

        # Build the Mimi codec from config.
        encoder = SEANetEncoder(**mimi_config["seanet"])
        decoder = SEANetDecoder(**mimi_config["seanet"])

        encoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
        decoder_transformer = mimi_transformer.ProjectedTransformer(**mimi_config["transformer"])
        quantizer = DummyQuantizer(**mimi_config["quantizer"])

        tts_model.mimi = MimiModel(
            encoder,
            decoder,
            quantizer,
            channels=mimi_config["channels"],
            sample_rate=mimi_config["sample_rate"],
            frame_rate=mimi_config["frame_rate"],
            encoder_frame_rate=mimi_config["sample_rate"] / encoder.hop_length,
            encoder_transformer=encoder_transformer,
            decoder_transformer=decoder_transformer,
        ).to(device="cpu")

        # Load mimi weights from the config safetensors file (strict name matching).
        if config.mimi.weights_path is not None:
            if config.flow_lm.weights_path is None:
                raise ValueError(
                    "If you specify mimi.weights_path you should specify flow_lm.weights_path"
                )
            logger.info(f"Loading Mimi weights from {config.mimi.weights_path}")
            mimi_state = get_mimi_state_dict(download_if_necessary(config.mimi.weights_path))
            tts_model.mimi.load_state_dict(mimi_state, strict=True)

        tts_model.mimi.eval()

        # uncomment to save the weights
        # tts_model = tts_model.to(dtype=torch.bfloat16)
        # safetensors.torch.save_file(tts_model.state_dict(), "tts_b6369a24.safetensors")
        if config.weights_path is not None:
            logger.info(f"Loading TTSModel weights from {config.weights_path}")
            try:
                weights_file = download_if_necessary(config.weights_path)
            except Exception:
                # Best-effort fallback: the full checkpoint (with voice cloning) may
                # be unavailable; fall back to the weights without voice cloning.
                logger.warning(
                    "Could not fetch %s, falling back to weights without voice cloning",
                    config.weights_path,
                )
                tts_model.has_voice_cloning = False
                weights_file = download_if_necessary(config.weights_path_without_voice_cloning)

            state_dict = safetensors.torch.load_file(weights_file)
            tts_model.load_state_dict(state_dict, strict=True)

        if config.flow_lm.weights_path is None and config.weights_path is None:
            logger.warning(
                "No weights_path specified for FlowLM or TTSModel, model is uninitialized!"
            )
        # Use the module logger (previously root `logging.info`) and an integer MB count.
        size_in_mb = size_of_dict(tts_model.state_dict()) // 1_000_000
        logger.info(f"TTS Model loaded successfully. Its size is {size_in_mb} MB")

        # TODO: move this in the __init__ and make self.mimi in __init__
        # Record each stateful module's absolute name so its state can be keyed by it.
        for top_module in (tts_model.flow_lm, tts_model.mimi):
            for module_name, module in top_module.named_modules():
                if not isinstance(module, StatefulModule):
                    continue
                module._module_absolute_name = module_name

        return tts_model

    @classmethod
    def load_model(
        cls,
        config: str | Path = DEFAULT_VARIANT,
        temp: float | int = DEFAULT_TEMPERATURE,
        lsd_decode_steps: int = DEFAULT_LSD_DECODE_STEPS,
        noise_clamp: float | int | None = DEFAULT_NOISE_CLAMP,
        eos_threshold: float = DEFAULT_EOS_THRESHOLD,
    ) -> Self:
        """Load a pre-trained TTS model with specified configuration.

        This class method loads a complete TTS model including the flow language model
        and Mimi compression model from pre-trained weights. The model is initialized
        with the specified generation parameters and ready for inference.

        Args:
            config: a path to a custom YAML config file saved locally (e.g., C://pocket_tts/pocket_tts_config.yaml)
                or a model variant identifier (e.g., '610b0b2c'; must match a YAML file in the config directory).
            temp: Sampling temperature for generation. Higher values produce more
                diverse but potentially lower quality output.
            lsd_decode_steps: Number of steps for Lagrangian Self Distillation
                decoding. More steps can improve quality but increase computation.
            noise_clamp: Maximum value for noise sampling. If None, no clamping
                is applied. Helps prevent extreme values in generation.
            eos_threshold: Threshold for end-of-sequence detection. Higher values
                make the model more likely to continue generating.

        Returns:
            TTSModel: Fully initialized model with loaded weights on cpu, ready for
                text-to-speech generation.

        Raises:
            FileNotFoundError: If the specified config file or model weights
                are not found.
            ValueError: If the configuration is invalid or incompatible.

        Example:
            ```python
            from pocket_tts import TTSModel

            # Load with default settings
            model = TTSModel.load_model()

            # Load with custom parameters
            model = TTSModel.load_model(config="b6369a24", temp=0.5, lsd_decode_steps=5, eos_threshold=-3.0)
            ```
        """
        # `config` is rebound below: it starts as a path/variant name and becomes
        # the parsed Config object.
        if str(config).endswith(".yaml"):
            config_path = Path(config)
            config = load_config(config_path)
            logger.info(f"Loading model from config at {config_path}...")
        else:
            # Treat `config` as a variant name and resolve it in the packaged config dir.
            config = load_config(Path(__file__).parents[1] / f"config/{config}.yaml")

        tts_model = TTSModel._from_pydantic_config_with_weights(
            config, temp, lsd_decode_steps, noise_clamp, eos_threshold
        )
        return tts_model

    def _run_flow_lm_and_increment_step(
        self,
        model_state: dict,
        text_tokens: torch.Tensor | None = None,
        backbone_input_latents: torch.Tensor | None = None,
        audio_conditioning: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one FlowLM step and advance the state by the consumed sequence length.

        Returns the backbone output (next latent) and the EOS flag.
        """
        device = self.flow_lm.device
        dtype = self.flow_lm.dtype
        # Replace missing inputs with empty (length-0) placeholders.
        if text_tokens is None:
            text_tokens = torch.zeros((1, 0), dtype=torch.int64, device=device)
        if backbone_input_latents is None:
            backbone_input_latents = torch.empty(
                (1, 0, self.flow_lm.ldim), dtype=dtype, device=device
            )
        if audio_conditioning is None:
            audio_conditioning = torch.empty(
                (1, 0, self.flow_lm.dim), dtype=dtype, device=device
            )

        output = self._run_flow_lm(
            text_tokens=text_tokens,
            backbone_input_latents=backbone_input_latents,
            model_state=model_state,
            audio_conditioning=audio_conditioning,
        )
        # Advance the stream position by the total number of consumed positions.
        consumed = sum(
            t.shape[1] for t in (text_tokens, backbone_input_latents, audio_conditioning)
        )
        increment_steps(self.flow_lm, model_state, increment=consumed)
        return output

    def _run_flow_lm(
        self,
        model_state: dict,
        text_tokens: torch.Tensor,
        backbone_input_latents: torch.Tensor,
        audio_conditioning: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed the text, append the audio conditioning, and sample the next latent."""
        conditioning = torch.cat(
            [self.flow_lm.conditioner(TokenizedText(text_tokens)), audio_conditioning],
            dim=1,
        )
        next_latent, is_eos = self.flow_lm._sample_next_latent(
            backbone_input_latents,
            conditioning,
            model_state=model_state,
            lsd_decode_steps=self.lsd_decode_steps,
            temp=self.temp,
            noise_clamp=self.noise_clamp,
            eos_threshold=self.eos_threshold,
        )
        # Restore a time axis of length 1 on the sampled latent.
        return next_latent[:, None, :], is_eos

    def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio with Mimi and project into the speaker-conditioning space."""
        latent = self.mimi.encode_to_latent(audio)
        # Swap the last two axes and promote to float32 before the projection.
        latent_t = latent.transpose(-1, -2).to(torch.float32)
        return F.linear(latent_t, self.flow_lm.speaker_proj_weight)

    def _expand_kv_cache(self, model_state: dict, sequence_length: int) -> None:
        """Grow every KV cache in `model_state` to `sequence_length` time steps.

        States restored from a cache may carry sliced (shorter) KV caches; this
        pads them back to the length the upcoming generation needs. New positions
        are filled with NaN so an accidental read of an unwritten slot is obvious.

        Args:
            model_state: Per-module state dicts, possibly holding a "cache" tensor.
            sequence_length: Target time dimension for every cache.
        """
        for module_state in model_state.values():
            if "cache" not in module_state:
                continue
            cache = module_state["cache"]
            # Cache layout: [2 (k/v), batch, time, heads, dim_per_head].
            kv, batch, current_length, heads, head_dim = cache.shape
            if current_length >= sequence_length:
                continue
            expanded = cache.new_full(
                (kv, batch, sequence_length, heads, head_dim), float("NaN")
            )
            # Keep the already-written prefix; the NaN tail is the fresh capacity.
            expanded[:, :, :current_length] = cache
            module_state["cache"] = expanded

    def _flow_lm_current_end(self, model_state: dict) -> int:
        """Return the current stream position recorded in the FlowLM state."""
        for module_state in model_state.values():
            current_end = module_state.get("current_end")
            if current_end is None:
                continue
            return int(current_end.shape[0])
        raise ValueError(
            "Could not find current_end in model state, please open an issue "
            "at https://github.com/kyutai-labs/pocket-tts/issues"
        )

    @torch.no_grad
    def _decode_audio_worker(self, latents_queue: queue.Queue, result_queue: queue.Queue):
        """Worker thread: decode latents from `latents_queue` into audio frames.

        Each decoded frame is pushed to `result_queue` as ("chunk", frame) as soon
        as it is ready, enabling streaming playback. A `None` latent is the
        upstream completion sentinel; the worker then emits ("done", None).
        Any exception is forwarded as ("error", exc) instead of killing the thread
        silently.
        """
        try:
            mimi_context = self.config.mimi.transformer.context
            mimi_state = init_states(self.mimi, batch_size=1, sequence_length=mimi_context)
            while True:
                latent = latents_queue.get()
                if latent is None:
                    break  # Generation finished upstream.
                # Undo the FlowLM latent normalization before feeding Mimi.
                mimi_decoding_input = latent * self.flow_lm.emb_std + self.flow_lm.emb_mean
                transposed = mimi_decoding_input.transpose(-1, -2)
                quantized = self.mimi.quantizer(transposed)

                t = time.monotonic()
                audio_frame = self.mimi.decode_from_latent(quantized, mimi_state)
                # NOTE(review): 16 decoder steps per latent frame — presumably the
                # Mimi upsampling factor; confirm against the Mimi config.
                increment_steps(self.mimi, mimi_state, increment=16)
                audio_frame_duration = audio_frame.shape[2] / self.config.mimi.sample_rate
                logger.debug(
                    " " * 30 + "Decoded %d ms of audio with mimi in %d ms",
                    int(audio_frame_duration * 1000),
                    int((time.monotonic() - t) * 1000),
                )
                # Stream the frame immediately. (The previous version also appended
                # every frame to an unused list, growing memory for long outputs.)
                result_queue.put(("chunk", audio_frame))
                latents_queue.task_done()

            # Signal completion to the consumer.
            result_queue.put(("done", None))

        except Exception as e:
            # Surface the failure to the consuming thread.
            result_queue.put(("error", e))

    @torch.no_grad
    def generate_audio(
        self,
        model_state: dict,
        text_to_generate: str,
        max_tokens: int = MAX_TOKEN_PER_CHUNK,
        frames_after_eos: int | None = None,
        copy_state: bool = True,
    ) -> torch.Tensor:
        """Generate complete audio tensor from text input.

        This method generates the full audio output for the given text prompt
        and returns it as a single tensor. It internally uses the streaming
        generation method but collects all chunks before returning.

        This method is NOT thread-safe; separate model instances should be used
        for concurrent generation.

        Args:
            model_state: Model state dictionary containing hidden states and
                positional information. Can be obtained from get_state_for_audio_prompt()
                or init_states(). The state may be modified during generation.
            text_to_generate: Input text to convert to speech. The text will be
                automatically formatted (capitalization, punctuation) for optimal
                generation quality.
            max_tokens: Maximum number of text tokens per generation chunk; longer
                texts are split into sentence-sized chunks of at most this many
                tokens before generation.
            frames_after_eos: Number of additional frames to generate after
                detecting end-of-sequence. If None, automatically determined
                based on text length (1-3 frames).
            copy_state: Whether to create a deep copy of the model state before
                generation. If True, preserves the original state for reuse.
                If False, modifies the input state in-place. Defaults to True.

        Returns:
            torch.Tensor: 1-D generated audio tensor with shape [samples] (mono)
                at the model's sample rate (typically 24kHz), assembled by
                concatenating the streamed chunks. The audio is ready for
                playback or saving.
                You can get the sample rate from the `sample_rate` attribute.

        Raises:
            ValueError: If text_to_generate is empty or invalid.
            RuntimeError: If generation fails due to model errors.

        Example:
            ```python
            from pocket_tts import TTSModel

            model = TTSModel.load_model()

            voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

            # Generate audio
            audio = model.generate_audio(voice_state, "Hello world!", frames_after_eos=2, copy_state=True)

            print(f"Generated audio shape: {audio.shape}")
            print(f"Audio duration: {audio.shape[-1] / model.sample_rate:.2f} seconds")
            ```
        """
        # Collect every streamed chunk, then concatenate along the time axis.
        audio_chunks = []
        for chunk in self.generate_audio_stream(
            model_state=model_state,
            text_to_generate=text_to_generate,
            frames_after_eos=frames_after_eos,
            copy_state=copy_state,
            max_tokens=max_tokens,
        ):
            audio_chunks.append(chunk)
        return torch.cat(audio_chunks, dim=0)

    @torch.no_grad
    def generate_audio_stream(
        self,
        model_state: dict,
        text_to_generate: str,
        max_tokens: int = MAX_TOKEN_PER_CHUNK,
        frames_after_eos: int | None = None,
        copy_state: bool = True,
    ):
        """Generate audio streaming chunks from text input.

        This method generates audio from text and yields chunks as they become
        available, enabling real-time playback or processing. It uses multithreading
        to parallelize generation and decoding for optimal performance.
        This method is NOT thread-safe; separate model instances should be used
        for concurrent generation.

        Args:
            model_state: Model state dictionary containing hidden states and
                positional information. Can be obtained from get_state_for_audio_prompt()
                or init_states(). The state may be modified during generation.
            text_to_generate: Input text to convert to speech. The text will be
                automatically formatted (capitalization, punctuation) for optimal
                generation quality.
            max_tokens: Maximum number of text tokens per generation chunk; the
                input text is split into sentence-sized chunks of at most this
                many tokens, generated one after another.
            frames_after_eos: Number of additional frames to generate after
                detecting end-of-sequence. If None, automatically determined
                based on text length (1-3 frames). Defaults to None.
            copy_state: Whether to create a deep copy of the model state before
                generation. If True, preserves the original state for reuse.
                If False, modifies the input state in-place. Defaults to True.

        Yields:
            torch.Tensor: Audio chunks with shape [samples] at the model's
                sample rate (typically 24kHz). Chunks are yielded as soon as
                they are decoded, enabling real-time streaming.

        Raises:
            ValueError: If text_to_generate is empty or invalid.
            RuntimeError: If generation fails due to model errors or threading issues.

        Example:
            ```python
            from pocket_tts import TTSModel

            model = TTSModel.load_model()

            voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")
            # Stream generation
            for chunk in model.generate_audio_stream(voice_state, "Long text content..."):
                # Process each chunk as it's generated
                print(f"Generated chunk: {chunk.shape[0]} samples")
                # Could save chunks to file or play in real-time
            ```

        Note:
            This method uses multithreading to parallelize latent generation
            and audio decoding. Generation performance is logged including
            real-time factor (RTF) metrics.
        """

        # This is a very simplistic way of handling long texts. We could do much better
        # by using teacher forcing, but it would be a bit slower.
        # TODO: add the teacher forcing method for long texts where we use the audio of one chunk
        # as conditioning for the next chunk.
        chunks = split_into_best_sentences(
            self.flow_lm.conditioner.tokenizer, text_to_generate, max_tokens
        )

        for chunk in chunks:
            # NOTE(review): the formatted text returned by prepare_text_prompt is
            # discarded (only the frames-after-EOS guess is used) and the raw
            # `chunk` is passed on below — confirm whether the prepared text was
            # meant to be generated instead.
            text_to_generate, frames_after_eos_guess = prepare_text_prompt(chunk)
            frames_after_eos_guess += 2  # small safety margin on the heuristic
            effective_frames = (
                frames_after_eos if frames_after_eos is not None else frames_after_eos_guess
            )
            yield from self._generate_audio_stream_short_text(
                model_state=model_state,
                text_to_generate=chunk,
                frames_after_eos=effective_frames,
                copy_state=copy_state,
            )

    @torch.no_grad
    def _generate_audio_stream_short_text(
        self, model_state: dict, text_to_generate: str, frames_after_eos: int, copy_state: bool
    ):
        """Generate and stream audio for a single (already-chunked) piece of text.

        Latent generation and Mimi decoding run in two daemon threads connected by
        `latents_queue`; decoded audio frames arrive on `result_queue` as
        ("chunk", tensor) messages, followed by ("done", None) or ("error", exc).
        Yields 1-D audio tensors (batch and channel dims stripped).
        """
        if copy_state:
            model_state = copy.deepcopy(model_state)

        # Set up multithreaded generation and decoding
        latents_queue = queue.Queue()
        result_queue = queue.Queue()

        # Start decoder worker thread
        decoder_thread = threading.Thread(
            target=self._decode_audio_worker, args=(latents_queue, result_queue), daemon=True
        )
        logger.info("starting timer now!")
        t_generating = time.monotonic()
        decoder_thread.start()

        # Generate latents and add them to queue (decoder processes them in parallel)
        self._generate(
            model_state=model_state,
            text_to_generate=text_to_generate,
            frames_after_eos=frames_after_eos,
            latents_queue=latents_queue,
            result_queue=result_queue,
        )

        # Stream audio chunks as they become available
        total_generated_samples = 0
        while True:
            result = result_queue.get()
            if result[0] == "chunk":
                # Audio chunk available immediately for streaming/playback
                audio_chunk = result[1]
                total_generated_samples += audio_chunk.shape[-1]
                yield audio_chunk[0, 0]  # Remove batch, channel
            elif result[0] == "done":
                # Generation complete
                break
            elif result[0] == "error":
                # Wait for decoder thread to finish cleanly before propagating error
                with display_execution_time("Waiting for mimi decoder to finish"):
                    decoder_thread.join()
                # Propagate error
                raise result[1]

        # Wait for decoder thread to finish cleanly
        with display_execution_time("Waiting for mimi decoder to finish"):
            decoder_thread.join()

        # Print timing information
        # NOTE(review): generation_time could in principle be 0 ms (division below);
        # in practice generation always takes measurable time.
        duration_generated_audio = int(
            total_generated_samples * 1000 / self.config.mimi.sample_rate
        )
        generation_time = int((time.monotonic() - t_generating) * 1000)
        real_time_factor = duration_generated_audio / generation_time

        logger.info(
            "Generated: %d ms of audio in %d ms so %.2fx faster than real-time",
            duration_generated_audio,
            generation_time,
            real_time_factor,
        )

    @torch.no_grad
    def _generate(
        self,
        model_state: dict,
        text_to_generate: str,
        frames_after_eos: int,
        latents_queue: queue.Queue,
        result_queue: queue.Queue,
    ):
        """Prompt the FlowLM with the text, then launch latent generation.

        The text prompt is processed synchronously; autoregressive latent
        generation then runs in a daemon thread, feeding `latents_queue` and
        reporting failures through `result_queue`.
        """
        prepared = self.flow_lm.conditioner.prepare(text_to_generate)
        n_text_tokens = prepared.tokens.shape[1]
        max_gen_len = self._estimate_max_gen_len(n_text_tokens)
        # Ensure the KV caches can hold: existing prompt + text + generated latents.
        needed_len = self._flow_lm_current_end(model_state) + n_text_tokens + max_gen_len
        self._expand_kv_cache(model_state, sequence_length=needed_len)

        with display_execution_time("Prompting text"):
            self._run_flow_lm_and_increment_step(
                model_state=model_state, text_tokens=prepared.tokens
            )

        def run_generation():
            try:
                self._autoregressive_generation(
                    model_state, max_gen_len, frames_after_eos, latents_queue
                )
            except Exception as e:
                logger.error(f"Error in autoregressive generation: {e}")
                # Unblock the decoder with the completion sentinel...
                if latents_queue is not None:
                    latents_queue.put(None)
                # ...and surface the error to the consumer loop.
                if result_queue is not None:
                    result_queue.put(("error", e))

        threading.Thread(target=run_generation, daemon=True).start()

    @torch.no_grad
    def _autoregressive_generation(
        self, model_state: dict, max_gen_len: int, frames_after_eos: int, latents_queue: queue.Queue
    ):
        """Sample latents autoregressively, pushing each one onto `latents_queue`.

        Stops `frames_after_eos` steps after the first EOS detection, or at
        `max_gen_len` steps. Always pushes the `None` sentinel at the end so the
        decoder thread terminates.
        """
        # First backbone input is a NaN placeholder: no previous latent exists yet.
        backbone_input = torch.full(
            (1, 1, self.flow_lm.ldim),
            fill_value=float("NaN"),
            device=next(iter(self.flow_lm.parameters())).device,
            dtype=self.flow_lm.dtype,
        )
        steps_times = []
        eos_step = None
        for generation_step in range(max_gen_len):
            with display_execution_time("Generating latent", print_output=False) as timer:
                next_latent, is_eos = self._run_flow_lm_and_increment_step(
                    model_state=model_state, backbone_input_latents=backbone_input
                )
                if is_eos.item() and eos_step is None:
                    eos_step = generation_step
                if eos_step is not None and generation_step >= eos_step + frames_after_eos:
                    break

                # Hand the latent to the decoder thread immediately.
                latents_queue.put(next_latent)
                backbone_input = next_latent
            steps_times.append(timer.elapsed_time_ms)
        else:
            # for/else: only reached when the loop exhausted max_gen_len without break.
            if os.environ.get("KPOCKET_TTS_ERROR_WITHOUT_EOS", "0") == "1":
                raise RuntimeError("Generation reached maximum length without EOS!")
            logger.warning(
                "Maximum generation length reached without EOS, this very often indicates an error."
            )

        # Always signal end of generation, even on early break.
        latents_queue.put(None)
        # Guard: steps_times is empty when EOS triggers the break on the very first
        # step (or max_gen_len == 0); statistics.mean would raise StatisticsError.
        if steps_times:
            logger.info("Average generation step time: %d ms", int(statistics.mean(steps_times)))

    def _cached_get_state_for_audio_prompt(
        self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
    ) -> dict:
        """Memoized `get_state_for_audio_prompt` with a 2-entry per-instance LRU.

        Replaces the previous `@lru_cache` on this instance method: lru_cache keys
        on `self` and keeps the whole model alive for the cache's lifetime
        (ruff B019). A lazily-created per-instance dict avoids that, while keeping
        the same bound (2 entries, matching the old maxsize=2).
        """
        cache = getattr(self, "_audio_prompt_state_cache", None)
        if cache is None:
            cache = {}
            self._audio_prompt_state_cache = cache
        key = (audio_conditioning, truncate)
        if key in cache:
            # Refresh recency: dicts preserve insertion order, so re-insert the hit.
            cache[key] = cache.pop(key)
            return cache[key]
        state = self.get_state_for_audio_prompt(audio_conditioning, truncate)
        cache[key] = state
        # Evict the least recently used entries beyond the 2-entry bound.
        while len(cache) > 2:
            del cache[next(iter(cache))]
        return state

    @torch.no_grad
    def get_state_for_audio_prompt(
        self, audio_conditioning: Path | str | torch.Tensor, truncate: bool = False
    ) -> dict:
        """Create model state conditioned on audio prompt for continuation.

        This method processes an audio prompt and creates a model state that
        captures the acoustic characteristics (speaker voice, style, prosody)
        for use in subsequent text-to-speech generation. The resulting state
        enables voice cloning and audio continuation with speaker consistency.

        Args:
            audio_conditioning: Audio prompt to condition (or .safetensors to load). Can be:
                - Path: Local file path to audio file (or .safetensors)
                - str: URL to download audio file (or .safetensors) from
                - torch.Tensor: Pre-loaded audio tensor with shape [channels, samples]
            truncate: Whether to truncate long audio prompts to 30 seconds.
                Helps prevent memory issues with very long inputs. Defaults to False.

        Returns:
            dict: Model state dictionary containing hidden states and positional
                information conditioned on the audio prompt. This state can be
                passed to `generate_audio()` or `generate_audio_stream()` for
                voice-consistent generation.

        Raises:
            FileNotFoundError: If audio file path doesn't exist.
            ValueError: If audio tensor is invalid or empty.
            RuntimeError: If audio processing or encoding fails.

        Example:
            ```python
            from pocket_tts import TTSModel

            model = TTSModel.load_model()
            # From HuggingFace URL
            voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

            # From local file
            voice_state = model.get_state_for_audio_prompt("./my_voice.wav")

            # Reload state from a .safetensors file (much faster than extracting from an audio file)
            voice_state = model.get_state_for_audio_prompt("./my_voices.safetensors")

            # From HTTP URL
            voice_state = model.get_state_for_audio_prompt(
                "https://huggingface.co/kyutai/tts-voices/resolve"
                "/main/expresso/ex01-ex02_default_001_channel1_168s.wav"
            )
            ```

        Note:
            - Audio is automatically resampled to the model's sample rate (24kHz)
            - The audio is encoded using the Mimi compression model and projected
              to the flow model's latent space
            - Processing time is logged for performance monitoring
            - The state preserves speaker characteristics for voice cloning
        """
        if isinstance(audio_conditioning, (str, Path)) and str(audio_conditioning).endswith(
            ".safetensors"
        ):
            if isinstance(audio_conditioning, str):
                audio_conditioning = download_if_necessary(audio_conditioning)

            return _import_model_state(audio_conditioning)

        elif isinstance(audio_conditioning, str) and audio_conditioning in PREDEFINED_VOICES:
            # We get the audio conditioning directly from the safetensors file.
            return _import_model_state(download_if_necessary(PREDEFINED_VOICES[audio_conditioning]))

        if not self.has_voice_cloning and isinstance(audio_conditioning, (str, Path)):
            raise ValueError(VOICE_CLONING_UNSUPPORTED)

        if isinstance(audio_conditioning, str):
            audio_conditioning = download_if_necessary(audio_conditioning)

        if isinstance(audio_conditioning, Path):
            audio, conditioning_sample_rate = audio_read(audio_conditioning)

            if truncate:
                max_samples = int(30 * conditioning_sample_rate)  # 30 seconds of audio
                if audio.shape[-1] > max_samples:
                    audio = audio[..., :max_samples]
                    logger.info(f"Audio truncated to first 30 seconds ({max_samples} samples)")

            audio_conditioning = convert_audio(
                audio, conditioning_sample_rate, self.config.mimi.sample_rate, 1
            )

        with display_execution_time("Encoding audio prompt"):
            prompt = self._encode_audio(audio_conditioning.unsqueeze(0).to(self.device))

        model_state = init_states(self.flow_lm, batch_size=1, sequence_length=prompt.shape[1])

        with display_execution_time("Prompting audio"):
            self._run_flow_lm_and_increment_step(model_state=model_state, audio_conditioning=prompt)

        logger.info(
            "Size of the model state for audio prompt: %d MB", size_of_dict(model_state) // 1e6
        )

        return model_state

    def _estimate_max_gen_len(self, token_count: int) -> int:
        """Estimate an upper bound on generation length, in audio frames.

        The token count is converted into an estimated duration in seconds
        (plus a fixed safety padding), then scaled by the Mimi frame rate and
        rounded up to a whole number of frames.
        """
        estimated_seconds = (
            self._GEN_SECONDS_PADDING + token_count / self._TOKENS_PER_SECOND_ESTIMATE
        )
        return math.ceil(self.config.mimi.frame_rate * estimated_seconds)

generate_audio(model_state, text_to_generate, max_tokens=MAX_TOKEN_PER_CHUNK, frames_after_eos=None, copy_state=True)

Generate complete audio tensor from text input.

This method generates the full audio output for the given text prompt and returns it as a single tensor. It internally uses the streaming generation method but collects all chunks before returning.

This method is NOT thread-safe; separate model instances should be used for concurrent generation.

Parameters:

Name Type Description Default
model_state dict

Model state dictionary containing hidden states and positional information. Can be obtained from get_state_for_audio_prompt() or init_states(). The state may be modified during generation.

required
text_to_generate str

Input text to convert to speech. The text will be automatically formatted (capitalization, punctuation) for optimal generation quality.

required
max_tokens int

Maximum number of tokens per generated chunk. Longer texts are split into sentence chunks of at most this many tokens before generation.

MAX_TOKEN_PER_CHUNK
frames_after_eos int | None

Number of additional frames to generate after detecting end-of-sequence. If None, automatically determined based on text length (1-3 frames).

None
copy_state bool

Whether to create a deep copy of the model state before generation. If True, preserves the original state for reuse. If False, modifies the input state in-place. Defaults to True.

True

Returns:

Type Description
Tensor

torch.Tensor: Generated audio tensor with shape [channels, samples] at the model's sample rate (typically 24kHz). The audio is normalized and ready for playback or saving. You can get the sample rate from the sample_rate attribute.

Raises:

Type Description
ValueError

If text_to_generate is empty or invalid.

RuntimeError

If generation fails due to model errors.

Example
from pocket_tts import TTSModel

model = TTSModel.load_model()

voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

# Generate audio
audio = model.generate_audio(voice_state, "Hello world!", frames_after_eos=2, copy_state=True)

print(f"Generated audio shape: {audio.shape}")
print(f"Audio duration: {audio.shape[-1] / model.sample_rate:.2f} seconds")
Source code in pocket_tts/models/tts_model.py
@torch.no_grad
def generate_audio(
    self,
    model_state: dict,
    text_to_generate: str,
    max_tokens: int = MAX_TOKEN_PER_CHUNK,
    frames_after_eos: int | None = None,
    copy_state: bool = True,
) -> torch.Tensor:
    """Synthesize the complete audio waveform for the given text.

    Thin convenience wrapper around `generate_audio_stream()`: it drains the
    stream and concatenates every decoded chunk into one tensor.

    This method is NOT thread-safe; use a separate model instance per
    concurrent generation.

    Args:
        model_state: Model state dictionary (hidden states and positional
            information), obtained from get_state_for_audio_prompt() or
            init_states(). The state may be modified during generation.
        text_to_generate: Input text to convert to speech. The text is
            automatically formatted (capitalization, punctuation) for optimal
            generation quality.
        max_tokens: Maximum number of tokens per generated chunk; longer
            texts are split into sentence chunks of at most this size.
        frames_after_eos: Number of additional frames to generate after
            detecting end-of-sequence. If None, an estimate based on the
            text length is used.
        copy_state: Whether to deep-copy the model state before generation.
            If True, the original state is preserved for reuse; if False,
            the input state is modified in place. Defaults to True.

    Returns:
        torch.Tensor: Generated audio (the streamed chunks concatenated
            along dim 0) at the model's sample rate (typically 24kHz), ready
            for playback or saving. The sample rate is available from the
            `sample_rate` attribute.

    Raises:
        ValueError: If text_to_generate is empty or invalid.
        RuntimeError: If generation fails due to model errors.

    Example:
        ```python
        from pocket_tts import TTSModel

        model = TTSModel.load_model()

        voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

        # Generate audio
        audio = model.generate_audio(voice_state, "Hello world!", frames_after_eos=2, copy_state=True)

        print(f"Generated audio shape: {audio.shape}")
        print(f"Audio duration: {audio.shape[-1] / model.sample_rate:.2f} seconds")
        ```
    """
    chunk_stream = self.generate_audio_stream(
        model_state=model_state,
        text_to_generate=text_to_generate,
        frames_after_eos=frames_after_eos,
        copy_state=copy_state,
        max_tokens=max_tokens,
    )
    return torch.cat(list(chunk_stream), dim=0)

generate_audio_stream(model_state, text_to_generate, max_tokens=MAX_TOKEN_PER_CHUNK, frames_after_eos=None, copy_state=True)

Generate audio streaming chunks from text input.

This method generates audio from text and yields chunks as they become available, enabling real-time playback or processing. It uses multithreading to parallelize generation and decoding for optimal performance. This method is NOT thread-safe; separate model instances should be used for concurrent generation.

Parameters:

Name Type Description Default
model_state dict

Model state dictionary containing hidden states and positional information. Can be obtained from get_state_for_audio_prompt() or init_states(). The state may be modified during generation.

required
text_to_generate str

Input text to convert to speech. The text will be automatically formatted (capitalization, punctuation) for optimal generation quality.

required
max_tokens int

Maximum number of tokens per generated chunk. Longer texts are split into sentence chunks of at most this many tokens before generation.

MAX_TOKEN_PER_CHUNK
frames_after_eos int | None

Number of additional frames to generate after detecting end-of-sequence. If None, automatically determined based on text length (1-3 frames). Defaults to None.

None
copy_state bool

Whether to create a deep copy of the model state before generation. If True, preserves the original state for reuse. If False, modifies the input state in-place. Defaults to True.

True

Yields:

Type Description

torch.Tensor: Audio chunks with shape [samples] at the model's sample rate (typically 24kHz). Chunks are yielded as soon as they are decoded, enabling real-time streaming.

Raises:

Type Description
ValueError

If text_to_generate is empty or invalid.

RuntimeError

If generation fails due to model errors or threading issues.

Example
from pocket_tts import TTSModel

model = TTSModel.load_model()

voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")
# Stream generation
for chunk in model.generate_audio_stream(voice_state, "Long text content..."):
    # Process each chunk as it's generated
    print(f"Generated chunk: {chunk.shape[0]} samples")
    # Could save chunks to file or play in real-time
Note

This method uses multithreading to parallelize latent generation and audio decoding. Generation performance is logged including real-time factor (RTF) metrics.

Source code in pocket_tts/models/tts_model.py
@torch.no_grad
def generate_audio_stream(
    self,
    model_state: dict,
    text_to_generate: str,
    max_tokens: int = MAX_TOKEN_PER_CHUNK,
    frames_after_eos: int | None = None,
    copy_state: bool = True,
):
    """Yield audio chunks for the given text as soon as they are decoded.

    The text is split into sentence chunks that fit within ``max_tokens``
    tokens; each chunk is generated in turn and the decoded audio is yielded
    immediately, enabling real-time playback or processing. Latent
    generation and audio decoding run in parallel threads internally.

    This method is NOT thread-safe; use a separate model instance per
    concurrent generation.

    Args:
        model_state: Model state dictionary (hidden states and positional
            information), obtained from get_state_for_audio_prompt() or
            init_states(). The state may be modified during generation.
        text_to_generate: Input text to convert to speech. The text is
            automatically formatted (capitalization, punctuation) for optimal
            generation quality.
        max_tokens: Maximum number of tokens per sentence chunk when the
            text is split for generation.
        frames_after_eos: Number of additional frames to generate after
            detecting end-of-sequence. If None, an estimate based on the
            text length is used. Defaults to None.
        copy_state: Whether to deep-copy the model state before generation.
            If True, the original state is preserved for reuse; if False,
            the input state is modified in place. Defaults to True.

    Yields:
        torch.Tensor: Audio chunks with shape [samples] at the model's
            sample rate (typically 24kHz), yielded as soon as they are
            decoded.

    Raises:
        ValueError: If text_to_generate is empty or invalid.
        RuntimeError: If generation fails due to model errors or threading issues.

    Example:
        ```python
        from pocket_tts import TTSModel

        model = TTSModel.load_model()

        voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")
        # Stream generation
        for chunk in model.generate_audio_stream(voice_state, "Long text content..."):
            # Process each chunk as it's generated
            print(f"Generated chunk: {chunk.shape[0]} samples")
            # Could save chunks to file or play in real-time
        ```

    Note:
        Generation performance is logged, including real-time factor (RTF)
        metrics.
    """

    # Simplistic long-text handling: generate each sentence chunk
    # independently. Teacher forcing would be better quality but slower.
    # TODO: add the teacher forcing method for long texts where we use the audio of one chunk
    # as conditioning for the next chunk.
    sentence_chunks = split_into_best_sentences(
        self.flow_lm.conditioner.tokenizer, text_to_generate, max_tokens
    )

    for sentence in sentence_chunks:
        # prepare_text_prompt also returns a formatted prompt, but only the
        # EOS-frame estimate is consumed here; the raw chunk text is what is
        # forwarded to the short-text generator (presumably it formats the
        # text itself — TODO confirm).
        _formatted, eos_frames_guess = prepare_text_prompt(sentence)
        eos_frames_guess += 2
        if frames_after_eos is None:
            effective_frames = eos_frames_guess
        else:
            effective_frames = frames_after_eos
        yield from self._generate_audio_stream_short_text(
            model_state=model_state,
            text_to_generate=sentence,
            frames_after_eos=effective_frames,
            copy_state=copy_state,
        )

get_state_for_audio_prompt(audio_conditioning, truncate=False)

Create model state conditioned on audio prompt for continuation.

This method processes an audio prompt and creates a model state that captures the acoustic characteristics (speaker voice, style, prosody) for use in subsequent text-to-speech generation. The resulting state enables voice cloning and audio continuation with speaker consistency.

Parameters:

Name Type Description Default
audio_conditioning Path | str | Tensor

Audio prompt to condition (or .safetensors to load). Can be: - Path: Local file path to audio file (or .safetensors) - str: URL to download audio file (or .safetensors) from - torch.Tensor: Pre-loaded audio tensor with shape [channels, samples]

required
truncate bool

Whether to truncate long audio prompts to 30 seconds. Helps prevent memory issues with very long inputs. Defaults to False.

False

Returns:

Name Type Description
dict dict

Model state dictionary containing hidden states and positional information conditioned on the audio prompt. This state can be passed to generate_audio() or generate_audio_stream() for voice-consistent generation.

Raises:

Type Description
FileNotFoundError

If audio file path doesn't exist.

ValueError

If audio tensor is invalid or empty.

RuntimeError

If audio processing or encoding fails.

Example
from pocket_tts import TTSModel

model = TTSModel.load_model()
# From HuggingFace URL
voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

# From local file
voice_state = model.get_state_for_audio_prompt("./my_voice.wav")

# Reload state from a .safetensors file (much faster than extracting from an audio file)
voice_state = model.get_state_for_audio_prompt("./my_voices.safetensors")

# From HTTP URL
voice_state = model.get_state_for_audio_prompt(
    "https://huggingface.co/kyutai/tts-voices/resolve"
    "/main/expresso/ex01-ex02_default_001_channel1_168s.wav"
)
Note
  • Audio is automatically resampled to the model's sample rate (24kHz)
  • The audio is encoded using the Mimi compression model and projected to the flow model's latent space
  • Processing time is logged for performance monitoring
  • The state preserves speaker characteristics for voice cloning
Source code in pocket_tts/models/tts_model.py
@torch.no_grad
def get_state_for_audio_prompt(
    self,
    audio_conditioning: Path | str | torch.Tensor,
    truncate: bool = False,
    max_prompt_seconds: float = 30.0,
) -> dict:
    """Create model state conditioned on audio prompt for continuation.

    This method processes an audio prompt and creates a model state that
    captures the acoustic characteristics (speaker voice, style, prosody)
    for use in subsequent text-to-speech generation. The resulting state
    enables voice cloning and audio continuation with speaker consistency.

    Args:
        audio_conditioning: Audio prompt to condition (or .safetensors to load). Can be:
            - Path: Local file path to audio file (or .safetensors)
            - str: URL to download audio file (or .safetensors) from
            - torch.Tensor: Pre-loaded audio tensor with shape [channels, samples]
        truncate: Whether to truncate long audio prompts to at most
            ``max_prompt_seconds`` seconds. Helps prevent memory issues with
            very long inputs. Defaults to False.
        max_prompt_seconds: Maximum prompt duration (in seconds) kept when
            ``truncate`` is True. Defaults to 30.0, matching the previous
            hard-coded behavior.

    Returns:
        dict: Model state dictionary containing hidden states and positional
            information conditioned on the audio prompt. This state can be
            passed to `generate_audio()` or `generate_audio_stream()` for
            voice-consistent generation.

    Raises:
        FileNotFoundError: If audio file path doesn't exist.
        ValueError: If audio tensor is invalid or empty.
        RuntimeError: If audio processing or encoding fails.

    Example:
        ```python
        from pocket_tts import TTSModel

        model = TTSModel.load_model()
        # From HuggingFace URL
        voice_state = model.get_state_for_audio_prompt("hf://kyutai/tts-voices/alba-mackenna/casual.wav")

        # From local file
        voice_state = model.get_state_for_audio_prompt("./my_voice.wav")

        # Reload state from a .safetensors file (much faster than extracting from an audio file)
        voice_state = model.get_state_for_audio_prompt("./my_voices.safetensors")

        # From HTTP URL
        voice_state = model.get_state_for_audio_prompt(
            "https://huggingface.co/kyutai/tts-voices/resolve"
            "/main/expresso/ex01-ex02_default_001_channel1_168s.wav"
        )
        ```

    Note:
        - Audio is automatically resampled to the model's sample rate (24kHz)
        - The audio is encoded using the Mimi compression model and projected
          to the flow model's latent space
        - Processing time is logged for performance monitoring
        - The state preserves speaker characteristics for voice cloning
    """
    # A .safetensors path/URL holds a precomputed state: load it directly
    # instead of re-encoding audio.
    if isinstance(audio_conditioning, (str, Path)) and str(audio_conditioning).endswith(
        ".safetensors"
    ):
        if isinstance(audio_conditioning, str):
            audio_conditioning = download_if_necessary(audio_conditioning)

        return _import_model_state(audio_conditioning)

    elif isinstance(audio_conditioning, str) and audio_conditioning in PREDEFINED_VOICES:
        # We get the audio conditioning directly from the safetensors file.
        return _import_model_state(download_if_necessary(PREDEFINED_VOICES[audio_conditioning]))

    # Raw-audio prompts require a model variant with voice-cloning support.
    if not self.has_voice_cloning and isinstance(audio_conditioning, (str, Path)):
        raise ValueError(VOICE_CLONING_UNSUPPORTED)

    if isinstance(audio_conditioning, str):
        audio_conditioning = download_if_necessary(audio_conditioning)

    if isinstance(audio_conditioning, Path):
        audio, conditioning_sample_rate = audio_read(audio_conditioning)

        if truncate:
            max_samples = int(max_prompt_seconds * conditioning_sample_rate)
            if audio.shape[-1] > max_samples:
                audio = audio[..., :max_samples]
                logger.info(
                    "Audio truncated to first %.1f seconds (%d samples)",
                    max_prompt_seconds,
                    max_samples,
                )

        # Resample to the model's rate and mix down to mono.
        audio_conditioning = convert_audio(
            audio, conditioning_sample_rate, self.config.mimi.sample_rate, 1
        )

    with display_execution_time("Encoding audio prompt"):
        prompt = self._encode_audio(audio_conditioning.unsqueeze(0).to(self.device))

    model_state = init_states(self.flow_lm, batch_size=1, sequence_length=prompt.shape[1])

    with display_execution_time("Prompting audio"):
        self._run_flow_lm_and_increment_step(model_state=model_state, audio_conditioning=prompt)

    # Integer division keeps the %d log argument an int.
    logger.info(
        "Size of the model state for audio prompt: %d MB", size_of_dict(model_state) // 1_000_000
    )

    return model_state

load_model(config=DEFAULT_VARIANT, temp=DEFAULT_TEMPERATURE, lsd_decode_steps=DEFAULT_LSD_DECODE_STEPS, noise_clamp=DEFAULT_NOISE_CLAMP, eos_threshold=DEFAULT_EOS_THRESHOLD) classmethod

Load a pre-trained TTS model with specified configuration.

This class method loads a complete TTS model including the flow language model and Mimi compression model from pre-trained weights. The model is initialized with the specified generation parameters and ready for inference.

Parameters:

Name Type Description Default
config str | Path

a path to a custom YAML config file saved locally (e.g., C://pocket_tts/pocket_tts_config.yaml) or a model variant identifier (e.g., '610b0b2c'; must match a YAML file in the config directory).

DEFAULT_VARIANT
temp float | int

Sampling temperature for generation. Higher values produce more diverse but potentially lower quality output.

DEFAULT_TEMPERATURE
lsd_decode_steps int

Number of steps for Lagrangian Self Distillation decoding. More steps can improve quality but increase computation.

DEFAULT_LSD_DECODE_STEPS
noise_clamp float | int | None

Maximum value for noise sampling. If None, no clamping is applied. Helps prevent extreme values in generation.

DEFAULT_NOISE_CLAMP
eos_threshold float

Threshold for end-of-sequence detection. Higher values make the model more likely to continue generating.

DEFAULT_EOS_THRESHOLD

Returns:

Name Type Description
TTSModel Self

Fully initialized model with loaded weights on cpu, ready for text-to-speech generation.

Raises:

Type Description
FileNotFoundError

If the specified config file or model weights are not found.

ValueError

If the configuration is invalid or incompatible.

Example
from pocket_tts import TTSModel

# Load with default settings
model = TTSModel.load_model()

# Load with custom parameters
model = TTSModel.load_model(config="b6369a24", temp=0.5, lsd_decode_steps=5, eos_threshold=-3.0)
Source code in pocket_tts/models/tts_model.py
@classmethod
def load_model(
    cls,
    config: str | Path = DEFAULT_VARIANT,
    temp: float | int = DEFAULT_TEMPERATURE,
    lsd_decode_steps: int = DEFAULT_LSD_DECODE_STEPS,
    noise_clamp: float | int | None = DEFAULT_NOISE_CLAMP,
    eos_threshold: float = DEFAULT_EOS_THRESHOLD,
) -> Self:
    """Load a pre-trained TTS model with specified configuration.

    This class method loads a complete TTS model including the flow language model
    and Mimi compression model from pre-trained weights. The model is initialized
    with the specified generation parameters and ready for inference.

    Args:
        config: a path to a custom YAML config file saved locally (e.g., C://pocket_tts/pocket_tts_config.yaml)
            or a model variant identifier (e.g., '610b0b2c'; must match a YAML file in the config directory).
        temp: Sampling temperature for generation. Higher values produce more
            diverse but potentially lower quality output.
        lsd_decode_steps: Number of steps for Lagrangian Self Distillation
            decoding. More steps can improve quality but increase computation.
        noise_clamp: Maximum value for noise sampling. If None, no clamping
            is applied. Helps prevent extreme values in generation.
        eos_threshold: Threshold for end-of-sequence detection. Higher values
            make the model more likely to continue generating.

    Returns:
        TTSModel: Fully initialized model with loaded weights on cpu, ready for
            text-to-speech generation.

    Raises:
        FileNotFoundError: If the specified config file or model weights
            are not found.
        ValueError: If the configuration is invalid or incompatible.

    Example:
        ```python
        from pocket_tts import TTSModel

        # Load with default settings
        model = TTSModel.load_model()

        # Load with custom parameters
        model = TTSModel.load_model(config="b6369a24", temp=0.5, lsd_decode_steps=5, eos_threshold=-3.0)
        ```
    """
    # Keep the raw identifier/path and the parsed config in separate names.
    if str(config).endswith(".yaml"):
        config_path = Path(config)
        loaded_config = load_config(config_path)
        logger.info("Loading model from config at %s...", config_path)
    else:
        # Variant identifier: resolve against the package's bundled configs.
        loaded_config = load_config(Path(__file__).parents[1] / f"config/{config}.yaml")

    # Use `cls` (not a hard-coded class name) so subclasses get instances of
    # their own type, matching the `-> Self` annotation.
    return cls._from_pydantic_config_with_weights(
        loaded_config, temp, lsd_decode_steps, noise_clamp, eos_threshold
    )