o
    
j~                     @   s   d dl Z d dlZd dlZd dlZd dlmZmZmZ d dlZ	d dl
Z
d dlZd dlmZ d dlmZmZmZ d dlmZ d dlmZ e
jd dd	 Zd
d Zdd ZG dd deZG dd deZG dd dZG dd dZdS )    N)DictListUnion)Dataset)prepare_dataprepare_stop_targetprepare_tensor)AudioProcessor)compute_energyfile_systemc                 C   sl   d }d }t | dkr| \}}}}}nt | dkr| \}}}}nt | dkr+| \}}}ntd|||||fS )N         z% [!] Dataset cannot parse the sample.)len
ValueError)itemlanguage_name	attn_filetextwav_filespeaker_name r   K/home/kuhnn/.local/lib/python3.10/site-packages/TTS/tts/datasets/dataset.py_parse_sample   s   r   c                 C   s   | dt jj| j   S )Ng       ?)nprandomrandshape)wavr   r   r   noise_augment_audio#   s   r   c                 C   s   t | ddd}|S )Nzutf-8ignore)base64urlsafe_b64encodeencodedecode)stringfilenamer   r   r   string2filename'   s   r'   c                /       s  e Zd Zddddddddddddeddedddddddddfdededed	ee d
ddedede	de	dedededededede	dededededededef. fddZ
ed d! Zed"d# Zejd$d# Zd%d& Zd'd( ZdNd)ed*dfd+d,Zd-d. Zd/d0 Zd1d2 Zd3d4 Zed5d6 Zd7d8 Zd9d: Zed;d< Zed=ee d>ed?efd@dAZed	ee fdBdCZedefdDdEZedFdG ZdHdI ZedJdK Z dLdM Z!  Z"S )O
TTSDataset   FNr   infoutputs_per_stepcompute_linear_specapsamples	tokenizerTTSTokenizer
compute_f0r
   f0_cache_pathenergy_cache_path
return_wavbatch_group_sizemin_text_lenmax_text_lenmin_audio_lenmax_audio_lenphoneme_cache_pathprecompute_num_workersspeaker_id_mappingd_vector_mappinglanguage_id_mappinguse_noise_augmentstart_by_longestverbosec                    s  t    || _|| _|| _|| _|
| _|| _|| _|| _	|	| _
|| _|| _|| _|| _|| _|| _|| _|| _|| _|| _|| _|| _d| _d| _|| _| jjr\t| j| j||d| _|rit| j| j||d| _|rvt | j| j|	|d| _!| jr| "  dS dS )u8  Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

        If you need something different, you can subclass and override.

        Args:
            outputs_per_step (int): Number of time frames predicted per step.

            compute_linear_spec (bool): compute linear spectrogram if True.

            ap (TTS.tts.utils.AudioProcessor): Audio processor object.

            samples (list): List of dataset samples.

            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else
                use the given. Defaults to None.

            compute_f0 (bool): compute f0 if True. Defaults to False.

            compute_energy (bool): compute energy if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            energy_cache_path (str): Path to store energy cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a
                batch. Set 0 to disable. Defaults to 0.

            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
                Defaults to float("inf").

            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
                this value if you encounter an OOM error in training. Defaults to float("inf").

            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                separate file. Defaults to None.

            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.

            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                embedding layer. Defaults to None.

            d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.

            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.

            verbose (bool): Print diagnostic information. Defaults to false.
        r)   F)r;   )
cache_pathr;   N)#super__init__r5   _samplesr+   r,   r4   r1   r
   r2   r3   r8   r9   r6   r7   r-   r:   r<   r=   r>   r?   r@   rA   rescue_item_idxpitch_computedr/   use_phonemesPhonemeDatasetr.   phoneme_dataset	F0Dataset
f0_datasetEnergyDatasetenergy_dataset
print_logs)selfr+   r,   r-   r.   r/   r1   r
   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r@   rA   	__class__r   r   rD   .   sP   
UzTTSDataset.__init__c                 C   s@   g }| j D ]}t|^}}}tj|d d }|| q|S )N      )r.   r   ospathgetsizeappend)rP   lensr   _r   	audio_lenr   r   r   lengths   s   
zTTSDataset.lengthsc                 C      | j S N)rE   rP   r   r   r   r.      s   zTTSDataset.samplesc                 C   sD   || _ t| dr|| j_t| dr|| j_t| dr || j_d S d S )NrL   rN   rJ   )rE   hasattrrL   r.   rN   rJ   )rP   new_samplesr   r   r   r.      s   


c                 C   
   t | jS r^   r   r.   r_   r   r   r   __len__      
zTTSDataset.__len__c                 C   s
   |  |S r^   )	load_data)rP   idxr   r   r   __getitem__   re   zTTSDataset.__getitem__levelreturnc                 C   X   d| }t d t | d t | d | j|d  t | dt| j  d S )N	
z> DataLoader initialization| > Tokenizer:r)   | > Number of instances : printr/   rO   r   r.   rP   ri   indentr   r   r   rO         zTTSDataset.print_logsc                 C   s   | j |}|jdksJ |S Nr   )r-   load_wavsize)rP   r&   waveformr   r   r   rv      s   zTTSDataset.load_wavc                 C   sD   | j | }||d ksJ | d|d  t|d dks J |S )Nr   z != 	token_idsr   )rJ   r   )rP   rg   r   out_dictr   r   r   get_phonemes   s   
"zTTSDataset.get_phonemesc                 C   ,   | j | }| j| }|d |d ksJ |S Naudio_unique_name)rL   r.   rP   rg   rz   r   r   r   r   get_f0      

zTTSDataset.get_f0c                 C   r|   r}   )rN   r.   r   r   r   r   
get_energy   r   zTTSDataset.get_energyc                 C   s
   t | S r^   )r   load)r   r   r   r   get_attn_mask   s   
zTTSDataset.get_attn_maskc                 C   s6   | j jr| ||d }n| j |}tj|tjdS )Nry   dtype)r/   rH   r{   text_to_idsr   arrayint32)rP   rg   r   ry   r   r   r   get_token_ids   s   zTTSDataset.get_token_idsc           
      C   s  | j | }|d }tj| |d tjd}| jrt|}| ||d }d }d|v r2| |d }t	|| j
ks@t	|| jk rM|  jd7  _| | jS d }| jrY| |d }d }| jre| |d }|||||||d |d |d	 tj|d |d
 d}	|	S )Nr   
audio_filer   alignment_filer)   f0energyr   languager~   )raw_textry   r   pitchr   attnitem_idxr   r   wav_file_namer~   )r.   r   asarrayrv   float32r?   r   r   r   r   r7   r8   rF   rf   r1   r   r
   r   rU   rV   basename)
rP   rg   r   r   r   ry   r   r   r   sampler   r   r   rf      s>   
zTTSDataset.load_datac                 C   sP   g }| D ]!}t j|d d d }t|d }||d< ||d< ||g7 }q|S )Nr   rS   rT   r   audio_lengthtext_length)rU   rV   rW   r   )r.   ra   r   r   text_lenghtr   r   r   _compute_lengths&  s   zTTSDataset._compute_lengthsr\   min_lenmax_lenc                 C   sR   t | }g }g }|D ]}| | }||k s||kr|| q|| q||fS r^   )r   argsortrX   )r\   r   r   idxs
ignore_idxkeep_idxrg   lengthr   r   r   filter_by_length1  s   
zTTSDataset.filter_by_lengthc                 C   s   dd | D }t |}|S )Nc                 S      g | ]}|d  qS r   r   .0sr   r   r   
<listcomp>@      z-TTSDataset.sort_by_length.<locals>.<listcomp>)r   r   )r.   audio_lengthsr   r   r   r   sort_by_length>  s   
zTTSDataset.sort_by_lengthc                 C   sX   |dksJ t t| | D ]}|| }|| }| || }t| || ||< q| S ru   )ranger   r   shuffle)r.   r5   ioffset
end_offset
temp_itemsr   r   r   create_bucketsD  s   
zTTSDataset.create_bucketsc                 C   s    g }| D ]	}| ||  q|S r^   )rX   )r   r.   samples_newrg   r   r   r   _select_samples_by_idxO  s   z!TTSDataset._select_samples_by_idxc                 C   s  |  | j}dd |D }dd |D }| || j| j\}}| || j| j\}}tt|t|@ }tt|t|B }	| 	||}| 
|}
| jrZ|
d }|
d |
d< ||
d< | 	|
|}t|dkrjtd| jdkrv| || j}dd |D }dd |D }|| _| jrtd	 td
t| tdt| tdt| td tdt| tdt| tdt| tdt|	  td| j dS dS )zSort `items` based on text length or audio length in ascending order. Filter out samples out or the length
        range.
        c                 S   r   r   r   r   r   r   r   r   r   ]  r   z1TTSDataset.preprocess_samples.<locals>.<listcomp>c                 S   r   r   r   r   r   r   r   r   ^  r   r   z [!] No samples leftc                 S   r   r   r   r   r   r   r   r   y  r   c                 S   r   r   r   r   r   r   r   r   z  r   z | > Preprocessing samplesz | > Max text length: {}z | > Min text length: {}z | > Avg text length: {}z | z | > Max audio length: {}z | > Min audio length: {}z | > Avg audio length: {}z' | > Num. instances discarded samples: z | > Batch group size: {}.N)r   r.   r   r6   r7   r8   r9   listsetr   r   r@   r   RuntimeErrorr5   r   rA   rq   formatr   maxminmean)rP   r.   text_lengthsr   text_ignore_idxtext_keep_idxaudio_ignore_idxaudio_keep_idxr   r   sorted_idxslongest_idxsr   r   r   preprocess_samplesV  sB   

zTTSDataset.preprocess_samplesc                    s6   t jt |ddd\}} fdd|D   ||fS )zSort the batch by the input text length for RNN efficiency.

        Args:
            batch (Dict): Batch returned by `__getitem__`.
            text_lengths (List[int]): Lengths of the input character sequences.
        r   T)dim
descendingc                    s   g | ]} | qS r   r   r   rg   batchr   r   r     r   z*TTSDataset._sort_batch.<locals>.<listcomp>)torchsort
LongTensor)r   r   ids_sorted_decreasingr   r   r   _sort_batch  s   
zTTSDataset._sort_batchc                    s
  t  d tjjrztdd  D } |\ }} fdd d D  jdur9fdd d D }nd}jdurPt	 d	 }fd
d|D }nd}j
rafdd d D }nd}fdd d D }dd |D }	fdd|D }
dd |	D }t|j}t d tj}t|j}|ddd}t|}t|}t| }t|	}	t|}|durt|}|durt|}|durt|}d}jr
fdd d D }t|j}|ddd}|jd |jd ksJ t| }d}jrqdd  d D }t|
jj }t|}tt d d|}t d D ]3\}}|
| }tj|djjj fdd}|d|jj  }t |||ddd|jd f< q7|!dd j"rt d }|jd |jd ksJ d|j d|j t|dddddf  }nd}j#rt d }|jd |jd ksJ d|j d|j t|dddddf  }nd}d} d d dur= fdd|D }t|D ]=\}}|jd |jd  }|jd |jd  }|dkr|dksJ d| d | t|d|gd|gg}|||< qt|j}t|$d}i d!|d"|d# d d$|d%|d&|	d'|d( d) d*|d+|d,|d-|d. d. d|d|d/|d0 d	 S t%d1&t' d )2z
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to features.
        3. PAD sequences wrt r.
        4. Load to Torch.
        r   c                 S      g | ]}t |d  qS ry   r   )r   dr   r   r   r         z)TTSDataset.collate_fn.<locals>.<listcomp>c                    s    i | ]   fd dD qS )c                    s   g | ]}|  qS r   r   )r   dickr   r   r     r   z4TTSDataset.collate_fn.<locals>.<dictcomp>.<listcomp>r   )r   r   r   r   
<dictcomp>  s     z)TTSDataset.collate_fn.<locals>.<dictcomp>Nc                       g | ]} j | qS r   )r>   )r   lnr_   r   r   r         r   r~   c                    s   g | ]	} j | d  qS )	embedding)r=   r   wr_   r   r   r         c                    r   r   )r<   )r   snr_   r   r   r     r   r   c                       g | ]} j |d qS r   )r-   melspectrogramastyper   r_   r   r   r         r   c                 S      g | ]}|j d  qS r)   r   r   mr   r   r   r     r   c                    sF   g | ]}|j d   j r|j d   j|j d   j   n|j d  qS r   )r   r+   r   r_   r   r   r     s    "c                 S   s&   g | ]}t d g|d  dg qS )        r)   g      ?)r   r   )r   mel_lenr   r   r   r     s   & ry      r)   c                    r   r   )r-   spectrogramr   r   r_   r   r   r     r   c                 S   r   r   r   r   r   r   r   r     r   edge)moder   z[!] z vs r   r   c                    s   g | ]	} d  | j qS )r   )Tr   r   r   r   r     r   z[!] Negative padding - z and token_idtoken_id_lengthsspeaker_nameslinearmelmel_lengthsstop_targets	item_idxsr   	d_vectorsspeaker_idsattnsrx   r   language_idsaudio_unique_nameszUbatch must contain tensors, numbers, dicts or lists;                         found {})(
isinstancecollectionsabcMappingr   r   r   r>   r=   r   r<   r   r+   r   r   r   r   	transposer   r   FloatTensor
contiguousr,   r   r4   r   r-   
hop_lengthzerosr   	enumeratepad
from_numpy
transpose_r1   r
   	unsqueeze	TypeErrorr   type)rP   r   token_ids_lengthsr   r  embedding_keysr   r   r   r   mel_lengths_adjustedr   ry   r   
wav_paddedwav_lengthsmax_wav_lenr   r   
mel_lengthr   r   r   rg   r   pad2pad1r   )r   rP   r   
collate_fn  s   











&."."(	

zTTSDataset.collate_fnr   )#__name__
__module____qualname__floatintboolr	   r   r   strrD   propertyr\   r.   setterrd   rh   rO   rv   r{   r   r   staticmethodr   r   rf   r   r   r   r   r   r   r   r  __classcell__r   r   rQ   r   r(   -   s    	
 


	
1



3
r(   c                   @   s   e Zd ZdZ	ddeee ee f dddefddZd	d
 Z	dd Z
dd Zdd ZdddZdd ZddeddfddZdS )rI   a  Phoneme Dataset for converting input text to phonemes and then token IDs

    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
    loading latency. If `cache_path` is already present, it skips the pre-computation.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        tokenizer (TTSTokenizer):
            Tokenizer to convert input text to phonemes.

        cache_path (str):
            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.

        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
    r   r.   r/   r0   rB   c                 C   sF   || _ || _|| _|d urtj|s!t| | | d S d S d S r^   )r.   r/   rB   rU   rV   existsmakedirs
precompute)rP   r.   r/   rB   r;   r   r   r   rD   J  s   
zPhonemeDataset.__init__c                 C   sJ   | j | }| t|d |d |d }| j|}|d ||t|dS )Nr~   r   r   )r   ph_hatry   token_ids_len)r.   compute_or_loadr'   r/   ids_to_textr   )rP   indexr   idsr+  r   r   r   rh   X  s   
zPhonemeDataset.__getitem__c                 C   rb   r^   rc   r_   r   r   r   rd   ^  re   zPhonemeDataset.__len__c                 C   sZ   d}t j| j|| }zt|}W |S  ty,   | jj||d}t	|| Y |S w )zpCompute phonemes for the given text.

        If the phonemes are already cached, load them from cache.
        z_phoneme.npy)r   )
rU   rV   joinrB   r   r   FileNotFoundErrorr/   r   save)rP   	file_namer   r   file_extrB   r0  r   r   r   r-  a  s   zPhonemeDataset.compute_or_loadc                 C   s   | j jS )z%Get pad token ID for sequence padding)r/   pad_idr_   r   r   r   
get_pad_ido  s   zPhonemeDataset.get_pad_idr)   c                 C   s|   t d tjt| d(}|dkr|nd}tjjj|| d|| jd}|D ]}|| q$W d   dS 1 s7w   Y  dS )zePrecompute phonemes for all samples.

        We use pytorch dataloader because we are lazy.
        z[*] Pre-computing phonemes...totalr   r)   F
batch_sizedatasetr   num_workersr  N)	rq   tqdmr   r   utilsdata
DataLoaderr  update)rP   r=  pbarr;  	dataloderrZ   r   r   r   r*  s  s   "zPhonemeDataset.precomputec           
      C   s   dd |D }dd |D }dd |D }dd |D }t |}tt|||  }t|D ]\}}	t|| ||d |	f< q1|||dS )Nc                 S   r   r   r   r   r   r   r   r   r     r   z-PhonemeDataset.collate_fn.<locals>.<listcomp>c                 S   r   )r,  r   rE  r   r   r   r     r   c                 S   r   )r   r   rE  r   r   r   r     r   c                 S   r   )r+  r   rE  r   r   r   r     r   )r   r+  ry   r   r   r   r   fill_r7  r  )
rP   r   r0  ids_lenstexts	texts_hatids_lens_max	ids_torchr   ids_lenr   r   r   r    s   zPhonemeDataset.collate_fnri   rj   Nc                 C   rk   )Nrl   rm   z> PhonemeDataset rn   r)   ro   rp   rr   r   r   r   rO     rt   zPhonemeDataset.print_logsr   r   )r  r  r  __doc__r   r   r   r#  rD   rh   rd   r-  r7  r*  r  r!  rO   r   r   r   r   rI   6  s"    

rI   c                   @   s   e Zd ZdZ					d(deee ee f ddd	efd
dZdd Z	dd Z
d)ddZdd Zedd Zed*ddZedd Zdd Zdd Zdd Zd d! Zd"d# Zd)d$ed%dfd&d'ZdS )+rK   aT  F0 Dataset for computing F0 from wav files in CPU

    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of F0 values if `normalize_f0` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Defaults to True.
    NFr   Tr.   r-   r	   rB   c                 C   sn   || _ || _|| _|| _|| _d| _d | _d | _|d ur,tj	
|s,t| | | |r5| | d S d S Nr   )r.   r-   rA   rB   normalize_f0r6  r   stdrU   rV   r(  r)  r*  
load_stats)rP   r.   r-   audio_configrA   rB   r;   rP  r   r   r   rD     s   


zF0Dataset.__init__c                 C   \   | j | }| |d t|d }| jr'| jd ur| jd us"J d| |}|d |dS )Nr   r~   " [!] Mean and STD is not available)r~   r   )r.   r-  r'   rP  r   rQ  	normalize)rP   rg   r   r   r   r   r   rh        

zF0Dataset.__getitem__c                 C   rb   r^   rc   r_   r   r   r   rd     re   zF0Dataset.__len__c                 C      t d tjt| d@}|dkr|nd}| j}d| _tjjj|| d|| jd}g }|D ]}|d }|	dd	 |D  |
| q,|| _W d    n1 sOw   Y  | jrzd
d |D }| |\}	}
|	|
d}tjtj| jd|dd d S d S )Nz[*] Pre-computing F0s...r8  r   r)   Fr:  r   c                 s       | ]}|V  qd S r^   r   )r   fr   r   r   	<genexpr>      z'F0Dataset.precompute.<locals>.<genexpr>c                 S      g | ]	}|D ]}|qqS r   r   r   r   tensorr   r   r   r     r   z(F0Dataset.precompute.<locals>.<listcomp>r   rQ  pitch_statsTallow_pickle)rq   r>  r   rP  r   r?  r@  rA  r  rX   rB  compute_pitch_statsr   r3  rU   rV   r1  rB   )rP   r=  rC  r;  rP  rD  computed_datar   r   
pitch_mean	pitch_stdra  r   r   r   r*    *   
 zF0Dataset.precomputec                 C   r]   r^   r6  r_   r   r   r   r7       zF0Dataset.get_pad_idc                 C   s   t j|| d }|S )Nz
_pitch.npy)rU   rV   r1  )r4  rB   
pitch_filer   r   r   create_pitch_file_path  s   z F0Dataset.create_pitch_file_pathc                 C   s(   |  |}| |}|rt|| |S r^   )rv   r1   r   r3  )r-   r   rk  r   r   r   r   r   _compute_and_save_pitch  s
   

z!F0Dataset._compute_and_save_pitchc                 C   2   t dd | D }t |t |}}||fS )Nc                 S   "   g | ]}|t |d kd  qS r   r   r   wherer   vr   r   r   r        " z1F0Dataset.compute_pitch_stats.<locals>.<listcomp>r   concatenater   rQ  )
pitch_vecsnonzerosr   rQ  r   r   r   rd       zF0Dataset.compute_pitch_statsc                 C   H   t j|d}tj|dd }|d tj| _|d tj| _	d S )Nzpitch_stats.npyTrb  r   rQ  
rU   rV   r1  r   r   r   r   r   r   rQ  rP   rB   
stats_pathstatsr   r   r   rR       zF0Dataset.load_statsc                 C   2   t |dkd }|| j }|| j }d||< |S Nr   r   r   rr  r   rQ  rP   r   	zero_idxsr   r   r   rV    
   

zF0Dataset.normalizec                 C   2   t |dkd }|| j9 }|| j7 }d||< |S r  r   rr  rQ  r   r  r   r   r   denormalize
  r  zF0Dataset.denormalizec                 C   B   |  || j}tj|s| | j||}nt|}|	tj
S )zH
        compute pitch and return a numpy array of pitch values
        )rl  rB   rU   rV   r(  rm  r-   r   r   r   r   )rP   r   r~   rk  r   r   r   r   r-    
   
zF0Dataset.compute_or_loadc           	      C      dd |D }dd |D }dd |D }t |}tt|||  }t|D ]\}}t|| ||d |f< q*|||dS )Nc                 S   r   r~   r   rE  r   r   r   r     r   z(F0Dataset.collate_fn.<locals>.<listcomp>c                 S   r   r   r   rE  r   r   r   r     r   c                 S   r   r  r   rE  r   r   r   r     r   )r~   r   f0_lensrF  )	rP   r   r~   f0sr  f0_lens_max	f0s_torchr   f0_lenr   r   r   r       zF0Dataset.collate_fnri   rj   c                 C   :   d| }t d t | d t | dt| j  d S )Nrl   rm   z> F0Dataset ro   rq   r   r.   rr   r   r   r   rO   &     zF0Dataset.print_logs)NFNr   Tr   r^   )r  r  r  rN  r   r   r   r#  rD   rh   rd   r*  r7  r&  rl  rm  rd  rR  rV  r  r-  r  r!  rO   r   r   r   r   rK     s<    




rK   c                   @   s   e Zd ZdZ				d(deee ee f ddd	efd
dZdd Z	dd Z
d)ddZdd Zedd Zed*ddZedd Zdd Zdd Zdd Zd d! Zd"d# Zd)d$ed%dfd&d'ZdS )+rM   a|  Energy Dataset for computing Energy from wav files in CPU

    Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of Energy values if `normalize_Energy` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute Energy from wav files.

        cache_path (str):
            Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation.
            Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the Energy values. Defaults to 0.

        normalize_Energy (bool):
            Whether to normalize Energy values by mean and std. Defaults to True.
    FNr   Tr.   r-   r	   rB   c                 C   sn   || _ || _|| _|| _|| _d| _d | _d | _|d ur,tj	
|s,t| | | |r5| | d S d S rO  )r.   r-   rA   rB   normalize_energyr6  r   rQ  rU   rV   r(  r)  r*  rR  )rP   r.   r-   rA   rB   r;   r  r   r   r   rD   E  s   	

zEnergyDataset.__init__c                 C   rT  )Nr   r~   rU  )r~   r   )r.   r-  r'   r  r   rQ  rV  )rP   rg   r   r   r   r   r   rh   \  rW  zEnergyDataset.__getitem__c                 C   rb   r^   rc   r_   r   r   r   rd   d  re   zEnergyDataset.__len__c                 C   rX  )Nz[*] Pre-computing energys...r8  r   r)   Fr:  r   c                 s   rY  r^   r   )r   er   r   r   r[  t  r\  z+EnergyDataset.precompute.<locals>.<genexpr>c                 S   r]  r   r   r^  r   r   r   r   y  r   z,EnergyDataset.precompute.<locals>.<listcomp>r`  energy_statsTrb  )rq   r>  r   r  r   r?  r@  rA  r  rX   rB  compute_energy_statsr   r3  rU   rV   r1  rB   )rP   r=  rC  r;  r  rD  re  r   r   energy_mean
energy_stdr  r   r   r   r*  g  rh  zEnergyDataset.precomputec                 C   r]   r^   ri  r_   r   r   r   r7  ~  rj  zEnergyDataset.get_pad_idc                 C   s.   t jt j| d }t j||d }|S )Nr   z_energy.npy)rU   rV   splitextr   r1  )r   rB   r4  energy_filer   r   r   create_energy_file_path  s   z%EnergyDataset.create_energy_file_pathc                 C   s4   |  |}t|| j| j| jd}|rt|| |S )N)fft_sizer
  
win_length)rv   calculate_energyr  r
  r  r   r3  )r-   r   r  r   r   r   r   r   _compute_and_save_energy  s
   
z&EnergyDataset._compute_and_save_energyc                 C   rn  )Nc                 S   ro  rp  rq  rs  r   r   r   r     ru  z6EnergyDataset.compute_energy_stats.<locals>.<listcomp>rv  )energy_vecsry  r   rQ  r   r   r   r    rz  z"EnergyDataset.compute_energy_statsc                 C   r{  )Nzenergy_stats.npyTrb  r   rQ  r|  r}  r   r   r   rR    r  zEnergyDataset.load_statsc                 C   r  r  r  rP   r   r  r   r   r   rV    r  zEnergyDataset.normalizec                 C   r  r  r  r  r   r   r   r    r  zEnergyDataset.denormalizec                 C   r  )zJ
        compute energy and return a numpy array of energy values
        )r  rB   rU   rV   r(  r  r-   r   r   r   r   )rP   r   r~   r  r   r   r   r   r-    r  zEnergyDataset.compute_or_loadc           	      C   r  )Nc                 S   r   r  r   rE  r   r   r   r     r   z,EnergyDataset.collate_fn.<locals>.<listcomp>c                 S   r   r   r   rE  r   r   r   r     r   c                 S   r   r  r   rE  r   r   r   r     r   )r~   r   energy_lensrF  )	rP   r   r~   energysr  energy_lens_maxenergys_torchr   
energy_lenr   r   r   r    r  zEnergyDataset.collate_fnri   rj   c                 C   r  )Nrl   rm   z> energyDataset ro   r  rr   r   r   r   rO     r  zEnergyDataset.print_logs)FNr   Tr   r^   )r  r  r  rN  r   r   r   r#  rD   rh   rd   r*  r7  r&  r  r  r  rR  rV  r  r-  r  r!  rO   r   r   r   r   rM   -  s:    




rM   ) r!   r  rU   r   typingr   r   r   numpyr   r   r>  torch.utils.datar   TTS.tts.utils.datar   r   r   TTS.utils.audior	    TTS.utils.audio.numpy_transformsr
   r  multiprocessingset_sharing_strategyr   r   r'   r(   rI   rK   rM   r   r   r   r   <module>   s2        _ 