o
    
j_                  	   @   sr  d dl Z d dlZd dlmZ d dlm  mZ d dlmZm	Z	m
Z
 d dlmZ d dlmZmZ dd Zdd ZG d	d
 d
ejZG dd de	ZG dd dejZG dd dejZdd ZG dd dejZG dd dejZedkreddddddZeeddddejd d!d"ed#d gejd$d%d"ed&d'gZe edddejd(d)d"ed#dg dS dS )*    N)
GPT2ConfigGPT2PreTrainedModelLogitsProcessorList)!CausalLMOutputWithCrossAttentions)AttentionBlockTypicalLogitsWarperc                 C   s"   t j| jd | jd |f| jdS )Nr      device)torchzerosshaper
   )rangedim r   Y/home/kuhnn/.local/lib/python3.10/site-packages/TTS/tts/layers/tortoise/autoregressive.pynull_position_embeddings      "r   c                 C   s$   | ot | t | d | d d jfS )Nr   )lenr   )tr   r   r   _p   s   $r   c                       s(   e Zd ZdZ fddZdd Z  ZS )ResBlockzA
    Basic residual convolutional block that uses GroupNorm.
    c                    sZ   t    ttj||dddt|d |t tj||dddt|d || _d S )N   r   kernel_sizepadding   )super__init__nn
SequentialConv1d	GroupNormReLUnet)selfchan	__class__r   r   r      s   

zResBlock.__init__c                 C   s   t | || S N)Frelur$   )r%   xr   r   r   forward$   s   zResBlock.forward)__name__
__module____qualname____doc__r   r-   __classcell__r   r   r'   r   r      s    
r   c                       s`   e Zd Z fddZdd ZdddZ														ddd	Zed
d Z  Z	S )GPT2InferenceModelc                    s6   t  | || _|| _|| _t||| _|| _d S r)   )	r   r   transformertext_pos_embedding
embeddingsr   r    lm_headkv_cache)r%   configgpttext_pos_embr6   normlinearr8   r'   r   r   r   )   s   
zGPT2InferenceModel.__init__c                 C   s
   || _ d S r)   )cached_mel_emb)r%   mel_embr   r   r   store_mel_emb1   s   
z GPT2InferenceModel.store_mel_embNc                 K   s   | dd }| jsd }|r'|d d df d}|d ur'|d d df d}| dd }| dd }|d urZ|d u rZ| dd }||dkd |rY|d d df d}nd }||| d|||dS )	Ntoken_type_idsattention_maskposition_idsr   r   	use_cache)	input_idspast_key_valuesrE   rD   rC   rA   )getr8   	unsqueezelongcumsummasked_fill_)r%   rF   rG   kwargsrA   rC   rD   r   r   r   prepare_inputs_for_generation4   s.   z0GPT2InferenceModel.prepare_inputs_for_generationc                 C   s`  | j d usJ |d u sJ |
d u sJ |d ur|n| jj}| j jd }|jd dkri|d d |d f }| |}|| | }| j jd |jd kr\| j |jd | j jd  d}n| j }tj||gdd}n| |}|| j	|jd | |j
 }| j||||||||	||||d}|d }| |}|s|f|dd   S td ||j|j|j|jdS )Nr   r   r   )inputs_embedsrG   rC   rA   rD   	head_maskencoder_hidden_statesencoder_attention_maskrE   output_attentionsoutput_hidden_statesreturn_dict)losslogitsrG   hidden_states
attentionscross_attentions)r>   r9   use_return_dictr   r6   r5   repeat_interleaver   catget_fixed_embeddingr
   r4   r7   r   rG   rY   rZ   r[   )r%   rF   rG   rC   rA   rD   rQ   rP   rR   rS   labelsrE   rT   rU   rV   mel_lentext_inputstext_embr?   embtransformer_outputsrY   	lm_logitsr   r   r   r-   R   sV   
"

zGPT2InferenceModel.forwardc                    s   t  fdd| D S )a>  
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        c                 3   s&    | ]}t  fd d|D V  qdS )c                 3   s$    | ]}| d  |jV  qdS )r   N)index_selecttor
   ).0
past_statebeam_idxr   r   	<genexpr>   s   " z>GPT2InferenceModel._reorder_cache.<locals>.<genexpr>.<genexpr>Ntuple)ri   
layer_pastrk   r   r   rm      s
    
z4GPT2InferenceModel._reorder_cache.<locals>.<genexpr>rn   )pastrl   r   rk   r   _reorder_cache   s   z!GPT2InferenceModel._reorder_cacher)   )NNNNNNNNNNNNNN)
r.   r/   r0   r   r@   rN   r-   staticmethodrr   r2   r   r   r'   r   r3   (   s*    
 
Dr3   c                       s.   e Zd Z				d fdd	Zdd Z  ZS )	ConditioningEncoder      Fc           	         s`   t    g }tj||dd| _t|D ]
}|t|| qtj| | _	|| _
|| _|| _d S )Nr   )r   )r   r   r   r!   initr   appendr   r    attnr   do_checkpointingmean)	r%   spec_dimembedding_dimattn_blocksnum_attn_headsrz   r{   ry   ar'   r   r   r      s   
	
zConditioningEncoder.__init__c                 C   s<   |  |}| |}| jr|jddS |d d d d df S )N   rO   r   )rw   ry   r{   )r%   r,   hr   r   r   r-      s
   

zConditioningEncoder.forward)ru   rv   FFr.   r/   r0   r   r-   r2   r   r   r'   r   rt      s    rt   c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	LearnedPositionEmbeddings{Gz?c                    s0   t    t||| _| jjjjd|d d S )N        r{   std)r   r   r   	Embeddingrd   weightdatanormal_)r%   seq_len	model_dimrw   r'   r   r   r      s   
z"LearnedPositionEmbeddings.__init__c                 C   s"   |j d }| tjd||jdS )Nr   r   r	   )r   rd   r   aranger
   )r%   r,   slr   r   r   r-      s   
z!LearnedPositionEmbeddings.forwardc                 C   s"   |  tjd||d|d | S )Nr   r	   r   )rd   r   r   )r%   inddevr   r   r   r_      r   z-LearnedPositionEmbeddings.get_fixed_embedding)r   )r.   r/   r0   r   r-   r_   r2   r   r   r'   r   r      s    r   c           
   
   C   sl   ddl m}m} |d|| || || ||| d}||}	|	`tjt|d|	_|	`|	t||t||ddfS )z7
    GPT-2 implemented by the HuggingFace library.
    r   )r   	GPT2Model   
vocab_sizen_positionsn_ctxn_embdn_layern_headgradient_checkpointingrE   rO   N)	transformersr   r   wpe	functoolspartialr   wter   )
layersr   headsmax_mel_seq_lenmax_text_seq_lencheckpointingr   r   
gpt_configr:   r   r   r   build_hf_gpt_transformer   s*   
r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )
MelEncoderP   r   c                    s   t     | _ttj| d dddtj fddt|D  tj d  d ddddt d	  d t tj fd
dt|D  tj d  ddddt d  t tj fddt|D  
| _	d| _
d S )Nrv   r   r   r   c                       g | ]}t  d  qS )rv   r   ri   _channelsr   r   
<listcomp>       z'MelEncoder.__init__.<locals>.<listcomp>r   )r   strider      c                    r   )r   r   r   r   r   r   r      r   r   c                    s   g | ]}t  qS r   r   r   r   r   r   r      s    )r   r   r   r   r    r!   r   r"   r#   encoder	reduction)r%   r   mel_channelsresblocks_per_reductionr'   r   r   r      s   

zMelEncoder.__init__c                 C   s"   | j D ]}||}q|dddS )Nr   r   r   )r   permute)r%   r,   er   r   r   r-      s   

zMelEncoder.forward)r   r   r   r   r   r'   r   r      s    r   c                       s   e Zd Z												
					d fdd	Zd ddZdd Zdd Z				d!ddZdd Z						d"ddZ						d#ddZ
  ZS )$UnifiedVoicer      x      r      r   N          FTc                    s  t    || _|	du r|| n|	| _d| _|
| _|| _|| _|| _|| _	|| _
|| _|| _|| _|| _td||d| _t| j| d || _|rSt| j|| _nt|dd| _t|||| j
d | j | jd |\| _| _| _| _| _|rtjtdd|d d	d
| _tjtdd|d d	d
| _ nd| _d| _ t!|| _"t#|| j| d | _$t#|| j| _%| jg}|r|&| j |D ]}|j'j(j)ddd qdS )a  
        Args:
            layers: Number of layers in transformer stack.
            model_dim: Operating dimensions of the transformer
            heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
            max_text_tokens: Maximum number of text tokens that will be encountered by model.
            max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
            mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
            number_text_tokens:
            start_text_token:
            stop_text_token:
            number_mel_codes:
            start_mel_token:
            stop_mel_token:
            train_solo_embeddings:
            use_mel_codes_as_input:
            checkpointing:
        Nr   r   )r   r   )r   r   r   T)requires_gradr   r   )*r   r   number_text_tokensstart_text_tokenstop_text_tokennumber_mel_codesstart_mel_tokenstop_mel_tokenr   r   max_mel_tokensmax_text_tokensr   max_conditioning_inputsmel_length_compressionrt   conditioning_encoderr   r   text_embeddingmel_embeddingr   r   r:   mel_pos_embeddingr5   mel_layer_pos_embeddingtext_layer_pos_embedding	Parameterr   randnmel_solo_embeddingtext_solo_embedding	LayerNorm
final_normLinear	text_headmel_headrx   r   r   r   )r%   r   r   r   r   r   r   r   r   r   r   r   r   train_solo_embeddingsuse_mel_codes_as_inputr   typesr6   moduler'   r   r   r     s^   
&	 zUnifiedVoice.__init__c              
   C   s`   | j | j d }t| j ||| j| j| jddd}t|| j| j| j	| j
| j|d| _| j	| j_d S )Nr   FTr   )r8   )r   r   r   r   r   r   r3   r:   r   r   r   r   inference_modelr   )r%   r8   
seq_lengthr   r   r   r   post_init_gpt2_configb  s*   

z"UnifiedVoice.post_init_gpt2_configc                 C   s(   t j|d|d}t j|d|d}||fS )N)r   r   valuer   r   )r*   pad)r%   inputstart_token
stop_tokeninptarr   r   r    build_aligned_inputs_and_targets{  s   z-UnifiedVoice.build_aligned_inputs_and_targetsc                 C   sT   t j|| jdd}tt|D ]}|| d }||jd k r'| j|||df< q|S )a"  
        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
        preformatting to create a working TTS model.
        trunc)rounding_moder   rB   N)r   divr   r   r   r   r   )r%   mel_input_tokenswav_lengthsmel_lengthsb
actual_endr   r   r   set_mel_padding  s   
zUnifiedVoice.set_mel_paddingc                 C   s*  |d urt j|||gdd}n	t j||gdd}| j|d|d}	|r%|	jS |	jd d dd f }
| |
}
|rZ|
d d |jd |jd |jd  f |
d d |jd  d f fS |
d d d |jd f }||}|ddd}|d ur|
d d |jd  d f }||}|ddd}||fS |S )Nr   rO   T)rP   rV   rT   r   r   )r   r^   r:   rZ   last_hidden_stater   r   r   )r%   speech_conditioning_inputsfirst_inputs
first_headsecond_inputssecond_head	get_attnsreturn_latentrd   gpt_outencfirst_logitssecond_logitsr   r   r   
get_logits  s2   

zUnifiedVoice.get_logitsc                 C   sn   t |jdkr|dn|}g }t|jd D ]}|| |d d |f  qtj|dd}|jdd}|S )Nr   r   rO   )	r   r   rI   r   rx   r   r   stackr{   )r%   speech_conditioning_inputcondsjr   r   r   get_conditioning  s   
zUnifiedVoice.get_conditioningc              	   C   s  |dur|d|  d }|rA| }|ddd|f }| | j }|ddd|f }|durA|ddddd|d f }| ||}tj|d| jd}tj|d| jd}| d}| || j	| j\}}| 
|| | }| || j| j\}}|durt|d}n|}| |}|| | }|r| j||| j|| j|	|
d\}}|
r|dddd	f S n| j||| j|| j|	|
d\}}|
r|dddd	f S |	r|S t|| }t|| }| | |fS )
a  
        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
        (actuated by `text_first`).

        speech_conditioning_input: MEL float tensor, (b,1024)
        text_inputs: long tensor, (b,t)
        text_lengths: long tensor, (b,)
        mel_inputs:  long tensor, (b,m)
        wav_lengths: long tensor, (b,)
        raw_mels: MEL float tensor (b,80,s)

        If return_attentions is specified, only logits are returned.
        If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
        Nr   rB   rv   r   r   )r   r   )r   r   )rI   maxr   r   r*   r   r   r   r   r   r   r5   r   r   r   r  r   r   cross_entropyrJ   r{   )r%   speech_conditioning_latentrb   text_lengths	mel_codesr   r   
text_firstraw_melsreturn_attentionsr   clip_inputsmax_text_lenmax_mel_lenr  text_targetsrc   mel_targetsmel_inpr?   text_logits
mel_logits	loss_textloss_melr   r   r   r-     st   




	
	zUnifiedVoice.forward?c              	   K   s  t j|d| jd}| || j| j\}}	| || | }
|d}tj	||
gdd}| j
| tj|jd |jd |jd  fdtj|jd}| j|d d df< |jd }|d u r_|}n'||jd  dkslJ d||d}|||jd  d}tj	||gdd}|rtt|d	gnt }|d u r|| j d n|| }| j
j|f| j| j| j|||d
|}|d d |d f S )Nr   r   r   rO   r   )
fill_valuedtyper
   rB   zQThe number of return sequences must be divisible by the number of input sequences)mass)bos_token_idpad_token_ideos_token_id
max_lengthlogits_processornum_return_sequences)r*   r   r   r   r   r   r5   rI   r   r^   r   r@   fullr   rJ   r
   r   repeatr   r   r   generater   )r%   r  rb   input_tokensr%  max_generate_lengthtypical_samplingtypical_masshf_generate_kwargsr  rc   r  rd   fake_inputstrunc_indexinputsr$  r#  genr   r   r   inference_speech'  sV   

	

zUnifiedVoice.inference_speech)r   r   r   r   r   r   r   r   Nr   r   r   FTTr   )T)NNFF)NTNFFT)Nr   NFr  )r.   r/   r0   r   r   r   r   r  r  r-   r2  r2   r   r   r'   r   r     sP    
[
*
dr   __main__r   rv   T)r   r   r   r   r   r   r   r   i   r   )r   r   )highsize    r   )r   r   i   i   2   )r   r   )!r   r   torch.nnr   torch.nn.functional
functionalr*   r   r   r   r   transformers.modeling_outputsr   "TTS.tts.layers.tortoise.arch_utilsr   r   r   r   Moduler   r3   rt   r   r   r   r   r.   r:   r   randinttensorltext_forwardr   r   r   r   <module>   sN   {  `