#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Speech-friendly Chinese briefing text cleaner (from OpenClaw).

Goals:
- Reduce out-of-vocabulary (OOV) input for offline Chinese TTS (sherpa-onnx Matcha/VITS, Piper)
- Make text more conversational and easier to read aloud
- Normalize years, dates, percentages, temperatures and punctuation, and insert line breaks at natural pauses ("breathing points")

Usage:
  brief_text_clean.py --in in.txt --out out.txt [--repl brief_replacements.json]
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys

# Chinese reading of each ASCII digit "0".."9"; used for digit-by-digit
# reading (year_to_zh) and as the digit table for int_to_zh.
DIGIT_ZH = {str(i): ch for i, ch in enumerate("零一二三四五六七八九")}

# Chinese-syllable approximation of each uppercase Latin letter; used by
# acronym_to_zh to spell acronyms out letter by letter.
LETTER_ZH = dict(zip(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    (
        "诶", "比", "西", "迪", "伊",
        "艾弗", "吉", "艾尺", "艾", "杰",
        "开", "艾勒", "艾姆", "恩", "哦",
        "屁", "苦", "阿尔", "艾丝", "提",
        "优", "维", "豆贝流", "艾克斯", "歪", "贼",
    ),
))

def load_replacements(path):
    """Load the optional replacement tables from a JSON file.

    The file is expected to hold a JSON object with optional "fixed" and
    "names" objects, each mapping source text to its spoken replacement.
    Keys and values are coerced to str.

    Args:
        path: Path to the JSON file, or a falsy value to skip loading.

    Returns:
        A ``(fixed, names)`` tuple of dicts. Both are empty when *path* is
        falsy, when the file does not exist, or when it cannot be read or has
        the wrong shape. The last case also prints a warning to stderr.
    """
    if not path:
        return {}, {}
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp936/cp1252 on
        # Windows) would mis-decode the Chinese replacement text.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        fixed = data.get("fixed") or {}
        names = data.get("names") or {}
        return (
            {str(k): str(v) for k, v in fixed.items()},
            {str(k): str(v) for k, v in names.items()},
        )
    except FileNotFoundError:
        # The table is optional: main() defaults to a path that may not exist.
        return {}, {}
    except (OSError, ValueError, AttributeError, TypeError) as err:
        # Unreadable file, invalid JSON/UTF-8 (ValueError), or a top level /
        # section that is not an object (AttributeError). Keep going without
        # replacements, but do not fail silently.
        print(f"warning: ignoring replacements file {path!r}: {err}", file=sys.stderr)
        return {}, {}

def year_to_zh(y):
    """Read a digit string one digit at a time ("2024" -> "二零二四").

    Characters without an entry in DIGIT_ZH (anything but "0".."9") are
    passed through unchanged.
    """
    spoken = [DIGIT_ZH.get(ch, ch) for ch in y]
    return "".join(spoken)

def int_to_zh(n):
    """Spell an integer in Chinese numerals (25 -> "二十五", 105 -> "一百零五").

    Negative values get a "负" prefix. Magnitudes 0-9999 are spelled out;
    larger magnitudes are returned as their decimal string, unchanged.
    10-19 are read without a leading "一" (十, 十一 … 十九).
    """
    if n < 0:
        return "负" + int_to_zh(-n)
    if n < 10:
        return DIGIT_ZH[str(n)]
    if n < 20:
        return "十" + (DIGIT_ZH[str(n % 10)] if n % 10 else "")
    if n >= 10000:
        return str(n)
    parts = []
    # One or more skipped places between non-zero digits are read as a
    # single "零" (1001 -> 一千零一, 1010 -> 一千零一十).
    need_zero = False
    rest = n
    for unit, place in (("千", 1000), ("百", 100), ("十", 10)):
        digit, rest = divmod(rest, place)
        if digit:
            if need_zero:
                parts.append("零")
                need_zero = False
            parts.append(DIGIT_ZH[str(digit)] + unit)
        elif parts:
            need_zero = True
    if rest:
        if need_zero:
            parts.append("零")
        parts.append(DIGIT_ZH[str(rest)])
    return "".join(parts)

def acronym_to_zh(s):
    """Spell an acronym letter by letter, space-separated ("AI" -> "诶 艾").

    Characters without an entry in LETTER_ZH are kept as-is.
    """
    spelled = []
    for letter in s:
        spelled.append(LETTER_ZH.get(letter, letter))
    return " ".join(spelled)

def _normalize_punctuation(t):
    """Replace typographic marks that TTS front-ends read poorly."""
    # Book-title marks become Chinese curly quotes (U+201C “ / U+201D ”).
    # Written as escapes so the characters cannot be mangled into ASCII '"'.
    t = t.replace("《", "\u201c").replace("》", "\u201d")
    # Dashes become commas (a short pause); bullets are dropped.
    t = t.replace("—", "，").replace("–", "，").replace("•", "")
    # Semicolons become full stops, which later receive a line break.
    return t.replace(";", "。").replace("；", "。")


def _apply_replacements(t, repl_fixed, repl_names):
    """Apply the user replacement tables, then the built-in name fixes."""
    # Names are replaced verbatim wherever they occur. str.replace inserts the
    # replacement literally; re.sub would treat backslashes in it as escapes.
    for src, dst in repl_names.items():
        if src:
            t = t.replace(src, dst)
    # Fixed terms only match when not touching another Latin letter, so "AI"
    # is not replaced inside "MAIL". A callable repl keeps the replacement
    # literal (no group-reference or escape processing).
    for src, dst in repl_fixed.items():
        if src:
            t = re.sub(
                rf"(?<![A-Za-z]){re.escape(src)}(?![A-Za-z])",
                lambda _m, dst=dst: dst,
                t,
            )
    return re.sub(r"Emmanuel\s+Macron", "马克龙", t)


def _spoken_number(sign, whole, frac):
    """Read a regex-captured number as Chinese.

    *sign* is "-" or "", *whole* the integer digits, *frac* the fraction
    digits or None; e.g. ("-", "36", "5") -> "负三十六点五". Fraction digits
    are read one by one.
    """
    spoken = int_to_zh(int(whole))
    if frac:
        spoken += "点" + year_to_zh(frac)
    return ("负" + spoken) if sign else spoken


def _spell_numbers(t):
    """Spell out years, dates, percentages, temperatures and unit quantities."""
    # Years 1900-2099 are read digit by digit ("2024" -> "二零二四"). The
    # (?!\d) lookahead also accepts a year at the very end of the text.
    t = re.sub(r"(?<!\d)((?:19|20)\d{2})(?!\d)",
               lambda m: year_to_zh(m.group(1)), t)
    # Month and day numbers ("3月5日" -> "三月五日").
    t = re.sub(r"(?<!\d)(\d{1,2})\s*([月日号])",
               lambda m: int_to_zh(int(m.group(1))) + m.group(2), t)
    # Percentages, optionally decimal ("3.5%" -> "百分之三点五").
    # (?<![\d.]) stops a match from starting mid-number ("1234%", "0.5%").
    t = re.sub(r"(?<![\d.])(\d{1,3})(?:\.(\d+))?\s*[%％]",
               lambda m: "百分之" + _spoken_number("", m.group(1), m.group(2)), t)
    # Temperatures and small quantities with common units, optionally signed
    # or decimal ("36.5℃" -> "三十六点五度", "3小时后" -> "三小时后"). No
    # trailing \b: with Unicode regexes there is no word boundary between two
    # CJK characters or between "℃" and punctuation, so "25度左右", "25℃。"
    # and "3小时后" would never match.
    t = re.sub(r"(?<![\d.])(-?)(\d{1,2})(?:\.(\d+))?\s*(?:°\s*C|℃|度)",
               lambda m: _spoken_number(m.group(1), m.group(2), m.group(3)) + "度", t)
    t = re.sub(r"(?<![\d.])(-?)(\d{1,2})(?:\.(\d+))?\s*(英里|公里|小时|分钟)",
               lambda m: _spoken_number(m.group(1), m.group(2), m.group(3)) + m.group(4), t)
    return t


def _spell_latin(t):
    """Spell out all-caps acronyms and drop all other Latin words."""
    # 2-6 capital letters not adjacent to other Latin letters = an acronym.
    t = re.sub(r"(?<![A-Za-z])([A-Z]{2,6})(?![A-Za-z])",
               lambda m: acronym_to_zh(m.group(1)), t)
    # Any remaining run of 2+ Latin letters is removed; single letters stay.
    return re.sub(r"[A-Za-z]{2,}", "", t)


def _add_breathing_breaks(t):
    """Insert line breaks at natural pauses and tidy whitespace."""
    # Break after list markers such as 第一条 or 第十二条.
    t = re.sub(r"(第[一二三四五六七八九十]+条)\s*", r"\1\n", t)
    # Break after sentence-ending punctuation. An ASCII "." followed by a
    # digit is a decimal point ("5.2"), not a sentence end.
    t = re.sub(r"([。！？!?]|\.(?!\d))\s*", r"\1\n", t)
    t = re.sub(r"[ \t]{2,}", " ", t)
    t = re.sub(r"\n{3,}", "\n\n", t)
    return t.strip() + "\n"


def normalize(text, repl_fixed=None, repl_names=None):
    """Rewrite briefing text into a form that reads well through Chinese TTS.

    Pipeline (order matters):
      1. punctuation: 《》 -> Chinese quotes, dashes -> "，", bullets dropped,
         semicolons -> "。"
      2. replacements: *repl_names* verbatim, then *repl_fixed* where the term
         is not part of a longer Latin word, then built-in name fixes
      3. numbers: 1900-2099 years digit by digit; month/day, percentages,
         temperatures and 英里/公里/小时/分钟 quantities spelled out
      4. Latin: 2-6 letter all-caps acronyms spelled letter by letter; every
         other run of 2+ Latin letters removed
      5. "谁" -> "哪位", then line breaks after list markers and sentence ends

    Args:
        text: Raw input text.
        repl_fixed: Optional {term: spoken replacement} mapping.
        repl_names: Optional {name: spoken replacement} mapping.

    Returns:
        The cleaned text, stripped, ending in exactly one newline.
    """
    t = _normalize_punctuation(text.strip())
    t = _apply_replacements(t, repl_fixed or {}, repl_names or {})
    t = _spell_numbers(t)
    t = _spell_latin(t)
    # NOTE(review): rewrites every 谁 to 哪位; presumably to sidestep the
    # shéi/shuí reading ambiguity in TTS — confirm intent.
    t = t.replace("谁", "哪位")
    return _add_breathing_breaks(t)

def main():
    """Command-line entry point: clean the --in file and write it to --out."""
    ap = argparse.ArgumentParser(description="Clean Chinese briefing text for offline TTS.")
    ap.add_argument("--in", dest="in_path", required=True,
                    help="input text file (UTF-8)")
    ap.add_argument("--out", dest="out_path", required=True,
                    help="output text file (written as UTF-8)")
    ap.add_argument(
        "--repl",
        dest="repl_path",
        default=os.path.join(os.path.dirname(__file__), "brief_replacements.json"),
        help="JSON replacement table (default: brief_replacements.json next to this script)",
    )
    args = ap.parse_args()
    fixed, names = load_replacements(args.repl_path)
    # "utf-8-sig" drops a leading byte-order mark (common in files saved on
    # Windows), which would otherwise be carried into the output. Undecodable
    # bytes are still skipped, as before. The with-block closes the handle.
    with open(args.in_path, encoding="utf-8-sig", errors="ignore") as f:
        raw = f.read()
    out = normalize(raw, fixed, names)
    with open(args.out_path, "w", encoding="utf-8") as f:
        f.write(out)

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
