From 0f806dc3ca8a483a52bdd90fb168cd0f6dd2cb17 Mon Sep 17 00:00:00 2001 From: Lynn Date: Mon, 20 Apr 2026 11:40:01 +0800 Subject: [PATCH] Feat: mysql sync (#14200) ### What problem does this PR solve? Add a script to sync db schema with peewee_migrate. ### Type of change - [x] Other (please describe): tool script --- tools/scripts/README.md | 138 ++++- tools/scripts/db_schema_sync.py | 951 ++++++++++++++++++++++++++++++++ 2 files changed, 1088 insertions(+), 1 deletion(-) create mode 100644 tools/scripts/db_schema_sync.py diff --git a/tools/scripts/README.md b/tools/scripts/README.md index cd7ebf075b..f310d9083c 100644 --- a/tools/scripts/README.md +++ b/tools/scripts/README.md @@ -1,4 +1,13 @@ -# MySQL Data Migration Script +# Database Scripts + +This directory contains database-related utility scripts for RAGFlow. + +- **mysql_migration.py**: Data migration between tables with stage-based execution +- **db_schema_sync.py**: Database schema synchronization using peewee-migrate + +--- + +# mysql_migration.py A flexible MySQL data migration tool for migrating data between tables with stage-based execution. @@ -208,3 +217,130 @@ Stage Details: | `[DRY RUN] Target table does not exist` | Target table missing, use `--execute` or `--create-table-only`to create | | `Dependency table does not exist` | Required table from previous stage missing | | `Inserted batch X: Y records` | Successfully inserted batch of records | + +--- + +# db_schema_sync.py + +A database schema synchronization tool that uses peewee-migrate to detect and manage schema changes. + +## Overview + +This script: +1. Reads model definitions from `api/db/db_models.py` +2. Compares with existing database tables specified via command line +3. Generates migration files in `tools/migrate/{version}/` + +### Detected Change Types + +| Change Type | Description | Auto-included? 
| +|-------------|-------------|----------------| +| New table | Model class with no corresponding DB table | Yes | +| New field | Model field not present in DB table | Yes | +| Field type change | Model field type differs from DB column type | Yes | +| Removed field | DB column not present in model definition | No (requires `--drop`) | + +> **Warning**: Removed fields are **not** included in migrations by default. You must explicitly use `--drop` to generate `DROP COLUMN` statements, as this operation permanently deletes data. + +## Prerequisites + +Install peewee-migrate: +```bash +pip install peewee-migrate +``` + +## Usage + +### Command Line Arguments + +``` +python db_schema_sync.py [OPTIONS] +``` + +| Option | Short | Description | +|--------|-------|-------------| +| `--host` | - | MySQL host (required) | +| `--port` | - | MySQL port (default: 3306) | +| `--user` | - | MySQL user (required) | +| `--password` | - | MySQL password (required) | +| `--database` | - | MySQL database name (required) | +| `--version` | `-v` | Version number in format `vxx.xx.xx` (required) | +| `--list` | `-l` | List all migrations | +| `--create` | `-c` | Create a new migration (auto-detect changes) | +| `--migrate` | `-m` | Run pending migrations | +| `--diff` | `-d` | Show schema differences | +| `--name` | `-n` | Migration name (default: auto) | +| `--drop` | - | Include `DROP COLUMN` for fields removed from models (destructive - permanently deletes data!) | + +### Version Format + +Version must be in format `vxx.xx.xx` where each `xx` is one or more digits: +- Valid: `v0.24.0`, `v1.0.0`, `v10.20.30` +- Invalid: `0.24.0`, `v0.24`, `v0.24.0.1` + +### Migration File Location + +Migration files are stored in: +``` +tools/migrate/{version_dir}/ +``` + +Where `{version_dir}` is the version with `.` replaced by `_`. 
+ +Example: Version `v0.24.0` → Directory `tools/migrate/v0_24_0/` + +### Examples + +```bash +# List all migrations +python db_schema_sync.py --list \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 + +# Create a new auto-detected migration (new tables, new fields, type changes only) +python db_schema_sync.py --create \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 + +# Create a migration including dropped fields (destructive!) +python db_schema_sync.py --create --drop \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 + +# Create a named migration +python db_schema_sync.py --create --name add_user_table \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 + +# Run all pending migrations +python db_schema_sync.py --migrate \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 + +# Show schema differences (including removed fields) +python db_schema_sync.py --diff \ + --host localhost --port 3306 --user root --password xxx --database rag_flow \ + --version v0.24.0 +``` + +## How It Works + +1. **Load Models**: Imports all model classes from `api/db/db_models.py` +2. **Connect Database**: Creates MySQL connection from command line arguments +3. **Detect Changes**: Compares model definitions with actual database schema: + - New tables → `create_model` + - New fields → `ALTER TABLE ADD COLUMN` + - Field type changes → `ALTER TABLE MODIFY COLUMN` + - Removed fields → `ALTER TABLE DROP COLUMN` (only with `--drop`) +4. 
**Generate Migration**: Creates Python migration file with `migrate()` and `rollback()` functions + +### Rollback Behavior + +| Forward Operation | Rollback Operation | +|-------------------|--------------------| +| `CREATE TABLE` | `remove_model` | +| `ADD COLUMN` | `DROP COLUMN` | +| `MODIFY COLUMN` | `MODIFY COLUMN` (restore original type) | +| `DROP COLUMN` | `ADD COLUMN` (restore column definition; **data is lost**) | + +> **Note**: Rolling back a `DROP COLUMN` will re-add the column structure, but the data that was in it cannot be recovered. diff --git a/tools/scripts/db_schema_sync.py b/tools/scripts/db_schema_sync.py new file mode 100644 index 0000000000..85d1dade57 --- /dev/null +++ b/tools/scripts/db_schema_sync.py @@ -0,0 +1,951 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Database Schema Sync Script + +This script synchronizes database models defined in api/db/db_models.py +with the actual database schema using peewee-migrate. + +Features: +1. Reads model definitions from api/db/db_models.py +2. Compares with existing database tables specified via command line +3. 
def validate_version(version: str) -> bool:
    """Return True when *version* is exactly ``v<digits>.<digits>.<digits>``.

    Examples: ``v0.24.0`` and ``v10.20.30`` are valid; ``0.24.0``, ``v0.24``
    and ``v0.24.0.1`` are not.
    """
    # fullmatch instead of match + '$': '$' also matches just before a trailing
    # newline, so "v0.24.0\n" used to be accepted and would then end up in a
    # directory name. re.ASCII keeps '\d' from matching non-ASCII digits.
    return re.fullmatch(r'v\d+\.\d+\.\d+', version, flags=re.ASCII) is not None


def version_to_dirname(version: str) -> str:
    """Convert a version string to a directory name (e.g. 'v0.24.0' -> 'v0_24_0')."""
    return version.replace('.', '_')


def load_db_models():
    """Load the model classes declared in api/db/db_models.py.

    The module is imported directly from its file path under PROJECT_BASE.

    Returns:
        tuple: (models, db_models) where ``models`` is the list of peewee
        ``Model`` subclasses found in the module (the ``BaseModel`` and
        ``DataBaseModel`` base classes excluded) and ``db_models`` is the
        imported module object.

    Raises:
        FileNotFoundError: if db_models.py does not exist at the expected path.
    """
    models_path = os.path.join(PROJECT_BASE, 'api', 'db', 'db_models.py')

    if not os.path.exists(models_path):
        raise FileNotFoundError(f"db_models.py not found at {models_path}")

    # Import the module from its path rather than by package name
    spec = importlib.util.spec_from_file_location("db_models", models_path)
    db_models = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(db_models)

    base_class_names = ('BaseModel', 'DataBaseModel')
    models = [
        obj
        for _name, obj in inspect.getmembers(db_models, inspect.isclass)
        if issubclass(obj, Model)
        and obj is not Model
        and obj.__name__ not in base_class_names
        # Only keep classes whose Meta carries a database attribute
        and hasattr(obj._meta, 'database')
    ]

    return models, db_models
# MySQL type to Peewee field type mapping
MYSQL_TO_PEEWEE_TYPE = {
    'varchar': 'CharField',
    'char': 'CharField',
    'text': 'TextField',
    'longtext': 'TextField',
    'mediumtext': 'TextField',
    'int': 'IntegerField',
    'integer': 'IntegerField',
    'bigint': 'BigIntegerField',
    'float': 'FloatField',
    'double': 'FloatField',
    'decimal': 'FloatField',
    'datetime': 'DateTimeField',
    'timestamp': 'DateTimeField',
    'tinyint(1)': 'BooleanField',
    'tinyint': 'IntegerField',
    'smallint': 'IntegerField',
    'mediumint': 'IntegerField',
}

PEEWEE_TO_MYSQL_TYPE = {
    'CharField': 'varchar',
    'TextField': 'text',
    'IntegerField': 'int',
    'BigIntegerField': 'bigint',
    'FloatField': 'float',
    'BooleanField': 'tinyint',
    'DateTimeField': 'datetime',
}


def get_table_columns(db, table_name: str) -> dict:
    """Read the column metadata of one table from information_schema.

    Args:
        db: Database object exposing ``execute_sql(sql, params)`` and a
            ``database`` attribute holding the schema name.
        table_name: Name of the table to inspect.

    Returns:
        dict: ``{column_name: info}`` in ordinal order, where ``info`` has the
        keys ``data_type``, ``column_type``, ``peewee_type``, ``nullable``,
        ``default``, ``is_primary``, ``column_key`` and ``extra``.
    """
    cursor = db.execute_sql("""
        SELECT
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """, (db.database, table_name))

    columns = {}
    for row in cursor.fetchall():
        col_name = row[0]
        data_type = row[1].lower()
        column_type = row[2].lower()
        column_key = row[5]

        # tinyint(1) is the MySQL representation of a boolean; every other
        # type is resolved by its bare data type, falling back to TextField.
        if column_type.startswith('tinyint(1)'):
            peewee_type = 'BooleanField'
        else:
            peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, 'TextField')

        columns[col_name] = {
            'data_type': data_type,
            'column_type': column_type,
            'peewee_type': peewee_type,
            'nullable': row[3] == 'YES',
            'default': row[4],
            'is_primary': column_key == 'PRI',
            # Raw key kind ('PRI', 'UNI', 'MUL' or ''). The rollback generator
            # looks this key up to decide whether to re-create an index; it
            # was selected by the query but previously never returned.
            'column_key': column_key,
            'extra': row[6] or '',
        }

    return columns
# Columns that the callers in this file strip from the model side before
# comparing (primary key plus the base-model timestamp fields). They must be
# ignored on the database side too, otherwise they look like removed columns.
_DEFAULT_SKIP_FIELDS = frozenset({'id', 'create_time', 'create_date', 'update_time', 'update_date'})


def get_base_field_type(field: Field) -> str:
    """Get base peewee field type by walking the MRO chain.

    Custom field types (like DateTimeTzField, JSONField) inherit from standard
    types. The first class in the MRO whose name is a standard peewee field
    type is returned; 'TextField' is the fallback when none matches.
    """
    # Standard peewee field types we consider as "base" types
    STANDARD_TYPES = {
        'CharField', 'TextField', 'IntegerField', 'BigIntegerField',
        'FloatField', 'BooleanField', 'DateTimeField', 'DateField',
        'TimeField', 'DecimalField', 'ForeignKeyField', 'ManyToManyField',
        'PrimaryKeyField', 'AutoField'
    }

    for cls in field.__class__.__mro__:
        if cls.__name__ in STANDARD_TYPES:
            return cls.__name__

    # Fallback to TextField if no standard type found
    return 'TextField'


def normalize_field_type(field: Field) -> str:
    """Normalize field type for comparison using base type"""
    return get_base_field_type(field)


def compare_fields(model_fields: dict, db_columns: dict, skip_fields=None) -> dict:
    """Compare model fields with database columns.

    Args:
        model_fields: ``{field_name: field_obj}`` taken from the model.
        db_columns: ``{column_name: col_info}`` as returned by
            ``get_table_columns`` (``peewee_type`` and ``column_type`` are read).
        skip_fields: Names ignored on both sides. Defaults to ``id`` plus the
            base-model timestamp columns (``create_time``, ``create_date``,
            ``update_time``, ``update_date``). Pass ``{'id'}`` to restore the
            previous behaviour.

    Returns:
        dict: {
            'added': {field_name: field_obj},                # New fields not in DB
            'changed': {field_name: (old_info, new_field)},  # Type changed
            'removed': {field_name: col_info},               # In DB but not in model
        }
    """
    if skip_fields is None:
        # Previously only 'id' was skipped although callers drop the timestamp
        # fields from model_fields; the matching DB columns were then reported
        # as "removed" and, with --drop, scheduled for DROP COLUMN.
        skip_fields = _DEFAULT_SKIP_FIELDS

    result = {
        'added': {},
        'changed': {},
        'removed': {},
    }

    for field_name, field in model_fields.items():
        if field_name in skip_fields:
            continue

        db_col = db_columns.get(field_name)
        if db_col is None:
            result['added'][field_name] = field
            logger.info(f" New field detected: {field_name} ({field.__class__.__name__})")
            continue

        # Compare on the normalized base type so custom subclasses do not
        # register as a change against the column they are stored in.
        model_base_type = normalize_field_type(field)
        db_type = db_col['peewee_type']
        if model_base_type != db_type:
            result['changed'][field_name] = (db_col, field)
            logger.info(f" Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})")

    # Detect removed fields: columns in DB but not in model
    for col_name, col_info in db_columns.items():
        if col_name in skip_fields or col_name in model_fields:
            continue
        result['removed'][col_name] = col_info
        logger.info(f" Removed field detected: {col_name} ({col_info['column_type']})")

    return result
def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> tuple:
    """Generate raw SQL for adding a field to a MySQL table.

    This is used for existing tables where migrator.add_fields doesn't work
    because the model is not registered in migrator.orm.

    Args:
        table_name: Table that receives the column.
        field: Field object; ``null``, ``default``, and (when present)
            ``max_length``, ``help_text`` and ``index`` are read.
        field_name: Name of the new column.

    Returns:
        tuple: ``(add_column_sql, index_sql)``; ``index_sql`` is ``None`` when
        the field is not indexed.
    """
    field_class = field.__class__.__name__
    max_length = getattr(field, 'max_length', None)
    varchar = f'VARCHAR({max_length})' if max_length else 'VARCHAR(255)'

    # Determine MySQL column type (unknown field classes fall back to LONGTEXT)
    mysql_type_map = {
        'CharField': varchar,
        'TextField': 'LONGTEXT',
        'LongTextField': 'LONGTEXT',
        'JSONField': 'LONGTEXT',
        'ListField': 'LONGTEXT',
        'SerializedField': 'LONGTEXT',
        'IntegerField': 'INT',
        'BigIntegerField': 'BIGINT',
        'FloatField': 'DOUBLE',
        'BooleanField': 'TINYINT(1)',
        'DateTimeField': 'DATETIME',
        'DateTimeTzField': varchar,
    }
    mysql_type = mysql_type_map.get(field_class, 'LONGTEXT')

    # Build column definition
    parts = [f'`{field_name}`', mysql_type, 'NULL' if field.null else 'NOT NULL']

    # DEFAULT. MySQL does not accept a literal DEFAULT on TEXT/BLOB columns,
    # so a model-level default on a LONGTEXT-backed field (e.g. a JSON field
    # with default={}) is left to the application instead of breaking the
    # ALTER TABLE statement.
    default_val = field.default
    if default_val is not None and mysql_type != 'LONGTEXT':
        if isinstance(default_val, str):
            escaped = default_val.replace("'", "''")
            parts.append(f"DEFAULT '{escaped}'")
        elif isinstance(default_val, bool):
            # bool is checked before int because bool is an int subclass
            parts.append(f"DEFAULT {1 if default_val else 0}")
        elif isinstance(default_val, (int, float)):
            parts.append(f"DEFAULT {default_val}")
        elif isinstance(default_val, (dict, list)):
            import json
            escaped = json.dumps(default_val).replace("'", "''")
            parts.append(f"DEFAULT '{escaped}'")

    # COMMENT
    help_text = getattr(field, 'help_text', None)
    if help_text:
        escaped = help_text.replace("'", "''")
        parts.append(f"COMMENT '{escaped}'")

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    # Add index if needed
    index_sql = None
    if getattr(field, 'index', False):
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"

    return sql, index_sql


def generate_drop_field_sql(table_name: str, field_name: str) -> str:
    """Generate SQL for dropping a field from a table."""
    return f"ALTER TABLE `{table_name}` DROP COLUMN `{field_name}`"


def generate_rollback_field_sql(table_name: str, field_name: str) -> str:
    """Generate SQL for rolling back an added field (i.e. dropping it again)."""
    return generate_drop_field_sql(table_name, field_name)
def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str:
    """Generate SQL for rolling back a field type change.

    The statement restores the column definition captured in ``old_info``
    before the change (keys read: ``column_type``, ``nullable``, ``default``
    and ``extra``; all optional).

    Note: This restores the column type, but data values may need manual
    handling if the type conversion caused data loss or transformation.

    Args:
        table_name: Table owning the column.
        old_info: Column description captured before the change.
        field_name: Column to restore.

    Returns:
        str: An ``ALTER TABLE ... MODIFY COLUMN`` statement.
    """
    mysql_type = old_info.get('column_type', 'LONGTEXT')
    extra = str(old_info.get('extra') or '')

    # Build column definition
    parts = [f'`{field_name}`', mysql_type]
    parts.append('NULL' if old_info.get('nullable', True) else 'NOT NULL')

    # DEFAULT (if available)
    default_val = old_info.get('default')
    if default_val is not None:
        if isinstance(default_val, str):
            text = default_val.strip()
            if re.fullmatch(r'CURRENT_TIMESTAMP(?:\(\d*\))?', text, flags=re.IGNORECASE):
                # A timestamp default is a keyword, not a string literal;
                # quoting it would produce an invalid default for the column.
                parts.append(f"DEFAULT {text}")
            elif 'DEFAULT_GENERATED' in extra.upper():
                # information_schema marks expression defaults this way; they
                # have to be written back as a parenthesized expression.
                parts.append(f"DEFAULT ({text})")
            else:
                escaped = default_val.replace("'", "''")
                parts.append(f"DEFAULT '{escaped}'")
        elif isinstance(default_val, bool):
            parts.append(f"DEFAULT {1 if default_val else 0}")
        elif isinstance(default_val, (int, float)):
            parts.append(f"DEFAULT {default_val}")

    # MODIFY COLUMN replaces the whole definition, so an auto-update clause
    # recorded in `extra` must be repeated or it is silently lost.
    on_update = re.search(r'on update (CURRENT_TIMESTAMP(?:\(\d*\))?)', extra, flags=re.IGNORECASE)
    if on_update:
        parts.append(f"ON UPDATE {on_update.group(1).upper()}")

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
def generate_migration_content(new_tables: list, field_changes: dict, migrate_dir: str, migration_name: str, drop_fields: bool = False) -> str:
    """Generate the source text of a peewee-migrate migration file.

    Args:
        new_tables: Model classes whose tables must be created.
        field_changes: ``{table_name: {'added': ..., 'changed': ..., 'removed': ...}}``
            in the shape produced by ``compare_fields``.
        migrate_dir: Not used by this function; kept for the existing call signature.
        migration_name: Not used by this function; kept for the existing call signature.
        drop_fields: When True, DROP COLUMN statements (and their rollback
            counterparts) are generated for the 'removed' entries.

    Returns:
        str: Python source defining ``migrate()`` and ``rollback()``.
    """
    def sql_call(sql: str) -> str:
        # repr() yields a correctly quoted and escaped Python literal. The
        # previous hard-coded "..." wrapper produced an invalid migration file
        # whenever the SQL contained a double quote or a backslash.
        return f'    migrator.sql({sql!r})'

    lines = [
        '"""Peewee migrations."""',
        '',
        'from contextlib import suppress',
        '',
        'import peewee as pw',
        'from peewee_migrate import Migrator',
        '',
        '',
        'with suppress(ImportError):',
        '    import playhouse.postgres_ext as pw_pext',
        '',
        '',
        'def migrate(migrator: Migrator, database: pw.Database, *, fake=False):',
        '    """Write your migrations here."""',
        '',
    ]

    # Generate create_model for new tables
    for model in new_tables:
        table_name = model._meta.table_name

        lines.append('    @migrator.create_model')
        lines.append(f'    class {model.__name__}(pw.Model):')

        for field_name, field in model._meta.fields.items():
            field_code = generate_field_code(field, field_name)
            lines.append(f'        {field_name} = {field_code}')

        lines.append('')
        lines.append('        class Meta:')
        lines.append(f'            table_name = {table_name!r}')

        # Add indexes if defined
        indexes = getattr(model._meta, 'indexes', None)
        if indexes:
            lines.append(f'            indexes = {indexes}')

        lines.append('')

    # Generate SQL for adding new fields to existing tables
    for table_name, changes in field_changes.items():
        for field_name, field in (changes.get('added') or {}).items():
            sql, index_sql = generate_add_field_sql(table_name, field, field_name)
            lines.append(sql_call(sql))
            if index_sql:
                lines.append(sql_call(index_sql))
            lines.append('')

    # Generate SQL for modifying fields in existing tables
    for table_name, changes in field_changes.items():
        for field_name, (_old_info, field) in (changes.get('changed') or {}).items():
            lines.append(sql_call(generate_modify_field_sql(table_name, field, field_name)))
            lines.append('')

    # Generate SQL for dropping removed fields from existing tables
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name in (changes.get('removed') or {}):
                lines.append(f'    # WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!')
                lines.append(sql_call(generate_drop_field_sql(table_name, field_name)))
                lines.append('')

    # Generate rollback
    lines.append('')
    lines.append('def rollback(migrator: Migrator, database: pw.Database, *, fake=False):')
    lines.append('    """Write your rollback migrations here."""')
    lines.append('')

    # Rollback: re-add dropped fields (before other rollbacks, since they may depend on these fields)
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get('removed') or {}).items():
                add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name)
                lines.append(f'    # Re-add dropped column `{field_name}` to `{table_name}` (data is lost)')
                lines.append(sql_call(add_sql))
                if index_sql:
                    lines.append(sql_call(index_sql))

    # Rollback: reverse field type changes first (before removing added fields)
    for table_name, changes in field_changes.items():
        for field_name, (old_info, _field) in (changes.get('changed') or {}).items():
            lines.append('    # Note: Data values may need manual handling if type conversion caused data loss')
            lines.append(sql_call(generate_rollback_modify_sql(table_name, old_info, field_name)))

    # Rollback: remove added fields using SQL
    for table_name, changes in field_changes.items():
        for field_name in (changes.get('added') or {}):
            lines.append(sql_call(generate_rollback_field_sql(table_name, field_name)))

    # Rollback: remove tables (in reverse order)
    for model in reversed(new_tables):
        lines.append(f'    migrator.remove_model({model._meta.table_name!r})')

    lines.append('')

    return '\n'.join(lines)
Removed fields (only when --drop is specified) -> generate drop_fields + + Args: + router: peewee-migrate Router instance + models: List of model classes to compare against database + db: Database connection + name: Migration name + drop_fields: Whether to include DROP COLUMN for removed fields + """ + try: + # Get existing tables from database + cursor = db.execute_sql( + "SELECT table_name FROM information_schema.tables WHERE table_schema = %s", + (db.database,) + ) + existing_tables = {row[0] for row in cursor.fetchall()} + + new_tables = [] + field_changes = {} + + for model in models: + table_name = model._meta.table_name + + if table_name not in existing_tables: + # New table + new_tables.append(model) + logger.info(f"New table detected: {table_name}") + else: + # Existing table - check for field changes + logger.info(f"Checking existing table: {table_name}") + + # Get model fields (exclude auto-generated) + model_fields = {} + for field_name, field in model._meta.fields.items(): + # Skip id and base model fields + if field_name in ('id', 'create_time', 'create_date', 'update_time', 'update_date'): + continue + if hasattr(field, '_auto_created') and field._auto_created: + continue + model_fields[field_name] = field + + # Get database columns + db_columns = get_table_columns(db, table_name) + + # Compare + changes = compare_fields(model_fields, db_columns) + + if changes['added'] or changes['changed'] or changes['removed']: + field_changes[table_name] = changes + + # Check if any changes detected + has_removed = any(changes.get('removed') for changes in field_changes.values()) + if not drop_fields and has_removed: + removed_details = [] + for table_name, changes in field_changes.items(): + if changes.get('removed'): + for col_name in changes['removed']: + removed_details.append(f"{table_name}.{col_name}") + logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}") + # Remove 'removed' from 
def diff_schema(models: list, db):
    """Log the schema differences between the models and the database.

    Reports tables missing from the database, tables present only in the
    database, and per-table added / changed / removed fields. Nothing is
    modified; all output goes through the module logger.

    Args:
        models: Model classes to compare.
        db: Database object exposing ``execute_sql`` and a ``database`` attribute.
    """
    logger.info("Checking schema differences...")

    # Tables to ignore (managed by peewee-migrate)
    IGNORE_TABLES = {'migratehistory'}
    # Primary key and base-model timestamp columns are excluded from the
    # comparison on BOTH sides. Filtering only the model side made these
    # columns show up as "removed" for every table.
    base_fields = ('id', 'create_time', 'create_date', 'update_time', 'update_date')

    model_tables = {model._meta.table_name for model in models}
    logger.info(f"Found {len(model_tables)} model tables")

    # Get existing tables from database
    cursor = db.execute_sql(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
        (db.database,)
    )
    existing_tables = {row[0] for row in cursor.fetchall() if row[0] not in IGNORE_TABLES}

    # Find tables that exist in models but not in database
    missing_tables = model_tables - existing_tables
    if missing_tables:
        logger.warning(f"Tables not in database ({len(missing_tables)}): {', '.join(sorted(missing_tables))}")

    # Find tables that exist in database but not in models
    extra_tables = existing_tables - model_tables
    if extra_tables:
        logger.info(f"Tables in database but not in models: {', '.join(sorted(extra_tables))}")

    # Check field differences for existing tables
    common_tables = model_tables & existing_tables
    if common_tables:
        logger.info(f"\nChecking field differences for {len(common_tables)} existing tables...")

    total_added = 0
    total_changed = 0
    total_removed = 0

    for model in models:
        table_name = model._meta.table_name
        if table_name not in common_tables:
            continue

        model_fields = {
            field_name: field
            for field_name, field in model._meta.fields.items()
            if field_name not in base_fields
        }
        db_columns = {
            col_name: col_info
            for col_name, col_info in get_table_columns(db, table_name).items()
            if col_name not in base_fields
        }

        changes = compare_fields(model_fields, db_columns)

        if changes['added']:
            total_added += len(changes['added'])
            field_details = [f"{k}:{v.__class__.__name__}" for k, v in changes['added'].items()]
            logger.info(f" {table_name}: {len(changes['added'])} new field(s) - {field_details}")

        if changes['changed']:
            total_changed += len(changes['changed'])
            # v is (old column info, new field object)
            field_details = [f"{k}:{v[1].__class__.__name__}" for k, v in changes['changed'].items()]
            logger.info(f" {table_name}: {len(changes['changed'])} changed field(s) - {field_details}")

        if changes['removed']:
            total_removed += len(changes['removed'])
            field_details = [f"{k}:{v['column_type']}" for k, v in changes['removed'].items()]
            logger.warning(f" {table_name}: {len(changes['removed'])} removed field(s) - {field_details}")

    logger.info(f"\nSummary: {total_added} new fields, {total_changed} changed fields, {total_removed} removed fields")
+ python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.24.0 + + # Run all pending migrations + python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.24.0 + + # Show schema differences + python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.24.0 +""" + ) + + # Database connection options + parser.add_argument('--host', type=str, required=True, help='MySQL host') + parser.add_argument('--port', type=int, default=3306, help='MySQL port (default: 3306)') + parser.add_argument('--user', type=str, required=True, help='MySQL user') + parser.add_argument('--password', type=str, required=True, help='MySQL password') + parser.add_argument('--database', type=str, required=True, help='MySQL database name') + + # Version option + parser.add_argument('--version', '-v', type=str, required=True, + help='Version number in format vxx.xx.xx (e.g., v0.24.0)') + + # Action options + parser.add_argument('--list', '-l', action='store_true', help='List all migrations') + parser.add_argument('--create', '-c', action='store_true', + help='Create migration from model changes (auto-detect)') + parser.add_argument('--migrate', '-m', action='store_true', help='Run pending migrations') + parser.add_argument('--diff', '-d', action='store_true', help='Show schema differences') + + # Migration options + parser.add_argument('--name', '-n', type=str, default='auto', help='Migration name') + parser.add_argument('--drop', action='store_true', + help='Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)') + + args = parser.parse_args() + + # Validate version format + if not validate_version(args.version): + logger.error(f"Invalid version format: {args.version}. 
Expected format: vxx.xx.xx (e.g., v0.24.0)") + sys.exit(1) + + # Validate at least one action is specified + if not any([args.list, args.create, args.migrate, args.diff]): + parser.print_help() + logger.error("Please specify at least one action: --list, --create, --migrate, or --diff") + sys.exit(1) + + # Convert version to directory name + version_dir = version_to_dirname(args.version) + migrate_dir = os.path.join(PROJECT_BASE, 'tools', 'migrate', version_dir) + + logger.info(f"Version: {args.version}") + logger.info(f"Migration directory: {migrate_dir}") + + # Create migration directory if it doesn't exist + os.makedirs(migrate_dir, exist_ok=True) + + # Load database models + logger.info("Loading database models from api/db/db_models.py...") + models, _ = load_db_models() + logger.info(f"Found {len(models)} model classes") + + # Create database connection + db = create_database_connection( + host=args.host, + port=args.port, + user=args.user, + password=args.password, + database=args.database + ) + + try: + db.connect() + logger.info(f"Connected to database: {args.database}") + + # Create router + router = Router( + db, + migrate_dir, + ignore=['basemodel', 'base_model', 'migratehistory'] + ) + + # Execute requested actions + if args.list: + list_migrations(router) + + if args.create: + create_migration(router, models, db, args.name, drop_fields=args.drop) + + if args.migrate: + run_migrations(router) + + if args.diff: + diff_schema(models, db) + + finally: + if not db.is_closed(): + db.close() + logger.info("Database connection closed") + + logger.info("Done.") + + +if __name__ == '__main__': + main() \ No newline at end of file