#!/usr/bin/env python3

import argparse
import csv
import os
import re
import subprocess
import tempfile
from collections import defaultdict

import pysam


def left_softclip(read):
    """Return the left soft-clipped bases of an alignment, or None.

    The SAM format allows hard clips (H) only at the very ends of a CIGAR and
    soft clips (S) directly inside them, so a left soft clip may be either the
    first operation or the one following a leading hard clip. Hard-clipped
    bases are not stored in the query sequence, which means the soft clip
    always starts at offset 0 of ``read.query_sequence``.

    Args:
        read: Object exposing ``cigartuples`` (list of ``(op, length)`` pairs
            or None) and ``query_sequence`` (str or None).

    Returns:
        The clipped prefix of the query sequence, or None when the record has
        no CIGAR, no stored sequence, or no left soft clip.
    """
    cigar = read.cigartuples
    sequence = read.query_sequence
    # "not cigar" also covers an empty list, which would otherwise fail on [0].
    if not cigar or sequence is None:
        return None

    # CIGAR operation codes: 4 = soft clip, 5 = hard clip.
    for op, length in cigar:
        if op == 5:
            # Leading hard clip: bases are absent from the sequence, keep looking.
            continue
        if op == 4 and length > 0:
            return sequence[:length]
        # First non-hard-clip operation is not a soft clip.
        return None
    return None


def overlaps_region(read, chrom, start0, end0):
    """Return True if the alignment overlaps the half-open interval [start0, end0).

    Args:
        read: Object exposing ``is_unmapped``, ``reference_name``,
            ``reference_start`` (0-based) and ``reference_end`` (0-based,
            exclusive; may be None).
        chrom: Reference name the alignment must be on.
        start0: 0-based inclusive interval start.
        end0: 0-based exclusive interval end.

    Returns:
        False for unmapped records, records on another reference, and records
        without a defined reference span; otherwise whether the spans overlap.
    """
    if read.is_unmapped or read.reference_name != chrom:
        return False

    ref_start = read.reference_start
    ref_end = read.reference_end
    # A record flagged as mapped but lacking a CIGAR has no reference end;
    # comparing None with an int would raise TypeError.
    if ref_start is None or ref_end is None:
        return False

    return ref_start < end0 and ref_end > start0


def write_fasta(records, path):
    """Write (name, sequence) pairs to ``path`` as two-line FASTA entries.

    Entries whose sequence is empty/None or 15 characters or shorter are
    skipped; the file is created (or truncated) even if nothing is written.
    """
    with open(path, "w") as handle:
        handle.writelines(
            f">{label}\n{sequence}\n"
            for label, sequence in records
            if sequence and len(sequence) > 15
        )


def run_minimap2(human_fasta, query_fasta, paf_out, preset="sr", threads=4):
    """Align ``query_fasta`` against ``human_fasta`` with minimap2.

    The program's standard output is redirected into ``paf_out``. A non-zero
    exit status raises ``subprocess.CalledProcessError``.
    """
    options = ["-x", preset, "-t", str(threads)]
    argv = ["minimap2", *options, human_fasta, query_fasta]

    with open(paf_out, "w") as paf_handle:
        subprocess.run(argv, stdout=paf_handle, check=True)


def parse_paf(paf_path, human_regex):
    """Read a PAF file and group its alignments by query name.

    Lines with fewer than 12 tab-separated fields are ignored. When
    ``human_regex`` is non-empty, only alignments whose target name matches it
    (via ``re.match``) are kept.

    Returns:
        A ``defaultdict(list)`` mapping query name to a list of hit dicts
        holding target coordinates (0-based and 1-based), strand, query
        coordinates, match count, alignment length and mapping quality.
    """
    target_filter = re.compile(human_regex) if human_regex else None
    hits_by_query = defaultdict(list)

    with open(paf_path) as paf:
        for raw in paf:
            cols = raw.rstrip("\n").split("\t")
            if len(cols) < 12:
                continue

            query_name, strand, target_name = cols[0], cols[4], cols[5]
            # All numeric columns are converted up front, so a malformed value
            # raises ValueError regardless of the target filter.
            query_len, query_from, query_to = (int(v) for v in cols[1:4])
            _target_len, target_from, target_to, n_match, block_len, map_quality = (
                int(v) for v in cols[6:12]
            )

            if target_filter is not None and target_filter.match(target_name) is None:
                continue

            # PAF coordinates are 0-based with an exclusive end, so the
            # 1-based inclusive end equals the 0-based exclusive end.
            record = {
                "human_chrom": target_name,
                "human_start_0based": target_from,
                "human_end_0based": target_to,
                "human_start_1based": target_from + 1,
                "human_end_1based": target_to,
                "strand": strand,
                "query_start": query_from,
                "query_end": query_to,
                "query_length": query_len,
                "matches": n_match,
                "aln_length": block_len,
                "mapq": map_quality,
            }
            hits_by_query[query_name].append(record)

    return hits_by_query


# Column order of the tab-separated output file.
_OUTPUT_COLUMNS = [
    "read_name",
    "candidate_type",
    "candidate_length",
    "source_ref",
    "source_start_1based",
    "source_end_1based",
    "source_flag",
    "human_chrom",
    "human_start_1based",
    "human_end_1based",
    "human_start_0based",
    "human_end_0based",
    "strand",
    "query_start",
    "query_end",
    "query_length",
    "matches",
    "aln_length",
    "mapq",
]


def _select_qnames(bam_path, chrom, start0, end0):
    """Return the query names of alignments overlapping chrom:[start0, end0)."""
    # The context manager closes the BAM even if fetch() raises.
    with pysam.AlignmentFile(bam_path, "rb") as bam:
        return {
            read.query_name
            for read in bam.fetch(chrom, start0, end0)
            if overlaps_region(read, chrom, start0, end0)
        }


def _collect_candidates(bam_path, qnames, chrom, start0, end0, min_softclip):
    """Scan every BAM record and gather candidate sequences for ``qnames``.

    Returns:
        ``(fasta_records, meta_by_id)`` where ``fasta_records`` is a list of
        ``(candidate_id, sequence)`` pairs and ``meta_by_id`` maps each
        candidate id to the source-alignment columns of the output table.
    """
    fasta_records = []
    meta_by_id = {}
    if not qnames:
        # Nothing was selected, so a full pass over the BAM cannot add anything.
        return fasta_records, meta_by_id

    def add(label, sequence, info):
        # A running serial keeps ids unique even when several records of the
        # same read share a flag (and position); duplicate ids would overwrite
        # metadata and attribute hits to the wrong source alignment.
        cid = f"{label}|n={len(fasta_records)}"
        fasta_records.append((cid, sequence))
        meta_by_id[cid] = info

    with pysam.AlignmentFile(bam_path, "rb") as bam:
        # until_eof=True iterates all records, including unmapped ones.
        for read in bam.fetch(until_eof=True):
            qn = read.query_name
            if qn not in qnames:
                continue

            mapped = not read.is_unmapped
            start_1based = read.reference_start + 1 if mapped else ""
            end_1based = read.reference_end if mapped else ""

            # Left soft-clipped sequence from any alignment of a selected qname.
            clip = left_softclip(read)
            if clip and len(clip) >= min_softclip:
                add(
                    f"{qn}|leftclip|flag={read.flag}"
                    f"|{read.reference_name}:{read.reference_start + 1}",
                    clip,
                    {
                        "read_name": qn,
                        "candidate_type": "left_clip",
                        "source_ref": read.reference_name,
                        "source_start_1based": start_1based,
                        "source_end_1based": end_1based,
                        "source_flag": read.flag,
                        "candidate_length": len(clip),
                    },
                )

            # "Mate" candidates: any paired record of a selected qname that
            # carries a sequence and does not itself overlap the interval.
            # NOTE(review): this also labels non-overlapping supplementary
            # records of the overlapping read as "mate" - confirm intent.
            if (
                read.is_paired
                and read.query_sequence
                and not overlaps_region(read, chrom, start0, end0)
            ):
                add(
                    f"{qn}|mate|flag={read.flag}",
                    read.query_sequence,
                    {
                        "read_name": qn,
                        "candidate_type": "mate",
                        "source_ref": read.reference_name if mapped else "unmapped",
                        "source_start_1based": start_1based,
                        "source_end_1based": end_1based,
                        "source_flag": read.flag,
                        "candidate_length": len(read.query_sequence),
                    },
                )

    return fasta_records, meta_by_id


def _write_hits(out_path, candidate_meta, hits, min_mapq):
    """Write one TSV row per hit with mapq >= ``min_mapq``; always write the header."""
    with open(out_path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=_OUTPUT_COLUMNS, delimiter="\t")
        writer.writeheader()
        for cid, meta in candidate_meta.items():
            for hit in hits.get(cid, ()):
                if hit["mapq"] >= min_mapq:
                    writer.writerow({**meta, **hit})


def main():
    """Command-line entry point.

    Steps:
      1. Select query names with an alignment overlapping the plasmid interval.
      2. For those query names, collect left soft-clips and non-overlapping
         paired records from the whole BAM.
      3. Align the collected sequences to the human FASTA with minimap2.
      4. Write hits passing ``--min-mapq`` to ``--out`` as a TSV.

    The output always contains the header line, including when no candidate
    sequences were found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--bam", required=True)
    ap.add_argument("--human-fasta", required=True)
    ap.add_argument("--out", required=True)

    ap.add_argument("--plasmid", default="plasmid3")
    ap.add_argument("--start", type=int, default=5399, help="1-based inclusive")
    ap.add_argument("--end", type=int, default=5662, help="1-based inclusive")

    ap.add_argument("--min-softclip", type=int, default=20)
    ap.add_argument("--min-mapq", type=int, default=10)
    ap.add_argument("--threads", type=int, default=4)
    ap.add_argument(
        "--human-regex",
        default=r"^(chr)?([1-9]|1[0-9]|2[0-2]|X|Y|M|MT)$",
        help="Regex for human chromosome names in FASTA",
    )
    args = ap.parse_args()

    # Fail early with a usage message instead of an error from the BAM fetch.
    if args.start < 1 or args.end < args.start:
        ap.error("--start must be >= 1 and --end must be >= --start")

    # Convert the 1-based inclusive interval to 0-based half-open.
    region_start0 = args.start - 1
    region_end0 = args.end

    selected_qnames = _select_qnames(
        args.bam, args.plasmid, region_start0, region_end0
    )
    candidate_fasta, candidate_meta = _collect_candidates(
        args.bam,
        selected_qnames,
        args.plasmid,
        region_start0,
        region_end0,
        args.min_softclip,
    )

    hits = {}
    if candidate_fasta:
        # Intermediate files are only needed until the PAF has been parsed.
        with tempfile.TemporaryDirectory() as tmp:
            query_fa = os.path.join(tmp, "candidates.fa")
            paf = os.path.join(tmp, "candidates_vs_human.paf")

            write_fasta(candidate_fasta, query_fa)
            run_minimap2(args.human_fasta, query_fa, paf, threads=args.threads)
            hits = parse_paf(paf, args.human_regex)

    _write_hits(args.out, candidate_meta, hits, args.min_mapq)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
