from __future__ import annotations

import logging
from datetime import datetime, timedelta
from typing import List, Optional

from sqlalchemy import (
    create_engine,
    BigInteger,
    Column,
    Integer,
    String,
    Text,
    DateTime,
    Boolean,
    ForeignKey,
    UniqueConstraint,
    select,
)
from sqlalchemy import update, delete, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base, relationship, sessionmaker

# SQLite database file ``bot.db``. The URL uses a relative path, so the file is
# resolved against the process's current working directory.
engine = create_engine(
    'sqlite:///bot.db',
    # Forwarded to the sqlite3 driver: disables its same-thread check.
    connect_args={"check_same_thread": False},
    echo=False,
    future=True,
)

# Session factory used by every helper in this module. Sessions neither
# autoflush nor autocommit, so each helper commits explicitly.
SessionLocal = sessionmaker(
    bind=engine,
    future=True,
    autocommit=False,
    autoflush=False,
)

# --- MySQL example (commented) ---
# To use MySQL on server, first install a driver, e.g.:
#   pip install pymysql
# Then replace the engine/session with something like below:
# import os
# from sqlalchemy.engine.url import URL
# mysql_url = URL.create(
#     drivername='mysql+pymysql',
#     username=os.environ.get('MYSQL_USER', 'your_user'),
#     password=os.environ.get('MYSQL_PASSWORD', 'your_password'),
#     host=os.environ.get('MYSQL_HOST', '127.0.0.1'),
#     port=int(os.environ.get('MYSQL_PORT', '3306')),
#     database=os.environ.get('MYSQL_DB', 'your_db'),
#     query={'charset': 'utf8mb4'}  # optional
# )
# engine = create_engine(
#     mysql_url,
#     future=True,
#     echo=False,
#     pool_pre_ping=True,          # helps with stale connections
# )
# SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)

# Declarative base class: every ORM model in this module subclasses it, and
# ``Base.metadata`` is what ``create_all`` uses to create the tables.
Base = declarative_base()


class User(Base):
    """A bot user, keyed by chat id.

    Rows are created on demand the first time a chat id is encountered.
    ``resumes`` and ``jobs`` are the two-way relationships to the posts the
    user owns.
    """

    __tablename__ = 'users'

    id = Column(Integer, primary_key=True)
    # Telegram identifiers can need more than 32 bits, which overflows a plain
    # INTEGER column on backends such as MySQL. BigInteger is 64-bit everywhere.
    # Existing SQLite files are unaffected: create_all() does not alter tables
    # that already exist, and SQLite integers are 64-bit regardless.
    chat_id = Column(BigInteger, unique=True, index=True, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)

    resumes = relationship('Resume', back_populates='user')
    jobs = relationship('EmployerJob', back_populates='user')


class Resume(Base):
    """A resume posted by a user, tagged with a county and a category."""

    __tablename__ = 'resumes'

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('users.id', ondelete='CASCADE'),
        nullable=False,
    )
    county = Column(String(100), nullable=False)
    category = Column(String(100), nullable=False)
    text = Column(Text, nullable=False)
    # Creation time (naive UTC); the helpers below use it for the expiry window.
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)

    # Owning user; the reverse side is ``User.resumes``.
    user = relationship('User', back_populates='resumes')


class EmployerJob(Base):
    """A job post published by a user, tagged with a county and a category."""

    __tablename__ = 'employer_jobs'

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('users.id', ondelete='CASCADE'),
        nullable=False,
    )
    county = Column(String(100), nullable=False)
    category = Column(String(100), nullable=False)
    text = Column(Text, nullable=False)
    # Creation time (naive UTC); the helpers below use it for the expiry window.
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    # Only rows with active=True are returned by the county listing helpers.
    active = Column(Boolean, nullable=False, default=True)
    # View counter, kept for future statistics.
    views = Column(Integer, nullable=False, default=0)

    # Owning user; the reverse side is ``User.jobs``.
    user = relationship('User', back_populates='jobs')


class Admin(Base):
    """An administrator identified by chat id (legacy, id-based scheme)."""

    __tablename__ = 'admins'

    id = Column(Integer, primary_key=True)
    # Telegram identifiers can need more than 32 bits, which overflows a plain
    # INTEGER column on backends such as MySQL. BigInteger is 64-bit everywhere.
    # Existing SQLite files are unaffected: create_all() does not alter tables
    # that already exist, and SQLite integers are 64-bit regardless.
    chat_id = Column(BigInteger, unique=True, index=True, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)


class AdminUsername(Base):
    """An administrator identified by username.

    The helper functions in this module lower-case usernames before storing
    or comparing them.
    """

    __tablename__ = 'admins_by_username'

    id = Column(Integer, primary_key=True)
    username = Column(
        String(64),
        nullable=False,
        unique=True,
        index=True,
    )
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)


# Simple key-value settings table (for night mode and future settings)
class Setting(Base):
    """One string-valued setting, addressed by a unique key."""

    __tablename__ = 'settings'

    id = Column(Integer, primary_key=True)
    key = Column(
        String(64),
        nullable=False,
        unique=True,
        index=True,
    )
    value = Column(String(256), nullable=False)
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)
    # Refreshed by the ORM whenever the row is updated.
    updated_at = Column(
        DateTime,
        nullable=False,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )


class DailyView(Base):
    """Per-user view counter for a single calendar day.

    At most one row exists per (user, date) pair.
    """

    __tablename__ = 'daily_views'
    __table_args__ = (
        UniqueConstraint('user_id', 'date', name='uq_daily_view_user_date'),
    )

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('users.id', ondelete='CASCADE'),
        nullable=False,
    )
    # Day key in YYYY-MM-DD form.
    date = Column(String(10), nullable=False)
    count = Column(Integer, nullable=False, default=0)


class SeenItem(Base):
    """Records that a user has already been shown a given job or resume.

    ``item_id`` is a plain integer, not a foreign key; ``item_type`` says
    which table it refers to. A (user, type, item) triple is stored at most
    once.
    """

    __tablename__ = 'seen_items'
    __table_args__ = (
        UniqueConstraint('user_id', 'item_type', 'item_id', name='uq_seen_user_item'),
    )

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('users.id', ondelete='CASCADE'),
        nullable=False,
    )
    # Either 'job' or 'resume'.
    item_type = Column(String(16), nullable=False)
    item_id = Column(Integer, nullable=False)
    seen_at = Column(DateTime, nullable=False, default=datetime.utcnow)


class PromoPost(Base):
    """A promotional post: a photo, a video, or plain text."""

    __tablename__ = 'promo_posts'

    id = Column(Integer, primary_key=True)
    user_id = Column(
        Integer,
        ForeignKey('users.id', ondelete='CASCADE'),
        nullable=False,
    )
    # One of 'photo', 'video' or 'text'.
    media_type = Column(String(10), nullable=False)
    # Telegram file_id of the photo/video; may be NULL.
    file_id = Column(String(256), nullable=True)
    caption = Column(Text, nullable=True)
    text = Column(Text, nullable=True)
    views = Column(Integer, nullable=False, default=0)
    created_at = Column(DateTime, nullable=False, default=datetime.utcnow)

    # User who created the post (one-way relationship, no back-reference).
    user = relationship('User')


# Create any tables that do not exist yet.
Base.metadata.create_all(bind=engine)

# Lightweight migration: create_all() never alters an existing table, so a
# 'promo_posts' table created before the 'views' column was added has to be
# patched by hand. PRAGMA table_info is SQLite-specific, so the whole step is
# skipped on any other dialect instead of failing and being swallowed.
if engine.dialect.name == 'sqlite':
    try:
        with engine.begin() as conn:
            info = conn.exec_driver_sql("PRAGMA table_info('promo_posts')").fetchall()
            cols = {row[1] for row in info}  # row[1] is the column name
            if 'views' not in cols:
                conn.exec_driver_sql(
                    "ALTER TABLE promo_posts ADD COLUMN views INTEGER NOT NULL DEFAULT 0"
                )
    except Exception:
        # Still best-effort (importing this module must not crash the bot), but
        # leave a trace: a failed migration otherwise only shows up later as
        # confusing errors on queries that touch promo_posts.views.
        logging.getLogger(__name__).warning(
            "Could not verify/add the promo_posts.views column", exc_info=True
        )


# Helpers

def _today_key() -> str:
    """Return the current local date formatted as ``YYYY-MM-DD``."""
    return datetime.now().date().isoformat()


def set_setting(key: str, value: str) -> None:
    """Store ``value`` under ``key``, inserting or overwriting as needed.

    Does nothing when ``key`` is empty.
    """
    if not key:
        return
    with SessionLocal() as session:
        lookup = select(Setting).where(Setting.key == key)
        existing = session.execute(lookup).scalar_one_or_none()
        if existing is None:
            session.add(Setting(key=key, value=value))
        else:
            existing.value = value
        session.commit()


def get_setting(key: str, default: Optional[str] = None) -> Optional[str]:
    """Return the stored value for ``key``.

    Falls back to ``default`` when ``key`` is empty or no such setting exists.
    """
    if not key:
        return default
    with SessionLocal() as session:
        lookup = select(Setting.value).where(Setting.key == key)
        stored = session.execute(lookup).scalar_one_or_none()
    return default if stored is None else stored


def get_night_mode_config() -> dict:
    """Read the night-mode settings.

    Returns a dict with ``enabled`` (True only when the stored flag is the
    string '1') and ``start`` / ``end`` (the stored strings, or None when
    they are not set).
    """
    return {
        "enabled": get_setting('night_enabled', '0') == '1',
        "start": get_setting('night_start', None),
        "end": get_setting('night_end', None),
    }


def get_or_create_user(chat_id: int) -> User:
    """Return the ``User`` row for ``chat_id``, inserting it first if missing.

    The instance is returned after its session has been closed; a newly
    inserted row is refreshed before that so its columns are loaded.
    """
    with SessionLocal() as session:
        lookup = select(User).where(User.chat_id == chat_id)
        found = session.execute(lookup).scalar_one_or_none()
        if not found:
            found = User(chat_id=chat_id)
            session.add(found)
            session.commit()
            # commit() expires attributes; reload them while the session is open.
            session.refresh(found)
        return found


# Admin helpers

# Deprecated: id-based admin helpers (kept for backward compatibility)
def is_admin(chat_id: int) -> bool:
    """Return True when ``chat_id`` is present in the admins table."""
    with SessionLocal() as session:
        lookup = select(Admin).where(Admin.chat_id == chat_id)
        match = session.execute(lookup).scalar_one_or_none()
        return match is not None


def add_admin(chat_id: int) -> bool:
    """Register ``chat_id`` as an admin.

    Returns True when a row was inserted, False when it already existed.
    """
    with SessionLocal() as session:
        lookup = select(Admin).where(Admin.chat_id == chat_id)
        already = session.execute(lookup).scalar_one_or_none()
        if not already:
            session.add(Admin(chat_id=chat_id))
            session.commit()
            return True
        return False


def remove_admin(chat_id: int) -> bool:
    """Remove ``chat_id`` from the admins table; True if a row was deleted."""
    with SessionLocal() as session:
        result = session.execute(delete(Admin).where(Admin.chat_id == chat_id))
        removed = result.rowcount or 0
        session.commit()
        return removed > 0


def list_admins() -> list[int]:
    """Return the chat ids of all id-based admins."""
    with SessionLocal() as session:
        return session.execute(select(Admin.chat_id)).scalars().all()


# Username-based admin helpers

def is_admin_username(username: str) -> bool:
    """Return True when ``username`` (compared lower-cased) is a stored admin.

    An empty username is never an admin.
    """
    if not username:
        return False
    with SessionLocal() as session:
        lookup = select(AdminUsername).where(AdminUsername.username == username.lower())
        match = session.execute(lookup).scalar_one_or_none()
        return match is not None


def add_admin_username(username: str) -> bool:
    """Register ``username`` (stored lower-cased) as an admin.

    Returns True when a row was inserted; False for an empty username or one
    that is already registered.
    """
    if not username:
        return False
    normalized = username.lower()
    with SessionLocal() as session:
        lookup = select(AdminUsername).where(AdminUsername.username == normalized)
        already = session.execute(lookup).scalar_one_or_none()
        if not already:
            session.add(AdminUsername(username=normalized))
            session.commit()
            return True
        return False


def remove_admin_username(username: str) -> bool:
    """Remove ``username`` (matched lower-cased); True if a row was deleted."""
    if not username:
        return False
    normalized = username.lower()
    with SessionLocal() as session:
        result = session.execute(
            delete(AdminUsername).where(AdminUsername.username == normalized)
        )
        removed = result.rowcount or 0
        session.commit()
        return removed > 0


def list_admin_usernames() -> list[str]:
    """Return all stored admin usernames."""
    with SessionLocal() as session:
        return session.execute(select(AdminUsername.username)).scalars().all()


# Resumes

def save_resume(chat_id: int, county: str, category: str, text: str) -> int:
    """Store a resume for the user behind ``chat_id`` and return its id.

    The user row is created first if it does not exist yet.
    """
    owner = get_or_create_user(chat_id)
    with SessionLocal() as session:
        # The user came from another (closed) session; bring it into this one.
        owner = session.merge(owner)
        new_resume = Resume(
            user_id=owner.id,
            county=county,
            category=category,
            text=text,
        )
        session.add(new_resume)
        session.commit()
        return new_resume.id


def get_active_resumes_by_county(county: str) -> List[dict]:
    """Return resumes for ``county`` created within the last 7 days, newest first.

    Each entry is a plain dict with the resume fields, the owner's ``chat_id``
    and ``created_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    oldest_allowed = datetime.utcnow() - timedelta(days=7)
    stmt = (
        select(Resume, User.chat_id)
        .join(User, User.id == Resume.user_id)
        .where(Resume.county == county, Resume.created_at >= oldest_allowed)
        .order_by(Resume.created_at.desc())
    )
    with SessionLocal() as session:
        return [
            {
                'id': resume.id,
                'chat_id': owner_chat_id,
                'county': resume.county,
                'category': resume.category,
                'text': resume.text,
                'created_at': resume.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            }
            for resume, owner_chat_id in session.execute(stmt).all()
        ]


def get_unseen_resumes_by_county(chat_id: int, county: str) -> List[dict]:
    """Active resumes for county that this user has not seen yet (7-day window).

    Args:
        chat_id: Chat id of the viewing user (the user row is created if missing).
        county: County to list resumes for.

    Returns:
        Newest-first list of dicts with the resume fields, the owner's
        ``chat_id`` and ``created_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    cutoff = datetime.utcnow() - timedelta(days=7)
    user = get_or_create_user(chat_id)
    with SessionLocal() as session:
        user = session.merge(user)
        # Let the database exclude already-seen resumes instead of loading the
        # user's whole seen history and every candidate row into Python.
        seen_ids = select(SeenItem.item_id).where(
            SeenItem.user_id == user.id,
            SeenItem.item_type == 'resume',
        )
        rows = session.execute(
            select(Resume, User.chat_id)
            .join(User, User.id == Resume.user_id)
            .where(
                Resume.county == county,
                Resume.created_at >= cutoff,
                Resume.id.not_in(seen_ids),
            )
            .order_by(Resume.created_at.desc())
        ).all()
        return [
            {
                'id': resume.id,
                'chat_id': r_chat,
                'county': resume.county,
                'category': resume.category,
                'text': resume.text,
                'created_at': resume.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            }
            for resume, r_chat in rows
        ]


def mark_resumes_seen(chat_id: int, resume_ids: list[int]) -> None:
    """Record that the user behind ``chat_id`` has seen the given resumes.

    Ids that are already recorded are skipped; duplicates in ``resume_ids``
    are collapsed. Does nothing for an empty list.

    Previously each insert was flushed inside a try/except that called
    ``session.rollback()`` on conflict. That rollback discarded every insert
    flushed earlier in the same call, so a single already-seen id silently
    lost the others. Existing rows are now looked up first and only the
    missing ones are inserted.
    """
    if not resume_ids:
        return
    user = get_or_create_user(chat_id)
    now = datetime.utcnow()
    wanted = set(resume_ids)
    with SessionLocal() as session:
        user = session.merge(user)
        already_seen = set(
            session.execute(
                select(SeenItem.item_id).where(
                    SeenItem.user_id == user.id,
                    SeenItem.item_type == 'resume',
                    SeenItem.item_id.in_(list(wanted)),
                )
            ).scalars()
        )
        missing = wanted - already_seen
        if not missing:
            return
        session.add_all([
            SeenItem(user_id=user.id, item_type='resume', item_id=rid, seen_at=now)
            for rid in missing
        ])
        try:
            session.commit()
        except IntegrityError:
            # A concurrent writer inserted one of these rows between the lookup
            # and the commit. Marking as seen stays best-effort, as before.
            session.rollback()


# Employer job posts

def save_job_post(chat_id: int, county: str, category: str, text: str) -> int:
    """Store an active job post for the user behind ``chat_id``; return its id.

    The user row is created first if it does not exist yet.
    """
    owner = get_or_create_user(chat_id)
    with SessionLocal() as session:
        # The user came from another (closed) session; bring it into this one.
        owner = session.merge(owner)
        new_job = EmployerJob(
            user_id=owner.id,
            county=county,
            category=category,
            text=text,
            active=True,
        )
        session.add(new_job)
        session.commit()
        return new_job.id


# Promo posts

def save_promo_post(chat_id: int, media_type: str, file_id: str | None, caption: str | None, text: str | None) -> int:
    """Persist a promo post and return its id.

    Args:
        chat_id: Chat id of the creator (the user row is created if missing).
        media_type: 'photo' | 'video' | 'text'.
        file_id: Telegram file_id of the photo/video, or None.
        caption: Optional caption for photo/video posts.
        text: Body of a text-only post, or optional extra text.
    """
    author = get_or_create_user(chat_id)
    with SessionLocal() as session:
        # The user came from another (closed) session; bring it into this one.
        author = session.merge(author)
        promo = PromoPost(
            user_id=author.id,
            media_type=media_type,
            file_id=file_id,
            caption=caption,
            text=text,
        )
        session.add(promo)
        session.commit()
        return promo.id


def get_random_promo_post() -> dict | None:
    """Pick one promo post at random; None when the table is empty.

    The post is returned as a plain dict with ``created_at`` formatted as
    ``YYYY-MM-DD HH:MM:SS``.
    """
    pick_one = select(PromoPost).order_by(func.random()).limit(1)
    with SessionLocal() as session:
        promo = session.execute(pick_one).scalar_one_or_none()
        if promo:
            return {
                'id': promo.id,
                'media_type': promo.media_type,
                'file_id': promo.file_id,
                'caption': promo.caption,
                'text': promo.text,
                'views': promo.views,
                'created_at': promo.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            }
        return None


def list_promo_posts(limit: int = 50) -> list[dict]:
    """Return up to ``limit`` promo posts as dicts, most recent first."""
    newest_first = select(PromoPost).order_by(PromoPost.created_at.desc()).limit(limit)
    with SessionLocal() as session:
        return [
            {
                'id': promo.id,
                'media_type': promo.media_type,
                'file_id': promo.file_id,
                'caption': promo.caption,
                'text': promo.text,
                'views': promo.views,
                'created_at': promo.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            }
            for promo in session.execute(newest_first).scalars().all()
        ]


def delete_promo_post(promo_id: int) -> bool:
    """Remove the promo post with the given id; True if a row was deleted."""
    with SessionLocal() as session:
        result = session.execute(delete(PromoPost).where(PromoPost.id == promo_id))
        removed = result.rowcount or 0
        session.commit()
        return removed > 0


def increment_promo_views(promo_id: int, inc: int = 1) -> None:
    """Add ``inc`` to the view counter of a promo post.

    The increment is done by a single UPDATE statement. Does nothing when
    ``promo_id`` is falsy.
    """
    if not promo_id:
        return
    bump = (
        update(PromoPost)
        .where(PromoPost.id == promo_id)
        .values(views=PromoPost.views + inc)
    )
    with SessionLocal() as session:
        session.execute(bump)
        session.commit()


def get_active_jobs_by_county(county: str, limit: int | None = None) -> list[dict]:
    """Return active job posts for ``county`` from the last 7 days, newest first.

    Args:
        county: County to list jobs for.
        limit: Maximum number of jobs to return. ``None`` or ``0`` means no
            limit. A negative value keeps its historical meaning of slicing
            that many entries off the end of the full list.

    Returns:
        List of dicts with the job fields, the owner's ``chat_id`` and
        ``created_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    cutoff = datetime.utcnow() - timedelta(days=7)
    stmt = (
        select(EmployerJob, User.chat_id)
        .join(User, User.id == EmployerJob.user_id)
        .where(
            EmployerJob.county == county,
            EmployerJob.active == True,  # noqa: E712 - builds a SQL expression
            EmployerJob.created_at >= cutoff,
        )
        .order_by(EmployerJob.created_at.desc())
    )
    if limit and limit > 0:
        # Let the database stop early instead of fetching every matching row
        # only to slice the list afterwards.
        stmt = stmt.limit(limit)
    with SessionLocal() as session:
        items = [
            {
                'id': job.id,
                'chat_id': chat_id,
                'county': job.county,
                'category': job.category,
                'text': job.text,
                'created_at': job.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'active': job.active,
                'views': job.views,
            }
            for job, chat_id in session.execute(stmt).all()
        ]
    # A no-op for positive limits (already applied in SQL); kept so that the
    # result for every other value of ``limit`` is unchanged.
    return items[:limit] if limit else items


def get_unseen_jobs_by_county(chat_id: int, county: str, limit: int | None = None) -> list[dict]:
    """Active jobs in county that this user has not seen yet (7-day window).

    Args:
        chat_id: Chat id of the viewing user (the user row is created if missing).
        county: County to list jobs for.
        limit: Maximum number of jobs to return. ``None`` or ``0`` means no
            limit. A negative value keeps its historical meaning of slicing
            that many entries off the end of the full list.

    Returns:
        Newest-first list of dicts with the job fields, the owner's
        ``chat_id`` and ``created_at`` formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    cutoff = datetime.utcnow() - timedelta(days=7)
    user = get_or_create_user(chat_id)
    with SessionLocal() as session:
        user = session.merge(user)
        # Let the database exclude already-seen jobs instead of loading the
        # user's whole seen history and every candidate row into Python.
        seen_ids = select(SeenItem.item_id).where(
            SeenItem.user_id == user.id,
            SeenItem.item_type == 'job',
        )
        stmt = (
            select(EmployerJob, User.chat_id)
            .join(User, User.id == EmployerJob.user_id)
            .where(
                EmployerJob.county == county,
                EmployerJob.active == True,  # noqa: E712 - builds a SQL expression
                EmployerJob.created_at >= cutoff,
                EmployerJob.id.not_in(seen_ids),
            )
            .order_by(EmployerJob.created_at.desc())
        )
        if limit and limit > 0:
            # The seen-filter now runs in SQL, so LIMIT can be applied there too.
            stmt = stmt.limit(limit)
        items = [
            {
                'id': job.id,
                'chat_id': j_chat,
                'county': job.county,
                'category': job.category,
                'text': job.text,
                'created_at': job.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'active': job.active,
                'views': job.views,
            }
            for job, j_chat in session.execute(stmt).all()
        ]
    # A no-op for positive limits (already applied in SQL); kept so that the
    # result for every other value of ``limit`` is unchanged.
    return items[:limit] if limit else items


def mark_jobs_seen(chat_id: int, job_ids: list[int]) -> None:
    """Record that the user behind ``chat_id`` has seen the given job posts.

    Ids that are already recorded are skipped; duplicates in ``job_ids`` are
    collapsed. Does nothing for an empty list.

    Previously each insert was flushed inside a try/except that called
    ``session.rollback()`` on conflict. That rollback discarded every insert
    flushed earlier in the same call, so a single already-seen id silently
    lost the others. Existing rows are now looked up first and only the
    missing ones are inserted.
    """
    if not job_ids:
        return
    user = get_or_create_user(chat_id)
    now = datetime.utcnow()
    wanted = set(job_ids)
    with SessionLocal() as session:
        user = session.merge(user)
        already_seen = set(
            session.execute(
                select(SeenItem.item_id).where(
                    SeenItem.user_id == user.id,
                    SeenItem.item_type == 'job',
                    SeenItem.item_id.in_(list(wanted)),
                )
            ).scalars()
        )
        missing = wanted - already_seen
        if not missing:
            return
        session.add_all([
            SeenItem(user_id=user.id, item_type='job', item_id=jid, seen_at=now)
            for jid in missing
        ])
        try:
            session.commit()
        except IntegrityError:
            # A concurrent writer inserted one of these rows between the lookup
            # and the commit. Marking as seen stays best-effort, as before.
            session.rollback()


def increment_job_views(job_ids: list[int], inc: int = 1) -> None:
    """Add ``inc`` to the view counter of each listed job post.

    One UPDATE is issued per entry, so an id that appears several times is
    incremented several times. All updates are committed together. Does
    nothing for an empty list.
    """
    if not job_ids:
        return
    with SessionLocal() as session:
        for job_id in job_ids:
            bump = (
                update(EmployerJob)
                .where(EmployerJob.id == job_id)
                .values(views=EmployerJob.views + inc)
            )
            session.execute(bump)
        session.commit()


# Daily views (shared for both job/resume viewing)

def get_remaining_views(chat_id: int, daily_limit: int = 3) -> int:
    """Return how many views the user still has today (never below zero).

    A user with no counter row for today has the full ``daily_limit``.
    """
    viewer = get_or_create_user(chat_id)
    day_key = _today_key()
    with SessionLocal() as session:
        viewer = session.merge(viewer)
        lookup = select(DailyView).where(
            DailyView.user_id == viewer.id,
            DailyView.date == day_key,
        )
        counter = session.execute(lookup).scalar_one_or_none()
        if counter:
            return max(0, daily_limit - counter.count)
        return daily_limit


def add_views(chat_id: int, n: int, daily_limit: int = 3) -> None:
    """Add ``n`` to the user's view counter for today.

    The counter row is created on first use and the stored value is capped
    at ``daily_limit``.
    """
    viewer = get_or_create_user(chat_id)
    day_key = _today_key()
    with SessionLocal() as session:
        viewer = session.merge(viewer)
        lookup = select(DailyView).where(
            DailyView.user_id == viewer.id,
            DailyView.date == day_key,
        )
        counter = session.execute(lookup).scalar_one_or_none()
        if not counter:
            counter = DailyView(user_id=viewer.id, date=day_key, count=0)
            session.add(counter)
        used_so_far = counter.count or 0
        counter.count = min(daily_limit, used_so_far + n)
        session.commit()


def get_my_resumes(chat_id: int) -> list[dict]:
    """Return every resume owned by ``chat_id`` as dicts, newest first.

    No age filter is applied here.
    """
    stmt = (
        select(Resume)
        .join(User, User.id == Resume.user_id)
        .where(User.chat_id == chat_id)
        .order_by(Resume.created_at.desc())
    )
    with SessionLocal() as session:
        return [
            {
                'id': resume.id,
                'county': resume.county,
                'category': resume.category,
                'text': resume.text,
                'created_at': resume.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            }
            for resume in session.execute(stmt).scalars().all()
        ]


def get_my_jobs(chat_id: int) -> list[dict]:
    """Return every job post owned by ``chat_id`` as dicts, newest first.

    No age or ``active`` filter is applied here.
    """
    stmt = (
        select(EmployerJob)
        .join(User, User.id == EmployerJob.user_id)
        .where(User.chat_id == chat_id)
        .order_by(EmployerJob.created_at.desc())
    )
    with SessionLocal() as session:
        return [
            {
                'id': job.id,
                'county': job.county,
                'category': job.category,
                'text': job.text,
                'created_at': job.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'active': job.active,
                'views': job.views,
            }
            for job in session.execute(stmt).scalars().all()
        ]


def delete_my_resume(chat_id: int, resume_id: int) -> bool:
    """Delete a resume owned by chat_id. Returns True if deleted.

    The ``seen_items`` rows that point at the resume are removed in the same
    transaction. ``seen_items.item_id`` is not a foreign key, so nothing else
    cleans them up, and SQLite may hand a freed primary-key value to a later
    row; a stale record would then hide that new resume from users who had
    seen the deleted one.
    """
    with SessionLocal() as session:
        user = session.execute(select(User).where(User.chat_id == chat_id)).scalar_one_or_none()
        if not user:
            return False
        rowcount = session.execute(
            delete(Resume).where(Resume.id == resume_id, Resume.user_id == user.id)
        ).rowcount or 0
        if rowcount:
            # Only purge when the resume really was this user's and is now gone.
            session.execute(
                delete(SeenItem).where(
                    SeenItem.item_type == 'resume',
                    SeenItem.item_id == resume_id,
                )
            )
        session.commit()
        return rowcount > 0


def delete_my_job(chat_id: int, job_id: int) -> bool:
    """Delete a job post owned by chat_id. Returns True if deleted.

    The ``seen_items`` rows that point at the job are removed in the same
    transaction. ``seen_items.item_id`` is not a foreign key, so nothing else
    cleans them up, and SQLite may hand a freed primary-key value to a later
    row; a stale record would then hide that new job from users who had seen
    the deleted one.
    """
    with SessionLocal() as session:
        user = session.execute(select(User).where(User.chat_id == chat_id)).scalar_one_or_none()
        if not user:
            return False
        rowcount = session.execute(
            delete(EmployerJob).where(EmployerJob.id == job_id, EmployerJob.user_id == user.id)
        ).rowcount or 0
        if rowcount:
            # Only purge when the job really was this user's and is now gone.
            session.execute(
                delete(SeenItem).where(
                    SeenItem.item_type == 'job',
                    SeenItem.item_id == job_id,
                )
            )
        session.commit()
        return rowcount > 0


def cleanup_expired(older_than_days: int = 7) -> dict:
    """Delete resumes and jobs older than N days. Returns counts.

    Also removes ``seen_items`` rows whose resume/job no longer exists.
    ``seen_items.item_id`` is not a foreign key, so those rows would otherwise
    accumulate forever, and because SQLite may reuse a freed primary-key
    value they could hide a newer item from users who had seen the old one.

    Args:
        older_than_days: Age threshold in days, measured against naive UTC
            ``created_at`` values.

    Returns:
        ``{"resumes_deleted": int, "jobs_deleted": int}``.
    """
    cutoff = datetime.utcnow() - timedelta(days=older_than_days)
    with SessionLocal() as session:
        res_del = session.execute(
            delete(Resume).where(Resume.created_at < cutoff)
        ).rowcount or 0
        job_del = session.execute(
            delete(EmployerJob).where(EmployerJob.created_at < cutoff)
        ).rowcount or 0
        # Purge seen-records left without a target. This runs in the same
        # transaction, so the subqueries already reflect the deletes above.
        for item_type, model in (('resume', Resume), ('job', EmployerJob)):
            session.execute(
                delete(SeenItem)
                .where(
                    SeenItem.item_type == item_type,
                    SeenItem.item_id.not_in(select(model.id)),
                )
                # The criteria contains a subquery, which cannot be evaluated
                # against in-session objects; none are loaded here anyway.
                .execution_options(synchronize_session=False)
            )
        session.commit()
        return {"resumes_deleted": res_del, "jobs_deleted": job_del}