import logging
from dataclasses import dataclass, field

from rag import embedder, store
from rag.config import DEFAULT_COLLECTION, DEFAULT_TOP_K

# Module-level logger named after this module; query() below logs through it.
logger = logging.getLogger(__name__)


@dataclass()
class Chunk:
    """A single piece of retrieved text together with its provenance.

    Attributes:
        text: The chunk's text content.
        score: Score reported by the vector store for this hit. Its scale
            and direction (higher-is-better or not) are defined by the store,
            not here.
        source: Presumably an identifier of the originating document; confirm
            against the store.
        chunk_index: Presumably the chunk's position within ``source``;
            confirm against the indexer.
        metadata: Extra per-chunk data. Each instance gets its own fresh
            empty dict when none is given.
    """

    text: str
    score: float
    source: str
    chunk_index: int
    # default_factory avoids sharing one mutable dict across instances.
    metadata: dict = field(default_factory=dict)


@dataclass()
class QueryResult:
    """Outcome of one retrieval query.

    Attributes:
        chunks: Retrieved chunks, possibly empty.
        collection: Name of the collection that was queried.
        top_k: Number of results that were requested. ``len(chunks)`` may be
            smaller, e.g. zero when the query was blank or failed.
    """

    chunks: list[Chunk]
    collection: str
    top_k: int


def query(
    text: str,
    collection: str = DEFAULT_COLLECTION,
    top_k: int = DEFAULT_TOP_K,
) -> QueryResult:
    """Retrieve the chunks most similar to *text* from a vector-store collection.

    The query is stripped of surrounding whitespace, embedded with
    ``embedder.embed_one`` and passed to ``store.similarity_search``. Each
    returned hit is converted into a :class:`Chunk`.

    This is a best-effort call. Blank input, or any exception raised while
    embedding, searching or converting hits, yields an empty
    :class:`QueryResult` instead of propagating. Errors are logged with a
    traceback.

    Args:
        text: Free-text query. ``None``, empty or whitespace-only input
            returns an empty result without touching the embedder or store.
        collection: Name of the collection to search.
        top_k: Number of hits requested from the store.

    Returns:
        A :class:`QueryResult` echoing *collection* and *top_k*. Its chunks
        are in the order the store returned them.
    """
    empty = QueryResult(chunks=[], collection=collection, top_k=top_k)

    # Strip once and reuse; None/"" short-circuit before calling .strip().
    cleaned = text.strip() if text else ""
    if not cleaned:
        return empty

    try:
        vector = embedder.embed_one(cleaned)
        hits = store.similarity_search(collection, vector, top_k)
        chunks = [
            Chunk(
                text=h["text"],
                score=h["score"],
                source=h["source"],
                chunk_index=h["chunk_index"],
                # A hit without (or with null) metadata falls back to an empty
                # dict, matching Chunk's own default. This prevents one such
                # hit from discarding the whole result via the except below.
                # NOTE(review): assumes hits are Mappings (they are indexed
                # by string key above) — confirm against store.
                metadata=h.get("metadata") or {},
            )
            for h in hits
        ]
    except Exception as exc:
        # Deliberate best-effort boundary: log with traceback, degrade to empty.
        logger.exception("Error en RAG query: %s", exc)
        return empty
    else:
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info(
            "RAG query en '%s': %d chunk(s) recuperados (top_k=%s).",
            collection,
            len(chunks),
            top_k,
        )
        return QueryResult(chunks=chunks, collection=collection, top_k=top_k)
