import logging

import chromadb

from rag.config import CHROMA_PERSIST_DIR, DEFAULT_COLLECTION, MIN_SCORE

logger = logging.getLogger(__name__)

# Process-wide Chroma client; stays None until _get_client() builds it.
# The annotation is quoted on purpose: chromadb.PersistentClient is a factory
# function, not a class, so evaluating `chromadb.PersistentClient | None` at
# import time raises TypeError on interpreters that evaluate annotations
# eagerly. The object the factory returns implements chromadb.api.ClientAPI.
_client: "chromadb.api.ClientAPI | None" = None


def _get_client() -> "chromadb.api.ClientAPI":
    """Return the shared Chroma client, building it on first use.

    The first call creates ``CHROMA_PERSIST_DIR`` (including missing parents)
    and opens a persistent client on it; later calls return the cached
    instance stored in the module-level ``_client``.

    The return annotation is a quoted ``chromadb.api.ClientAPI`` because
    ``chromadb.PersistentClient`` is a factory function and not a valid type.
    """
    global _client
    if _client is not None:
        return _client

    # The storage directory must exist before the client is opened on it.
    CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
    _client = chromadb.PersistentClient(path=str(CHROMA_PERSIST_DIR))
    return _client


def get_collection(name: str = DEFAULT_COLLECTION) -> chromadb.Collection:
    """Fetch the collection called *name*, creating it if it is missing.

    Args:
        name: Collection name; defaults to ``DEFAULT_COLLECTION``.

    Returns:
        The collection handle returned by the shared client.
    """
    client = _get_client()
    # Passed as collection metadata on every get-or-create call; the key
    # requests cosine as the HNSW distance space.
    collection_metadata = {"hnsw:space": "cosine"}
    return client.get_or_create_collection(
        name=name,
        metadata=collection_metadata,
    )


def upsert(
    collection_name: str,
    ids: list[str],
    embeddings: list[list[float]],
    documents: list[str],
    metadatas: list[dict],
) -> tuple[int, int]:
    """Insert or update records and report how many ids were new.

    Args:
        collection_name: Collection to write to (created if it is missing).
        ids: Record identifiers, one per record.
        embeddings: Embedding vector for each id.
        documents: Document text for each id.
        metadatas: Metadata dict for each id.

    Returns:
        ``(added, skipped)``. ``added`` counts ids that were not stored
        before the call; ``skipped`` counts ids that already existed.
        Existing records are still overwritten by the upsert, so "skipped"
        means "not new", not "left untouched".

        An empty ``ids`` list returns ``(0, 0)`` without touching the store.
    """
    # Nothing to write: return early instead of sending empty batches to the
    # store.
    if not ids:
        return 0, 0

    coll = get_collection(collection_name)

    # Only the ids are needed to tell new records from existing ones, so do
    # not fetch documents or metadatas for the existence check.
    existing_ids = set(coll.get(ids=ids, include=[])["ids"])
    added = sum(1 for record_id in ids if record_id not in existing_ids)
    skipped = len(ids) - added

    coll.upsert(
        ids=ids,
        embeddings=embeddings,
        documents=documents,
        metadatas=metadatas,
    )
    return added, skipped


def similarity_search(
    collection_name: str,
    query_embedding: list[float],
    top_k: int,
) -> list[dict]:
    """Return the stored chunks most similar to *query_embedding*.

    Args:
        collection_name: Collection to search (created if it is missing).
        query_embedding: Embedding vector of the query.
        top_k: Maximum number of hits to request. Values ``<= 0`` yield an
            empty result.

    Returns:
        A list of hit dicts sorted by descending ``score``. Each hit has the
        keys ``text``, ``score``, ``source``, ``chunk_index`` and
        ``metadata``. ``score`` is ``1 - distance`` clamped at ``0.0``; hits
        scoring below ``MIN_SCORE`` are dropped, so fewer than ``top_k``
        items may be returned.
    """
    # A non-positive result count is not a meaningful query; answer it the
    # same way as an empty collection.
    if top_k <= 0:
        return []

    coll = get_collection(collection_name)
    count = coll.count()
    if count == 0:
        return []

    results = coll.query(
        query_embeddings=[query_embedding],
        # Never request more results than the collection holds.
        n_results=min(top_k, count),
        include=["documents", "metadatas", "distances"],
    )

    hits = []
    # A single query embedding was sent, so only the first result row is used.
    for doc, meta, dist in zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ):
        # Convert distance to a similarity score; clamp so it is never negative.
        score = max(0.0, 1.0 - dist)
        if score < MIN_SCORE:
            continue
        # Records stored without metadata come back as None; treat as empty
        # so the lookups below cannot fail.
        meta = meta or {}
        hits.append(
            {
                "text": doc,
                "score": score,
                "source": meta.get("source", ""),
                "chunk_index": meta.get("chunk_index", 0),
                "metadata": meta,
            }
        )

    hits.sort(key=lambda h: h["score"], reverse=True)
    return hits


def delete_collection(name: str = DEFAULT_COLLECTION) -> None:
    """Delete the collection called *name* on a best-effort basis.

    Success is logged at INFO level. Any exception raised along the way is
    logged at WARNING level and not propagated to the caller.

    Args:
        name: Collection name; defaults to ``DEFAULT_COLLECTION``.
    """
    try:
        client = _get_client()
        client.delete_collection(name)
        logger.info(f"Colección '{name}' eliminada.")
    except Exception as error:
        logger.warning(f"No se pudo eliminar la colección '{name}': {error}")


def reset(collection: str = DEFAULT_COLLECTION) -> None:
    """Reset *collection* by delegating to ``delete_collection``.

    Args:
        collection: Collection name; defaults to ``DEFAULT_COLLECTION``.
    """
    delete_collection(name=collection)


def collection_exists(name: str) -> bool:
    """Report whether the collection called *name* can be fetched.

    Args:
        name: Collection name to look up.

    Returns:
        ``True`` when the lookup succeeds; ``False`` when it raises any
        exception, so failures other than "not found" also read as ``False``.
    """
    try:
        _get_client().get_collection(name)
    except Exception:
        return False
    else:
        return True
