#!/usr/bin/env python3
"""
资治通鉴人名提取 v2
/no_think + 简洁prompt + 严格过滤
"""
import re
import json
import os
import time
from llama_cpp import Llama

# ========== Configuration ==========
# GGUF model file passed to llama_cpp.Llama as model_path.
MODEL_PATH = "/tmp/modelscope/qwen/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf"
# UTF-8 source text that is read in full and scanned for names.
TEXT_FILE = "/tmp/tongjian_full.txt"
# Final JSON result ({"total_count": ..., "names": [...]}).
OUTPUT_FILE = "/var/www/html/republic/audio/tongjian/names_extracted.json"
# Progress file used to resume an interrupted run; removed after the final
# result has been written.
CHECKPOINT_FILE = "/var/www/html/republic/audio/tongjian/names_ckp.json"
# Upper bound, in characters, used when packing sentences into a chunk.
CHUNK_SIZE = 180
# Number of chunks joined into a single prompt.
BATCH = 5
# max_tokens passed to each chat completion call.
MAX_TOKENS = 100
# A checkpoint is written whenever the batch index is a multiple of this.
SAVE_EVERY = 30
# =========================

# Filtering of candidate historical person names.

# Exact strings that are never accepted as a name.
STOPWORDS = {
    # These look like classical-Chinese function words and set phrases.
    '无', '之王', '之也', '之时', '之后', '之前', '之中', '之内', '之外',
    '之下', '之上', '于是', '然而', '何以', '莫不', '无不', '乃止', '而已',
    '若是', '如此', '何如', '无几', '后世', '先帝', '今世', '古者', '昔者',
    # These look like fragments of the model talking about the task itself
    # rather than names taken from the text.
    '人名', '答案', '提取', '如下', '这些', '哪些', '是否', '中文', '名字',
    '出现', '文本', '描述', '事件', '历史', '用于', '古代', '个文本', '个名',
    '出现在', '中的', '之后是', '人名出现', '人名提取', '些名字', '个地方',
    '这个地方', '人名的地方', '中文名字', '些人名', '人名是否', '名人',
    '这些人', '那些人', '有几人', '几个人', '人名列表', '著名人物', '哪些人',
    '哪些是', '以下人名', '以上人名', '人名有', '如下人名', '包括',
}

# Single-character surnames (an auxiliary signal for two-character words).
SURNAMES = set('赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜戚谢邹喻柏水窦章云苏潘葛奚范彭郎鲁韦昌马苗凤花方俞任袁柳酆鲍史步陶甘曲毋')
# Characters treated as "common" (particles, titles, time words, ...); a
# two-character word that ends in one of them is viewed with suspicion.
COMMON_TWO = set('无之也以于中为上下内后前岁时年月日侯王帝后子夫将军事守官治法令卒兵师国家代世代时中之之也以而于为')


def is_likely_name(word):
    """Return True when ``word`` passes the person-name heuristics.

    Rules, applied in order:

    * the length must be between 2 and 5 characters inclusive;
    * the word must not be listed in ``STOPWORDS``;
    * words of 3 to 5 characters are then accepted as-is;
    * a two-character word whose second character is not in ``COMMON_TWO``
      is accepted;
    * a two-character word whose second character IS in ``COMMON_TWO`` is
      accepted only if its first character is in ``SURNAMES`` and is not
      itself in ``COMMON_TWO``.
    """
    length = len(word)
    if not 2 <= length <= 5:
        return False
    if word in STOPWORDS:
        return False
    if length != 2:
        return True

    first, second = word[0], word[1]
    if second not in COMMON_TWO:
        return True
    # The second character is a common one: only a "clean" surname in front
    # of it keeps the word.
    return first in SURNAMES and first not in COMMON_TWO

# Load the model once; the same Llama instance serves every batch.
print("Loading Qwen3...")
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=2048,
    n_threads=4,
    verbose=False,
)
print("Model loaded!")

# Read the complete source text into memory as UTF-8.
with open(TEXT_FILE, encoding='utf-8') as text_file:
    text = text_file.read()
print(f"Text: {len(text)} chars")

# Segmentation: split the text into sentences, then greedily pack whole
# sentences into chunks of at most CHUNK_SIZE characters.
# The capturing group makes re.split() keep every terminator as its own list
# item, so `sentences` alternates [text, punct, text, punct, ..., tail].
sentences = re.split(r'([。！？；])', text)
chunks = []
current = ""
# Walk the (text, punct) pairs across the WHOLE list. Stopping one item early
# would skip the final element, i.e. any text that follows the last
# terminator (or the entire input when it contains no terminator at all).
for i in range(0, len(sentences), 2):
    punct = sentences[i + 1] if i + 1 < len(sentences) else ''
    sent = sentences[i] + punct
    if len(current) + len(sent) <= CHUNK_SIZE:
        current += sent
    else:
        # Flush what has been collected and start a new chunk with this
        # sentence. A single sentence longer than CHUNK_SIZE is not split
        # further; it simply becomes a chunk of its own.
        if current.strip():
            chunks.append(current.strip())
        current = sent
if current.strip():
    chunks.append(current.strip())
# Ceiling division: a trailing partial batch is still a batch to process.
print(f"Chunks: {len(chunks)} | Batches: {(len(chunks) + BATCH - 1) // BATCH}")

# Checkpoint / resume, followed by the main extraction loop.
def _save_checkpoint(batches_done, names):
    """Atomically persist progress so an interrupted run can resume.

    The JSON is written to a sibling temporary file which is then renamed
    over CHECKPOINT_FILE, so an interruption in the middle of the write
    cannot leave a truncated checkpoint behind.
    """
    tmp_path = CHECKPOINT_FILE + ".tmp"
    # Explicit UTF-8: ensure_ascii=False emits raw Chinese characters.
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump({"batches_done": batches_done, "names": list(names)}, f,
                  ensure_ascii=False)
    os.replace(tmp_path, CHECKPOINT_FILE)


done = 0
names_set = set()
if os.path.exists(CHECKPOINT_FILE):
    try:
        with open(CHECKPOINT_FILE, encoding='utf-8') as f:
            ckp = json.load(f)
        done = ckp.get("batches_done", 0)
        names_set = set(ckp.get("names", []))
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError: an unreadable
        # checkpoint should not block the run, so start over from batch 0.
        print(f"Checkpoint unreadable ({e}); starting from scratch")
        done = 0
        names_set = set()

print(f"Resume from batch {done} ({len(names_set)} names)")

start = time.time()
total_batches = (len(chunks) + BATCH - 1) // BATCH
# Candidate names are runs of 2-5 CJK characters; separators such as
# enumeration commas, commas or spaces in the reply are skipped implicitly.
name_pattern = re.compile(r'[\u4e00-\u9fa5]{2,5}')
failed_batches = 0

for b in range(done, total_batches):
    start_idx = b * BATCH
    batch_chunks = chunks[start_idx:start_idx + BATCH]
    joined = "。\n".join(batch_chunks) + "。"

    # The prompt asks (in Chinese) to output strictly person names only.
    # "/no_think" is presumably the Qwen3 switch that suppresses reasoning
    # output -- confirm against the model documentation.
    prompt = f"/no_think\n严格只输出人名：\n{joined}\n人名："

    try:
        resp = llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=MAX_TOKENS, temperature=0.0,
        )
        result = resp['choices'][0]['message']['content'].strip()
    except Exception as e:
        # Best effort: one failing batch must not abort a long run, but it
        # must not vanish silently either.
        failed_batches += 1
        print(f"[Batch {b}] skipped: {type(e).__name__}: {e}")
    else:
        names_set.update(
            name for name in name_pattern.findall(result) if is_likely_name(name)
        )

    if b % 10 == 0:
        # Guard against a zero interval so the rate never divides by zero.
        elapsed = max(time.time() - start, 1e-9)
        rate = (b - done + 1) / elapsed * 60
        # Batch b has just finished, so b + 1 batches are behind us.
        remain = (total_batches - b - 1) / rate if rate > 0 else 0
        print(f"[Batch {b}/{total_batches}] {len(names_set)} names | {rate:.1f}/min | ETA {remain:.0f}min")

    # Also save after the very last batch so that a failure while writing the
    # final result does not throw away the batches since the last checkpoint.
    if b % SAVE_EVERY == 0 or b == total_batches - 1:
        _save_checkpoint(b + 1, names_set)

if failed_batches:
    print(f"Warning: {failed_batches} batch(es) failed and were skipped")

# Final result: names sorted by code point, plus their count.
sorted_names = sorted(names_set)
result = {"total_count": len(sorted_names), "names": sorted_names}
# Explicit UTF-8: ensure_ascii=False writes raw Chinese characters, which
# must not depend on the locale's default encoding.
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)

# The run is complete, so the checkpoint is no longer needed. Removing it
# directly avoids the gap between an existence check and the removal.
try:
    os.remove(CHECKPOINT_FILE)
except FileNotFoundError:
    pass

print(f"\n✅ Done! {len(sorted_names)} unique names")
print(f"First 50: {sorted_names[:50]}")
