#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Speech-friendly Chinese briefing text cleaner.

Goals:
- Reduce out-of-vocabulary (OOV) tokens for offline Chinese TTS (sherpa-onnx Matcha/VITS, Piper)
- Make text more conversational and easier to read aloud
- Normalize years, dates and percentages, tidy punctuation, and insert line breaks at natural breathing points

Usage:
  brief_text_clean.py --in in.txt --out out.txt [--repl replacements.json]
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys

# Spoken Chinese reading of each ASCII digit, one character per digit.
DIGIT_ZH = dict(zip("0123456789", "零一二三四五六七八九"))

# Chinese-syllable approximations of the English letter names A-Z, used to
# spell out upper-case acronyms so the TTS model only sees Chinese characters.
LETTER_ZH = dict(
    zip(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        (
            "诶", "比", "西", "迪", "伊", "艾弗", "吉", "艾尺", "艾", "杰",
            "开", "艾勒", "艾姆", "恩", "哦", "屁", "苦", "阿尔", "艾丝",
            "提", "优", "维", "豆贝流", "艾克斯", "歪", "贼",
        ),
    )
)


def load_replacements(path: str | None) -> tuple[dict[str, str], dict[str, str]]:
    """Load user-editable replacement dictionaries.

    JSON format:
      {"fixed": {"BBC": "英国广播公司"}, "names": {"Emmanuel Macron": "马克龙"}}

    - fixed: acronym/short tokens replacements
    - names: multi-word name replacements
    """
    if not path:
        return {}, {}

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        fixed = data.get("fixed", {}) or {}
        names = data.get("names", {}) or {}
        # ensure str->str
        fixed = {str(k): str(v) for k, v in fixed.items()}
        names = {str(k): str(v) for k, v in names.items()}
        return fixed, names
    except Exception:
        return {}, {}


def year_to_zh(y: str) -> str:
    """Read *y* one digit at a time, e.g. "2026" -> "二零二六".

    Characters that are not keys of ``DIGIT_ZH`` pass through unchanged.
    """
    digit_table = str.maketrans(DIGIT_ZH)
    return y.translate(digit_table)


def int_to_zh(n: int) -> str:
    """Spell an integer in Chinese for reading aloud.

    Covers magnitudes 0-9999, with a "负" prefix for negatives, e.g.
    15 -> "十五", 105 -> "一百零五", 1010 -> "一千零一十". Numbers 10-19
    drop the leading "一" as in everyday speech, and a run of zero places
    between non-zero digits is read as a single "零". Magnitudes of 10000
    and above are returned as plain digits (``str(n)``), unchanged from the
    previous behavior.
    """
    if n < 0:
        return "负" + int_to_zh(-n)
    if n >= 10000:
        return str(n)
    if n < 10:
        return DIGIT_ZH[str(n)]
    if n < 20:
        # Spoken form: "十", "十一", ... rather than "一十", "一十一".
        return "十" + (DIGIT_ZH[str(n % 10)] if n % 10 else "")

    parts: list[str] = []
    rest = n
    pending_zero = False
    for place, unit in ((1000, "千"), (100, "百"), (10, "十"), (1, "")):
        digit, rest = divmod(rest, place)
        if digit:
            if pending_zero:
                # Skipped places between two non-zero digits are read as one 零.
                parts.append("零")
                pending_zero = False
            parts.append(DIGIT_ZH[str(digit)] + unit)
        elif parts:
            # Trailing zeros (e.g. 一百, 二十) are silent; only mark a gap.
            pending_zero = True
    return "".join(parts)


def acronym_to_zh(s: str) -> str:
    """Spell *s* letter by letter using ``LETTER_ZH``, space-separated.

    Characters with no entry in ``LETTER_ZH`` are kept as-is.
    """
    spelled = [LETTER_ZH[letter] if letter in LETTER_ZH else letter for letter in s]
    return " ".join(spelled)


# Full-width digits and percent sign (common in Chinese copy) -> ASCII, so the
# numeric rules below only have to handle one form.
_FULLWIDTH_TO_ASCII = str.maketrans("０１２３４５６７８９％", "0123456789%")

# Start-of-number guard: not preceded by a digit, and not by "<digit>." so the
# fractional tail of a decimal such as "3.5" is never read as its own number.
_NUM_START = r"(?<!\d)(?<!\d\.)"


def _read_number(int_part: str, frac_part: str | None = None) -> str:
    """Spell "[-]int[.frac]" in Chinese, e.g. ("-3", "5") -> "负三点五"."""
    spoken = int_to_zh(abs(int(int_part)))
    if int_part.startswith("-"):
        spoken = "负" + spoken
    if frac_part:
        # Fractional digits are read one by one after 点; str(int(d)) maps any
        # Unicode decimal digit matched by \d onto a DIGIT_ZH key.
        spoken += "点" + "".join(DIGIT_ZH[str(int(d))] for d in frac_part)
    return spoken


def _apply_replacements(t: str, repl_fixed: dict[str, str], repl_names: dict[str, str]) -> str:
    """Apply the user dictionaries literally, names first, longest key first."""
    # Longest first so e.g. "Emmanuel Macron" wins over a shorter "Macron" entry;
    # sorted() is stable, so equal-length keys keep the dictionary's order.
    for key in sorted((k for k in repl_names if k), key=len, reverse=True):
        t = t.replace(key, repl_names[key])
    # Fixed replacements (robust boundaries; avoid \b because Unicode: Python
    # treats CJK characters as word characters, so \bBBC\b misses "BBC报道").
    for key in sorted((k for k in repl_fixed if k), key=len, reverse=True):
        value = repl_fixed[key]
        # A function replacement inserts `value` verbatim; a plain string would
        # be parsed as a template and break on backslashes or "\1".
        t = re.sub(
            rf"(?<![A-Za-z]){re.escape(key)}(?![A-Za-z])",
            lambda _m, value=value: value,
            t,
        )
    return t


def _normalize_numbers(t: str) -> str:
    """Rewrite years, month/day dates, percentages, temperatures and small unit counts in Chinese."""
    # Years 1900-2099 are read digit by digit: 2026 -> 二零二六. (?!\d) also
    # accepts a year at the very end of the text. Heuristic: any standalone
    # 19xx/20xx number is treated as a year.
    t = re.sub(
        _NUM_START + r"((?:19|20)\d{2})(?!\d)",
        lambda m: year_to_zh(m.group(1)),
        t,
    )

    # Month/day: 2月10日 -> 二月十日
    t = re.sub(
        _NUM_START + r"(\d{1,2})\s*(月|日|号)",
        lambda m: int_to_zh(int(m.group(1))) + m.group(2),
        t,
    )

    # Percentages: 85% -> 百分之八十五, 3.5% -> 百分之三点五. The left guard
    # stops "1234567%" from being read as a partial tail.
    t = re.sub(
        _NUM_START + r"(\d{1,6})(?:\.(\d+))?\s*[%％]",
        lambda m: "百分之" + _read_number(m.group(1), m.group(2)),
        t,
    )

    # Temperatures: 9度 / 9°C / -3.5℃ -> 九度 / 九度 / 负三点五度. No \b after
    # the unit: with Unicode word rules "9度左右" or a trailing "9℃" never matched.
    t = re.sub(
        _NUM_START + r"(-?\d{1,2})(?:\.(\d+))?\s*(?:°\s*C(?![A-Za-z])|℃|度)",
        lambda m: _read_number(m.group(1), m.group(2)) + "度",
        t,
    )

    # Common small numbers before units -> Chinese (kept conservative):
    # 5公里 -> 五公里, 1.5小时内 -> 一点五小时内
    t = re.sub(
        _NUM_START + r"(-?\d{1,2})(?:\.(\d+))?\s*(英里|公里|小时|分钟)",
        lambda m: _read_number(m.group(1), m.group(2)) + m.group(3),
        t,
    )
    return t


def _add_breathing(t: str) -> str:
    """Insert line breaks at natural pauses and collapse redundant whitespace."""
    # Break after list markers such as 第一条 ... 第九十九条.
    t = re.sub(r"(第[一二三四五六七八九十]{1,3}条)\s*", r"\1\n", t)
    # Break after sentence-ending punctuation (runs like "！？" or "..." count as
    # one, and a closing quote stays on the same line). A "." followed by a
    # digit is a decimal point ("1.5亿") and must not split the number.
    t = re.sub(r"((?:[。！？!?]+|\.+(?!\d))[”’」』]*)\s*", r"\1\n", t)
    # Clean spaces/newlines
    t = re.sub(r"[ \t]{2,}", " ", t)
    t = re.sub(r"\n{3,}", "\n\n", t)
    return t


def normalize(text: str, repl_fixed: dict[str, str] | None = None, repl_names: dict[str, str] | None = None) -> str:
    """Turn briefing text into speech-friendly Chinese for offline TTS.

    Args:
        text: Raw briefing text.
        repl_fixed: Token replacements applied only where the key is not
            flanked by ASCII letters (acronyms, short tokens).
        repl_names: Literal substring replacements (multi-word names),
            applied before ``repl_fixed``.

    Returns:
        The cleaned text. Numbers are spelled in Chinese, acronyms are
        spelled by letter, and other Latin words are removed. Line breaks
        follow sentence ends and list markers. The result ends in exactly
        one newline.
    """
    t = text.strip().translate(_FULLWIDTH_TO_ASCII)

    # Title brackets to quotes
    t = t.replace("《", "“").replace("》", "”")

    # Dashes (including the doubled Chinese dash "——") become a single comma
    # pause; bullets are dropped; semicolons end a sentence.
    t = re.sub(r"[—–]+", "，", t).replace("•", "")
    t = t.replace(";", "。").replace("；", "。")

    t = _apply_replacements(t, repl_fixed or {}, repl_names or {})

    # Backstop for a common pattern
    t = re.sub(r"Emmanuel\s+Macron", "马克龙", t)

    t = _normalize_numbers(t)

    # Acronyms (ALL CAPS 2-6) -> letter names in Chinese
    t = re.sub(r"(?<![A-Za-z])([A-Z]{2,6})(?![A-Za-z])", lambda m: acronym_to_zh(m.group(1)), t)

    # Remaining Latin words (names/tickers) -> remove to avoid OOV
    t = re.sub(r"[A-Za-z]{2,}", "", t)

    # Tokenization quirks observed: avoid rare pinyin token for 谁
    t = t.replace("谁", "哪位")

    t = _add_breathing(t)
    return t.strip() + "\n"


def main() -> None:
    """Command-line entry point: clean ``--in`` and write the result to ``--out``.

    The replacement dictionary defaults to ``brief_replacements.json`` in the
    same directory as this script; it is optional (see ``load_replacements``).
    """
    ap = argparse.ArgumentParser(description="Clean Chinese briefing text for offline TTS.")
    ap.add_argument("--in", dest="in_path", required=True, help="Input text file (UTF-8)")
    ap.add_argument("--out", dest="out_path", required=True, help="Output text file (written as UTF-8)")
    ap.add_argument(
        "--repl",
        dest="repl_path",
        # abspath so the default still points next to the script when
        # __file__ is a bare relative name.
        default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "brief_replacements.json"),
        help="Path to replacement dictionary JSON (default: brief_replacements.json next to this script)",
    )
    args = ap.parse_args()

    fixed, names = load_replacements(args.repl_path)
    # errors="ignore" drops undecodable bytes instead of aborting the run;
    # `with` guarantees the input handle is closed.
    with open(args.in_path, "r", encoding="utf-8", errors="ignore") as f:
        raw = f.read()
    out = normalize(raw, repl_fixed=fixed, repl_names=names)
    with open(args.out_path, "w", encoding="utf-8") as f:
        f.write(out)


if __name__ == "__main__":
    main()
