Files
ragflow/tools/scripts/db_schema_sync.py

938 lines
34 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Database Schema Sync Script
This script synchronizes database models defined in api/db/db_models.py
with the actual database schema using peewee-migrate.
Features:
1. Reads model definitions from api/db/db_models.py
2. Compares with existing database tables specified via command line
3. Generates migration files in tools/migrate/{version}/
"""
import argparse
import importlib.util
import inspect
import logging
import os
import re
import sys
from peewee import MySQLDatabase, Model, Field
from peewee_migrate import Router
# Repository root: two directory levels above the folder holding this script.
# It is put first on sys.path so project packages (e.g. ``api``) can be
# imported when the file is executed directly as a script.
PROJECT_BASE = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir))
sys.path.insert(0, PROJECT_BASE)

# Script-wide logging: INFO level, timestamped lines.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
def validate_version(version: str) -> bool:
    """Return True if *version* has the form ``v<major>.<minor>.<patch>``.

    Each component must be one or more ASCII digits (e.g. ``v0.26.3``). The
    whole string must match. ``re.match`` with ``$`` would also accept a
    trailing newline, and ``\\d`` alone accepts non-ASCII digits. Either would
    leak into the migration directory name built from this value.
    """
    return re.fullmatch(r"v\d+\.\d+\.\d+", version, flags=re.ASCII) is not None
def version_to_dirname(version: str) -> str:
    """Turn a dotted version into a directory-safe name by swapping every '.' for '_'.

    Example: ``'v0.26.3'`` becomes ``'v0_26_3'``.
    """
    segments = version.split(".")
    return "_".join(segments)
def load_db_models():
    """Import ``api/db/db_models.py`` from disk and collect its model classes.

    The file is executed as a fresh module named ``db_models`` (loaded from
    its path under PROJECT_BASE, not via the package import system).

    Returns:
        tuple: ``(models, module)``.
            ``models`` lists the peewee ``Model`` subclasses found among the
            module's attributes, in ``inspect.getmembers`` (name-sorted) order.
            It excludes ``Model`` itself and classes named ``BaseModel`` or
            ``DataBaseModel``.
            ``module`` is the executed module object.

    Raises:
        FileNotFoundError: if ``db_models.py`` does not exist.
    """
    source_file = os.path.join(PROJECT_BASE, "api", "db", "db_models.py")
    if not os.path.exists(source_file):
        raise FileNotFoundError(f"db_models.py not found at {source_file}")

    module_spec = importlib.util.spec_from_file_location("db_models", source_file)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)

    abstract_bases = ("BaseModel", "DataBaseModel")

    def is_table_model(candidate) -> bool:
        # Only concrete peewee models qualify; abstract project bases are skipped.
        if not (inspect.isclass(candidate) and issubclass(candidate, Model) and candidate is not Model):
            return False
        if candidate.__name__ in abstract_bases:
            return False
        # NOTE(review): peewee's model metadata appears to always define
        # ``database`` (possibly None), so this check may filter nothing.
        return hasattr(candidate._meta, "database")

    found = [member for _, member in inspect.getmembers(loaded_module) if is_table_model(member)]
    return found, loaded_module
def create_database_connection(host: str, port: int, user: str, password: str, database: str):
    """Build a peewee ``MySQLDatabase`` from command-line connection settings.

    The database object is only constructed here and uses the ``utf8mb4``
    charset. The caller (``main``) opens it with ``connect()`` and closes it.
    """
    connect_kwargs = {
        "host": host,
        "port": port,
        "user": user,
        "password": password,
        "charset": "utf8mb4",
    }
    return MySQLDatabase(database, **connect_kwargs)
# MySQL type to Peewee field type mapping.
# Keyed by information_schema DATA_TYPE (lower-cased in get_table_columns);
# types missing here fall back to "TextField" there. Values must use the base
# type names produced by get_base_field_type (DateField, TimeField,
# DecimalField, ...). Otherwise matching columns are always reported as
# "changed" and a MODIFY COLUMN is generated for them.
MYSQL_TO_PEEWEE_TYPE = {
    "varchar": "CharField",
    "char": "CharField",
    "text": "TextField",
    "longtext": "TextField",
    "mediumtext": "TextField",
    "int": "IntegerField",
    "integer": "IntegerField",
    "bigint": "BigIntegerField",
    "float": "FloatField",
    "double": "FloatField",
    "decimal": "DecimalField",
    "date": "DateField",
    "time": "TimeField",
    "datetime": "DateTimeField",
    "timestamp": "DateTimeField",
    # DATA_TYPE values carry no length suffix, so this key is never hit by the
    # DATA_TYPE lookup. get_table_columns detects tinyint(1) from COLUMN_TYPE.
    # Kept for backward compatibility.
    "tinyint(1)": "BooleanField",
    "tinyint": "IntegerField",
    "smallint": "IntegerField",
    "mediumint": "IntegerField",
}

# Peewee field type to MySQL type mapping (reverse direction).
# Not referenced elsewhere in this script.
PEEWEE_TO_MYSQL_TYPE = {
    "CharField": "varchar",
    "TextField": "text",
    "IntegerField": "int",
    "BigIntegerField": "bigint",
    "FloatField": "float",
    "DecimalField": "decimal",
    "BooleanField": "tinyint",
    "DateTimeField": "datetime",
    "DateField": "date",
    "TimeField": "time",
}
def get_table_columns(db, table_name: str) -> dict:
    """Read column definitions of ``table_name`` from information_schema.

    Args:
        db: connected peewee database. ``db.database`` is used as the schema
            name in the query.
        table_name: table to inspect.

    Returns:
        dict: ``{column_name: info}`` in ordinal (definition) order. ``info`` has:
            ``data_type`` / ``column_type``: lower-cased DATA_TYPE / COLUMN_TYPE.
            ``peewee_type``: "BooleanField" for ``tinyint(1)``; otherwise the
                MYSQL_TO_PEEWEE_TYPE entry for ``data_type`` ("TextField" if
                unknown).
            ``nullable``: True when IS_NULLABLE is "YES".
            ``default``: raw COLUMN_DEFAULT (may be None).
            ``is_primary``: True when COLUMN_KEY is "PRI".
            ``column_key``: raw COLUMN_KEY. Rollback code uses it to re-create
                indexes; it was previously missing from this dict.
            ``extra``: raw EXTRA, or "" when NULL.
    """
    cursor = db.execute_sql(
        """
        SELECT
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (db.database, table_name),
    )
    columns = {}
    for col_name, data_type, column_type, is_nullable, column_default, column_key, extra in cursor.fetchall():
        data_type = data_type.lower()
        column_type = column_type.lower()
        # tinyint(1) is MySQL's storage for booleans; DATA_TYPE alone cannot tell.
        if column_type.startswith("tinyint(1)"):
            peewee_type = "BooleanField"
        else:
            peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, "TextField")
        columns[col_name] = {
            "data_type": data_type,
            "column_type": column_type,
            "peewee_type": peewee_type,
            "nullable": is_nullable == "YES",
            "default": column_default,
            "is_primary": column_key == "PRI",
            "column_key": column_key,
            "extra": extra or "",
        }
    return columns
def get_peewee_field_type(field: Field) -> str:
    """Return the exact class name of *field* (e.g. ``'JSONField'``), without base-type resolution."""
    return field.__class__.__name__
def get_base_field_type(field: Field) -> str:
    """Map *field* to the nearest standard peewee field type name.

    Walks the field's MRO from most to least specific and returns the first
    class name in the standard set below. Project types such as
    DateTimeTzField or JSONField therefore resolve to the standard type they
    derive from. Returns ``"TextField"`` when no standard type appears in
    the MRO.
    """
    known_base_types = frozenset(
        (
            "CharField",
            "TextField",
            "IntegerField",
            "BigIntegerField",
            "FloatField",
            "BooleanField",
            "DateTimeField",
            "DateField",
            "TimeField",
            "DecimalField",
            "ForeignKeyField",
            "ManyToManyField",
            "PrimaryKeyField",
            "AutoField",
        )
    )
    matches = (klass.__name__ for klass in field.__class__.__mro__ if klass.__name__ in known_base_types)
    return next(matches, "TextField")
def normalize_field_type(field: Field) -> str:
    """Return the type name used when comparing *field* against a DB column.

    Delegates to ``get_base_field_type``, so custom subclasses compare as the
    standard peewee type they derive from.
    """
    comparable_type = get_base_field_type(field)
    return comparable_type
def compare_fields(model_fields: dict, db_columns: dict, *, skip_fields=None) -> dict:
    """Compare a model's fields with the columns of its database table.

    Args:
        model_fields: ``{field_name: peewee Field}`` taken from the model.
        db_columns: ``{column_name: info}`` as returned by ``get_table_columns``.
        skip_fields: names ignored on both sides. Defaults to ``id`` plus the
            base-model timestamp fields (``create_time``, ``create_date``,
            ``update_time``, ``update_date``). ``create_migration`` and
            ``diff_schema`` strip those from *model_fields* before calling.
            If they were not skipped here as well, they would be reported as
            removed, and ``--drop`` would generate DROP COLUMN for them.

    Returns:
        dict: {
            'added': {field_name: field_obj},              # in model, not in DB
            'changed': {field_name: (db_col_info, field)},  # base type differs
            'removed': {column_name: col_info},             # in DB, not in model
        }
    """
    if skip_fields is None:
        skip_fields = {"id", "create_time", "create_date", "update_time", "update_date"}

    result = {
        "added": {},
        "changed": {},
        "removed": {},
    }

    for field_name, field in model_fields.items():
        if field_name in skip_fields:
            continue
        if field_name not in db_columns:
            result["added"][field_name] = field
            logger.info(f" New field detected: {field_name} ({field.__class__.__name__})")
            continue
        db_col = db_columns[field_name]
        model_base_type = normalize_field_type(field)
        db_type = db_col["peewee_type"]
        if model_base_type != db_type:
            result["changed"][field_name] = (db_col, field)
            logger.info(f" Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})")

    # Columns present in the database but no longer declared on the model.
    for col_name, col_info in db_columns.items():
        if col_name in skip_fields:
            continue
        if col_name not in model_fields:
            result["removed"][col_name] = col_info
            logger.info(f" Removed field detected: {col_name} ({col_info['column_type']})")

    return result
def generate_field_code(field: Field, field_name: str) -> str:
    """Render a ``pw.<Type>(...)`` expression that re-declares *field* in a migration.

    Project-specific field classes are mapped to the standard peewee class
    they are stored as, because the generated migration only imports
    ``peewee as pw``.

    Keyword arguments are emitted in this order, each only when set:
        ``max_length``: CharField only.
        ``primary_key``: needed so peewee does not add its own auto-increment
            ``id`` key.
        ``column_name``: only when it differs from *field_name*.
        ``null``.
        ``default``: literal str/bool/int/float/dict/list values only;
            callables and other objects are omitted.
        ``index``.
        ``unique``.

    Args:
        field: peewee field instance from a model.
        field_name: attribute name the field gets in the generated model class.

    Returns:
        str: e.g. ``"pw.CharField(max_length=32, primary_key=True)"``.
    """
    field_class = field.__class__.__name__
    custom_to_standard = {
        "LongTextField": "TextField",
        "JSONField": "TextField",
        "ListField": "TextField",
        "SerializedField": "TextField",
        "DateTimeTzField": "CharField",
    }
    pw_field_class = custom_to_standard.get(field_class, field_class)

    args = []
    max_length = getattr(field, "max_length", None)
    if pw_field_class == "CharField" and max_length is not None:
        args.append(f"max_length={max_length}")

    # Without an explicit primary key, peewee adds an auto-increment ``id``.
    # That would replace e.g. a varchar ``id`` column in the created table.
    if getattr(field, "primary_key", False):
        args.append("primary_key=True")

    column_name = getattr(field, "column_name", None)
    if column_name and column_name != field_name:
        args.append(f"column_name={column_name!r}")

    if field.null:
        args.append("null=True")

    default_val = field.default
    if isinstance(default_val, str):
        # repr() yields a valid Python literal even with quotes, backslashes or newlines.
        args.append(f"default={default_val!r}")
    elif isinstance(default_val, (bool, int, float, dict, list)):
        args.append(f"default={default_val}")

    if getattr(field, "index", False):
        args.append("index=True")
    if getattr(field, "unique", False):
        args.append("unique=True")

    return f"pw.{pw_field_class}({', '.join(args)})"
# MySQL/MariaDB spellings of a CURRENT_TIMESTAMP default as reported by
# information_schema (e.g. "CURRENT_TIMESTAMP", "current_timestamp()",
# "CURRENT_TIMESTAMP(3)"). Such defaults must stay unquoted.
_CURRENT_TIMESTAMP_RE = re.compile(r"current_timestamp(?:\(\d*\))?", re.IGNORECASE)
# "on update CURRENT_TIMESTAMP" clause inside information_schema EXTRA.
_ON_UPDATE_RE = re.compile(r"on update (current_timestamp(?:\(\d*\))?)", re.IGNORECASE)


def _quote_identifier(name: str) -> str:
    """Return *name* as a backtick-quoted MySQL identifier (embedded backticks doubled)."""
    return "`" + name.replace("`", "``") + "`"


def _quote_string(value: str) -> str:
    """Return *value* as a single-quoted MySQL string literal with quotes and backslashes escaped."""
    # MySQL treats backslash as an escape character unless NO_BACKSLASH_ESCAPES is set.
    return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'"


def _create_index_sql(table_name: str, field_name: str, unique: bool) -> str:
    """Return CREATE [UNIQUE] INDEX ``idx_<table>_<field>`` on a single column."""
    keyword = "CREATE UNIQUE INDEX" if unique else "CREATE INDEX"
    index_name = _quote_identifier(f"idx_{table_name}_{field_name}")
    return f"{keyword} {index_name} ON {_quote_identifier(table_name)} ({_quote_identifier(field_name)})"


def _mysql_type_for_field(field: Field) -> str:
    """Return the MySQL column type used when adding/modifying *field*.

    The field's MRO is walked so that subclasses of a known type resolve to
    that type instead of falling back to LONGTEXT. Unknown types still fall
    back to LONGTEXT.
    """
    max_length = getattr(field, "max_length", None)
    varchar = f"VARCHAR({max_length})" if max_length else "VARCHAR(255)"
    type_by_class = {
        "CharField": varchar,
        "TextField": "LONGTEXT",
        "LongTextField": "LONGTEXT",
        "JSONField": "LONGTEXT",
        "ListField": "LONGTEXT",
        "SerializedField": "LONGTEXT",
        "IntegerField": "INT",
        "SmallIntegerField": "SMALLINT",
        "BigIntegerField": "BIGINT",
        "FloatField": "DOUBLE",
        "DecimalField": f"DECIMAL({getattr(field, 'max_digits', 10)},{getattr(field, 'decimal_places', 5)})",
        "BooleanField": "TINYINT(1)",
        "DateTimeField": "DATETIME",
        "DateField": "DATE",
        "TimeField": "TIME",
        "DateTimeTzField": varchar,
    }
    for klass in field.__class__.__mro__:
        mysql_type = type_by_class.get(klass.__name__)
        if mysql_type is not None:
            return mysql_type
    return "LONGTEXT"


def _field_default_clause(field: Field, mysql_type: str):
    """Return the ``DEFAULT ...`` clause for *field*'s static default, or None if it has none."""
    default_val = field.default
    if default_val is None:
        return None
    if isinstance(default_val, str):
        literal = _quote_string(default_val)
    elif isinstance(default_val, bool):
        literal = "1" if default_val else "0"
    elif isinstance(default_val, (int, float)):
        literal = str(default_val)
    elif isinstance(default_val, (dict, list)):
        import json

        literal = _quote_string(json.dumps(default_val))
    else:
        # Callables and other objects have no static SQL representation.
        return None
    if mysql_type.endswith(("TEXT", "BLOB")):
        # MySQL accepts a default on TEXT/BLOB columns only when written as an
        # expression in parentheses; a bare literal is rejected.
        return f"DEFAULT ({literal})"
    return f"DEFAULT {literal}"


def _field_column_definition(field: Field, field_name: str) -> str:
    """Return ``<name> <type> NULL|NOT NULL [DEFAULT ...] [COMMENT ...]`` for a model field."""
    mysql_type = _mysql_type_for_field(field)
    parts = [_quote_identifier(field_name), mysql_type, "NULL" if field.null else "NOT NULL"]
    default_clause = _field_default_clause(field, mysql_type)
    if default_clause:
        parts.append(default_clause)
    help_text = getattr(field, "help_text", None)
    if help_text:
        parts.append(f"COMMENT {_quote_string(help_text)}")
    return " ".join(parts)


def _column_info_definition(col_info: dict, field_name: str) -> str:
    """Return a column definition rebuilt from a ``get_table_columns`` entry.

    Reproduces type, nullability, literal or CURRENT_TIMESTAMP defaults, and
    ``ON UPDATE CURRENT_TIMESTAMP`` from EXTRA.
    NOTE(review): column comments, other expression defaults and AUTO_INCREMENT
    are not captured and are therefore not restored. MariaDB reports quoted
    default literals, which are not unwrapped here.
    """
    parts = [
        _quote_identifier(field_name),
        col_info.get("column_type", "LONGTEXT"),
        "NULL" if col_info.get("nullable", True) else "NOT NULL",
    ]
    default_val = col_info.get("default")
    if isinstance(default_val, str):
        if _CURRENT_TIMESTAMP_RE.fullmatch(default_val.strip()):
            parts.append(f"DEFAULT {default_val.strip()}")
        else:
            parts.append(f"DEFAULT {_quote_string(default_val)}")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")
    on_update = _ON_UPDATE_RE.search(col_info.get("extra") or "")
    if on_update:
        parts.append(f"ON UPDATE {on_update.group(1)}")
    return " ".join(parts)


def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> "tuple[str, str | None]":
    """Generate raw SQL that adds *field* as a new column of an existing table.

    Used for existing tables where ``migrator.add_fields`` does not work
    because the model is not registered in ``migrator.orm``.

    Returns:
        tuple: ``(alter_sql, index_sql)``. ``index_sql`` is a CREATE UNIQUE
        INDEX when the field is ``unique``, a CREATE INDEX when it is
        ``index``, otherwise None.
    """
    sql = f"ALTER TABLE {_quote_identifier(table_name)} ADD COLUMN {_field_column_definition(field, field_name)}"
    index_sql = None
    if getattr(field, "unique", False):
        index_sql = _create_index_sql(table_name, field_name, unique=True)
    elif getattr(field, "index", False):
        index_sql = _create_index_sql(table_name, field_name, unique=False)
    return sql, index_sql


def generate_drop_field_sql(table_name: str, field_name: str) -> str:
    """Generate SQL that drops column *field_name* from *table_name*."""
    return f"ALTER TABLE {_quote_identifier(table_name)} DROP COLUMN {_quote_identifier(field_name)}"


def generate_rollback_field_sql(table_name: str, field_name: str) -> str:
    """Generate SQL that undoes an added column (same statement as a drop)."""
    return generate_drop_field_sql(table_name, field_name)


def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: str) -> "tuple[str, str | None]":
    """Generate SQL that re-adds a dropped column from its captured ``col_info``.

    The column's data is not restored. The index is re-created from
    ``col_info["column_key"]``: "UNI" gives a unique index, "MUL" a plain one.
    The key may be absent, in which case no index is re-created.
    NOTE(review): "MUL" may also mark the leading column of a multi-column
    index; only a single-column index is re-created.

    Returns:
        tuple: ``(alter_sql, index_sql_or_None)``.
    """
    sql = f"ALTER TABLE {_quote_identifier(table_name)} ADD COLUMN {_column_info_definition(col_info, field_name)}"
    column_key = col_info.get("column_key")
    index_sql = None
    if column_key == "UNI":
        index_sql = _create_index_sql(table_name, field_name, unique=True)
    elif column_key == "MUL":
        index_sql = _create_index_sql(table_name, field_name, unique=False)
    return sql, index_sql


def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str:
    """Generate SQL that restores a column to its definition captured in ``old_info``.

    Only the column definition is restored. Data changed by the forward type
    conversion may need manual handling.
    """
    return f"ALTER TABLE {_quote_identifier(table_name)} MODIFY COLUMN {_column_info_definition(old_info, field_name)}"


def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> str:
    """Generate SQL that changes an existing column to match *field*'s definition."""
    return f"ALTER TABLE {_quote_identifier(table_name)} MODIFY COLUMN {_field_column_definition(field, field_name)}"
def generate_migration_content(new_tables: list, field_changes: dict, migrate_dir: str, migration_name: str, drop_fields: bool = False) -> str:
    """Render the Python source of a peewee-migrate migration module.

    ``migrate()`` performs, in order:
        1. create new tables (``@migrator.create_model``);
        2. add columns;
        3. modify changed columns;
        4. drop removed columns, only if ``drop_fields``.
    ``rollback()`` undoes these steps in reverse order.

    Args:
        new_tables: model classes whose tables do not exist yet.
        field_changes: ``{table_name: {"added": ..., "changed": ..., "removed": ...}}``
            as produced by ``compare_fields``.
        migrate_dir: unused; kept for interface compatibility.
        migration_name: unused; kept for interface compatibility.
        drop_fields: emit DROP COLUMN for ``removed`` columns (re-added in rollback).

    Returns:
        str: source code of the migration file.
    """
    indent = "    "

    def sql_line(statement: str) -> str:
        # repr() gives a valid Python literal whatever quotes/backslashes the
        # SQL holds (e.g. JSON defaults contain double quotes).
        return f"{indent}migrator.sql({statement!r})"

    lines = [
        '"""Peewee migrations."""',
        "",
        "from contextlib import suppress",
        "",
        "import peewee as pw",
        "from peewee_migrate import Migrator",
        "",
        "",
        "with suppress(ImportError):",
        f"{indent}import playhouse.postgres_ext as pw_pext",
        "",
        "",
        "def migrate(migrator: Migrator, database: pw.Database, *, fake=False):",
        f'{indent}"""Write your migrations here."""',
        "",
    ]

    # Forward 1: create missing tables. Class bodies must be indented deeper
    # than the class statement for the generated file to be valid Python.
    for model in new_tables:
        table_name = model._meta.table_name
        lines.append(f"{indent}@migrator.create_model")
        lines.append(f"{indent}class {model.__name__}(pw.Model):")
        for field_name, field in model._meta.fields.items():
            lines.append(f"{indent * 2}{field_name} = {generate_field_code(field, field_name)}")
        lines.append("")
        lines.append(f"{indent * 2}class Meta:")
        lines.append(f'{indent * 3}table_name = "{table_name}"')
        indexes = getattr(model._meta, "indexes", None)
        if indexes:
            lines.append(f"{indent * 3}indexes = {indexes}")
        lines.append("")

    # Forward 2: add new columns (plus their index, if any).
    for table_name, changes in field_changes.items():
        added = changes.get("added")
        if not added:
            continue
        for field_name, field in added.items():
            sql, index_sql = generate_add_field_sql(table_name, field, field_name)
            lines.append(sql_line(sql))
            if index_sql:
                lines.append(sql_line(index_sql))
        lines.append("")

    # Forward 3: change column definitions.
    for table_name, changes in field_changes.items():
        changed = changes.get("changed")
        if not changed:
            continue
        for field_name, (_old_info, field) in changed.items():
            lines.append(sql_line(generate_modify_field_sql(table_name, field, field_name)))
        lines.append("")

    # Forward 4 (opt-in): drop columns no longer declared on the models.
    if drop_fields:
        for table_name, changes in field_changes.items():
            removed = changes.get("removed")
            if not removed:
                continue
            for field_name in removed:
                lines.append(f"{indent}# WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!")
                lines.append(sql_line(generate_drop_field_sql(table_name, field_name)))
            lines.append("")

    lines.append("")
    lines.append("def rollback(migrator: Migrator, database: pw.Database, *, fake=False):")
    lines.append(f'{indent}"""Write your rollback migrations here."""')
    lines.append("")

    # Rollback 4: re-add dropped columns first (data is not recoverable).
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get("removed") or {}).items():
                add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name)
                lines.append(f"{indent}# Re-add dropped column `{field_name}` to `{table_name}` (data is lost)")
                lines.append(sql_line(add_sql))
                if index_sql:
                    lines.append(sql_line(index_sql))

    # Rollback 3: restore previous column definitions.
    for table_name, changes in field_changes.items():
        for field_name, (old_info, _field) in (changes.get("changed") or {}).items():
            lines.append(f"{indent}# Note: Data values may need manual handling if type conversion caused data loss")
            lines.append(sql_line(generate_rollback_modify_sql(table_name, old_info, field_name)))

    # Rollback 2: remove added columns.
    for table_name, changes in field_changes.items():
        for field_name in changes.get("added") or {}:
            lines.append(sql_line(generate_rollback_field_sql(table_name, field_name)))

    # Rollback 1: remove created tables in reverse creation order.
    for model in reversed(new_tables):
        lines.append(f'{indent}migrator.remove_model("{model._meta.table_name}")')
    lines.append("")

    return "\n".join(lines)
def _next_migration_number(migrate_dir: str) -> int:
    """Return one more than the highest ``NNN_`` prefix of a ``.py`` file in *migrate_dir* (1 if none).

    Counting files instead could reuse, and overwrite, an existing number
    after a migration is deleted or when non-migration ``.py`` files exist.
    """
    highest = 0
    for filename in os.listdir(migrate_dir):
        if not filename.endswith(".py"):
            continue
        match = re.match(r"(\d+)_", filename)
        if match:
            highest = max(highest, int(match.group(1)))
    return highest + 1


def create_migration(router: Router, models: list, db, name: str = "auto", drop_fields: bool = False):
    """Write a migration file for the differences between *models* and the database.

    Detects:
    1. New tables -> ``@migrator.create_model`` classes.
    2. New fields in existing tables -> ``ALTER TABLE ... ADD COLUMN`` via ``migrator.sql``.
    3. Field type changes -> ``ALTER TABLE ... MODIFY COLUMN`` via ``migrator.sql``.
    4. Removed fields (only when ``drop_fields`` is True) -> ``ALTER TABLE ... DROP COLUMN``.
       Otherwise they are only reported with a warning.

    Args:
        router: peewee-migrate Router; ``router.migrate_dir`` receives the file.
        models: model classes to compare against the database.
        db: connected database (``db.database`` is the schema inspected).
        name: suffix of the migration file name (``NNN_<name>.py``).
        drop_fields: whether to include DROP COLUMN for removed fields.

    Returns:
        str | None: path of the written migration file, or None when no change
        needs migrating.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        cursor = db.execute_sql("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db.database,))
        existing_tables = {row[0] for row in cursor.fetchall()}

        new_tables = []
        field_changes = {}
        for model in models:
            table_name = model._meta.table_name
            if table_name not in existing_tables:
                new_tables.append(model)
                logger.info(f"New table detected: {table_name}")
                continue

            logger.info(f"Checking existing table: {table_name}")
            model_fields = {}
            for field_name, field in model._meta.fields.items():
                # id and base-model timestamp fields are not compared.
                if field_name in ("id", "create_time", "create_date", "update_time", "update_date"):
                    continue
                if getattr(field, "_auto_created", False):
                    continue
                model_fields[field_name] = field

            changes = compare_fields(model_fields, get_table_columns(db, table_name))
            if changes["added"] or changes["changed"] or changes["removed"]:
                field_changes[table_name] = changes

        has_removed = any(changes.get("removed") for changes in field_changes.values())
        if not drop_fields and has_removed:
            removed_details = [
                f"{table_name}.{col_name}"
                for table_name, changes in field_changes.items()
                for col_name in changes.get("removed") or {}
            ]
            logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}")
            # Not acting on removals without --drop.
            for changes in field_changes.values():
                changes["removed"] = {}

        if not new_tables and not any(changes["added"] or changes["changed"] for changes in field_changes.values()):
            if not (drop_fields and has_removed):
                logger.info("No schema changes detected, migration not created")
                return None

        migration_content = generate_migration_content(new_tables, field_changes, router.migrate_dir, name, drop_fields=drop_fields)

        migration_num = _next_migration_number(router.migrate_dir)
        migration_file = os.path.join(router.migrate_dir, f"{migration_num:03d}_{name}.py")
        # Explicit encoding: help_text/defaults may contain non-ASCII text.
        with open(migration_file, "w", encoding="utf-8") as f:
            f.write(migration_content)

        logger.info(f"Created migration: {migration_file}")
        return migration_file
    except Exception as e:
        logger.error(f"Failed to create migration: {e}")
        raise
def run_migrations(router: Router):
    """Apply every pending migration known to *router*.

    If ``router.diff`` is empty, this only logs that there is nothing to run.
    Any error is logged and re-raised.
    """
    try:
        pending = router.diff
        if pending:
            router.run()
            logger.info("Migrations completed successfully")
        else:
            logger.info("No pending migrations to run")
    except Exception as exc:
        logger.error(f"Failed to run migrations: {exc}")
        raise
def list_migrations(router: Router):
    """Log each migration file in ``router.todo`` with an applied/pending marker.

    A migration counts as applied when its name appears in ``router.done``.
    """
    all_migrations = router.todo
    if not all_migrations:
        logger.info("No migration files found")
        return

    logger.info("Available migrations:")
    applied = set(router.done)
    for migration_name in all_migrations:
        state = "applied" if migration_name in applied else "pending"
        logger.info(f" [{state}] {migration_name}")
def diff_schema(models: list, db):
    """Log differences between the model definitions and the live database (read-only).

    Reports:
    - tables that exist on only one side (peewee-migrate's ``migratehistory``
      table is ignored);
    - for tables on both sides, per-table added/changed/removed fields as
      computed by ``compare_fields``, followed by a summary line.
    """
    logger.info("Checking schema differences...")

    ignored_tables = {"migratehistory"}  # bookkeeping table owned by peewee-migrate

    model_tables = {model._meta.table_name for model in models}
    logger.info(f"Found {len(model_tables)} model tables")

    rows = db.execute_sql("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db.database,)).fetchall()
    existing_tables = {row[0] for row in rows if row[0] not in ignored_tables}

    missing_tables = model_tables - existing_tables
    if missing_tables:
        logger.warning(f"Tables not in database ({len(missing_tables)}): {', '.join(sorted(missing_tables))}")

    extra_tables = existing_tables - model_tables
    if extra_tables:
        logger.info(f"Tables in database but not in models: {', '.join(sorted(extra_tables))}")

    common_tables = model_tables & existing_tables
    if not common_tables:
        return

    logger.info(f"\nChecking field differences for {len(common_tables)} existing tables...")
    excluded_fields = ("id", "create_time", "create_date", "update_time", "update_date")
    totals = {"added": 0, "changed": 0, "removed": 0}

    for model in models:
        table_name = model._meta.table_name
        if table_name not in common_tables:
            continue

        model_fields = {fname: fobj for fname, fobj in model._meta.fields.items() if fname not in excluded_fields}
        changes = compare_fields(model_fields, get_table_columns(db, table_name))

        added = changes["added"]
        if added:
            totals["added"] += len(added)
            details = [f"{fname}:{fobj.__class__.__name__}" for fname, fobj in added.items()]
            logger.info(f" {table_name}: {len(added)} new field(s) - {details}")

        changed = changes["changed"]
        if changed:
            totals["changed"] += len(changed)
            details = [f"{fname}:{fobj.__class__.__name__}" for fname, (_, fobj) in changed.items()]
            logger.info(f" {table_name}: {len(changed)} changed field(s) - {details}")

        removed = changes["removed"]
        if removed:
            totals["removed"] += len(removed)
            details = [f"{cname}:{cinfo['column_type']}" for cname, cinfo in removed.items()]
            logger.warning(f" {table_name}: {len(removed)} removed field(s) - {details}")

    logger.info(f"\nSummary: {totals['added']} new fields, {totals['changed']} changed fields, {totals['removed']} removed fields")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: connection options, --version, action flags, migration options."""
    parser = argparse.ArgumentParser(
        description="Database Schema Synchronization Tool using peewee-migrate",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# List all migrations
python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
# Create migration from model changes
python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
# Create migration including dropped fields (destructive!)
python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
# Run all pending migrations
python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
# Show schema differences
python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
""",
    )

    # Database connection options
    parser.add_argument("--host", type=str, required=True, help="MySQL host")
    parser.add_argument("--port", type=int, default=3306, help="MySQL port (default: 3306)")
    parser.add_argument("--user", type=str, required=True, help="MySQL user")
    parser.add_argument("--password", type=str, required=True, help="MySQL password")
    parser.add_argument("--database", type=str, required=True, help="MySQL database name")

    # Version option
    parser.add_argument("--version", "-v", type=str, required=True, help="Version number in format vxx.xx.xx (e.g., v0.26.3)")

    # Action options
    parser.add_argument("--list", "-l", action="store_true", help="List all migrations")
    parser.add_argument("--create", "-c", action="store_true", help="Create migration from model changes (auto-detect)")
    parser.add_argument("--migrate", "-m", action="store_true", help="Run pending migrations")
    parser.add_argument("--diff", "-d", action="store_true", help="Show schema differences")

    # Migration options
    parser.add_argument("--name", "-n", type=str, default="auto", help="Migration name")
    parser.add_argument("--drop", action="store_true", help="Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)")
    return parser


def main():
    """Command-line entry point.

    Exits with status 1 on a malformed ``--version`` or when no action flag is
    given. Otherwise it:
    - ensures ``tools/migrate/<version_dir>`` exists;
    - loads the models and opens one database connection;
    - runs the requested actions in the fixed order --list, --create,
      --migrate, --diff;
    - closes the connection.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if not validate_version(args.version):
        logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.26.3)")
        sys.exit(1)

    if not any((args.list, args.create, args.migrate, args.diff)):
        parser.print_help()
        logger.error("Please specify at least one action: --list, --create, --migrate, or --diff")
        sys.exit(1)

    migrate_dir = os.path.join(PROJECT_BASE, "tools", "migrate", version_to_dirname(args.version))
    logger.info(f"Version: {args.version}")
    logger.info(f"Migration directory: {migrate_dir}")
    os.makedirs(migrate_dir, exist_ok=True)

    logger.info("Loading database models from api/db/db_models.py...")
    models = load_db_models()[0]
    logger.info(f"Found {len(models)} model classes")

    db = create_database_connection(host=args.host, port=args.port, user=args.user, password=args.password, database=args.database)
    try:
        db.connect()
        logger.info(f"Connected to database: {args.database}")

        router = Router(db, migrate_dir, ignore=["basemodel", "base_model", "migratehistory"])

        # (flag, action) pairs, executed in this order when the flag is set.
        requested_actions = (
            (args.list, lambda: list_migrations(router)),
            (args.create, lambda: create_migration(router, models, db, args.name, drop_fields=args.drop)),
            (args.migrate, lambda: run_migrations(router)),
            (args.diff, lambda: diff_schema(models, db)),
        )
        for enabled, action in requested_actions:
            if enabled:
                action()
    finally:
        if not db.is_closed():
            db.close()
            logger.info("Database connection closed")

    logger.info("Done.")


if __name__ == "__main__":
    main()