做 Wannier 紧束缚计算的同学都懂:先用 Wannier90 拟合出 wannier90_hr.dat,再用 WannierTools 算拓扑性质。但这两个程序的输入文件格式差异较大,手动编写复杂繁琐.

本文代码则自动通过识别wannier90.win 从而一键生成wannier_tools的输入文件wt.in

使用方法

一键生成全部 19 种任务的 wt.in 文件

python win2wt.py wannier90.win --all

只生成能带和 DOS

python win2wt.py wannier90.win -t bands,dos

交互式选择(有菜单)

python win2wt.py wannier90.win

列出所有可用任务

python win2wt.py --list

指定输出目录

python win2wt.py wannier90.win --all -o ./my_calc


### 自旋极化体系脚本会自动检测当前目录下是否有 `wannier90.up_hr.dat` 或 `wannier90.dn_hr.dat`:```bash
# 自动检测:有 up/dn 就生成两个独立目录
python win2wt.py wannier90.win --all
# → wt-up/  (Hrfile=wannier90.up_hr.dat, SOC=0)
# → wt-dn/  (Hrfile=wannier90.dn_hr.dat, SOC=0)# 手动指定通道
python win2wt.py wannier90.win --all --up      # 只要 up
python win2wt.py wannier90.win --all --dn      # 只要 dn
python win2wt.py wannier90.win --all --nospin  # 强制标准模式

配套批量测试脚本:

bash auto_test.sh -t 50 # 逐个测试,超时 50 秒自动跳过

支持的计算任务

图片
电子结构

bands 体态能带结构

bands_plane k 平面能带(可视化 Dirac 锥)

dos 态密度

fs 三维费米面

fs_plane k 平面费米面切片

拓扑性质

berry Berry 曲率分布

wcc Wannier 电荷中心(Wilson loop)

chirality Weyl 点手性(需手动填入坐标)

findnodes 自动搜索 Weyl 点 / 节点线

mirror_chern 镜面 Chern 数

表面/输运

slab_band 表面态能带

slab_ss 表面态自旋织构

ahc 反常霍尔电导率

ane 反常能斯特效应

shc 自旋霍尔电导率

ohe 轨道霍尔效应

landau Hofstadter 蝴蝶图谱

unfold 能带反折叠(超胞 → 原胞)

valley 谷自由度投影

注:本脚本仅辅助文件生成,任何计算任务执行前请自行详细检查输入文件确保文件路径及设置准确。

``#!/usr/bin/env python3

-- coding: utf-8 --

"""
win2wt: Wannier90 (.win) → WannierTools (wt.in) 自动转换脚本

输入:  wannier90.win
输出:  wt.in-{task_name}  (每个任务一个独立文件)
依赖:  Python 3.6+, numpy
用法:
   python win2wt.py wannier90.win          # 交互式选择任务
   python win2wt.py wannier90.win --all    # 生成所有任务的 wt.in 文件
   python win2wt.py wannier90.win --list   # 列出所有可用任务
   python win2wt.py wannier90.win -t bands,dos,ahc  # 生成指定任务
.win → wt.in 参数映射关系:
   .win begin unit_cell_cart  →  wt.in LATTICE
   .win begin atoms_cart      →  wt.in ATOM_POSITIONS (Cartesian → Direct)
   .win begin projections     →  wt.in PROJECTORS
   .win fermi_energy          →  wt.in E_FERMI
   .win spinors               →  wt.in SOC
   .win num_wann              →  wt.in NumOccupied
   .win begin kpoint_path     →  wt.in KPATH_BULK
   .win berry / kslice        →  wt.in AHC/BerryCurvature/KPLANE 参数
   .win mp_grid               →  默认 k 网格参考
   .win write_hr              →  wt.in Hrfile = 'wannier90_hr.dat'
"
""
import os
import sys
import re
import argparse
import numpy as np
from copy import deepcopy
import textwrap

============================================================================

WannierTools 中属于 &SYSTEM 的参数(不在 &PARAMETERS 中!)

这些参数如果放到 &PARAMETERS 会导致 "Invalid line in namelist PARAMETERS" 错误

============================================================================

SYSTEM_PARAMS = {
   "NSLAB", "NSLAB1", "NSLAB2", "NP",
   "Bmagnitude", "Btheta", "Bphi", "Bx", "By", "Bz",
   "surf_onsite",
}

============================================================================

计算任务定义

每个任务包含: 名称、描述、默认 CONTROL 开关、默认 PARAMETERS、是否需要特殊模块

============================================================================

TASK_DEFINITIONS = {
   "bands": {
       "name": "BulkBand - 体态能带结构",
       "description": "沿高对称 k-path 计算体态能带结构。最基本的计算任务。",
       "controls": {
           "BulkBand_calc": True,
           "BulkBand_plane_calc": False,
           "BulkBand_points_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "BerryPhase_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 61},
       "needs": ["KPATH_BULK"],
   },
   "bands_plane": {
       "name": "BulkBand_plane - k 平面能带",
       "description": "在 k 空间平面内计算能带,用于可视化 Dirac 锥等色散特征。",
       "controls": {
           "BulkBand_calc": False,
           "BulkBand_plane_calc": True,
           "BulkBand_points_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "BerryPhase_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 201, "Nk2": 201},
       "needs": ["KPLANE_BULK"],
   },
   "dos": {
       "name": "DOS - 态密度",
       "description": "计算体态密度(Density of States),用于分析能隙、范霍夫奇点等。",
       "controls": {
           "DOS_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.05,
           "OmegaNum": 2001,
           "OmegaMin": -10.0,
           "OmegaMax": 10.0,
           "Nk1": 101,
           "Nk2": 101,
           "Nk3": 1,
       },
       "needs": ["KCUBE_BULK"],
   },
   "fs": {
       "name": "BulkFS - 3D 费米面",
       "description": "计算三维费米面,用于分析金属/半金属的费米面拓扑。",
       "controls": {
           "BulkFS_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 101, "Nk2": 101, "Nk3": 41},
       "needs": ["KCUBE_BULK"],
   },
   "fs_plane": {
       "name": "BulkFS_plane - 2D 费米面截面",
       "description": "计算 k 空间平面内的费米面截面(等高线图)。",
       "controls": {
           "BulkFS_plane_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Eta_Arc": 0.05, "E_arc": 0.0, "Nk1": 101, "Nk2": 101},
       "needs": ["KPLANE_BULK"],
   },
   "findnodes": {
       "name": "FindNodes - Weyl/Dirac 点搜索",
       "description": "在 3D 布里渊区中搜索能带交叉点(Weyl 点或 Dirac 点)。",
       "controls": {
           "FindNodes_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 6, "Nk2": 6, "Nk3": 6, "Gap_threshold": 0.0001},
       "needs": ["KCUBE_BULK"],
   },
   "chirality": {
       "name": "WeylChirality - Weyl 点手性计算",
       "description": "计算每个 Weyl 点的手性(Chern 数 ±1),需要先运行 FindNodes。",
       "controls": {
           "WeylChirality_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 60, "Nk2": 60},
       "needs": ["WEYL_CHIRALITY_PLACEHOLDER"],
   },
   "slab_band": {
       "name": "SlabBand - 表面能带结构",
       "description": "计算半无限 slab 体系的表面能带结构。",
       "controls": {
           "SlabBand_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"NSLAB": 40, "NP": 1, "Nk1": 101},
       "needs": ["KPATH_SLAB"],
   },
   "slab_ss": {
       "name": "SlabSS - 表面态谱函数",
       "description": "计算半无限 slab 的表面态谱函数,可视化拓扑表面态和费米弧。",
       "controls": {
           "SlabSS_calc": True,
           "SlabArc_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.001,
           "E_arc": 0.0,
           "OmegaNum": 400,
           "OmegaMin": -1.0,
           "OmegaMax": 1.0,
           "Nk1": 201,
           "Nk2": 201,
           "NP": 2,
       },
       "needs": ["KPLANE_SLAB", "KPATH_SLAB"],
   },
   "wcc": {
       "name": "Wanniercenter - Wilson Loop / WCC",
       "description": "计算 Wannier 电荷中心演化(Wilson loop),用于获取 Chern 数或 Z2 不变量。",
       "controls": {
           "Wanniercenter_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 60, "Nk2": 60},
       "needs": ["KPLANE_BULK"],
   },
   "berry": {
       "name": "BerryCurvature - Berry 曲率分布",
       "description": "在 k 平面内计算 Berry 曲率分布,用于分析动量空间的拓扑性质。",
       "controls": {
           "BerryCurvature_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 101, "Nk2": 101},
       "needs": ["KPLANE_BULK"],
   },
   "ahc": {
       "name": "AHC - 反常霍尔电导率",
       "description": "计算反常霍尔电导率 σ_xy(E),需要高 k 点密度。仅适用于磁性/SOC 体系。",
       "controls": {
           "AHC_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.01,
           "OmegaNum": 401,
           "OmegaMin": -1.0,
           "OmegaMax": 1.0,
           "Nk1": 101,
           "Nk2": 101,
           "Nk3": 101,
       },
       "needs": ["KCUBE_BULK"],
   },
   "ane": {
       "name": "ANE - 反常能斯特效应",
       "description": "计算反常能斯特电导率随温度的变化,需要高 k 点密度。",
       "controls": {
           "ANE_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.01,
           "OmegaNum": 401,
           "OmegaMin": -1.0,
           "OmegaMax": 1.0,
           "Nk1": 101,
           "Nk2": 101,
           "Nk3": 101,
           "Tmin": 10,
           "Tmax": 310,
           "NumT": 31,
           "Bmagnitude": 1.0,
           "Btheta": 0.0,
           "Bphi": 0.0,
       },
       "needs": ["KCUBE_BULK"],
   },
   "shc": {
       "name": "SHC - 自旋霍尔电导率",
       "description": "计算自旋霍尔电导率,需要 SOC 和高 k 点密度。",
       "controls": {
           "SHC_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.05,
           "OmegaNum": 1001,
           "OmegaMin": -10.0,
           "OmegaMax": 10.0,
           "Nk1": 101,
           "Nk2": 101,
           "Nk3": 101,
       },
       "needs": ["KCUBE_BULK"],
   },
   "ohe": {
       "name": "Boltz_OHE - 轨道霍尔效应",
       "description": "计算轨道霍尔效应和磁阻,支持不同磁场方向。",
       "controls": {
           "Boltz_OHE_calc": True,
           "Symmetry_Import_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "OmegaNum": 3,
           "OmegaMin": -0.01,
           "OmegaMax": 0.01,
           "EF_broadening": 0.06,
           "Nk1": 41,
           "Nk2": 41,
           "Nk3": 41,
           "BTauNum": 100,
           "BTauMax": 40.0,
           "Tmin": 30,
           "Tmax": 330,
           "NumT": 11,
           "Nslice_BTau_Max": 20000,
       },
       "needs": ["KCUBE_BULK", "SELECTEDBANDS"],
   },
   "landau": {
       "name": "LandauLevel - 朗道能级 / Hofstadter 蝴蝶",
       "description": "在磁场下计算朗道能级谱(Hofstadter 蝴蝶),需要定义磁超胞。",
       "controls": {
           "LandauLevel_B_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "NSLAB": 400,
           "Eta_Arc": 0.1,
           "OmegaNum": 1001,
           "OmegaMin": -8.0,
           "OmegaMax": 12.0,
           "Nk1": 11,
           "Magp": 100,
           "NumRandomConfs": 10,
           "Bmagnitude": 10.0,
           "Btheta": 0.0,
           "Bphi": 0.0,
       },
       "needs": ["KPATH_BULK"],
   },
   "mirror_chern": {
       "name": "MirrorChern - 镜面 Chern 数",
       "description": "计算镜面对称保护的 Chern 数,用于镜面拓扑绝缘体。",
       "controls": {
           "MirrorChern_calc": True,
           "BulkBand_calc": True,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {"Nk1": 81, "Nk2": 201},
       "needs": ["KPATH_BULK", "KPLANE_BULK"],
   },
   "unfold": {
       "name": "BulkBand_Unfold - 能带展开",
       "description": "将超胞能带展开回原胞布里渊区,需要定义 LATTICE_UNFOLD 等模块。",
       "controls": {
           "BulkBand_Unfold_line_calc": True,
           "BulkBand_calc": False,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_plane_calc": False,
           "valley_projection_calc": False,
       },
       "params": {
           "Eta_Arc": 0.01,
           "OmegaNum_unfold": 600,
           "OmegaMin": -10.0,
           "OmegaMax": 10.0,
           "Nk1": 101,
       },
       "needs": ["KPATH_BULK", "UNFOLD_BLOCKS_PLACEHOLDER"],
   },
   "valley": {
       "name": "Valley - 谷投影能带",
       "description": "计算谷投影能带,区分 K/K' 谷的贡献。适用于六方 2D 材料。",
       "controls": {
           "valley_projection_calc": True,
           "BulkBand_calc": True,
           "BulkBand_plane_calc": False,
           "DOS_calc": False,
           "BulkFS_calc": False,
           "BulkFS_plane_calc": False,
           "FindNodes_calc": False,
           "SlabBand_calc": False,
           "SlabSS_calc": False,
           "SlabArc_calc": False,
           "Wanniercenter_calc": False,
           "BerryCurvature_calc": False,
           "AHC_calc": False,
           "ANE_calc": False,
           "SHC_calc": False,
           "Boltz_OHE_calc": False,
           "LandauLevel_B_calc": False,
           "LandauLevel_k_calc": False,
           "MirrorChern_calc": False,
           "WeylChirality_calc": False,
           "BulkBand_Unfold_line_calc": False,
           "BulkBand_Unfold_plane_calc": False,
       },
       "params": {"Nk1": 101},
       "needs": ["KPATH_BULK"],
   },
}

任务分组(用于菜单显示)

TASK_GROUPS = {
   "能带与态密度": ["bands", "bands_plane", "dos", "valley"],
   "费米面": ["fs", "fs_plane"],
   "拓扑不变量": ["wcc", "berry", "mirror_chern"],
   "Weyl 半金属": ["findnodes", "chirality"],
   "表面态": ["slab_band", "slab_ss"],
   "输运性质": ["ahc", "ane", "shc", "ohe"],
   "磁场效应": ["landau"],
   "超胞分析": ["unfold"],
}

============================================================================

.win 文件解析器

============================================================================

class WinParser:
   """解析 Wannier90 (.win) 输入文件"""
   def init(self, filepath):
       self.filepath = filepath
       self.raw_content = ""
       self.data = {}
       self._parse()
   def _parse(self):
       """主解析入口"""
       with open(self.filepath, "r") as f:
           self.raw_content = f.read()
       lines = self.raw_content.split("\n")
       # 解析单行键值对
       self._parse_simple_vars(lines)
       # 解析 begin/end 块
       self._parse_blocks(self.raw_content)
       # 后处理
       self._post_process()
   def _parse_simple_vars(self, lines):
       """解析简单变量赋值(非 begin/end 块)"""
       simple_vars = {
           "num_wann": int,
           "num_bands": int,
           "fermi_energy": float,
           "spinors": lambda x: x.lower().strip() in [".true.", "true", "t"],
           "write_hr": lambda x: x.lower().strip() in [".true.", "true", "t"],
           "mp_grid": str,
           "berry": lambda x: x.lower().strip() in [".true.", "true", "t"],
           "berry_task": str,
           "berry_kmesh": str,
           "kpath": lambda x: x.lower().strip() in [".true.", "true", "t"],
           "kpath_task": str,
           "kpath_num_points": int,
           "kslice": lambda x: x.lower().strip() in [".true.", "true", "t"],
           "kslice_task": str,
           "kslice_2dkmesh": int,
           "kslice_corner": str,
           "kslice_b1": str,
           "kslice_b2": str,
       }
       in_block = False
       for line in lines:
           stripped = line.split("!")[0].split("#")[0].strip()
           if not stripped:
               continue
           if stripped.lower().startswith("begin"):
               in_block = True
           if stripped.lower().startswith("end"):
               in_block = False
               continue
           if in_block:
               continue
           # 尝试匹配 key = value
           for key, converter in simple_vars.items():
               pattern = rf"^\s{key}\s[=:]\s(.+)$"
               match = re.match(pattern, stripped, re.IGNORECASE)
               if match:
                   val_str = match.group(1).strip()
                   try:
                       self.data[key] = converter(val_str)
                   except (ValueError, TypeError):
                       self.data[key] = val_str
                   break
   def _parse_blocks(self, content):
       """解析 begin/end 块"""
       blocks = {
           "unit_cell_cart": "lattice_vectors",
           "atoms_cart": "atom_positions",
           "projections": "projections",
           "kpoint_path": "kpoint_path",
           "kpoints": "kpoints",
       }
       for block_name, data_key in blocks.items():
           pattern = rf"begin\s+{block_name}\s
\n(.?)\n\send\s+{block_name}"
           match = re.search(pattern, content, re.IGNORECASE | re.DOTALL)
           if match:
               self.data[data_key] = match.group(1).strip()
   def _post_process(self):
       """后处理:补全缺失值,计算衍生参数"""
       # 默认值
       self.data.setdefault("spinors", False)
       self.data.setdefault("write_hr", True)
       self.data.setdefault("fermi_energy", 0.0)
       self.data.setdefault("num_wann", 1)
       self.data.setdefault("num_bands", 2)
   def get_lattice_vectors(self):
       """获取晶格矢量(Å),返回 3×3 矩阵"""
       raw = self.data.get("lattice_vectors", "")
       if not raw:
           return None
       # 处理 Bohr 单位
       lines = raw.strip().split("\n")
       unit = "Angstrom"
       vec_lines = []
       for line in lines:
           s = line.split("!")[0].split("#")[0].strip().lower()
           if s == "bohr":
               unit = "Bohr"
           elif s == "angstrom" or s == "ang":
               unit = "Angstrom"
           elif s:
               vec_lines.append(line)
       vectors = []
       for line in vec_lines[:3]:
           nums = re.findall(r"[-+]?\d.?\d+(?:[eEdD][-+]?\d+)?", line)
           if len(nums) >= 3:
               v = [float(n.replace("d", "e").replace("D", "e")) for n in nums[:3]]
               vectors.append(v)
       if len(vectors) != 3:
           return None
       # Bohr → Angstrom 转换
       if unit.lower() == "bohr":
           vectors = [[x * 0.529177210903 for x in v] for v in vectors]
       return vectors
   def get_atom_positions_cart(self):
       """获取 Cartesian 原子坐标,返回 [(symbol, x, y, z), ...]"""
       raw = self.data.get("atom_positions", "")
       if not raw:
           return []
       lines = raw.strip().split("\n")
       unit = "Angstrom"
       atom_lines = []
       for line in lines:
           s = line.split("!")[0].split("#")[0].strip().lower()
           if s == "bohr":
               unit = "Bohr"
           elif s == "angstrom" or s == "ang":
               unit = "Angstrom"
           elif s:
               atom_lines.append(line)
       atoms = []
       for line in atom_lines:
           parts = line.split()
           if len(parts) >= 4:
               # 格式: Symbol x y z
               symbol = parts[0]
               nums = re.findall(r"[-+]?\d
.?\d+(?:[eEdD][-+]?\d+)?", line)
               if len(nums) >= 3:
                   x = float(nums[0].replace("d", "e").replace("D", "e"))
                   y = float(nums[1].replace("d", "e").replace("D", "e"))
                   z = float(nums[2].replace("d", "e").replace("D", "e"))
                   if unit.lower() == "bohr":
                       x, y, z = x * 0.529177210903, y * 0.529177210903, z * 0.529177210903
                   atoms.append((symbol, x, y, z))
       return atoms
   def get_projections(self):
       """解析投影,返回投影列表。
       对于中心投影(c=),自动匹配最近原子坐标确定元素符号。
       "
""
       raw = self.data.get("projections", "")
       if not raw:
           return []
       lines = raw.strip().split("\n")
       projections = []
       atoms_cart = self.get_atom_positions_cart()  # [(symbol, x, y, z), ...]
       for line in lines:
           s = line.split("!")[0].split("#")[0].strip()
           comment = ""
           # 提取注释,用于识别原子类型(如 "! V1", "! Nb3", "! S9")
           if "!" in line:
               comment = line.split("!", 1)[1].strip()
           elif "#" in line:
               comment = line.split("#", 1)[1].strip()
           if not s:
               continue
           # 格式1: "c=x,y,z:l=0;l=2" (中心投影)
           # 格式2: "Symbol : orbital1; orbital2" (原子投影)
           if s.startswith("c="):
               # 中心投影:解析坐标和轨道
               # 格式: c=x,y,z:l=0;l=2
               c_match = re.match(r"c\s=\s([-+]?\d.?\d+(?:[eEdD][-+]?\d+)?)\s,\s([-+]?\d.?\d+(?:[eEdD][-+]?\d+)?)\s,\s([-+]?\d.?\d+(?:[eEdD][-+]?\d+)?)\s:\s(.+)", s)
               if not c_match:
                   # 尝试无空格版本: c=x,y,z : l=0;l=2
                   c_match = re.match(r"c\s
=\s(.+?)\s:\s*(.+)", s)
                   if c_match:
                       coords_str = c_match.group(1)
                       orb_str = c_match.group(2)
                       coord_parts = [c.strip() for c in coords_str.split(",")]
                       if len(coord_parts) != 3:
                           projections.append({"type": "center", "raw": s, "orbitals": [], "symbol": "?"})
                           continue
                       cx = float(coord_parts[0].replace("d", "e").replace("D", "e"))
                       cy = float(coord_parts[1].replace("d", "e").replace("D", "e"))
                       cz = float(coord_parts[2].replace("d", "e").replace("D", "e"))
                   else:
                       projections.append({"type": "center", "raw": s, "orbitals": [], "symbol": "?"})
                       continue
               else:
                   cx = float(c_match.group(1).replace("d", "e").replace("D", "e"))
                   cy = float(c_match.group(2).replace("d", "e").replace("D", "e"))
                   cz = float(c_match.group(3).replace("d", "e").replace("D", "e"))
                   orb_str = c_match.group(4)
               # 解析轨道
               orbs = []
               if "l=" in orb_str.lower():
                   l_vals = re.findall(r"l=(\d+)", orb_str.lower())
                   for lv in l_vals:
                       orbs.extend(self._l_to_orbitals(int(lv)))
               else:
                   raw_orbs = [o.strip() for o in orb_str.split(";")]
                   for ro in raw_orbs:
                       if ro:
                           orbs.extend(self._expand_orbital(ro.lower()))
               # 匹配最近原子坐标确定元素符号
               symbol = "?"
               if atoms_cart:
                   min_dist = float("inf")
                   for atom_sym, ax, ay, az in atoms_cart:
                       d = (cx - ax)2 + (cy - ay)2 + (cz - az)**2
                       if d < min_dist:
                           min_dist = d
                           symbol = atom_sym
                   # 如果距离 > 0.5 Å 则可能匹配失败,保留 ?
                   if min_dist > 0.25:
                       symbol = "?"
               # 如果注释中包含元素符号,优先使用注释
               if comment:
                   # 注释格式如 "V1", "Nb3", "S9" — 提取字母部分
                   atom_symbol_from_comment = re.match(r"([A-Za-z]+)", comment)
                   if atom_symbol_from_comment:
                       symbol = atom_symbol_from_comment.group(1)
               projections.append({
                   "type": "center",
                   "raw": s,
                   "orbitals": orbs,
                   "symbol": symbol,
                   "comment": comment,
               })
           else:
               # 原子投影
               parts = s.split("😊
               if len(parts) >= 2:
                   symbol = parts[0].strip()
                   orb_str = parts[1].strip()
                   orbs = []
                   if "l=" in orb_str.lower():
                       l_vals = re.findall(r"l=(\d+)", orb_str.lower())
                       for lv in l_vals:
                           orbs.extend(self._l_to_orbitals(int(lv)))
                   else:
                       raw_orbs = [o.strip() for o in orb_str.split(";")]
                       for ro in raw_orbs:
                           if ro:
                               orbs.extend(self._expand_orbital(ro.lower()))
                   projections.append({"type": "atom", "symbol": symbol, "orbitals": orbs})
       return projections
   def _l_to_orbitals(self, l_val):
       """将角量子数 l 转换为轨道名称"""
       l_map = {
           0: ["s"],
           1: ["pz", "px", "py"],
           2: ["dz2", "dxz", "dyz", "dx2-y2", "dxy"],
           3: ["fz3", "fxz2", "fyz2", "fxyz", "fz(x2-y2)", "fx(x2-3y2)", "fy(3x2-y2)"],
       }
       return l_map.get(l_val, [])
   def _expand_orbital(self, orb_name):
       """展开简写轨道名"""
       if orb_name == "s":
           return ["s"]
       elif orb_name == "p":
           return ["pz", "px", "py"]
       elif orb_name == "d":
           return ["dz2", "dxz", "dyz", "dx2-y2", "dxy"]
       elif orb_name == "f":
           return ["fz3", "fxz2", "fyz2", "fxyz", "fz(x2-y2)", "fx(x2-3y2)", "fy(3x2-y2)"]
       elif orb_name == "sp":
           return ["s", "pz", "px", "py"]
       elif orb_name == "spd":
           return ["s", "pz", "px", "py", "dz2", "dxz", "dyz", "dx2-y2", "dxy"]
       else:
           return [orb_name]
   def get_kpoint_path(self):
       """获取 k-path 定义,返回 [(label1, k1, label2, k2), ...]"""
       raw = self.data.get("kpoint_path", "")
       if not raw:
           return []
       lines = raw.strip().split("\n")
       paths = []
       for line in lines:
           s = line.split("!")[0].split("#")[0].strip()
           if not s:
               continue
           # 格式: "Label1 kx1 ky1 kz1  Label2 kx2 ky2 kz2"
           parts = s.split()
           if len(parts) >= 8:
               label1 = parts[0]
               k1 = [float(x) for x in parts[1:4]]
               label2 = parts[4]
               k2 = [float(x) for x in parts[5:8]]
               paths.append((label1, k1, label2, k2))
       return paths
   def get_mp_grid(self):
       """获取 Monkhorst-Pack k 网格"""
       raw = self.data.get("mp_grid", "")
       if not raw:
           return (1, 1, 1)
       parts = raw.split()
       if len(parts) >= 3:
           return (int(parts[0]), int(parts[1]), int(parts[2]))
       return (1, 1, 1)
   def get_fermi_energy(self):
       """获取费米能级"""
       return self.data.get("fermi_energy", 0.0)
   def get_soc(self):
       """获取 SOC 标志"""
       return 1 if self.data.get("spinors", False) else 0
   def get_num_wann(self):
       """获取 Wannier 函数数量"""
       return self.data.get("num_wann", 1)
   def get_berry_info(self):
       """获取 Berry 曲率相关参数"""
       return {
           "berry": self.data.get("berry", False),
           "berry_task": self.data.get("berry_task", ""),
           "berry_kmesh": self.data.get("berry_kmesh", ""),
       }
   def get_kslice_info(self):
       """获取 kslice 参数"""
       return {
           "kslice": self.data.get("kslice", False),
           "kslice_task": self.data.get("kslice_task", ""),
           "kslice_2dkmesh": self.data.get("kslice_2dkmesh", 50),
           "kslice_corner": self.data.get("kslice_corner", "0.0 0.0 0.0"),
           "kslice_b1": self.data.get("kslice_b1", "1.0 0.0 0.0"),
           "kslice_b2": self.data.get("kslice_b2", "0.0 1.0 0.0"),
       }

============================================================================

wt.in 生成器

============================================================================

class WtInGenerator:
   """根据解析的 .win 数据生成 wt.in 文件
   支持自旋极化体系:若检测到 wannier90.up/dn_hr.dat,自动设置
   spin_channel='up' 或 'dn',并调整 Hrfile、SOC、NumOccupied。
   "
""
   def init(self, win_parser, spin_channel=None, hrfile_path=None):
       """
       Parameters
       ----------
       win_parser : WinParser
           已解析的 .win 文件解析器
       spin_channel : str or None
           None:  非自旋极化(默认),SOC 从 .win 读取
           'up':  自旋上通道,SOC = 0,NumOccupied = num_wann
           'dn':  自旋下通道,SOC = 0,NumOccupied = num_wann
       hrfile_path : str or None
           Hrfile 的完整路径。若不提供,自动推导为 wannier90.{spin_channel}_hr.dat
           或 wannier90_hr.dat。提供路径可确保 wt.in 从任意工作目录都能找到 hr.dat。
       "
""
       self.win = win_parser
       self.spin_channel = spin_channel
       self._hrfile_path = hrfile_path  # 保存原始路径,用于 _gen_tb_file
       self.lattice = win_parser.get_lattice_vectors()
       self.atoms_cart = win_parser.get_atom_positions_cart()
       self.projections = win_parser.get_projections()
       self.kpath = win_parser.get_kpoint_path()
       self.fermi = win_parser.get_fermi_energy()
       self.num_wann = win_parser.get_num_wann()
       self.mp_grid = win_parser.get_mp_grid()
       self.berry = win_parser.get_berry_info()
       self.kslice = win_parser.get_kslice_info()
       # SOC 检测:自旋极化体系(up/dn 分开)SOC=0,否则从 .win 读取
       if spin_channel:
           self.soc = 0
       else:
           self.soc = win_parser.get_soc()
       # 计算原子 Direct 坐标
       self.atoms_direct = self._cart_to_direct()
       # 计算默认 NumOccupied
       # 自旋极化体系(up/dn):每个自旋通道独立,NumOccupied = num_wann
       # SOC 体系(spinors=true):num_wann = 2 * 轨道数,NumOccupied = num_wann // 2
       # 非 SOC 非极化体系:NumOccupied = num_wann
       self.num_occupied = self._estimate_num_occupied()
   def _cart_to_direct(self):
       """将 Cartesian 坐标转换为 Direct(分数)坐标"""
       if not self.lattice or not self.atoms_cart:
           return []
       lat = np.array(self.lattice, dtype=float)
       try:
           inv_lat = np.linalg.inv(lat.T)  # 注意:晶格矢量是行向量
       except np.linalg.LinAlgError:
           return []
       result = []
       for symbol, x, y, z in self.atoms_cart:
           cart = np.array([x, y, z], dtype=float)
           direct = inv_lat @ cart
           result.append((symbol, direct[0], direct[1], direct[2]))
       return result
   def _estimate_num_occupied(self):
       """估算占据能带数。
       物理依据:
       - 自旋极化体系(spin_channel='up'/'dn'): 每个自旋通道的 Wannier
         函数全部对应该通道的能带,NumOccupied = num_wann(全部占据或
         由 .win 中的 num_bands 和 dis_froz_max 决定)
       - SOC 体系(spinors=true): Wannier 函数的自旋自由度已编码在轨道中,
         num_wann = 2 * 轨道数,占据数约为 num_wann // 2
       - 非 SOC 非极化体系: NumOccupied = num_wann
       注意:此估算对含半金属/磁性金属可能是近似值,建议用户根据 DFT
       价带数手动核对。
       "
""
       if self.spin_channel:
           # 自旋极化体系:所有 num_wann 个 Wannier 函数都对应自旋通道的能带
           return self.num_wann
       elif self.soc == 1:
           # SOC 情况下,num_wann 包含自旋,占据数约为 num_wann/2
           return self.num_wann // 2
       else:
           return self.num_wann
   def _get_default_surface(self):
       """获取默认 SURFACE 矩阵"""
       return " 1  0  0\n 0  1  0\n 0  0  1"
   def _get_default_kplane(self):
       """获取默认 KPLANE_BULK"""
       kslice = self.win.get_kslice_info()
       if kslice["kslice"]:
           corner = kslice["kslice_corner"]
           b1 = kslice["kslice_b1"]
           b2 = kslice["kslice_b2"]
           return (
               f"  {corner}   ! Original point for 3D k plane\n"
               f"  {b1}   ! The first vector to define 3d k space plane\n"
               f"  {b2}   ! The second vector to define 3d k space plane"
           )
       return (
           " 0.00  0.00  0.00   ! Original point for 3D k plane\n"
           " 1.00  0.00  0.00   ! The first vector to define 3d k space plane\n"
           " 0.00  1.00  0.00   ! The second vector to define 3d k space plane"
       )
   def _get_default_kcube(self):
       """获取默认 KCUBE_BULK"""
       return (
           " 0.00  0.00  0.00   ! Original point for 3D k plane\n"
           " 1.00  0.00  0.00   ! The first vector to define 3d k space plane\n"
           " 0.00  1.00  0.00   ! The second vector to define 3d k space plane\n"
           " 0.00  0.00  1.00   ! The third vector to define 3d k cube"
       )
   def generate(self, task_key, output_path=None, overrides=None):
       """生成指定任务的 wt.in 文件。
       自动将 NSLAB, NP, Bmagnitude 等属于 &SYSTEM 的参数
       分离到 &SYSTEM namelist,避免 "
Invalid line in namelist PARAMETERS
" 错误。
       "
""
       if task_key not in TASK_DEFINITIONS:
           raise ValueError(f"未知任务: {task_key}。可用任务: {list(TASK_DEFINITIONS.keys())}")
       # 记住输出目录,供 _gen_tb_file 计算 Hrfile 相对路径
       if output_path:
           self._output_dir = os.path.dirname(os.path.abspath(output_path)) or "."
       task = TASK_DEFINITIONS[task_key]
       controls = deepcopy(task["controls"])
       all_params = deepcopy(task["params"])
       # 应用覆盖参数
       if overrides:
           all_params.update(overrides)
       # 分离系统级参数和数值参数
       # 系统级参数属于 &SYSTEM(NSLAB, NP, Bmagnitude 等)
       # 数值参数属于 &PARAMETERS(Nk1, Nk2, Eta_Arc 等)
       system_params = {}
       params = {}
       for key, val in all_params.items():
           if key in SYSTEM_PARAMS:
               system_params[key] = val
           else:
               params[key] = val
       lines = []
       lines.append(self._gen_tb_file())
       lines.append("")
       lines.append(self._gen_control(controls))
       lines.append("")
       lines.append(self._gen_system(system_params))
       lines.append("")
       lines.append(self._gen_parameters(params))
       lines.append("")
       lines.append(self._gen_lattice())
       lines.append("")
       lines.append(self._gen_atom_positions())
       lines.append("")
       lines.append(self._gen_projectors())
       lines.append("")
       lines.append(self._gen_surface())
       lines.append("")
       # 按需添加模块
       for need in task.get("needs", []):
           module = self._gen_module(need)
           if module:
               lines.append(module)
               lines.append("")
       content = "\n".join(lines)
       if output_path:
           with open(output_path, "w") as f:
               f.write(content)
           print(f"  ✓ 已生成: {output_path}")
       return content
   def _gen_tb_file(self):
       """生成 &TB_FILE namelist。
       Hrfile 使用从输出目录到 hr.dat 的相对路径(如 '../wannier90_hr.dat'),
       避免硬编码绝对路径导致换机器运行时报 "
no HmnR input
"。
       "
""
       if self._hrfile_path:
           hr_abs = os.path.abspath(self._hrfile_path)
           out_dir = getattr(self, '_output_dir', os.getcwd())
           hrfile = os.path.relpath(hr_abs, out_dir)
       elif self.spin_channel:
           hrfile = f"wannier90.{self.spin_channel}_hr.dat"
       else:
           hrfile = "wannier90_hr.dat"
       return f"&TB_FILE\nHrfile = '{hrfile}'\n/"
   def _gen_control(self, controls):
       lines = ["&CONTROL"]
       for key, val in controls.items():
           type_char = "T" if val else "F"
           lines.append(f"  {key:<30s} = {type_char}")
       lines.append("/")
       return "\n".join(lines)
   def _gen_system(self, system_params=None):
       """生成 &SYSTEM namelist。
       SOC, E_FERMI, NumOccupied 为必填项。
       system_params 包含属于 &SYSTEM 的可选参数(NSLAB, NP, Bmagnitude 等)。
       "
""
       lines = ["&SYSTEM"]
       lines.append(f"  SOC = {self.soc}")
       lines.append(f"  E_FERMI = {self.fermi}")
       lines.append(f"  NumOccupied = {self.num_occupied}")
       if system_params:
           for key in sorted(system_params.keys()):
               val = system_params[key]
               if isinstance(val, float):
                   lines.append(f"  {key} = {val}")
               elif isinstance(val, bool):
                   lines.append(f"  {key} = {'T' if val else 'F'}")
               else:
                   lines.append(f"  {key} = {val}")
       lines.append("/")
       return "\n".join(lines)
   def _gen_parameters(self, params):
       lines = ["&PARAMETERS"]
       for key, val in params.items():
           if isinstance(val, float):
               lines.append(f"  {key} = {val}")
           elif isinstance(val, bool):
               lines.append(f"  {key} = {'T' if val else 'F'}")
           else:
               lines.append(f"  {key} = {val}")
       lines.append("/")
       return "\n".join(lines)
   def _gen_lattice(self):
       if not self.lattice:
           return "LATTICE\nAngstrom\n# [请手动填写晶格矢量]\n1.0 0.0 0.0\n0.0 1.0 0.0\n0.0 0.0 1.0"
       lines = ["LATTICE", "Angstrom"]
       for v in self.lattice:
           lines.append(f"  {v[0]:12.6f} {v[1]:12.6f} {v[2]:12.6f}")
       return "\n".join(lines)
   def _gen_atom_positions(self):
       if not self.atoms_direct:
           if not self.atoms_cart:
               return "ATOM_POSITIONS\n1\nDirect\n# [请手动填写原子坐标]\nX 0.0 0.0 0.0"
           # 有 Cartesian 但没有 Direct(转换失败),使用 Cartesian
           lines = ["ATOM_POSITIONS", f"{len(self.atoms_cart)}", "Cartesian"]
           for symbol, x, y, z in self.atoms_cart:
               lines.append(f"  {symbol:<4s} {x:12.6f} {y:12.6f} {z:12.6f}  0.0  0.0  0.0")
           return "\n".join(lines)
       lines = ["ATOM_POSITIONS", f"{len(self.atoms_direct)}", "Direct"]
       for symbol, x, y, z in self.atoms_direct:
           lines.append(f"  {symbol:<4s} {x:12.6f} {y:12.6f} {z:12.6f}  0.0  0.0  0.0")
       return "\n".join(lines)
   def _gen_projectors(self):
       """生成 PROJECTORS 块。
       自动处理中心投影(c=)和原子投影(Symbol:orbitals)两种格式。
       中心投影:每个 c= 已对应一个原子,直接使用。
       原子投影:按符号类型定义,需展开为每个原子的投影。
       "
""
       if not self.projections:
           return "PROJECTORS\n1  ! number of projectors\nX s"
       proj_counts = []
       proj_lines = []
       # 判断是否为原子类型投影(需要展开到每个原子)
       is_atom_type = all(p.get("type") == "atom" for p in self.projections)
       if is_atom_type:
           # 原子类型投影:symbol → orbitals 映射
           proj_map = {}
           for p in self.projections:
               proj_map[p["symbol"]] = p.get("orbitals", [])
           # 遍历原子列表,为每个原子匹配投影
           for symbol, x, y, z in self.atoms_cart:
               if symbol in proj_map:
                   orbs = proj_map[symbol]
               else:
                   # 尝试大小写不敏感匹配
                   matched = False
                   for key in proj_map:
                       if key.lower() == symbol.lower():
                           orbs = proj_map[key]
                           matched = True
                           break
                   if not matched:
                       orbs = []
               norbs = len(orbs)
               proj_counts.append(norbs)
               if norbs > 0:
                   proj_lines.append(f"  {symbol:<4s} {' '.join(orbs)}")
               else:
                   proj_lines.append(f"# {symbol}  ! [未找到投影定义]")
       else:
           # 中心投影:每个 c= 已对应一个原子
           for p in self.projections:
               norbs = len(p.get("orbitals", []))
               proj_counts.append(norbs)
               symbol = p.get("symbol", "?")
               if norbs > 0:
                   proj_lines.append(f"  {symbol:<4s} {' '.join(p['orbitals'])}")
               else:
                   proj_lines.append(f"# {p.get('raw', '?')}  ! [未解析到轨道]")
       count_str = " ".join(str(c) for c in proj_counts)
       lines = [f"PROJECTORS", f"  {count_str}  ! number of projectors per atom"]
       lines.extend(proj_lines)
       return "\n".join(lines)
   def _gen_surface(self):
       return "SURFACE            ! See doc for details\n" + self._get_default_surface()
   def _gen_module(self, need):
       """生成特定模块"""
       if need == "KPATH_BULK":
           return self._gen_kpath_bulk()
       elif need == "KPLANE_BULK":
           return self._gen_kplane_bulk()
       elif need == "KCUBE_BULK":
           return self._gen_kcube_bulk()
       elif need == "KPATH_SLAB":
           return self._gen_kpath_slab()
       elif need == "KPLANE_SLAB":
           return self._gen_kplane_slab()
       elif need == "SELECTEDBANDS":
           return self._gen_selectedbands()
       elif need == "WEYL_CHIRALITY_PLACEHOLDER":
           return self._gen_weyl_chirality_placeholder()
       elif need == "UNFOLD_BLOCKS_PLACEHOLDER":
           return self._gen_unfold_placeholder()
       return ""
   def _gen_kpath_bulk(self):
       """从 .win 的 kpoint_path 生成 KPATH_BULK"""
       if not self.kpath:
           return (
               "KPATH_BULK            ! k point path\n"
               "1              ! number of k line\n"
               "  G 0.0 0.0 0.0  X 0.5 0.0 0.0  ! [自动生成,请手动检查]"
           )
       lines = ["KPATH_BULK            ! k point path"]
       lines.append(f"  {len(self.kpath)}              ! number of k line only for bulk band")
       for label1, k1, label2, k2 in self.kpath:
           lines.append(f"  {label1:<4s} {k1[0]:9.5f} {k1[1]:9.5f} {k1[2]:9.5f}  "
                        f"{label2:<4s} {k2[0]:9.5f} {k2[1]:9.5f} {k2[2]:9.5f}")
       return "\n".join(lines)
   def _gen_kplane_bulk(self):
       return "KPLANE_BULK\n" + self._get_default_kplane()
   def _gen_kcube_bulk(self):
       return "KCUBE_BULK\n" + self._get_default_kcube()
   def _gen_kpath_slab(self):
       """从 bulk kpath 推导 slab kpath(投影到 2D)"""
       if self.kpath:
           # 从 bulk kpath 提取 2D 投影(取前两个分量)
           lines = ["KPATH_SLAB"]
           lines.append(f"  {len(self.kpath)}        ! number of k line for 2D case")
           for label1, k1, label2, k2 in self.kpath:
               lines.append(f"  {label1:<4s} {k1[0]:9.5f} {k1[1]:9.5f}  "
                            f"{label2:<4s} {k2[0]:9.5f} {k2[1]:9.5f}")
           return "\n".join(lines)
       return (
           "KPATH_SLAB\n"
           "2        ! number of k line for 2D case\n"
           "  X -0.5 0.0  G 0.0 0.0  ! k path for 2D case\n"
           "  G 0.0 0.0  X 0.5 0.0   ! [自动生成,请手动检查]"
       )
   def _gen_kplane_slab(self):
       """从 kslice 或 bulk kpath 推导 slab k-plane"""
       # 优先使用 kslice 信息
       kslice = self.win.get_kslice_info()
       if kslice["kslice"]:
           corner = kslice["kslice_corner"]
           b1 = kslice["kslice_b1"]
           b2 = kslice["kslice_b2"]
           # 取前两个分量作为 2D plane
           corner_parts = corner.split()
           b1_parts = b1.split()
           b2_parts = b2.split()
           if len(corner_parts) >= 2 and len(b1_parts) >= 2 and len(b2_parts) >= 2:
               return (
                   "KPLANE_SLAB\n"
                   f"  {corner_parts[0]:>6s} {corner_parts[1]:>6s}      ! Original point for 2D k plane\n"
                   f"  {b1_parts[0]:>6s} {b1_parts[1]:>6s}      ! The first vector to define 2D k plane\n"
                   f"  {b2_parts[0]:>6s} {b2_parts[1]:>6s}      ! The second vector to define 2D k plane"
               )
       return (
           "KPLANE_SLAB\n"
           "  -0.5 -0.5      ! Original point for 2D k plane\n"
           "   1.0  0.0      ! The first vector to define 2D k plane\n"
           "   0.0  1.0      ! The second vector to define 2D k plane"
       )
   def _gen_selectedbands(self):
       return (
           "SELECTEDBANDS\n"
           f"  1\n"
           f"  {self.num_occupied}"
       )
   def _gen_weyl_chirality_placeholder(self):
       return (
           "WEYL_CHIRALITY\n"
           "0            ! Num_Weyls (请从 FindNodes 输出中获取并修改)\n"
           "Cartesian    ! Direct or Cartesian coordinate\n"
           "0.004        ! Radius of the ball surround a Weyl point\n"
           "# [提示] 先运行 findnodes 获取 Weyl 点坐标,然后修改 Num_Weyls\n"
           "# 并在此处逐个添加 Weyl 点坐标行,格式:    "
       )
   def _gen_unfold_placeholder(self):
       return (
           "# ============================================================\n"
           "# [能带展开] 需要手动定义以下模块(替换注释为实际数据):\n"
           "# ============================================================\n"
           "# LATTICE_UNFOLD         - 展开目标晶格(原胞晶格矢量)\n"
           "# ATOM_POSITIONS_UNFOLD  - 展开目标原子坐标\n"
           "# PROJECTORS_UNFOLD      - 展开目标投影轨道\n"
           "# SELECTED_ATOMS         - 超胞中选中的原子索引\n"
           "# ============================================================\n"
           "# 示例(请根据实际超胞修改):\n"
           "#\n"
           "# LATTICE_UNFOLD\n"
           "# Angstrom\n"
           "#   \n"
           "#   \n"
           "#   \n"
           "#\n"
           "# ATOM_POSITIONS_UNFOLD\n"
           "#   <num_atoms>      ! 原胞原子数\n"
           "# Direct\n"
           "#   \n"
           "#   ...\n"
           "#\n"
           "# PROJECTORS_UNFOLD\n"
           "#   <num_proj> <num_proj> ...  ! 每个原子的投影轨道数\n"
           "#   ...\n"
           "#   ...\n"
           "#\n"
           "# SELECTED_ATOMS\n"
           "#   1               ! 组数\n"
           "#   <num_selected>  ! 选中的原子数\n"
           "#   ...\n"
           "# ============================================================"
       )

============================================================================

自旋极化检测

============================================================================

def _prompt_hr_path(hr_type, search_dirs, expect_name):
   """当自动搜索找不到 hr.dat 时,提示用户手动输入路径。
   Parameters
   ----------
   hr_type : str
       描述性文字,如 "
spin-up
", "
spin-down
", "
标准(无自旋)
"
   expect_name : str
       期望文件名,如 "
wannier90.up_hr.dat
"
   search_dirs : list
       已搜索过的目录列表
   Returns
   -------
   str or None : 用户指定的有效路径,或 None(用户放弃)
   "
""
   print(f"\n  {hr_type} 的 hr.dat 文件未自动找到。")
   print(f"  期望文件名: {expect_name}")
   print(f"  已搜索目录:")
   for d in search_dirs:
       print(f"    {d}")
   print(f"\n  请手动输入 {expect_name} 的完整路径,")
   print(f"  或输入 'q' / 回车跳过。")
   for _ in range(5):
       try:
           path = input(f"  请输入路径: ").strip()
       except (EOFError, KeyboardInterrupt):
           return None
       if not path or path.lower() == 'q':
           return None
       if os.path.isfile(path):
           return os.path.abspath(path)
       print(f"  文件不存在: {path},请重新输入。")
   print(f"  达到最大尝试次数,已跳过。")
   return None
def _find_required_hr(win_dir, cwd, channel):
   """搜索 hr.dat 文件并返回绝对路径。未找到返回 None。
   Parameters
   ----------
   win_dir : str
       .win 文件所在目录
   cwd : str
       当前工作目录
   channel : str
       "
up
" → wannier90.up_hr.dat, "
dn
" → wannier90.dn_hr.dat, "
" → wannier90_hr.dat
   "
""
   fname = f"wannier90.{channel}_hr.dat" if channel else "wannier90_hr.dat"
   # 去重搜索路径(win_dir 优先,cwd 其次)
   search_dirs = []
   for d in [win_dir, cwd]:
       d_abs = os.path.abspath(d)
       if d_abs not in search_dirs:
           search_dirs.append(d_abs)
   for d in search_dirs:
       fpath = os.path.join(d, fname)
       if os.path.isfile(fpath):
           return os.path.abspath(fpath)
   return None
def detect_spin_channel(win_dir, cwd=None):
   """检测是否存在自旋极化 hr.dat 文件,返回实际文件路径。
   搜索顺序: win_dir (与 .win 同目录) → cwd (当前工作目录)
   物理背景:
   - 共线磁性 DFT 计算中 Wannier90 独立输出 wannier90.up_hr.dat / .dn_hr.dat
   - 两个自旋通道的紧束缚哈密顿量不完全相同(交换劈裂),需分别计算
   - 通道内 SOC=0,Hrfile 指向对应文件,NumOccupied=num_wann
   Parameters
   ----------
   win_dir : str
       .win 文件所在目录(hr.dat 通常在同一目录)
   cwd : str or None
       当前工作目录(备选搜索路径)
   Returns
   -------
   dict: {
       "
up
": "
/path/to/wannier90.up_hr.dat
" 或 None,
       "
dn
": "
/path/to/wannier90.dn_hr.dat
" 或 None,
       "
has_spin
": bool,
       "
available
": ["
up
", "
dn
"]  # 实际找到的通道列表
   }
   "
""
   result = {"up": None, "dn": None, "has_spin": False, "available": []}
   # 搜索目录列表:优先 .win 所在目录,其次当前工作目录
   search_dirs = [os.path.abspath(win_dir)]
   if cwd:
       cwd_abs = os.path.abspath(cwd)
       if cwd_abs not in search_dirs:
           search_dirs.append(cwd_abs)
   for search_dir in search_dirs:
       for ch in ["up", "dn"]:
           if result[ch] is not None:
               continue  # 已找到,不覆盖(win_dir 优先)
           fname = f"wannier90.{ch}_hr.dat"
           fpath = os.path.join(search_dir, fname)
           if os.path.isfile(fpath):
               result[ch] = fpath
               result[ch + "found_dir"] = search_dir
               result["available"].append(ch)
   # 也检测 Wannier90 其他命名: *
_hr.dat (排除标准 wannier90_hr.dat)
   if not result["available"]:
       for search_dir in search_dirs:
           try:
               for f in os.listdir(search_dir):
                   for suffix, ch in [("_up_hr.dat", "up"), ("_down_hr.dat", "dn")]:
                       if f.endswith(suffix) and f != "wannier90_hr.dat":
                           if result[ch] is None:
                               fpath = os.path.join(search_dir, f)
                               result[ch] = fpath
                               result[ch + "_found_dir"] = search_dir
                               if ch not in result["available"]:
                                   result["available"].append(ch)
           except (FileNotFoundError, PermissionError):
               continue
   result["has_spin"] = len(result["available"]) > 0
   # 检测标准(非自旋)hr.dat
   result["non_spin_hr"] = None
   for search_dir in search_dirs:
       std_path = os.path.join(search_dir, "wannier90_hr.dat")
       if os.path.isfile(std_path):
           result["non_spin_hr"] = std_path
           result["non_spin_found_dir"] = search_dir
           break
   return result
def get_spin_selection(spin_info, mode="interactive"):
   """获取用户自旋通道选择。
   Parameters
   ----------
   spin_info : dict
       detect_spin_channel() 的返回值
   mode : str
       "
interactive
": 交互式选择
       "
auto
": 自动选择(有 up/dn 就全选,包括 None(标准通道)如果无 hr 文件)
   Returns
   -------
   list of (str, str) tuples: [("
up
", hrfile_path), ("
dn
", hrfile_path)] 或
                             [(None, hrfile_path)] 或 []
   "
""
   available = spin_info.get("available", [])
   if not available:
       # 无自旋极化,检查标准 wannier90_hr.dat 是否存在
       std_hr = spin_info.get("non_spin_hr")
       if std_hr:
           return [(None, std_hr)]
       else:
           return []  # 无任何 hr.dat,交由调用方报错
   if mode == "auto":
       # 自动模式:有 up 就选 up,有 dn 就选 dn
       result = []
       for ch in available:
           hr_path = spin_info.get(ch)
           if hr_path:
               result.append((ch, hr_path))
       return result
   # 交互式选择
   print(f"\n{'=' * 72}")
   print("  检测到自旋极化计算!")
   for ch in available:
       print(f"  发现文件: {spin_info.get(ch)}")
   print(f"{'=' * 72}")
   print(f"\n  请选择要生成的自旋通道:")
   print(f"    [1] spin-up   (Hrfile -> wannier90.up_hr.dat, SOC = 0, NumOccupied = num_wann)")
   if "dn" in available:
       print(f"    [2] spin-down (Hrfile -> wannier90.dn_hr.dat, SOC = 0, NumOccupied = num_wann)")
       print(f"    [3] 两个通道都生成 (独立 wt-up/ 和 wt-dn/ 目录)")
       print(f"    [4] 跳过(不生成自旋极化文件)")
   else:
       print(f"    [2] 跳过(不生成)")
   while True:
       try:
           choice = input("\n  请选择: ").strip()
           if choice == "1":
               return [(ch, spin_info[ch]) for ch in available if spin_info[ch]][:1]
           elif choice == "2" and "dn" in available:
               return [(ch, spin_info[ch]) for ch in ["dn"] if spin_info[ch]]
           elif choice == "3" and "dn" in available:
               return [(ch, spin_info[ch]) for ch in available if spin_info[ch]]
           elif choice in ("2", "4"):
               return []
           else:
               print("  无效选择,请重新输入")
       except (EOFError, KeyboardInterrupt):
           return []

============================================================================

交互式菜单

============================================================================

def print_banner():
   """打印横幅"""
   print("=" * 72)
   print("   win2wt: Wannier90 (.win) → WannierTools (wt.in) 转换工具")
   print(f"   版本 {version} ({date})")
   print("=" * 72)
def print_menu():
   """打印任务选择菜单"""
   print("\n可用计算任务(输入编号或任务代码选择):\n")
   task_list = list(TASK_DEFINITIONS.keys())
   idx = 1
   task_index = {}
   for group_name, group_keys in TASK_GROUPS.items():
       print(f"  ┌─ {group_name} ──────────────────────────────────────────────┐")
       for key in group_keys:
           task = TASK_DEFINITIONS[key]
           print(f"  │ [{idx:2d}] {key:<18s} {task['name']:<42s} │")
           task_index[idx] = key
           idx += 1
       print(f"  └{'─' * 62}┘")
   print(f"\n  [{idx:2d}] 生成全部 19 个任务")
   task_index[idx] = "ALL"
   idx += 1
   print(f"  [{idx:2d}] 退出")
   return task_index
def get_user_selection(task_index):
   """获取用户选择"""
   while True:
       try:
           choice = input("\n请选择 [输入编号或任务代码,多个用逗号分隔]: ").strip()
           if not choice:
               continue
           selected = set()
           # 支持逗号分隔的多个选择
           parts = [p.strip() for p in choice.split(",")]
           for part in parts:
               # 尝试数字
               try:
                   num = int(part)
                   if num in task_index:
                       val = task_index[num]
                       if val == "ALL":
                           return list(TASK_DEFINITIONS.keys())
                       selected.add(val)
                   else:
                       print(f"  无效编号: {num}")
                       continue
               except ValueError:
                   # 尝试任务代码
                   if part.lower() in TASK_DEFINITIONS:
                       selected.add(part.lower())
                   elif part.lower() == "all":
                       return list(TASK_DEFINITIONS.keys())
                   elif part.lower() in ["q", "quit", "exit"]:
                       return []
                   else:
                       print(f"  无效任务代码: {part}")
                       continue
           if selected:
               return list(selected)
           else:
               print("  未选择有效任务,请重新输入。")
       except (EOFError, KeyboardInterrupt):
           print("\n")
           return []
def interactive_mode(win_path, output_dir=None, spin_channel=None):
   """交互式模式。
   自动检测自旋极化体系(搜索 .win 目录 + 当前工作目录),
   提示用户选择自旋通道和计算任务,生成对应的 wt.in 文件。
   "
""
   print_banner()
   # 解析 .win 文件
   print(f"\n正在解析: {win_path}")
   try:
       parser = WinParser(win_path)
   except Exception as e:
       print(f"  ✗ 解析失败: {e}")
       sys.exit(1)
   # 检测自旋极化(搜索 win_dir + cwd)
   win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
   cwd = os.getcwd()
   spin_info = detect_spin_channel(win_dir, cwd)
   # 显示解析摘要
   print(f"  ✓ 晶格: {parser.get_lattice_vectors() is not None}")
   print(f"  ✓ 原子数: {len(parser.get_atom_positions_cart())}")
   print(f"  ✓ 投影数: {len(parser.get_projections())}")
   print(f"  ✓ k-path: {len(parser.get_kpoint_path())} 条")
   print(f"  ✓ 费米能级: {parser.get_fermi_energy()} eV")
   print(f"  ✓ SOC: {'是' if parser.get_soc() else '否'}")
   print(f"  ✓ Wannier 函数数: {parser.get_num_wann()}")
   num_occ = parser.get_num_wann() // 2 if parser.get_soc() else parser.get_num_wann()
   print(f"  ✓ 估算占据数: {num_occ}")
   if spin_info["has_spin"]:
       print(f"  ✓ 检测到自旋极化 HR 文件:")
       for ch in spin_info["available"]:
           print(f"      {ch.upper()}: {spin_info[ch]}")
   elif spin_info.get("non_spin_hr"):
       print(f"  ✓ 标准 hr.dat: {spin_info['non_spin_hr']}")
   else:
       print(f"  ⚠ 未检测到 hr.dat 文件(将使用默认路径 wannier90_hr.dat)")
   # 获取自旋通道选择
   if spin_channel is not None:
       # 用户显式指定了 --spin up|dn|both
       if spin_channel == "both":
           spin_channels = [(ch, spin_info[ch]) for ch in spin_info["available"] if spin_info[ch]]
       else:
           hr_path = os.path.join(win_dir, f"wannier90.{spin_channel}_hr.dat")
           if not os.path.isfile(hr_path):
               hr_path = f"wannier90.{spin_channel}_hr.dat"
           spin_channels = [(spin_channel, hr_path)]
   else:
       spin_channels = get_spin_selection(spin_info, mode="interactive")
   if not spin_channels:
       print("退出。")
       return
   # 显示菜单
   task_index = print_menu()
   # 获取选择
   selected = get_user_selection(task_index)
   if not selected:
       print("退出。")
       return
   # 生成文件
   if output_dir is None:
       output_dir = win_dir
   os.makedirs(output_dir, exist_ok=True)
   print(f"\n正在生成 wt.in 文件到: {output_dir}/")
   print("-" * 72)
   for ch, hr_path in spin_channels:
       generator = WtInGenerator(parser, spin_channel=ch, hrfile_path=hr_path)
       if ch:
           spin_dir = os.path.join(output_dir, f"wt-{ch}")
           os.makedirs(spin_dir, exist_ok=True)
           print(f"\n  [自旋通道: {ch.upper()}] Hrfile: {hr_path}")
           print(f"  [输出到: {spin_dir}/]")
       else:
           spin_dir = output_dir
       for task_key in selected:
           task = TASK_DEFINITIONS[task_key]
           output_path = os.path.join(spin_dir, f"wt.in-{task_key}")
           try:
               generator.generate(task_key, output_path)
           except Exception as e:
               print(f"  ✗ {task_key} 生成失败: {e}")
   print("-" * 72)
   total = len(selected) * len(spin_channels)
   print(f"\n完成! 共生成 {total} 个 wt.in 文件。")
   print(f"使用方法: cp wt.in- wt.in && mpirun -np 4 wt.x")
def batch_mode_forced(win_path, tasks, output_dir=None):
   """强制标准模式 — 不检测 up/dn,直接使用 wannier90_hr.dat。
   用于用户显式指定 --nospin 的场景。
   "
""
   parser = WinParser(win_path)
   if output_dir is None:
       output_dir = os.path.dirname(os.path.abspath(win_path)) or "."
   os.makedirs(output_dir, exist_ok=True)
   win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
   cwd = os.getcwd()
   search_dirs = [win_dir, cwd]
   hr_path = _find_required_hr(win_dir, cwd, "")
   if hr_path is None:
       hr_path = _prompt_hr_path("标准(无自旋)", search_dirs, "wannier90_hr.dat")
       if hr_path is None:
           print("  已跳过,未生成任何文件。")
           return
   generator = WtInGenerator(parser, spin_channel=None, hrfile_path=hr_path)
   print(f"\n  模式: 强制标准")
   print(f"  Hrfile: {hr_path}")
   print(f"  输出到: {output_dir}/")
   for task_key in tasks:
       if task_key not in TASK_DEFINITIONS:
           print(f"  ✗ 未知任务: {task_key},跳过。")
           continue
       output_path = os.path.join(output_dir, f"wt.in-{task_key}")
       generator.generate(task_key, output_path)
def batch_mode(win_path, tasks, output_dir=None, spin_channel=None):
   """批量生成模式。
   自动检测当前文件夹和 .win 目录的 up/dn hr.dat 文件,
   有则自动分别生成对应通道的 wt.in 文件。
   "
""
   parser = WinParser(win_path)
   if output_dir is None:
       output_dir = os.path.dirname(os.path.abspath(win_path)) or "."
   win_dir = os.path.dirname(os.path.abspath(win_path)) or "."
   cwd = os.getcwd()
   if spin_channel is not None:
       # 用户显式指定了 --up / --dn → 必须找到对应文件
       ch = spin_channel
       hr_path = _find_required_hr(win_dir, cwd, ch)
       if hr_path is None:
           hr_path = _prompt_hr_path(
               f"spin-{ch}", [win_dir, cwd], f"wannier90.{ch}_hr.dat"
           )
           if hr_path is None:
               print("  已跳过,未生成任何文件。")
               return False
       spin_channels = [(ch, hr_path)]
   else:
       # 自动检测:有 up/dn 就都用,没有就回退到标准 hr.dat
       spin_info = detect_spin_channel(win_dir, cwd)
       spin_channels = get_spin_selection(spin_info, mode="auto")
   if not spin_channels:
       same_dirs = []
       for d in [win_dir, cwd]:
           d_abs = os.path.abspath(d)
           if d_abs not in same_dirs:
               same_dirs.append(d_abs)
       hr_path = _prompt_hr_path("通用", same_dirs, "wannier90_hr.dat")
       if hr_path is None:
           print("  未检测到任何 hr.dat 文件,退出。")
           print("  期望文件: wannier90_hr.dat, wannier90.up_hr.dat 或 wannier90.dn_hr.dat")
           return False
       spin_channels = [(None, hr_path)]
   for ch, hr_path in spin_channels:
       generator = WtInGenerator(parser, spin_channel=ch, hrfile_path=hr_path)
       if ch:
           spin_dir = os.path.join(output_dir, f"wt-{ch}")
           os.makedirs(spin_dir, exist_ok=True)
           print(f"\n{'=' * 60}")
           print(f"  自旋通道: {ch.upper()}")
           print(f"  Hrfile:   {hr_path}")
           print(f"  SOC:      0 (自旋通道内无 SOC)")
           print(f"  NumOccupied: {generator.num_occupied} (= num_wann)")
           print(f"  输出目录: {spin_dir}/")
           print(f"{'=' * 60}")
       else:
           spin_dir = output_dir
           os.makedirs(spin_dir, exist_ok=True)
           print(f"\n  标准模式(无自旋极化)")
           print(f"  Hrfile: {hr_path}")
           print(f"  输出到: {spin_dir}/")
       for task_key in tasks:
           if task_key not in TASK_DEFINITIONS:
               print(f"  ✗ 未知任务: {task_key},跳过。")
               continue
           output_path = os.path.join(spin_dir, f"wt.in-{task_key}")
           generator.generate(task_key, output_path)
   return True

============================================================================

主入口

============================================================================

def main():
   parser = argparse.ArgumentParser(
       description="win2wt: Wannier90 (.win) → WannierTools (wt.in) 自动转换工具",
       formatter_class=argparse.RawDescriptionHelpFormatter,
       epilog=textwrap.dedent("""
           示例:
             %(prog)s wannier90.win                  # 交互式选择任务和自旋通道
             %(prog)s wannier90.win --all            # 生成所有任务(自动检测 up/dn)
             %(prog)s wannier90.win -t bands,dos     # 生成指定任务
             %(prog)s wannier90.win --all --up       # 所有任务, 只生成 spin-up
             %(prog)s wannier90.win --all --dn       # 所有任务, 只生成 spin-down
             %(prog)s wannier90.win --all -o ./out   # 指定输出目录
             %(prog)s wannier90.win --list           # 列出任务
           提示: --all 模式下默认自动检测 up/dn 并同时生成,
                 无需手动指定 --spin。
       "
""
),
   )
   parser.add_argument("win_file", nargs="?", help="Wannier90 .win 输入文件路径")
   parser.add_argument("--all", "-a", action="store_true",
                       help="生成所有 19 个任务的 wt.in 文件")
   parser.add_argument("--list", "-l", action="store_true",
                       help="列出所有可用任务")
   parser.add_argument("--tasks", "-t", type=str,
                       help="逗号分隔的任务代码列表 (如 bands,dos,ahc)")
   parser.add_argument("--output", "-o", type=str, default=None,
                       help="输出目录(默认与 .win 同目录)")
   parser.add_argument("--up", action="store_true",
                       help="只生成 spin-up 通道(跳过交互选择)")
   parser.add_argument("--dn", action="store_true",
                       help="只生成 spin-down 通道(跳过交互选择)")
   parser.add_argument("--nospin", action="store_true",
                       help="强制使用标准模式(不检测 up/dn,即使用 wannier90_hr.dat)")
   args = parser.parse_args()
   # --list 模式
   if args.list:
       print_banner()
       print("\n可用任务列表:\n")
       for group_name, group_keys in TASK_GROUPS.items():
           print(f"  [{group_name}]")
           for key in group_keys:
               task = TASK_DEFINITIONS[key]
               print(f"    {key:<18s} - {task['name']}")
           print()
       return
   # 检查互斥选项
   if args.up and args.dn:
       print("错误: --up 和 --dn 不能同时使用。要生成两个通道,直接不加即可(默认自动检测)。")
       sys.exit(1)
   if args.up and args.nospin:
       print("错误: --up 和 --nospin 不能同时使用。")
       sys.exit(1)
   if args.dn and args.nospin:
       print("错误: --dn 和 --nospin 不能同时使用。")
       sys.exit(1)
   # 需要 .win 文件
   if not args.win_file:
       parser.print_help()
       sys.exit(1)
   if not os.path.exists(args.win_file):
       print(f"错误: 文件不存在: {args.win_file}")
       sys.exit(1)
   # 确定输出目录
   output_dir = args.output
   if output_dir is None:
       output_dir = os.path.dirname(os.path.abspath(args.win_file)) or "."
   # 确定自旋通道
   if args.nospin:
       spin_ch = "nospin"
   elif args.up:
       spin_ch = "up"
   elif args.dn:
       spin_ch = "dn"
   else:
       spin_ch = None  # auto-detect
   if args.all:
       print_banner()
       print(f"\n正在生成所有 19 个任务的 wt.in 文件...")
       if spin_ch == "nospin":
           batch_mode_forced(args.win_file, list(TASK_DEFINITIONS.keys()), output_dir)
           print(f"\n完成! 文件输出到: {output_dir}/")
       else:
           ok = batch_mode(args.win_file, list(TASK_DEFINITIONS.keys()), output_dir, spin_ch)
           if ok:
               print(f"\n完成! 文件输出到: {output_dir}/")
   elif args.tasks:
       tasks = [t.strip() for t in args.tasks.split(",")]
       if spin_ch == "nospin":
           batch_mode_forced(args.win_file, tasks, output_dir)
           print(f"\n完成! 文件输出到: {output_dir}/")
       else:
           ok = batch_mode(args.win_file, tasks, output_dir, spin_ch)
           if ok:
               print(f"\n完成! 文件输出到: {output_dir}/")
   else:
       interactive_mode(args.win_file, output_dir, spin_ch)
if name == "main":
   main()``
auto_test.sh 测试脚本
`#!/bin/bash

============================================================================

auto_test.sh — WannierTools wt.in 批量自动测试脚本

版本: 1.0 | 日期: 2026-06-06

功能:

1. 批量测试所有 wt.in-* 文件(cp 为 wt.in → wt.x → 保存 WT.out)

2. 超时控制(防止某任务卡死)

3. 自动检测 hr.dat 文件(支持 up/dn 自旋极化)

4. 汇总报告(通过/超时/失败/错误统计)

5. 支持并行模式(利用 GNU parallel 或背景进程)

用法:

./auto_test.sh                             # 在当前目录测试所有 wt.in-*

./auto_test.sh -d /path/to/wt-files         # 指定目录

./auto_test.sh -t 120                       # 超时 120 秒

./auto_test.sh -p 4                         # 4 核并行

./auto_test.sh -s "bands dos ahc"           # 只测试指定任务

./auto_test.sh --only-check                 # 只检查 in 文件,不运行

============================================================================

set -euo pipefail

--- 默认参数 ---

WT_EXEC="${WT_EXEC:-wt.x}"
TIMEOUT=50
PARALLEL=1
WORK_DIR="."
SELECTED_TASKS=""
ONLY_CHECK=false
VERBOSE=false

--- 颜色 ---

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
NC='\033[0m' # No Color

--- 帮助 ---

usage() {
   cat <<EOF
用法: $0 [选项]
选项:
 -d DIR       工作目录 (默认: .)
 -t SECONDS   单任务超时秒数 (默认: 50)
 -p N         并行核数 (默认: 1, 串行)
 -s "task1 task2 ..."  只测试指定任务
 --only-check 只检查 in 文件语法,不运行 wt.x
 -v           详细输出
 -h           显示帮助
示例:
 $0                          # 测试当前目录所有 wt.in-*
 $0 -d ./wt-up -t 120        # 测试 wt-up 目录,120s 超时
 $0 -s "bands dos ahc"       # 只测试 bands, dos, ahc
 $0 -p 4                     # 4 核并行测试
EOF
   exit 0
}

--- 解析参数 ---

while [[ $# -gt 0 ]]; do
   case "$1" in
       -d) WORK_DIR="$2"; shift 2 ;;
       -t) TIMEOUT="$2"; shift 2 ;;
       -p) PARALLEL="$2"; shift 2 ;;
       -s) SELECTED_TASKS="$2"; shift 2 ;;
       --only-check) ONLY_CHECK=true; shift ;;
       -v) VERBOSE=true; shift ;;
       -h|--help) usage ;;
       *) echo "未知选项: $1"; usage ;;
   esac
done
cd "$WORK_DIR" || { echo -e "${RED}错误: 无法进入目录 $WORK_DIR${NC}"; exit 1; }

--- 检测 hr.dat ---

detect_hr() {
   if [ -f "wannier90.up_hr.dat" ]; then
       echo "up"
   elif [ -f "wannier90.dn_hr.dat" ]; then
       echo "dn"
   elif [ -f "wannier90_hr.dat" ]; then
       echo "standard"
   else
       echo ""
   fi
}

--- 检查单个 wt.in 是否有致命参数错误 ---

check_wt_in() {
   local f="$1"
   local errors=0
   # 检查 NSLAB 是否在 &PARAMETERS(应在 &SYSTEM)
   if grep -q "^&PARAMETERS" "$f" 2>/dev/null; then
       if awk '/&PARAMETERS/,///' "$f" | grep -q "NSLAB|NSLAB1|NSLAB2|[1]NP[[:space:]]=" 2>/dev/null; then
           echo -e "  ${RED}✗ NSLAB/NP 在 &PARAMETERS 中(应在 &SYSTEM)${NC}"
           errors=$((errors + 1))
       fi
       if awk '/&PARAMETERS/,///' "$f" | grep -q "Bmagnitude|Btheta|Bphi|[[:space:]]*Bx|[[:space:]]By|[2]Bz" 2>/dev/null; then
           echo -e "  ${RED}✗ Bmagnitude/Btheta/Bphi 在 &PARAMETERS 中(应在 &SYSTEM)${NC}"
           errors=$((errors + 1))
       fi
   fi
   # 检查 Hrfile 是否与目录中的 hr.dat 匹配
   local hr=$(detect_hr)
   if [ -n "$hr" ] && [ "$hr" != "standard" ]; then
       if ! grep -q "wannier90.${hr}_hr.dat" "$f" 2>/dev/null; then
           echo -e "  ${YELLOW}⚠ Hrfile 可能不匹配(期望 wannier90.${hr}_hr.dat)${NC}"
       fi
   fi
   return $errors
}

--- 运行单个测试 ---

run_one() {
   local f="$1"
   local task="${f#wt.in-}"
   local out_file="WT-${f}.out"
   local start_time=$(date +%s)
   if $VERBOSE; then
       printf "${BLUE}[测试]${NC} %-18s ... " "$task"
   fi
   # 清理旧文件
   rm -f WT.out "$out_file"
   # 复制为 wt.in(WannierTools 只读 ./wt.in)
   cp "$f" wt.in
   # 运行(超时控制)
   if timeout "$TIMEOUT" "$WT_EXEC" > /dev/null 2>&1; then
       local elapsed=$(($(date +%s) - start_time))
       if [ -f WT.out ]; then
           cp WT.out "$out_file"
       fi
       if $VERBOSE; then
           echo -e "${GREEN}OK${NC} (${elapsed}s)"
       fi
       echo "PASS ${task} ${elapsed}" >> "$SUMMARY_FILE"
   else
       local rc=$?
       if [ -f WT.out ]; then
           cp WT.out "$out_file"
       fi
       if [ $rc -eq 124 ]; then
           local elapsed="$TIMEOUT"
           if $VERBOSE; then
               echo -e "${YELLOW}TIMEOUT${NC} (>${TIMEOUT}s)"
           fi
           echo "TIMEOUT ${task} ${elapsed}" >> "$SUMMARY_FILE"
       else
           local nerr=$(grep -c "Error|ERROR" "$out_file" 2>/dev/null || echo 0)
           if $VERBOSE; then
               echo -e "${RED}FAIL${NC} (exit=$rc, errors=$nerr)"
           fi
           echo "FAIL ${task} ${rc} ${nerr}" >> "$SUMMARY_FILE"
       fi
   fi
   # 清理临时 wt.in
   rm -f wt.in
}

--- 主流程 ---

main() {
   echo -e "${CYAN}${NC}"
   echo -e "${CYAN}  WannierTools 批量测试工具${NC}"
   echo -e "${CYAN}
${NC}"
   echo "  工作目录: $(pwd)"
   echo "  超时设置: ${TIMEOUT}s"
   echo "  并行核数: ${PARALLEL}"
   echo "  wt.x路径: ${WT_EXEC}"
   echo "  时间: $(date '+%Y-%m-%d %H:%M:%S')"
   # 检测 hr.dat
   local hr=$(detect_hr)
   if [ -z "$hr" ]; then
       echo -e "${RED}错误: 未找到 hr.dat 文件!${NC}"
       echo "  期望: wannier90_hr.dat, wannier90.up_hr.dat 或 wannier90.dn_hr.dat"
       exit 1
   fi
   echo -e "  hr.dat:  ${GREEN}$hr${NC}"
   # 收集任务列表
   local files=()
   if [ -n "$SELECTED_TASKS" ]; then
       for task in $SELECTED_TASKS; do
           if [ -f "wt.in-${task}" ]; then
               files+=("wt.in-${task}")
           else
               echo -e "  ${YELLOW}警告: wt.in-${task} 不存在,跳过${NC}"
           fi
       done
   else
       for f in wt.in-; do
           [ -f "$f" ] && files+=("$f")
       done
   fi
   local total=${#files[@]}
   if [ $total -eq 0 ]; then
       echo -e "${RED}错误: 未找到任何 wt.in-
文件!${NC}"
       exit 1
   fi
   echo -e "  任务数:  ${total}"
   # 语法检查模式
   if $ONLY_CHECK; then
       echo ""
       echo -e "${CYAN}--- 语法检查 ---${NC}"
       local check_errors=0
       for f in "${files[@]}"; do
           local task="${f#wt.in-}"
           echo -e "  ${BLUE}$task${NC}"
           if ! check_wt_in "$f"; then
               check_errors=$((check_errors + 1))
           fi
       done
       echo ""
       if [ $check_errors -eq 0 ]; then
           echo -e "${GREEN}全部通过!${NC}"
       else
           echo -e "${RED}$check_errors 个文件存在问题${NC}"
       fi
       exit $check_errors
   fi
   # 运行测试
   echo ""
   echo -e "${CYAN}--- 运行测试 ---${NC}"
   echo ""
   SUMMARY_FILE=$(mktemp)
   if [ "$PARALLEL" -gt 1 ]; then
       echo "  并行模式: $PARALLEL 核"
       export TIMEOUT WT_EXEC SUMMARY_FILE VERBOSE RED GREEN YELLOW BLUE CYAN NC
       export -f run_one detect_hr
       printf '%s\n' "${files[@]}" | xargs -P "$PARALLEL" -I {} bash -c 'run_one "$@"' _ {}
       wait
   else
       local count=0
       for f in "${files[@]}"; do
           count=$((count + 1))
           printf "[%2d/%2d] " $count $total
           run_one "$f"
       done
   fi
   # 汇总
   echo ""
   echo -e "${CYAN}${NC}"
   echo -e "${CYAN}  测试汇总${NC}"
   echo -e "${CYAN}
${NC}"
   local n_pass=$(grep -c "^PASS" "$SUMMARY_FILE" 2>/dev/null || echo 0)
   local n_timeout=$(grep -c "^TIMEOUT" "$SUMMARY_FILE" 2>/dev/null || echo 0)
   local n_fail=$(grep -c "^FAIL" "$SUMMARY_FILE" 2>/dev/null || echo 0)
   echo ""
   echo -e "  ${GREEN}通过:   $n_pass${NC}"
   echo -e "  ${YELLOW}超时:   $n_timeout${NC}"
   echo -e "  ${RED}失败:   $n_fail${NC}"
   echo -e "  ${CYAN}总计:   $total${NC}"
   echo ""
   # 列出失败详情
   if [ "$n_fail" -gt 0 ]; then
       echo -e "${RED}失败详情:${NC}"
       grep "^FAIL" "$SUMMARY_FILE" | while read -r line; do
           echo "  $line"
       done
   fi
   # 列出超时详情
   if [ "$n_timeout" -gt 0 ]; then
       echo -e "${YELLOW}超时任务 (仍在运行)😒{NC}"
       grep "^TIMEOUT" "$SUMMARY_FILE" | while read -r line; do
           echo "  $line"
       done
   fi
   # 输出文件位置
   echo ""
   echo "WT.out 文件已保存至: $(pwd)/WT-wt.in-*.out"
   rm -f "$SUMMARY_FILE"
   # 返回码
   if [ "$n_fail" -gt 0 ]; then
       exit 1
   else
       exit 0
   fi
}
main`


  1. [:space:] ↩︎

  2. [:space:] ↩︎