o
    
jm8                     @   sL  d dl mZ d dlmZ d dlmZmZ d dlZd dlm	Z	 d dl
m	  mZ d dlm	  m  mZ dZG dd dej	jZG dd	 d	ej	jZG d
d de	jZeG dd dZG dd deZedkre ZedddZedddZeej eeeZeej ejeg dksJ e dd e! D Z"ee" dS dS )    )	dataclass)Enum)CallableOptionalNg      @c                       sB   e Zd ZdZdddddddif fdd		Zd
d Zdd Z  ZS )KernelPredictorz7Kernel predictor for the location-variable convolutions   @           	LeakyReLUnegative_slopeg?c                    sN  t    || _|| _|| _|| _|| | | }|| }ttjj	
tj||ddddtt|	di |
| _t | _|d d }tdD ]<}| jtt|tjj	
tj||||ddtt|	di |
tjj	
tj||||ddtt|	di |
 qHtjj	
tj||||dd| _tjj	
tj||||dd| _dS )	a7  
        Args:
            cond_channels (int): number of channel for the conditioning sequence,
            conv_in_channels (int): number of channel for the input sequence,
            conv_out_channels (int): number of channel for the output sequence,
            conv_layers (int): number of layers
              T)paddingbias   r   N )super__init__conv_in_channelsconv_out_channelsconv_kernel_sizeconv_layersnn
Sequentialutilsparametrizationsweight_normConv1dgetattr
input_conv
ModuleListresidual_convsrangeappendDropoutkernel_conv	bias_conv)selfcond_channelsr   r   r   r   kpnet_hidden_channelskpnet_conv_sizekpnet_dropoutkpnet_nonlinear_activation!kpnet_nonlinear_activation_paramskpnet_kernel_channelskpnet_bias_channelsr   _	__class__r   R/home/kuhnn/.local/lib/python3.10/site-packages/TTS/tts/layers/tortoise/vocoder.pyr      sv   

			
zKernelPredictor.__init__c           
      C   s   |j \}}}| |}| jD ]}||j ||| }q| |}| |}| || j	| j
| j| j|}| || j	| j|}	||	fS )zm
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        )shaper   r!   todevicer%   r&   
contiguousviewr   r   r   r   )
r'   cbatchr0   cond_lengthresidual_convkbkernelsr   r   r   r3   forwardd   s,   



zKernelPredictor.forwardc                 C   s\   t | jd d t | jd t | j | jD ]}t |d d t |d d qd S )Nr   weightr   r   )parametrizeremove_parametrizationsr   r%   r&   r!   r'   blockr   r   r3   remove_weight_norm   s   
z"KernelPredictor.remove_weight_norm)__name__
__module____qualname____doc__r   r@   rF   __classcell__r   r   r1   r3   r      s    Tr   c                       sN   e Zd ZdZg dddddddf fdd		Zd
d ZdddZdd Z  ZS )LVCBlockz"the location-variable convolutionsr   r   	      皙?r      r   r	   c                    s   t    || _t|| _|| _t||d| t||||	|
d|id	| _t	t
|tjjtj||d| ||d |d  |d d| _t | _|D ]&}| jt	t
|tjjtj|||||d  d |dt
| qMd S )Nr   r   )	r(   r   r   r   r   r)   r*   r+   r-   )strider   output_paddingr   )r   dilation)r   r   cond_hop_lengthlenr   r   r   kernel_predictorr   r   r
   r   r   r   ConvTranspose1d	convt_prer    conv_blocksr#   r   )r'   in_channelsr(   rR   	dilationslReLU_sloper   rU   r)   r*   r+   rT   r1   r   r3   r      s\   


	zLVCBlock.__init__c              
   C   s   |j \}}}| |}| |\}}t| jD ]V\}}||}	|dd|ddddddddf }
|dd|ddddf }| j|	|
|| jd}	|t|	ddd|ddf t	|	dd|dddf   }q|S )aL  forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        N)hop_size)
r4   rY   rW   	enumeraterZ   location_variable_convolutionrU   torchsigmoidtanh)r'   xr9   r0   r[   r?   r   iconvoutputr=   r>   r   r   r3   r@      s   	
(
$
zLVCBlock.forwardr   c                 C   s,  |j \}}}|j \}}}	}
}||| ksJ d|t|
d d  }t|||fdd}|d|d|  |}||k rEt|d|fdd}|d||}|ddddddddd|f }|dd}|d|
d}td	||}|jtj	d
}|
d
djtj	d
}|| }| ||	d}|S )u  perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        z$length of (x, kernel) is not matchedr   r   constantr   r   N   zbildsk,biokl->bolsd)memory_format)r4   intFpadunfold	transposera   einsumr5   channels_last_3d	unsqueezer7   r8   )r'   rd   kernelr   rT   r^   r:   r0   	in_lengthout_channelskernel_sizekernel_lengthr   or   r   r3   r`      s(   &z&LVCBlock.location_variable_convolutionc                 C   s<   | j   t| jd d | jD ]
}t|d d qd S )Nr   rA   )rW   rF   rB   rC   rY   rZ   rD   r   r   r3   rF     s
   

zLVCBlock.remove_weight_norm)r   rQ   )	rG   rH   rI   rJ   r   r@   r`   rF   rK   r   r   r1   r3   rL      s    ?
%rL   c                       sb   e Zd ZdZddg dg ddddd	f fd
d	Zdd Zd fdd	Zdd ZdddZ  Z	S )UnivNetGeneratorzw
    UnivNet Generator

    Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
    r       rM   )   r|   ri   rP   r   rQ   d   c	           
         s   t t|   || _|| _|| _|}|}t | _d}|D ]}	|	| }| j	t
|||	||||d qtjjtj||dddd| _tt|tjjtj|dddddt | _d S )Nr   )rR   r\   r]   rU   r*      r   reflect)r   padding_mode)r   rz   r   mel_channel	noise_dim
hop_lengthr   r    	res_stackr#   rL   r   r   r   r   conv_prer   r
   Tanh	conv_post)
r'   r   channel_sizer\   stridesr]   r*   r   n_mel_channelsrR   r1   r   r3   r     s:   

zUnivNetGenerator.__init__c                 C   s:   |  |}| jD ]}||j |||}q| |}|S )z
        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            z (Tensor): the noise sequence (batch, noise_dim, in_length)

        )r   r   r5   r6   r   )r'   r9   z	res_blockr   r   r3   r@   I  s   


zUnivNetGenerator.forwardFc                    s"   t t|   |r|   d S d S N)r   rz   evalrF   )r'   	inferencer1   r   r3   r   Z  s   zUnivNetGenerator.evalc                 C   sN   t | jd | jD ]}t| dkrt |d q
| jD ]}|  qd S )NrA   r   )rB   rC   r   r   rV   
state_dictr   rF   )r'   layerr   r   r   r3   rF   `  s   


z#UnivNetGenerator.remove_weight_normNc                 C   s   t |jd | jdfd|j}t j||fdd}|d u r0t |jd | j|	d|j}| 
||}|d d d d d | jd  f }|jddd}|S )	Nr   
   g<,Ԛ'r   )dimrk   r   )minmax)ra   fullr4   r   r5   r6   catrandnr   sizer@   r   clamp)r'   r9   r   zeromelaudior   r   r3   r   j  s   "$"zUnivNetGenerator.inference)Fr   )
rG   rH   rI   rJ   r   r@   r   rF   r   rK   r   r   r1   r3   rz     s    -
rz   c                   @   s@   e Zd ZU eg ejf ed< eed< dZe	e ed< dd Z
dS )VocTypeconstructor
model_pathNsubkeyc                 C   s   | j d ur
|| j  S |S r   )r   )r'   
model_dictr   r   r3   optionally_index  s   

zVocType.optionally_index)rG   rH   rI   r   r   Module__annotations__strr   r   r   r   r   r   r3   r   y  s
   
 r   c                   @   s   e Zd ZeeddZdS )VocConfzvocoder.pthmodel_gN)rG   rH   rI   r   rz   Univnetr   r   r   r3   r     s    r   __main__r   r}   r   r   )r   r   i 
  c                 c   s    | ]
}|j r| V  qd S r   )requires_gradnumel).0pr   r   r3   	<genexpr>  s    r   )#dataclassesr   enumr   typingr   r   ra   torch.nnr   torch.nn.functional
functionalrm   torch.nn.utils.parametrizer   rB   MAX_WAV_VALUEr   r   rL   rz   r   r   rG   modelr   r9   r   printr4   ySizesum
parameterspytorch_total_paramsr   r   r   r3   <module>   s4    } d


