#!/usr/bin/env python3
"""
Migration script to add course_speciality_association table for shared/interdisciplinary courses.

This migration adds support for linking multiple specialities to courses,
which is essential for interdisciplinary courses and timetable generation.
"""

import sqlite3
import os
from datetime import datetime

def _table_exists(cursor, table_name):
    """Return True if a table named *table_name* exists in the open database."""
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
        (table_name,),
    )
    return cursor.fetchone() is not None


def _create_association_table(cursor):
    """Create course_speciality_association and its three lookup indexes."""
    cursor.execute("""
        CREATE TABLE course_speciality_association (
            course_id VARCHAR(150) NOT NULL,
            speciality_id VARCHAR(150) NOT NULL,
            is_primary_speciality BOOLEAN DEFAULT 0,
            assigned_date DATE DEFAULT CURRENT_DATE,
            assigned_by VARCHAR(150),
            notes TEXT,
            is_active BOOLEAN DEFAULT 1,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (course_id) REFERENCES courses(id),
            FOREIGN KEY (speciality_id) REFERENCES specialities(id),
            FOREIGN KEY (assigned_by) REFERENCES supervisors(id),
            PRIMARY KEY (course_id, speciality_id)
        )
    """)

    # Indexes for lookups by course, by speciality, and by the active flag.
    for index_name, column in (
        ('idx_course_speciality_course_id', 'course_id'),
        ('idx_course_speciality_speciality_id', 'speciality_id'),
        ('idx_course_speciality_active', 'is_active'),
    ):
        cursor.execute(
            f"CREATE INDEX {index_name} "
            f"ON course_speciality_association({column})"
        )


def _add_shared_course_columns(cursor):
    """Add any missing shared-course columns to the ``courses`` table.

    Returns:
        True if ``courses`` exists (missing columns were added), False if the
        table does not exist, in which case nothing is changed.
    """
    if not _table_exists(cursor, 'courses'):
        return False

    cursor.execute("PRAGMA table_info(courses)")
    # Field 1 of each PRAGMA table_info row is the column name.
    existing_columns = {row[1] for row in cursor.fetchall()}

    shared_course_columns = (
        ('is_shared_course', "BOOLEAN DEFAULT 0"),
        ('shared_course_type', "VARCHAR(50) DEFAULT 'department_specific'"),
        ('sharing_level', "VARCHAR(50) DEFAULT 'single'"),
    )
    for column_name, column_def in shared_course_columns:
        if column_name not in existing_columns:
            cursor.execute(
                f"ALTER TABLE courses ADD COLUMN {column_name} {column_def}"
            )
            print(f"Added {column_name} column to courses table")
    return True


def run_migration(db_path='test.db'):
    """Run the migration to add the course_speciality_association table.

    Creates the association table and its indexes. If a ``courses`` table
    exists, also adds the shared-course columns it is missing. All schema
    changes run in one explicit transaction, so a failure at any step leaves
    the database exactly as it was.

    Args:
        db_path: Path to the SQLite database file. Defaults to ``'test.db'``,
            the path previously hard-coded here.

    Returns:
        True if the migration was applied or the table already existed.
        False if the database file is missing or any error occurred.
    """
    if not os.path.exists(db_path):
        print(f"Database file {db_path} not found!")
        return False

    conn = None  # bound up front so the except handlers can always test it
    try:
        # isolation_level=None turns off the sqlite3 module's implicit
        # transaction handling, which never opens a transaction for DDL.
        # Without the explicit BEGIN below, every CREATE/ALTER would
        # autocommit and a mid-migration failure could not be rolled back.
        conn = sqlite3.connect(db_path, isolation_level=None)
        cursor = conn.cursor()

        print("Starting course_speciality_association table migration...")

        if _table_exists(cursor, 'course_speciality_association'):
            print("course_speciality_association table already exists. Skipping migration.")
            return True

        cursor.execute("BEGIN")
        _create_association_table(cursor)
        courses_updated = _add_shared_course_columns(cursor)
        conn.commit()

        print("✅ Successfully created course_speciality_association table")
        print("✅ Added indexes for better performance")
        if courses_updated:
            print("✅ Updated courses table with shared course fields")
        else:
            print("⚠️ courses table not found; shared course fields were not added")

        # Verify the table was created
        cursor.execute("SELECT COUNT(*) FROM course_speciality_association")
        count = cursor.fetchone()[0]
        print(f"✅ Table created successfully. Current records: {count}")
        return True

    except sqlite3.Error as e:
        print(f"❌ SQLite error during migration: {e}")
        if conn is not None:
            conn.rollback()
        return False
    except Exception as e:
        print(f"❌ Unexpected error during migration: {e}")
        if conn is not None:
            conn.rollback()
        return False
    finally:
        if conn is not None:
            conn.close()

def rollback_migration(db_path='test.db'):
    """Roll back the migration by dropping the course_speciality_association table.

    Dropping the table also drops the indexes created on it. The shared-course
    columns on ``courses`` are intentionally kept; the comment in the body
    explains why.

    Args:
        db_path: Path to the SQLite database file. Defaults to ``'test.db'``,
            the path previously hard-coded here.

    Returns:
        True if the drop succeeded, including when the table was already absent.
        False if the database file is missing or an error occurred.
    """
    if not os.path.exists(db_path):
        print(f"Database file {db_path} not found!")
        return False

    conn = None  # bound up front so the except handlers can always test it
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        print("Rolling back course_speciality_association table migration...")

        # IF EXISTS makes the rollback safe to run more than once.
        cursor.execute("DROP TABLE IF EXISTS course_speciality_association")

        # The shared-course columns on ``courses`` are deliberately not removed.
        # ALTER TABLE ... DROP COLUMN only exists in SQLite >= 3.35.0. Also,
        # run_migration skips columns that already existed, so dropping them
        # here could delete data this migration never created.

        conn.commit()
        print("✅ Successfully rolled back migration")
        return True

    except sqlite3.Error as e:
        print(f"❌ SQLite error during rollback: {e}")
        if conn is not None:
            conn.rollback()
        return False
    except Exception as e:
        print(f"❌ Unexpected error during rollback: {e}")
        if conn is not None:
            conn.rollback()
        return False
    finally:
        if conn is not None:
            conn.close()

if __name__ == "__main__":
    import sys

    # Passing "rollback" as the first CLI argument undoes the migration;
    # any other invocation (including no arguments) applies it.
    wants_rollback = sys.argv[1:2] == ["rollback"]
    action = rollback_migration if wants_rollback else run_migration
    ok = action()

    if not ok:
        print("\n💥 Migration failed!")
        sys.exit(1)

    print("\n🎉 Migration completed successfully!")
    sys.exit(0)
