o
    
j                     @   s   d Z ddlZddlmZ ddlmZmZ ddlZddl	m
Z
 eejZdZdZdZed	Zh d
Zeeeg dZdddddddZedZG dd deZG dd dZdS )zPython implementation of libtashkeel.

See: https://github.com/mush42/libtashkeel

Ported with the help of ChatGPT 2025-05-01.
    N)Path)OptionalUnion)InferenceSessioni.  _#u   0123456789٠١٢٣٤٥٦٧٨٩>      ٌ   ٍ   َ   ُ   ِ   ّ   ْ)R  iQ  iN  iO  iP  iK  iL  iM  u   َّu   ًّu   ُّu   ٌّu   ِّu   ٍّ)u   َّu   ًّu   ُّu   ٌّu   ِّu   ٍّr   c                   @   s   e Zd ZdZdS )TashkeelErrorzError for tashkeel.N)__name__
__module____qualname____doc__ r   r   J/home/kuhnn/.local/lib/python3.10/site-packages/piper/tashkeel/__init__.pyr      s    r   c                   @   s|  e Zd ZdZefdeeef ddfddZd*dede	e
 defd	d
Zd*dedefddZdee dee dedeee ee
 f fddZdedee dee defddZdedee dee dee
 de
defddZdefddZ	d+dededeeee f fdd Zdedeeee f fd!d"Zdedee fd#d$Zdee dee fd%d&Zd'ee dee fd(d)ZdS ),TashkeelDiacritizerz0Add diacritics for Arabic text with libtashkeel.	model_dirreturnNc                    s   t |}t|d | _t|d ddd}t|| _W d   n1 s%w   Y  t|d ddd}t| dd	   D | _W d   n1 sMw   Y   fd
dt	fD | _
t|d ddd}t|| _W d   dS 1 sxw   Y  dS )zInitialize diacritizer.z
model.onnxzinput_id_map.jsonrzutf-8)encodingNztarget_id_map.jsonc                 S   s   i | ]\}}||qS r   r   ).0cir   r   r   
<dictcomp>0   s    
z0TashkeelDiacritizer.__init__.<locals>.<dictcomp>c                    s   h | ]} | qS r   r   r   r   target_id_mapr   r   	<setcomp>4   s    z/TashkeelDiacritizer.__init__.<locals>.<setcomp>zhint_id_map.json)r   r   sessionopenjsonloadinput_id_mapitemsid_target_mapPADtarget_id_meta_charshint_id_map)selfr   input_id_map_filetarget_id_map_filehint_id_map_filer   r!   r   __init__!   s0   




"zTashkeelDiacritizer.__init__texttaskeen_thresholdc                 C   s
   |  |S )!Add diacritics using libtashkeel.)
diacritize)r.   r3   r4   r   r   r   __call__;   s   
zTashkeelDiacritizer.__call__c                 C   s   |  }t|tkrtdt | |\}}| j|dd\}}| |}| |}t|}|dkr5|S | |||\}	}
| 	|	}|du rN| 
|||S | ||||
|S )r5   zText length cannot exceed T)normalize_diacriticsr   N)striplen
CHAR_LIMITr   _to_valid_chars_extract_chars_and_diacritics_input_to_ids_hint_to_ids_infer_target_to_diacritics_annotate_text_with_diacritics&_annotate_text_with_diacritics_taskeen)r.   r3   r4   
input_textremoved_chars
diacritics	input_idsdiac_ids
seq_length
target_idslogitsr   r   r   r6   ?   s&   




zTashkeelDiacritizer.diacritizerG   rH   rI   c                 C   s   t j|t jdd|}t j|t jdd|}t j|gt jdd}|||d}| jd|}|d  t j	 }	|d  t j
	 }
|	|
fS )zInfer target ids and logits.)dtype   )char_inputsdiac_inputsinput_lengthsNr   )nparrayint64reshaper$   runflattenastypeuint8tolistfloat32)r.   rG   rH   rI   input_ids_arrdiac_ids_arrinput_len_arrinputsoutputsrJ   rK   r   r   r   r@   \   s   zTashkeelDiacritizer._inferrD   rF   rE   c                 C   sZ   g }t |}|D ]}| |rq||v r|| q|| |t|d qd|S N )iter_is_diacritic_charappendnextjoin)r.   rD   rF   rE   output	diac_iterr   r   r   r   rB   s   s   


z2TashkeelDiacritizer._annotate_text_with_diacriticsrK   	thresholdc                 C   sx   g }t ||}|D ]-}| |rq	||v r|| q	|| t|d\}	}
|
|kr1|t q	||	 q	d|S )N)ra   g        ra   )ziprc   rd   re   SUKOONrf   )r.   rD   rF   rE   rK   ri   rg   rh   r   diaclogitr   r   r   rC      s   



z:TashkeelDiacritizer._annotate_text_with_diacritics_taskeenc                 C   s   |t v S N)ARABIC_DIACRITICS)r.   r   r   r   r   rc      s   z&TashkeelDiacritizer._is_diacritic_charTr8   c           	      C   s   | dt}g }g }d}t|dg D ]}| |r!||7 }q|| || d}q|r4|  |r;|d |rSt|D ]\}}|| jvrRt	
|d||< qAd||fS )Nra    r   )lstriprf   ro   listrc   rd   pop	enumerater-   NORMALIZED_DIAC_MAPget)	r.   r3   r8   clean_charsrF   pending_diacr   r   dr   r   r   r=      s(   





z1TashkeelDiacritizer._extract_chars_and_diacriticsc                 C   s^   g }t  }|D ] }|| jv s|tv r|| q|tv r"|t q|| qd||fS r`   )setr(   ro   rd   NUMERALSNUMERAL_SYMBOLaddrf   )r.   r3   validinvalidr   r   r   r   r<      s   z#TashkeelDiacritizer._to_valid_charsc                        fdd|D S )Nc                       g | ]} j | qS r   )r(   r    r.   r   r   
<listcomp>       z5TashkeelDiacritizer._input_to_ids.<locals>.<listcomp>r   )r.   r3   r   r   r   r>         z!TashkeelDiacritizer._input_to_idsc                    r   )Nc                    r   r   )r-   )r   ry   r   r   r   r      r   z4TashkeelDiacritizer._hint_to_ids.<locals>.<listcomp>r   )r.   rF   r   r   r   r?      r   z TashkeelDiacritizer._hint_to_idsrJ   c                    r   )Nc                    s    g | ]}| j vr j| qS r   )r,   r*   )r   r   r   r   r   r      s
    
z=TashkeelDiacritizer._target_to_diacritics.<locals>.<listcomp>r   )r.   rJ   r   r   r   rA      s   
z)TashkeelDiacritizer._target_to_diacriticsrn   )T)r   r   r   r   TASHKEEL_DIRr   strr   r2   r   floatr7   r6   rr   inttupler@   rz   rB   rC   boolrc   r=   r<   r>   r?   rA   r   r   r   r   r      sb    



r   )r   r&   pathlibr   typingr   r   numpyrQ   onnxruntimer   __file__parentr   r;   r+   r|   rz   r{   HARAKAT_CHARSmapchrro   ru   rk   	Exceptionr   r   r   r   r   r   <module>   s"    
