o
    
ja                     @   s  d dl Z d dlZd dl mZ d dlmZmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ dZd	d
 ZG dd de jjZG dd de jjZG dd de jjZG dd dejZG dd dejZdd ZG dd dejZG dd dejZG dd de jjZdS )    N)nn)Conv1dConvTranspose1d)
functional)weight_norm)remove_parametrizations)load_fsspecg?c                 C   s   t | | | d S )N   )int)kd r   V/home/kuhnn/.local/lib/python3.10/site-packages/TTS/tts/layers/xtts/hifigan_decoder.pyget_padding   s   r   c                       2   e Zd ZdZd
 fdd	Zdd Zdd	 Z  ZS )	ResBlock1a;  Residual Block Type 1. It has 3 convolutional layers in each convolutional block.

    Network::

        x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
        |--------------------------------------------------------------------------------------------------|


    Args:
        channels (int): number of hidden channels for the convolutional layers.
        kernel_size (int): size of the convolution filter in each layer.
        dilations (list): list of dilation value for each conv layer in a block.
          r      c                    s   t    ttt|||d|d t||d dtt|||d|d t||d dtt|||d|d t||d dg| _ttt|||ddt|ddtt|||ddt|ddtt|||ddt|ddg| _d S )Nr   r   dilationpaddingr	   )	super__init__r   
ModuleListr   r   r   convs1convs2selfchannelskernel_sizer   	__class__r   r   r   !   s   


#


zResBlock1.__init__c                 C   sL   t | j| jD ]\}}t|t}||}t|t}||}|| }q|S )z
        Args:
            x (Tensor): input tensor.
        Returns:
            Tensor: output tensor.
        Shapes:
            x: [B, C, T]
        )zipr   r   F
leaky_reluLRELU_SLOPE)r   xc1c2xtr   r   r   forwardi   s   	
zResBlock1.forwardc                 C   s0   | j D ]}t|d q| jD ]}t|d qd S Nweight)r   r   r   r   lr   r   r   remove_weight_normz   s
   

zResBlock1.remove_weight_norm)r   r   __name__
__module____qualname____doc__r   r,   r1   __classcell__r   r   r"   r   r      s
    Hr   c                       r   )	ResBlock2a  Residual Block Type 2. It has 1 convolutional layers in each convolutional block.

    Network::

        x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
        |---------------------------------------------------|


    Args:
        channels (int): number of hidden channels for the convolutional layers.
        kernel_size (int): size of the convolution filter in each layer.
        dilations (list): list of dilation value for each conv layer in a block.
    r   r   r   c                    sb   t    ttt|||d|d t||d dtt|||d|d t||d dg| _d S )Nr   r   r   )r   r   r   r   r   r   r   convsr   r"   r   r   r      s0   


zResBlock2.__init__c                 C   s,   | j D ]}t|t}||}|| }q|S N)r:   r%   r&   r'   )r   r(   cr+   r   r   r   r,      s
   

zResBlock2.forwardc                 C   s   | j D ]}t|d qd S r-   )r:   r   r/   r   r   r   r1      s   
zResBlock2.remove_weight_norm)r   r9   r2   r   r   r"   r   r8      s
    r8   c                       sX   e Zd Z						d fdd	Zddd	Ze d
d Zdd Z	dddZ	  Z
S )HifiganGeneratorr   r   TFc                    s  t    |	| _t|| _t|| _|| _tt||dddd| _	|dkr't
nt}t | _tt||D ]#\}\}}| jtt|d|  |d|d   |||| d d q5t | _tt| jD ]"}|d|d   }tt||D ]\}\}}| j|||| qvqett||ddd|d| _|
dkrt|
|d| _|st| j	d	 |st| jd	 | jrt | _tt| jD ]}|d|d   }| jt|
|d qd
S d
S )a  HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)

        Network:
            x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
                                                 ..          -> zI ---|
                                              resblockN_kNx1 -> zN ---'

        Args:
            in_channels (int): number of input tensor channels.
            out_channels (int): number of output tensor channels.
            resblock_type (str): type of the `ResBlock`. '1' or '2'.
            resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
            resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
            upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
                for each consecutive upsampling layer.
            upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
            inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
           r   r   )r   1r	   )r   biasr   r.   N)r   r   inference_paddinglennum_kernelsnum_upsamplescond_in_each_up_layerr   r   conv_prer   r8   r   r   ups	enumerater$   appendr   	resblocksrange	conv_post
cond_layerr   conds)r   in_channelsout_channelsresblock_typeresblock_dilation_sizesresblock_kernel_sizesupsample_kernel_sizesupsample_initial_channelupsample_factorsrA   cond_channelsconv_pre_weight_normconv_post_weight_normconv_post_biasrE   resblockiur   ch_r   r"   r   r   r      sP   
$






zHifiganGenerator.__init__Nc                 C   s   |  |}t| dr|| | }t| jD ]I}t|t}| j| |}| j	r1|| j
| | }d}t| jD ]!}|du rK| j|| j |  |}q8|| j|| j |  |7 }q8|| j }qt|}| |}t|}|S )z
        Args:
            x (Tensor): feature input tensor.
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            x: [B, C, T]
            Tensor: [B, 1, T]
        rM   N)rF   hasattrrM   rK   rD   r%   r&   r'   rG   rE   rN   rC   rJ   rL   torchtanh)r   r(   gor\   z_sumjr   r   r   r,   
  s$   




zHifiganGenerator.forwardc                 C   s4   | | jjj}tjj|| j| jfd}| 	|S )z
        Args:
            x (Tensor): conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            x: [B, C, T]
            Tensor: [B, 1, T]
        	replicate)
torF   r.   devicera   r   r   padrA   r,   )r   r<   r   r   r   	inference-  s   
zHifiganGenerator.inferencec                 C   sN   t d | jD ]}t|d q| jD ]}|  qt| jd t| jd d S )NzRemoving weight norm...r.   )printrG   r   rJ   r1   rF   rL   r/   r   r   r   r1   >  s   


z#HifiganGenerator.remove_weight_normc                 C   sH   t j|t dd}| |d  |r"|   | jrJ |   d S d S )Ncpumap_locationmodel)ra   loadri   load_state_dictevaltrainingr1   )r   configcheckpoint_pathrs   cachestater   r   r   load_checkpointG  s   
z HifiganGenerator.load_checkpoint)r   r   TTTFr;   )FF)r3   r4   r5   r   r,   ra   no_gradrk   r1   ry   r7   r   r   r"   r   r=      s    
R#

r=   c                       &   e Zd Zd fdd	Zdd Z  ZS )SELayer   c                    sT   t t|   td| _tt||| tjddt|| |t	 | _
d S )Nr   Tinplace)r   r|   r   r   AdaptiveAvgPool2davg_pool
SequentialLinearReLUSigmoidfc)r   channel	reductionr"   r   r   r   S  s   

zSELayer.__init__c                 C   s@   |  \}}}}| |||}| |||dd}|| S )Nr   )sizer   viewr   )r   r(   br<   r_   yr   r   r   r,   ]  s   zSELayer.forward)r}   r3   r4   r5   r   r,   r7   r   r   r"   r   r|   R  s    
r|   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )	SEBasicBlockr   Nr}   c                    s~   t t|   tj||d|ddd| _t|| _tj||dddd| _t|| _	tj
dd| _t||| _|| _|| _d S )Nr   r   F)r!   strider   r@   )r!   r   r@   Tr~   )r   r   r   r   Conv2dconv1BatchNorm2dbn1conv2bn2r   relur|   se
downsampler   )r   inplanesplanesr   r   r   r"   r   r   r   g  s   
zSEBasicBlock.__init__c                 C   sj   |}|  |}| |}| |}| |}| |}| |}| jd ur*| |}||7 }| |}|S r;   )r   r   r   r   r   r   r   )r   r(   residualoutr   r   r   r,   r  s   








zSEBasicBlock.forward)r   Nr}   )r3   r4   r5   	expansionr   r,   r7   r   r   r"   r   r   d  s    r   c                    s   |  D ]\}}| vrtd| q fdd|  D } fdd|  D }|drE|jd urE|jD ]fdd|  D }q7 | tdt|t   S )Nz. | > Layer missing in the model definition: {}c                    s   i | ]\}}| v r||qS r   r   .0r   v
model_dictr   r   
<dictcomp>      z!set_init_dict.<locals>.<dictcomp>c                    s*   i | ]\}}|   |   kr||qS r   )numelr   r   r   r   r     s   * reinit_layersc                    s   i | ]\}} |vr||qS r   r   r   )reinit_layer_namer   r   r     r   z! | > {} / {} layers are restored.)itemsrl   formathasr   updaterB   )r   checkpoint_stater<   r   r   pretrained_dictr   )r   r   r   set_init_dict  s   

r   c                       r{   )PreEmphasis
ףp=
?c                    s:   t    || _| dt| j dgdd d S )Nfilterg      ?r   )r   r   coefficientregister_bufferra   FloatTensor	unsqueeze)r   r   r"   r   r   r     s   
*zPreEmphasis.__init__c                 C   sD   t | dks
J tjj|ddd}tjj|| j	dS )Nr	   r   )r   r   reflect)
rB   r   ra   r   r   rj   r   conv1dr   squeeze)r   r(   r   r   r   r,     s   zPreEmphasis.forward)r   r   r   r   r"   r   r     s    r   c                       s~   e Zd ZdZddg dg dddddf fd	d
	Zdd ZdddZdd ZdddZ				dde	de
de
fddZ  ZS )ResNetSpeakerEncoderu?   This is copied from 🐸TTS to remove it from the dependencies.@      )r         r   )    r         ASPFNc	                    s  t t|   || _|| _|| _|| _|| _|| _t	j
d|d dddd| _t	jdd| _t	|d | _|d | _| t|d |d | _| jt|d |d dd| _| jt|d	 |d	 dd| _| jt|d |d dd| _t	|| _| jrtj	t|d
 tjj|d |d |d |d tj|d d| _ nd | _ t!| jd }	t	t	j"|d |	 dddt	 t	#dt	j"d|d |	 ddt	j$d	d| _%| jdkr|d |	 }
n| jdkr|d |	 d	 }
nt&dt	'|
|| _(| )  d S )Nr   r   r   )r!   r   r   Tr~   )r	   r	   )r   r	   preemphasissample_ratefft_size
win_length
hop_lengthnum_mels)r   n_fftr   r   	window_fnn_melsr}   r   )r!   dimSAPr   zUndefined encoder)*r   r   r   encoder_type	input_dim	log_inputuse_torch_specaudio_configproj_dimr   r   r   r   r   r   r   r   create_layerr   layer1layer2layer3layer4InstanceNorm1dinstancenormra   r   r   
torchaudio
transformsMelSpectrogramhamming_window
torch_specr
   r   BatchNorm1dSoftmax	attention
ValueErrorr   r   _init_layers)r   r   r   layersnum_filtersr   r   r   r   outmap_sizeout_dimr"   r   r   r     sX   




zResNetSpeakerEncoder.__init__c                 C   s`   |   D ])}t|tjrtjj|jddd qt|tjr-tj|jd tj|j	d qd S )Nfan_outr   )modenonlinearityr   r   )
modules
isinstancer   r   initkaiming_normal_r.   r   	constant_r@   )r   mr   r   r   r     s   z!ResNetSpeakerEncoder._init_layersr   c              	   C   s   d }|dks| j ||j kr&ttj| j ||j d|ddt||j }g }||| j ||| ||j | _ td|D ]}||| j | q>tj| S )Nr   F)r!   r   r@   )r   r   r   r   r   r   rI   rK   )r   blockr   blocksr   r   r   r_   r   r   r   r     s   
z!ResNetSpeakerEncoder.create_layerc                 G   s    t tj| }t j| |S r;   )r   	Parameterra   r   r   xavier_normal_)r   r   r   r   r   r   new_parameter  s   z"ResNetSpeakerEncoder.new_parameterc                 C   sZ  | d | jr| |}| jr|d  }| |d}| |}| |}| 	|}| 
|}| |}| |}| |}|| d d| d }| |}| jdkrdtj|| dd}n,| jdkrtj|| dd}ttj|d | dd|d  jd	d
}t||fd}|| d d}| |}|rtjjj|ddd}|S )a{  Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        r   gư>r   r   r	   r   r   gh㈵>)min)pr   )squeeze_r   r   r   logr   r   r   r   r   r   r   r   r   reshaper   r   r   ra   sumsqrtclampcatr   r   r   r   	normalize)r   r(   l2_normwmusgr   r   r   r,     s4   











,
zResNetSpeakerEncoder.forwardrv   rs   use_cudac           	   
   C   s&  t |td|d}z| |d  td W n- ttfyD } z|r%|td |  }t||d }| | ~W Y d }~nd }~ww |d urpd|v rpz	||d  W n ttfyo } ztd| W Y d }~nd }~ww |r~| 	  |d ur~|	 }|r| 
  | jrJ |s||d fS |S )	Nrm   )ro   rw   rp   z > Model fully restored. z  > Partial model initialization.	criterionz% > Criterion load ignored because of:step)r   ra   ri   rr   rl   KeyErrorRuntimeError
state_dictr   cudars   rt   )	r   rv   rs   r  r  rw   rx   errorr   r   r   r   ry   ;  s>   

z$ResNetSpeakerEncoder.load_checkpoint)r   F)FFNF)r3   r4   r5   r6   r   r   r   r   r,   strboolry   r7   r   r   r"   r   r     s4    E

2r   c                       s   e Zd Zddddddg dg dg dgg dg dd	g d
d	dd	ddddddf fdd	Zedd ZdddZe dd Z	dddZ
  ZS ) HifiDecoderi"V  i]  r   i   r?   r   )r   r>      )r}   r}   r	   r	   r   )   r  r   r   Ti     i>  r   r   )r   r   r   r   r   r   c                    sf   t    || _|| _|| _|| _|| _t|d|||||
|	d|ddd|d| _t	dddd|d| _
d S )	Nr   r   F)rA   rW   rX   rY   rZ   rE   r   r   T)r   r   r   r   r   )r   r   input_sample_rateoutput_sample_rateoutput_hop_lengthar_mel_length_compressionspeaker_encoder_audio_configr=   waveform_decoderr   speaker_encoder)r   r  r  r  r  decoder_input_dimresblock_type_decoderresblock_dilation_sizes_decoderresblock_kernel_sizes_decoderupsample_rates_decoder upsample_initial_channel_decoderupsample_kernel_sizes_decoderd_vector_dim&cond_d_vector_in_each_upsampling_layerr  r"   r   r   r   h  s:   
zHifiDecoder.__init__c                 C   s   t |  jS r;   )next
parametersri   )r   r   r   r   ri     s   zHifiDecoder.deviceNc                 C   sn   t jjj|dd| j| j gddd}| j| j	kr.t jjj|| j| j	 gddd}| j
||d}|S )  
        Args:
            x (Tensor): feature input tensor (GPT latent).
            g (Tensor): global conditioning input tensor.

        Returns:
            Tensor: output waveform.

        Shapes:
            x: [B, C, T]
            Tensor: [B, 1, T]
        r   r	   linear)scale_factorr   r   rc   )ra   r   r   interpolate	transposer  r  r   r  r  r  )r   latentsrc   zrd   r   r   r   r,     s"   
zHifiDecoder.forwardc                 C   s   | j ||dS )r'  r*  )r,   )r   r<   rc   r   r   r   rk     s   zHifiDecoder.inferenceFc                 C   sx   t |tdd}|d }t| }|D ]}d|vr"d|vr"||= q| | |r:|   | jr3J | j	  d S d S )Nrm   rn   rp   zwaveform_decoder.zspeaker_encoder.)
r   ra   ri   listkeysrr   rs   rt   r  r1   )r   rv   rs   rx   states_keyskeyr   r   r   ry     s   

zHifiDecoder.load_checkpointr;   r  )r3   r4   r5   r   propertyri   r,   ra   rz   rk   ry   r7   r   r   r"   r   r  g  s8    6


r  )ra   r   r   torch.nnr   r   r   r%   torch.nn.utils.parametrizationsr   torch.nn.utils.parametrizer   TTS.utils.ior   r'   r   Moduler   r8   r=   r|   r   r   r   r   r  r   r   r   r   <module>   s*    o6 ! C