#!/usr/bin/env python3
import sys
import re
from collections import defaultdict

# -------------------------
# Functions
# -------------------------
def expand_config(config_str):
    """
    Translate an occupation string such as '221010' into per-orbital spin
    occupations.

    Surrounding whitespace is ignored. Every remaining character is looked
    up in a fixed table of (n_alpha, n_beta) pairs:
    '0' -> (0, 0), '1' -> (1, 0), '2' -> (1, 1).
    A singly occupied orbital is therefore always counted as an alpha
    electron (assumed convention).

    Returns a list with one (n_alpha, n_beta) tuple per character.
    Raises ValueError if a character other than '0', '1' or '2' is found.
    """
    occupation_of = {"2": (1, 1), "1": (1, 0), "0": (0, 0)}
    symbols = config_str.strip()
    try:
        # __getitem__ raises KeyError on the first unknown symbol.
        return list(map(occupation_of.__getitem__, symbols))
    except KeyError as e:
        raise ValueError(f"Unrecognized character in configuration: {e}")

def diff_squared(conf, ref):
    """
    Return the sum over orbitals of the squared alpha and beta occupation
    differences between `conf` and `ref`.

    Both arguments are sequences of (n_alpha, n_beta) pairs. Orbitals are
    paired positionally with zip, so any surplus entries in the longer
    sequence are ignored.
    """
    return sum(
        (na - ra) ** 2 + (nb - rb) ** 2
        for (na, nb), (ra, rb) in zip(conf, ref)
    )

def get_excitation_order(conf, ref):
    """
    Count how much spin-orbital occupation `conf` GAINS relative to `ref`:
    k = sum( max(0, na-ra) + max(0, nb-rb) ) over all orbitals.

    Only positive differences contribute; occupation that is lost with
    respect to the reference is not counted. Orbitals are paired
    positionally with zip.
    """
    return sum(
        gain
        for (na, nb), (ra, rb) in zip(conf, ref)
        for gain in (na - ra, nb - rb)
        if gain > 0
    )

# Signed decimal number with an optional exponent, e.g. "0.28210", ".5", "-1e-3".
float_re = re.compile(r"[+\-]?\d*\.?\d+(?:[Ee][+\-]?\d+)?")

# One coefficient line such as "0.28210 [   4]: 220200".
# Compiled once here instead of being rebuilt from an f-string for every line.
# Group 1 is the weight, group 2 the occupation string.
_config_line_re = re.compile(
    rf"\s*({float_re.pattern})\s+\[\s*\d+\]\s*:\s*(\d+)\s*$"
)

def parse_block(text_block):
    """
    Extracts (c2, config) from a text block with lines of the type:
      0.28210 [   4]: 220200

    Each line is searched for "<number> [<index>]: <digits>" ending at the
    end of the line; lines that do not match are skipped silently.

    Returns a list of (c2, config) tuples in the order they appear, where
    c2 is a float and config is the digit string exactly as printed.
    """
    matches = map(_config_line_re.search, text_block.strip().splitlines())
    # The config group is \d+, so it can never carry surrounding whitespace.
    return [(float(m.group(1)), m.group(2)) for m in matches if m]

def parse_conf_file(filename):
    """
    Read `filename` and parse its configuration lines, grouped by root.

    If the text contains the word ROOT, it is split at headers of the form
    "ROOT <n>:"; whatever precedes the first header is discarded and each
    following section is parsed with parse_block. If no section results
    from the split (no ROOT word at all, or the word appears but never as
    a "ROOT <n>:" header), the whole text is parsed as a single root.

    Returns a list with one entry per root, in file order; each entry is
    the list of (c2, config) tuples produced by parse_block.
    """
    with open(filename, "r") as f:
        text = f.read()

    sections = []
    if re.search(r"\bROOT\b", text):
        sections = re.split(r"ROOT\s+\d+:", text)[1:]
    if not sections:
        # The detection pattern above is looser than the split pattern, so
        # "ROOT" may be present without any real header. Fall back to the
        # single-block interpretation instead of silently returning nothing.
        sections = [text]
    return [parse_block(section) for section in sections]

def calc_Nex_for_roots(roots):
    """
    Compute the excitation number N_ex of every root relative to a common
    reference configuration.

    `roots` is a list with one entry per root; each entry is a list of
    (c2, config) pairs (the shape returned by parse_conf_file). The
    reference is the configuration with the largest c2 in roots[0].

    For each root the printed weights are renormalized to sum to one and
        N_ex = sum_i  0.5 * (c2_i / sum_j c2_j) * diff_squared(conf_i, ref)
    is accumulated. Every contribution is also binned by the excitation
    order k returned by get_excitation_order.

    Returns a tuple
        (ref_config, results, fractions_list, counts_list, weights_list)
    where, per root, results holds N_ex, fractions_list holds {k: share of
    N_ex} (all 0.0 when N_ex is not positive), counts_list holds
    {k: number of configurations}, and weights_list holds the sum of the
    printed c2 values before normalization.

    Raises ValueError if roots[0] is missing or empty, if a root's total
    weight is not positive, if a configuration contains an unknown
    character, or if a configuration does not have the same number of
    orbitals as the reference.
    """
    if not roots or not roots[0]:
        raise ValueError("No data found in ROOT 0 to determine the reference.")

    ref_config = max(roots[0], key=lambda x: x[0])[1]
    ref = expand_config(ref_config)

    results = []
    fractions_list = []
    counts_list = []
    weights_list = []

    for root_index, root in enumerate(roots):
        Nex_total = 0.0
        k_contrib = defaultdict(float)
        k_counts = defaultdict(int)

        # Printed coefficients are usually truncated by a print threshold,
        # so they are renormalized explicitly.
        total_c2 = sum(c2 for c2, _ in root)
        if total_c2 <= 0.0:
            raise ValueError("Total CI weight is zero or negative; cannot normalize.")

        for c2, config in root:
            conf = expand_config(config)
            # The helpers pair orbitals with zip, which would silently drop
            # the surplus orbitals and produce a wrong N_ex.
            if len(conf) != len(ref):
                raise ValueError(
                    f"Configuration '{config}' in ROOT {root_index} has "
                    f"{len(conf)} orbitals but the reference '{ref_config}' "
                    f"has {len(ref)}."
                )
            c2_norm = c2 / total_c2
            diff = diff_squared(conf, ref)
            k = get_excitation_order(conf, ref)
            contribution = 0.5 * c2_norm * diff
            Nex_total += contribution
            k_contrib[k] += contribution
            k_counts[k] += 1

        # Share of N_ex carried by each excitation order; defined as 0.0
        # when there is nothing to share out.
        fractions = {
            k: (contrib / Nex_total if Nex_total > 0 else 0.0)
            for k, contrib in sorted(k_contrib.items())
        }

        results.append(Nex_total)
        fractions_list.append(fractions)
        counts_list.append(dict(k_counts))
        weights_list.append(total_c2)

    return ref_config, results, fractions_list, counts_list, weights_list

# -------------------------
# Main
# -------------------------
def main():
    """
    Command-line entry point: python Nex.py file.conf

    Parses the file named by the first argument, computes N_ex for every
    root and prints, per root, N_ex, the printed CI weight before
    normalization, and the fraction of N_ex and configuration count for
    each excitation order k. Exits with status 1 when no file is given.
    """
    if len(sys.argv) < 2:
        print("Usage: python Nex.py file.conf")
        sys.exit(1)

    filename = sys.argv[1]
    roots = parse_conf_file(filename)
    ref_config, Nex_list, fractions_list, counts_list, weights_list = calc_Nex_for_roots(roots)

    print(f"Reference configuration (ROOT 0, largest c2): {ref_config}")
    print("="*60)

    for i, (Nex, fracs, counts, w) in enumerate(zip(Nex_list, fractions_list, counts_list, weights_list)):
        print(f"\nROOT {i}: N_ex = {Nex:.6f}")
        print(f"Printed CI weight before normalization: {w:.6f}")
        print("-"*40)
        total_fraction = 0.0
        for k in sorted(fracs.keys()):
            f_ex_k = fracs[k]
            total_fraction += f_ex_k
            print(f"  k = {k}: f_ex^({k}) = {f_ex_k:.6f} ({f_ex_k*100:.2f}%), count = {counts.get(k,0)}")
        print(f"  Total sum of fractions: {total_fraction:.12f}")
        if Nex <= 0:
            # calc_Nex_for_roots reports every fraction as 0.0 when N_ex is
            # not positive (e.g. only the reference configuration was
            # printed), so a sum of 0 is expected here and not an error.
            print("  N_ex is not positive: fractions are undefined and reported as 0.0")
        elif abs(total_fraction - 1.0) < 1e-8:
            print("  ✓ The sum of fractions is correct (≈ 1.0)")
        else:
            print(f"  ⚠ The sum of fractions differs from 1.0: {total_fraction:.6f}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
