import numpy as np
from entropy_numba import calculate_entropy_similarity, apply_weight_to_intensity
import tools
from gnps_index import compute_all_pairs_spectra_entropy, create_index, find_bin_range  # Replace with your actual script/module name

from numba.typed import List
import numba

# ========== Parse spectrum strings ==========
def parse_mgf_block(mgf_str):
    """Extract the peak list from one MGF ``BEGIN IONS`` ... ``END IONS`` block.

    Header lines (any line containing ``=``) and the ``BEGIN``/``END`` markers
    are skipped. Every other non-empty line is read as
    ``m/z intensity [extra columns...]``. Trailing columns, such as an optional
    per-peak charge annotation, are ignored instead of causing the peak to be
    dropped. Lines that do not start with two numeric tokens (e.g. comment lines
    or an m/z without an intensity) are skipped.

    Args:
        mgf_str: Text of a single MGF spectrum block.

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(n_peaks, 2)``. Column 0 is
        m/z and column 1 is intensity, in file order. The shape is ``(0, 2)``
        when no peaks are found.
    """
    mzs = []
    intensities = []
    for raw_line in mgf_str.strip().splitlines():
        line = raw_line.strip()
        if not line or line.startswith(("BEGIN", "END")) or "=" in line:
            continue
        tokens = line.split()
        if len(tokens) < 2:
            # An m/z without an intensity cannot contribute to a similarity score.
            continue
        try:
            mz, inten = float(tokens[0]), float(tokens[1])
        except ValueError:
            # Non-numeric content, e.g. an MGF comment line such as "# ...".
            continue
        mzs.append(mz)
        intensities.append(inten)
    return np.column_stack(
        (np.array(mzs, dtype=np.float32), np.array(intensities, dtype=np.float32))
    )

# Query spectrum as a single MGF text block (SCANS=6144, PEPMASS=430.91354,
# CHARGE=1). Peak lines are tab-separated "m/z<TAB>intensity" pairs; the block
# is turned into a peak array by parse_mgf_block.
query_mgf ="""BEGIN IONS
SCANS=6144
PEPMASS=430.91354
CHARGE=1
COLLISION_ENERGY=0.0
55.580627	3206.192871
55.815353	3218.787842
56.970882	3178.656982
77.562996	3133.969238
90.971001	6860.199707
90.977066	93055.726562
158.963852	46303.84375
221.166138	3846.468262
222.263962	8268.759766
226.951141	21141.042969
281.258545	4045.204834
379.04541	3317.030029
END IONS"""
# Reference spectrum (SCANS=6145, PEPMASS=226.95129, CHARGE=1). Its precursor
# m/z lies close to the query peak at 226.951141. The SCANS header line carries
# trailing whitespace inside the literal; header lines contain "=" and are not
# read as peaks by parse_mgf_block.
reference_mgf = """BEGIN IONS
SCANS=6145  
PEPMASS=226.95129
CHARGE=1
COLLISION_ENERGY=0.0
54.248283	2962.395264
61.309139	3229.900146
73.028709	4186.556152
74.181908	2997.041504
79.456779	2975.758545
90.977036	758335.5
90.990501	6890.162109
91.305176	3304.454102
103.127068	3252.181152
103.34082	3011.128174
120.010719	3149.802734
155.081284	4428.489746
158.963776	32663.013672
160.440475	3207.70166
217.417847	3289.767334
225.068069	5826.655273
227.10408	5153.543457
242.53569	3254.464111
247.280106	3951.192139
END IONS"""

# ========== Set Parameters ==========
# This script scores the query/reference pair twice: once with the reference
# implementation (entropy_numba.calculate_entropy_similarity) and once through
# the index-based gnps_index path. It then reports whether the scores agree.
tolerance = 0.01
threshold = 0.0  # show all matches


def _parse_mgf_precursor(mgf_str, default_charge=1):
    """Return ``(precursor_mz, charge)`` read from an MGF block's header lines.

    ``PEPMASS`` may carry an optional intensity after the m/z; only the m/z is
    kept. ``CHARGE`` may be written as ``1``, ``1+`` or ``1-``. When several
    charges are listed, the first one is used. ``default_charge`` applies when
    no ``CHARGE`` line is present.

    Raises:
        ValueError: if the block has no ``PEPMASS`` line.
    """
    precursor_mz = None
    charge = default_charge
    for raw_line in mgf_str.splitlines():
        key, sep, value = raw_line.strip().partition("=")
        if not sep:
            continue
        key = key.strip().upper()
        value = value.strip()
        if key == "PEPMASS":
            precursor_mz = float(value.split()[0])
        elif key == "CHARGE":
            token = value.replace(",", " ").split()[0]
            sign = -1 if token.endswith("-") else 1
            charge = sign * int(token.rstrip("+-"))
    if precursor_mz is None:
        raise ValueError("MGF block has no PEPMASS line")
    return precursor_mz, charge


# ========== Clean and weight ==========
mgf_blocks = [query_mgf, reference_mgf]
specs_raw = [parse_mgf_block(block) for block in mgf_blocks]

numba_spectra = List()
for mgf_str, raw_spec in zip(mgf_blocks, specs_raw):
    cleaned = tools.clean_spectrum(raw_spec,
        min_ms2_difference_in_da=2 * tolerance,
        noise_threshold=0.01,
        normalize_intensity=True
    )
    weighted = apply_weight_to_intensity(cleaned)
    # Take precursor m/z and charge from each block's own header. The previous
    # hard-coded values (223.06 / 381.24) did not match either PEPMASS.
    precursor_mz, charge = _parse_mgf_precursor(mgf_str)
    numba_spectra.append((
        weighted[:, 0],
        weighted[:, 1],
        np.float32(precursor_mz),
        np.int32(charge),
    ))

# ========== Run reference calculation ==========
score_ref = calculate_entropy_similarity(specs_raw[0], specs_raw[1],
    ms2_tolerance_in_da=tolerance,
    clean_spectra=True
)
print(f"[Ref] Entropy similarity: {score_ref:.6f}")

# ========== Run index-based version ==========
shared_idx = create_index(numba_spectra, is_shifted=False, tolerance=tolerance, shifted_offset=6000)
results = compute_all_pairs_spectra_entropy(
    spectra=numba_spectra,
    shared_entries=shared_idx,
    tolerance=tolerance,
    threshold=threshold,
    query_start=0,
    # only test index 0
    # NOTE(review): this relies on query_end being inclusive in gnps_index;
    # if it is exclusive, the query range is empty. Confirm.
    query_end=0
)

# NOTE(review): assumes results[0] is (query_idx, [(ref_idx, score), ...]);
# verify against gnps_index.compute_all_pairs_spectra_entropy.
score_test = results[0][1][0][1] if results[0][1] else 0.0
print(f"[Index] Entropy similarity: {score_test:.6f}")

# ========== Compare ==========
diff = abs(score_ref - score_test)
print(f"Difference: {diff:.8f}")
if diff < 1e-6:
    print("✅ PASS: Scores match!")
else:
    print("❌ FAIL: Scores differ!")
