2026-04-20 11:40:01 +08:00
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Database Schema Sync Script
This script synchronizes database models defined in api / db / db_models . py
with the actual database schema using peewee - migrate .
Features :
1. Reads model definitions from api / db / db_models . py
2. Compares with existing database tables specified via command line
3. Generates migration files in tools / migrate / { version } /
"""
import argparse
import importlib . util
import inspect
import logging
import os
import re
import sys
from peewee import MySQLDatabase , Model , Field
from peewee_migrate import Router
# Project root: three directory levels above this script's absolute path. It is
# inserted at index 0 of sys.path so project packages (e.g. ``api``) are found first.
PROJECT_BASE = os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir))
sys.path.insert(0, PROJECT_BASE)

# Configure logging
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
def validate_version(version: str) -> bool:
    """Return True if ``version`` has the form ``v<major>.<minor>.<patch>``.

    Each component must be one or more ASCII digits (e.g. ``v0.26.3``).

    ``re.fullmatch`` is used because ``$`` under ``re.match`` also matches just
    before a trailing newline, which would accept a value ending in a newline.
    ``re.ASCII`` stops the digit class from accepting non-ASCII digits. Either
    would otherwise end up in the migration directory name.
    """
    return re.fullmatch(r"v\d+\.\d+\.\d+", version, re.ASCII) is not None
def version_to_dirname(version: str) -> str:
    """Make a directory-safe name from a version string by swapping dots for underscores.

    Example: 'v0.26.3' -> 'v0_26_3'.
    """
    pieces = version.split(".")
    return "_".join(pieces)
def load_db_models():
    """Execute api/db/db_models.py from the project tree and collect its peewee models.

    Returns:
        tuple: ``(models, module)``.
            - ``models`` lists the distinct ``Model`` subclasses exposed by the module.
              ``Model`` itself and the ``BaseModel``/``DataBaseModel`` bases are excluded,
              as are classes whose ``_meta.database`` is None. Order follows
              ``inspect.getmembers`` (sorted by attribute name).
            - ``module`` is the executed module object.

    Raises:
        FileNotFoundError: if db_models.py does not exist.
        ImportError: if no import spec/loader can be built for the file.
    """
    models_path = os.path.join(PROJECT_BASE, "api", "db", "db_models.py")
    if not os.path.exists(models_path):
        raise FileNotFoundError(f"db_models.py not found at {models_path}")

    # Import the module directly from its file path
    spec = importlib.util.spec_from_file_location("db_models", models_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {models_path}")
    db_models = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(db_models)

    models = []
    seen_ids = set()
    for _, obj in inspect.getmembers(db_models, inspect.isclass):
        if not issubclass(obj, Model) or obj is Model:
            continue
        # Skip base model classes
        if obj.__name__ in ["BaseModel", "DataBaseModel"]:
            continue
        # peewee's model metadata always defines ``database`` (None when unbound), so
        # testing with hasattr() never filtered anything; test the value instead.
        if getattr(obj._meta, "database", None) is None:
            continue
        # The same class exported under several names must only be migrated once.
        if id(obj) in seen_ids:
            continue
        seen_ids.add(id(obj))
        models.append(obj)

    return models, db_models
def create_database_connection(host: str, port: int, user: str, password: str, database: str):
    """Build a peewee ``MySQLDatabase`` for the given connection settings.

    The settings come from the command-line arguments. The handle always uses the
    ``utf8mb4`` charset.
    """
    connect_params = {
        "host": host,
        "port": port,
        "user": user,
        "password": password,
        "charset": "utf8mb4",
    }
    return MySQLDatabase(database, **connect_params)
# MySQL type to Peewee field type mapping, keyed by lower-cased information_schema
# DATA_TYPE. The values are compared with get_base_field_type() results, so every
# type that a standard peewee field produces must map back to that field's name.
# Unknown types fall back to TextField in get_table_columns().
MYSQL_TO_PEEWEE_TYPE = {
    "varchar": "CharField",
    "char": "CharField",
    "text": "TextField",
    "tinytext": "TextField",
    "longtext": "TextField",
    "mediumtext": "TextField",
    "int": "IntegerField",
    "integer": "IntegerField",
    "bigint": "BigIntegerField",
    "float": "FloatField",
    "double": "FloatField",
    # DECIMAL columns come from DecimalField, which get_base_field_type() reports as such.
    "decimal": "DecimalField",
    "datetime": "DateTimeField",
    "timestamp": "DateTimeField",
    # Without these, DateField/TimeField columns fell back to TextField and were
    # reported as type changes on every run.
    "date": "DateField",
    "time": "TimeField",
    "tinyint(1)": "BooleanField",
    "tinyint": "IntegerField",
    "smallint": "IntegerField",
    "mediumint": "IntegerField",
}
# Peewee field class name -> MySQL data type name.
PEEWEE_TO_MYSQL_TYPE = dict(
    CharField="varchar",
    TextField="text",
    IntegerField="int",
    BigIntegerField="bigint",
    FloatField="float",
    BooleanField="tinyint",
    DateTimeField="datetime",
)
def get_table_columns(db, table_name: str) -> dict:
    """Get column information for ``table_name`` from information_schema.

    Args:
        db: database handle; ``db.database`` names the schema to query and
            ``db.execute_sql`` runs the parameterized query.
        table_name: table to inspect.

    Returns:
        dict: ``{column_name: info}``, in ordinal position order. Each ``info`` has these keys:
            - ``data_type`` and ``column_type``: lower-cased.
            - ``peewee_type``: the comparable peewee field name.
            - ``nullable``: bool.
            - ``default``: the raw COLUMN_DEFAULT value.
            - ``is_primary``: bool.
            - ``column_key``: the raw COLUMN_KEY, e.g. ``"PRI"`` or ``"MUL"``, with ``""``
              when empty. The rollback generator needs it to restore indexes.
            - ``extra``: ``""`` when empty.
    """
    cursor = db.execute_sql(
        """
        SELECT
            column_name,
            data_type,
            column_type,
            is_nullable,
            column_default,
            column_key,
            extra
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (db.database, table_name),
    )

    columns = {}
    for col_name, data_type, column_type, is_nullable, column_default, column_key, extra in cursor.fetchall():
        data_type = data_type.lower()
        column_type = column_type.lower()

        # tinyint(1) is how MySQL stores booleans; it must win over the plain tinyint mapping.
        if column_type.startswith("tinyint(1)"):
            peewee_type = "BooleanField"
        else:
            peewee_type = MYSQL_TO_PEEWEE_TYPE.get(data_type, "TextField")

        columns[col_name] = {
            "data_type": data_type,
            "column_type": column_type,
            "peewee_type": peewee_type,
            "nullable": is_nullable == "YES",
            "default": column_default,
            "is_primary": column_key == "PRI",
            "column_key": column_key or "",
            "extra": extra or "",
        }

    return columns
def get_peewee_field_type(field: Field) -> str:
    """Return the concrete class name of ``field``.

    The result is e.g. "CharField", or the name of a custom subclass.
    """
    cls = field.__class__
    return cls.__name__
def get_base_field_type(field: Field) -> str:
    """Return the standard peewee type a field instance derives from.

    Custom field types (like DateTimeTzField or JSONField) inherit from standard
    types. Walking the class MRO and stopping at the first standard name lets them
    be compared against database column types. "TextField" is returned when no
    class in the MRO has a standard name.
    """
    standard_names = frozenset(
        (
            "CharField",
            "TextField",
            "IntegerField",
            "BigIntegerField",
            "FloatField",
            "BooleanField",
            "DateTimeField",
            "DateField",
            "TimeField",
            "DecimalField",
            "ForeignKeyField",
            "ManyToManyField",
            "PrimaryKeyField",
            "AutoField",
        )
    )
    return next(
        (klass.__name__ for klass in field.__class__.__mro__ if klass.__name__ in standard_names),
        "TextField",
    )
def normalize_field_type(field: Field) -> str:
    """Return the type name used when comparing a model field with a DB column.

    Delegates to get_base_field_type(), so custom subclasses normalize to the
    standard peewee type they inherit from.
    """
    base_type = get_base_field_type(field)
    return base_type
def compare_fields(model_fields: dict, db_columns: dict, skip_fields=None) -> dict:
    """Compare model fields with database columns.

    Args:
        model_fields: ``{field_name: peewee Field}`` taken from the model.
        db_columns: ``{column_name: column info}`` as returned by ``get_table_columns``.
        skip_fields: names ignored on both sides. When None, the primary key and the
            base-model timestamp columns are skipped: ``id``, ``create_time``,
            ``create_date``, ``update_time`` and ``update_date``.

    Returns:
        dict: ``{"added": ..., "changed": ..., "removed": ...}`` where
            - ``"added"`` is ``{field_name: field_obj}``: new fields not in the DB.
            - ``"changed"`` is ``{field_name: (old_info, new_field)}``: the base type differs.
            - ``"removed"`` is ``{field_name: col_info}``: columns in the DB but not in the model.
    """
    result = {
        "added": {},
        "changed": {},
        "removed": {},
    }

    # create_migration() strips the base-model timestamp fields from the model side.
    # Unless they are skipped here as well, every existing table reports them as
    # "removed", and --drop would delete those columns.
    if skip_fields is None:
        skipped = {"id", "create_time", "create_date", "update_time", "update_date"}
    else:
        skipped = set(skip_fields)

    for field_name, field in model_fields.items():
        if field_name in skipped:
            continue

        # Check if field exists in database
        if field_name not in db_columns:
            result["added"][field_name] = field
            logger.info(f"New field detected: {field_name} ({field.__class__.__name__})")
            continue

        # Check if type changed
        db_col = db_columns[field_name]
        model_base_type = normalize_field_type(field)
        db_type = db_col["peewee_type"]
        if model_base_type != db_type:
            result["changed"][field_name] = (db_col, field)
            logger.info(f"Field type changed: {field_name} ({db_type} -> {model_base_type}, actual: {field.__class__.__name__})")

    # Detect removed fields: columns in DB but not in model
    for col_name, col_info in db_columns.items():
        if col_name in skipped or col_name in model_fields:
            continue
        result["removed"][col_name] = col_info
        logger.info(f"Removed field detected: {col_name} ({col_info['column_type']})")

    return result
def generate_field_code(field: Field, field_name: str) -> str:
    """Render a ``pw.<FieldClass>(...)`` expression for a field of a newly created table.

    The expression is written into a ``@migrator.create_model`` class body of the
    generated migration. It must therefore be valid Python that references only
    classes available on the ``pw`` (peewee) module.

    Args:
        field: peewee field instance from the model definition.
        field_name: attribute name of the field (unused; kept for interface compatibility).

    Returns:
        str: e.g. ``"pw.CharField(max_length=32, primary_key=True)"``.
    """
    # Map custom field types to standard peewee types for migration;
    # these custom types are stored as their underlying standard type.
    custom_to_standard = {
        "LongTextField": "TextField",
        "JSONField": "TextField",
        "ListField": "TextField",
        "SerializedField": "TextField",
        "DateTimeTzField": "CharField",
    }

    field_class = field.__class__.__name__
    pw_field_class = custom_to_standard.get(field_class)
    if pw_field_class is None:
        # Any other non-peewee subclass would render as ``pw.<CustomName>``, which does
        # not exist when the migration runs. Use its nearest peewee-defined ancestor.
        pw_field_class = next(
            (cls.__name__ for cls in field.__class__.__mro__ if cls.__module__ == "peewee"),
            field_class,
        )

    args = []

    # max_length for CharField
    if pw_field_class == "CharField" and getattr(field, "max_length", None) is not None:
        args.append(f"max_length={field.max_length}")

    # Without primary_key=True the generated model declares no primary key, and
    # peewee then supplies its own implicit auto-increment ``id``.
    if getattr(field, "primary_key", False):
        args.append("primary_key=True")

    if field.null:
        args.append("null=True")

    # Only plain values can be rendered as literals; callable defaults
    # (e.g. datetime.now) are skipped. repr() yields a correctly escaped literal
    # (the old manual quoting did not escape backslashes).
    default_val = field.default
    if isinstance(default_val, (str, bool, int, float, dict, list)):
        args.append(f"default={default_val!r}")

    if getattr(field, "index", False):
        args.append("index=True")

    if getattr(field, "unique", False):
        args.append("unique=True")

    return f"pw.{pw_field_class}({', '.join(args)})"
def generate_add_field_sql(table_name: str, field: Field, field_name: str) -> tuple:
    """Generate raw MySQL DDL that adds ``field`` as a new column of ``table_name``.

    This is used for existing tables, where ``migrator.add_fields`` doesn't work
    because the model is not registered in ``migrator.orm``.

    Args:
        table_name: table receiving the column.
        field: peewee field instance from the model definition.
        field_name: name of the column to add.

    Returns:
        tuple: ``(add_column_sql, index_sql)``. ``index_sql`` is a CREATE INDEX statement
        when ``field.index`` is set, otherwise None.
    """
    import json

    def quote(text):
        # MySQL treats backslash as an escape character inside string literals
        # (unless NO_BACKSLASH_ESCAPES is set), so double it as well as quotes.
        return "'" + str(text).replace("\\", "\\\\").replace("'", "''") + "'"

    varchar = f"VARCHAR({field.max_length})" if getattr(field, "max_length", None) else "VARCHAR(255)"
    mysql_type_map = {
        "CharField": varchar,
        "TextField": "LONGTEXT",
        "LongTextField": "LONGTEXT",
        "JSONField": "LONGTEXT",
        "ListField": "LONGTEXT",
        "SerializedField": "LONGTEXT",
        "IntegerField": "INT",
        "BigIntegerField": "BIGINT",
        "FloatField": "DOUBLE",
        "BooleanField": "TINYINT(1)",
        "DateTimeField": "DATETIME",
        "DateTimeTzField": varchar,
    }
    # Resolve through the class hierarchy so subclasses of mapped types keep their
    # storage type instead of falling back to LONGTEXT.
    mysql_type = next(
        (mysql_type_map[cls.__name__] for cls in field.__class__.__mro__ if cls.__name__ in mysql_type_map),
        "LONGTEXT",
    )

    parts = [f"`{field_name}`", mysql_type, "NULL" if field.null else "NOT NULL"]

    default_val = field.default
    literal = None
    if isinstance(default_val, str):
        literal = quote(default_val)
    elif isinstance(default_val, bool):
        literal = "1" if default_val else "0"
    elif isinstance(default_val, (int, float)):
        literal = str(default_val)
    elif isinstance(default_val, (dict, list)):
        literal = quote(json.dumps(default_val))
    if literal is not None:
        # MySQL rejects literal defaults on TEXT columns; they must be written as a
        # parenthesised expression (supported since MySQL 8.0.13).
        parts.append(f"DEFAULT ({literal})" if mysql_type == "LONGTEXT" else f"DEFAULT {literal}")

    help_text = getattr(field, "help_text", None)
    if help_text:
        parts.append(f"COMMENT {quote(help_text)}")

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    index_sql = None
    if getattr(field, "index", False):
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"

    return sql, index_sql
def generate_drop_field_sql(table_name: str, field_name: str) -> str:
    """Build the ALTER TABLE statement that drops column ``field_name`` from ``table_name``."""
    return " ".join(["ALTER TABLE", f"`{table_name}`", "DROP COLUMN", f"`{field_name}`"])
def generate_rollback_field_sql(table_name: str, field_name: str) -> str:
    """Build the statement that undoes an ADD COLUMN by dropping the column again."""
    return "ALTER TABLE `{}` DROP COLUMN `{}`".format(table_name, field_name)
def generate_rollback_add_field_sql(table_name: str, col_info: dict, field_name: str) -> tuple:
    """Generate SQL for rolling back a dropped field (re-adding it).

    This reconstructs the ADD COLUMN statement from the column info that was
    captured before the field was dropped. Data in the column is not restored.

    Args:
        table_name: table the column is re-added to.
        col_info: column info as returned by ``get_table_columns``. The keys read are
            ``column_type``, ``nullable``, ``default`` and ``column_key``.
        field_name: name of the column to re-add.

    Returns:
        tuple: ``(add_column_sql, index_sql)``. ``index_sql`` re-creates a secondary
        index when ``column_key`` is ``"MUL"``, or a unique index when it is ``"UNI"``.
        Otherwise it is None.
    """

    def quote(text):
        # Escape backslashes (a MySQL string escape character) as well as quotes.
        return "'" + str(text).replace("\\", "\\\\").replace("'", "''") + "'"

    column_type = col_info.get("column_type", "LONGTEXT")
    parts = [f"`{field_name}`", column_type, "NULL" if col_info.get("nullable", True) else "NOT NULL"]

    default_val = col_info.get("default")
    if isinstance(default_val, str):
        is_temporal = str(column_type).lower().startswith(("datetime", "timestamp"))
        if is_temporal and default_val.upper().startswith("CURRENT_TIMESTAMP"):
            # A function default must stay unquoted; 'CURRENT_TIMESTAMP' as a string is invalid.
            parts.append(f"DEFAULT {default_val}")
        else:
            parts.append(f"DEFAULT {quote(default_val)}")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    sql = f"ALTER TABLE `{table_name}` ADD COLUMN {' '.join(parts)}"

    # Re-create the non-primary index the column carried before it was dropped.
    index_sql = None
    column_key = col_info.get("column_key")
    if column_key == "MUL":
        index_sql = f"CREATE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"
    elif column_key == "UNI":
        index_sql = f"CREATE UNIQUE INDEX `idx_{table_name}_{field_name}` ON `{table_name}` (`{field_name}`)"

    return sql, index_sql
def generate_rollback_modify_sql(table_name: str, old_info: dict, field_name: str) -> str:
    """Generate SQL for rolling back a field type change.

    The column type, nullability and default are restored from ``old_info``, using
    the ``column_type``, ``nullable`` and ``default`` keys from ``get_table_columns``.
    Note: data values may need manual handling if the forward type conversion
    caused data loss or transformation.
    """

    def quote(text):
        # Escape backslashes (a MySQL string escape character) as well as quotes.
        return "'" + str(text).replace("\\", "\\\\").replace("'", "''") + "'"

    column_type = old_info.get("column_type", "LONGTEXT")
    parts = [f"`{field_name}`", column_type, "NULL" if old_info.get("nullable", True) else "NOT NULL"]

    default_val = old_info.get("default")
    if isinstance(default_val, str):
        is_temporal = str(column_type).lower().startswith(("datetime", "timestamp"))
        if is_temporal and default_val.upper().startswith("CURRENT_TIMESTAMP"):
            # A function default must stay unquoted; 'CURRENT_TIMESTAMP' as a string is invalid.
            parts.append(f"DEFAULT {default_val}")
        else:
            parts.append(f"DEFAULT {quote(default_val)}")
    elif isinstance(default_val, bool):
        parts.append(f"DEFAULT {1 if default_val else 0}")
    elif isinstance(default_val, (int, float)):
        parts.append(f"DEFAULT {default_val}")

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
def generate_modify_field_sql(table_name: str, field: Field, field_name: str) -> str:
    """Generate a MySQL MODIFY COLUMN statement matching ``field``'s current definition.

    Args:
        table_name: table owning the column.
        field: peewee field instance from the model definition.
        field_name: name of the column to modify.

    Returns:
        str: the ALTER TABLE ... MODIFY COLUMN statement.
    """
    import json

    def quote(text):
        # MySQL treats backslash as an escape character inside string literals
        # (unless NO_BACKSLASH_ESCAPES is set), so double it as well as quotes.
        return "'" + str(text).replace("\\", "\\\\").replace("'", "''") + "'"

    varchar = f"VARCHAR({field.max_length})" if getattr(field, "max_length", None) else "VARCHAR(255)"
    mysql_type_map = {
        "CharField": varchar,
        "TextField": "LONGTEXT",
        "LongTextField": "LONGTEXT",
        "JSONField": "LONGTEXT",
        "ListField": "LONGTEXT",
        "SerializedField": "LONGTEXT",
        "IntegerField": "INT",
        "BigIntegerField": "BIGINT",
        "FloatField": "DOUBLE",
        "BooleanField": "TINYINT(1)",
        "DateTimeField": "DATETIME",
        "DateTimeTzField": varchar,
    }
    # Resolve through the class hierarchy so subclasses of mapped types keep their
    # storage type instead of falling back to LONGTEXT.
    mysql_type = next(
        (mysql_type_map[cls.__name__] for cls in field.__class__.__mro__ if cls.__name__ in mysql_type_map),
        "LONGTEXT",
    )

    parts = [f"`{field_name}`", mysql_type, "NULL" if field.null else "NOT NULL"]

    default_val = field.default
    literal = None
    if isinstance(default_val, str):
        literal = quote(default_val)
    elif isinstance(default_val, bool):
        literal = "1" if default_val else "0"
    elif isinstance(default_val, (int, float)):
        literal = str(default_val)
    elif isinstance(default_val, (dict, list)):
        literal = quote(json.dumps(default_val))
    if literal is not None:
        # MySQL rejects literal defaults on TEXT columns; they must be written as a
        # parenthesised expression (supported since MySQL 8.0.13).
        parts.append(f"DEFAULT ({literal})" if mysql_type == "LONGTEXT" else f"DEFAULT {literal}")

    help_text = getattr(field, "help_text", None)
    if help_text:
        parts.append(f"COMMENT {quote(help_text)}")

    return f"ALTER TABLE `{table_name}` MODIFY COLUMN {' '.join(parts)}"
def generate_migration_content(new_tables: list, field_changes: dict, migrate_dir: str, migration_name: str, drop_fields: bool = False) -> str:
    """Render the Python source of a peewee-migrate migration file.

    Args:
        new_tables: model classes whose tables do not exist yet. They are emitted as
            ``@migrator.create_model`` classes and removed again on rollback.
        field_changes: ``{table_name: {"added": ..., "changed": ..., "removed": ...}}``,
            as produced by ``compare_fields``.
        migrate_dir: migration directory (unused; kept for interface compatibility).
        migration_name: migration name (unused; kept for interface compatibility).
        drop_fields: when True, emit DROP COLUMN statements for removed fields
            (and re-add them on rollback).

    Returns:
        str: the complete module source, ending with a newline.
    """
    indent = "    "
    lines = [
        '"""Peewee migrations."""',
        "",
        "from contextlib import suppress",
        "",
        "import peewee as pw",
        "from peewee_migrate import Migrator",
        "",
        "",
        "with suppress(ImportError):",
        f"{indent}import playhouse.postgres_ext as pw_pext",
        "",
        "",
        "def migrate(migrator: Migrator, database: pw.Database, *, fake=False):",
        f'{indent}"""Write your migrations here."""',
        "",
    ]

    def emit_sql(statement):
        # repr() produces a valid Python string literal even when the SQL contains
        # double quotes (e.g. JSON defaults) or backslashes. Interpolating the SQL
        # into "..." broke the generated file or silently altered escapes.
        lines.append(f"{indent}migrator.sql({statement!r})")

    # Generate create_model for new tables
    for model in new_tables:
        lines.append(f"{indent}@migrator.create_model")
        lines.append(f"{indent}class {model.__name__}(pw.Model):")
        for field_name, field in model._meta.fields.items():
            lines.append(f"{indent * 2}{field_name} = {generate_field_code(field, field_name)}")
        lines.append("")
        lines.append(f"{indent * 2}class Meta:")
        lines.append(f"{indent * 3}table_name = {model._meta.table_name!r}")
        indexes = getattr(model._meta, "indexes", None)
        if indexes:
            lines.append(f"{indent * 3}indexes = {indexes!r}")
        lines.append("")

    # Add new fields to existing tables with raw SQL
    for table_name, changes in field_changes.items():
        added = changes.get("added")
        if not added:
            continue
        for field_name, field in added.items():
            sql, index_sql = generate_add_field_sql(table_name, field, field_name)
            emit_sql(sql)
            if index_sql:
                emit_sql(index_sql)
        lines.append("")

    # Modify changed fields in existing tables
    for table_name, changes in field_changes.items():
        changed = changes.get("changed")
        if not changed:
            continue
        for field_name, (_old_info, field) in changed.items():
            emit_sql(generate_modify_field_sql(table_name, field, field_name))
        lines.append("")

    # Drop removed fields (only when explicitly requested)
    if drop_fields:
        for table_name, changes in field_changes.items():
            removed = changes.get("removed")
            if not removed:
                continue
            for field_name in removed:
                lines.append(f"{indent}# WARNING: Dropping column `{field_name}` from `{table_name}` - this will permanently delete data!")
                emit_sql(generate_drop_field_sql(table_name, field_name))
            lines.append("")

    # Generate rollback
    lines.append("")
    lines.append("def rollback(migrator: Migrator, database: pw.Database, *, fake=False):")
    lines.append(f'{indent}"""Write your rollback migrations here."""')
    lines.append("")

    # Rollback: re-add dropped fields first, since later rollbacks may depend on them
    if drop_fields:
        for table_name, changes in field_changes.items():
            for field_name, col_info in (changes.get("removed") or {}).items():
                add_sql, index_sql = generate_rollback_add_field_sql(table_name, col_info, field_name)
                lines.append(f"{indent}# Re-add dropped column `{field_name}` to `{table_name}` (data is lost)")
                emit_sql(add_sql)
                if index_sql:
                    emit_sql(index_sql)

    # Rollback: reverse type changes before removing added fields
    for table_name, changes in field_changes.items():
        for field_name, (old_info, _field) in (changes.get("changed") or {}).items():
            lines.append(f"{indent}# Note: Data values may need manual handling if type conversion caused data loss")
            emit_sql(generate_rollback_modify_sql(table_name, old_info, field_name))

    # Rollback: remove added fields
    for table_name, changes in field_changes.items():
        for field_name in changes.get("added") or {}:
            emit_sql(generate_rollback_field_sql(table_name, field_name))

    # Rollback: remove new tables in reverse creation order
    for model in reversed(new_tables):
        lines.append(f"{indent}migrator.remove_model({model._meta.table_name!r})")

    lines.append("")
    return "\n".join(lines)
def _next_migration_number(migrate_dir) -> int:
    """Return one more than the highest numeric ``NNN_`` prefix among .py files in migrate_dir."""
    highest = 0
    for entry in os.listdir(migrate_dir):
        if not entry.endswith(".py") or entry.startswith("_"):
            continue
        match = re.match(r"(\d+)_", entry, re.ASCII)
        if match:
            highest = max(highest, int(match.group(1)))
    return highest + 1


def create_migration(router: Router, models: list, db, name: str = "auto", drop_fields: bool = False):
    """Create a new migration by auto-detecting model changes.

    Detects:
        1. New tables -> ``@migrator.create_model`` classes.
        2. New fields in existing tables -> raw ``ALTER TABLE ... ADD COLUMN``.
        3. Field type changes -> raw ``ALTER TABLE ... MODIFY COLUMN``.
        4. Removed fields (only when ``--drop`` is specified) -> raw ``DROP COLUMN``.

    Args:
        router: peewee-migrate Router instance; migrations are written to ``router.migrate_dir``.
        models: list of model classes to compare against the database.
        db: database connection.
        name: migration name suffix.
        drop_fields: whether to include DROP COLUMN for removed fields.

    Returns:
        The path of the created migration file, or None when no changes were found.
    """
    # Primary key and base-model timestamp columns are managed outside this diff and
    # are excluded on both the model side and the database side. Otherwise the
    # existing DB columns would be reported as removed (and dropped with --drop).
    base_columns = ("id", "create_time", "create_date", "update_time", "update_date")
    try:
        # Get existing tables from database
        cursor = db.execute_sql("SELECT table_name FROM information_schema.tables WHERE table_schema = %s", (db.database,))
        existing_tables = {row[0] for row in cursor.fetchall()}

        new_tables = []
        field_changes = {}

        for model in models:
            table_name = model._meta.table_name
            if table_name not in existing_tables:
                new_tables.append(model)
                logger.info(f"New table detected: {table_name}")
                continue

            logger.info(f"Checking existing table: {table_name}")
            model_fields = {
                field_name: field
                for field_name, field in model._meta.fields.items()
                if field_name not in base_columns and not getattr(field, "_auto_created", False)
            }
            db_columns = {col: info for col, info in get_table_columns(db, table_name).items() if col not in base_columns}

            changes = compare_fields(model_fields, db_columns)
            if changes["added"] or changes["changed"] or changes["removed"]:
                field_changes[table_name] = changes

        has_removed = any(changes.get("removed") for changes in field_changes.values())
        if not drop_fields and has_removed:
            removed_details = [
                f"{table_name}.{col_name}"
                for table_name, changes in field_changes.items()
                for col_name in changes.get("removed") or {}
            ]
            logger.warning(f"Removed fields detected (not included in migration, use --drop to include): {', '.join(removed_details)}")
            # Removed fields are reported but not acted upon without --drop
            for changes in field_changes.values():
                changes["removed"] = {}

        if not new_tables and not any(changes["added"] or changes["changed"] for changes in field_changes.values()):
            if not (drop_fields and has_removed):
                logger.info("No schema changes detected, migration not created")
                return None

        migration_content = generate_migration_content(new_tables, field_changes, router.migrate_dir, name, drop_fields=drop_fields)

        migrate_dir = router.migrate_dir
        os.makedirs(migrate_dir, exist_ok=True)
        # Counting files could reuse an existing number when migrations are missing
        # from the sequence, and "w" would then overwrite that file. Number after the
        # highest prefix and refuse to clobber.
        migration_num = _next_migration_number(migrate_dir)
        migration_file = os.path.join(migrate_dir, f"{migration_num:03d}_{name}.py")
        with open(migration_file, "x", encoding="utf-8") as f:
            f.write(migration_content)

        logger.info(f"Created migration: {migration_file}")
        return migration_file

    except Exception as e:
        logger.error(f"Failed to create migration: {e}")
        raise
def run_migrations(router: Router):
    """Apply every migration the router reports as not yet applied.

    When ``router.diff`` is empty, nothing is executed and only an info
    message is logged. Any exception raised while inspecting or running the
    migrations is logged and then re-raised, so the caller still sees it.

    Args:
        router: peewee-migrate router bound to the target database and the
            per-version migration directory.
    """
    try:
        if router.diff:
            router.run()
            logger.info("Migrations completed successfully")
        else:
            logger.info("No pending migrations to run")
    except Exception as e:
        logger.error(f"Failed to run migrations: {e}")
        raise
def list_migrations(router: Router):
    """Log each migration reported by ``router.todo`` with its status.

    An entry is tagged ``applied`` when its name also appears in
    ``router.done``; otherwise it is tagged ``pending``. If ``router.todo``
    is empty, a single "no files" message is logged and ``router.done`` is
    never consulted.

    Args:
        router: peewee-migrate router bound to the target database and the
            per-version migration directory.
    """
    known = router.todo
    if not known:
        logger.info("No migration files found")
        return

    logger.info("Available migrations:")
    applied = frozenset(router.done)  # O(1) membership tests in the loop below
    for name in known:
        logger.info(f"[{'applied' if name in applied else 'pending'}] {name}")
def diff_schema(models: list, db):
    """Log how the model definitions differ from the live database schema.

    Output goes only to the module logger. Nothing is returned and nothing is
    modified. The following are reported:
      * tables defined by a model but absent from the database (warning);
      * tables present in the database but not backed by any model (info);
      * for tables present in both, the fields that ``compare_fields``
        classifies as added, changed, or removed, followed by a summary line.

    Args:
        models: Model classes. Their ``_meta.table_name`` and
            ``_meta.fields`` are inspected.
        db: Open database connection. ``db.database`` selects the schema
            queried in ``information_schema.tables``. The query uses ``%s``
            placeholders (MySQL-style driver).
    """
    logger.info("Checking schema differences...")

    # peewee-migrate's bookkeeping table has no model counterpart.
    ignored_tables = {"migratehistory"}
    # These column names are never compared. NOTE(review): presumably
    # identity/audit columns managed by the shared base model; confirm.
    skipped_fields = ("id", "create_time", "create_date", "update_time", "update_date")

    model_tables = {model._meta.table_name for model in models}
    logger.info(f"Found {len(model_tables)} model tables")

    cursor = db.execute_sql(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = %s",
        (db.database,),
    )
    existing_tables = {row[0] for row in cursor.fetchall()} - ignored_tables

    only_in_models = model_tables - existing_tables
    if only_in_models:
        logger.warning(f"Tables not in database ({len(only_in_models)}): {', '.join(sorted(only_in_models))}")

    only_in_db = existing_tables - model_tables
    if only_in_db:
        logger.info(f"Tables in database but not in models: {', '.join(sorted(only_in_db))}")

    shared_tables = model_tables & existing_tables
    if not shared_tables:
        return

    logger.info(f"\nChecking field differences for {len(shared_tables)} existing tables...")

    counts = {"added": 0, "changed": 0, "removed": 0}
    for model in models:
        table_name = model._meta.table_name
        if table_name not in shared_tables:
            continue

        wanted = {name: field for name, field in model._meta.fields.items() if name not in skipped_fields}
        changes = compare_fields(wanted, get_table_columns(db, table_name))

        added = changes["added"]
        if added:
            counts["added"] += len(added)
            details = [f"{k}: {v.__class__.__name__}" for k, v in added.items()]
            logger.info(f"{table_name}: {len(added)} new field(s) - {details}")

        changed = changes["changed"]
        if changed:
            counts["changed"] += len(changed)
            # Element [1] of each value is reported. NOTE(review): presumably
            # the model-side field; confirm against compare_fields.
            details = [f"{k}: {v[1].__class__.__name__}" for k, v in changed.items()]
            logger.info(f"{table_name}: {len(changed)} changed field(s) - {details}")

        removed = changes["removed"]
        if removed:
            counts["removed"] += len(removed)
            details = [f"{k}: {v['column_type']}" for k, v in removed.items()]
            logger.warning(f"{table_name}: {len(removed)} removed field(s) - {details}")

    logger.info(f"\nSummary: {counts['added']} new fields, {counts['changed']} changed fields, {counts['removed']} removed fields")
def main():
    """Command-line entry point for the schema synchronization tool.

    Steps:
      1. Parse the connection, version, action, and migration options.
      2. Exit with status 1 if the version string is malformed or no action
         flag was given.
      3. Make sure ``tools/migrate/<version dir>`` exists under PROJECT_BASE.
      4. Load the models and open the database connection.
      5. Run the requested actions in the fixed order
         list -> create -> migrate -> diff.

    The connection is closed in a ``finally`` clause, so it is released even
    when an action raises.
    """
    usage_examples = """
Examples:
  # List all migrations
  python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3

  # Create migration from model changes
  python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3

  # Create migration including dropped fields (destructive!)
  python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3

  # Run all pending migrations
  python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3

  # Show schema differences
  python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.26.3
"""

    parser = argparse.ArgumentParser(
        description="Database Schema Synchronization Tool using peewee-migrate",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    add = parser.add_argument

    # Database connection options
    add("--host", type=str, required=True, help="MySQL host")
    add("--port", type=int, default=3306, help="MySQL port (default: 3306)")
    add("--user", type=str, required=True, help="MySQL user")
    add("--password", type=str, required=True, help="MySQL password")
    add("--database", type=str, required=True, help="MySQL database name")

    # Version option
    add("--version", "-v", type=str, required=True, help="Version number in format vxx.xx.xx (e.g., v0.26.3)")

    # Action options
    add("--list", "-l", action="store_true", help="List all migrations")
    add("--create", "-c", action="store_true", help="Create migration from model changes (auto-detect)")
    add("--migrate", "-m", action="store_true", help="Run pending migrations")
    add("--diff", "-d", action="store_true", help="Show schema differences")

    # Migration options
    add("--name", "-n", type=str, default="auto", help="Migration name")
    add("--drop", action="store_true", help="Include DROP COLUMN for fields removed from models (destructive - will permanently delete data!)")

    args = parser.parse_args()

    if not validate_version(args.version):
        logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.26.3)")
        sys.exit(1)

    if not (args.list or args.create or args.migrate or args.diff):
        parser.print_help()
        logger.error("Please specify at least one action: --list, --create, --migrate, or --diff")
        sys.exit(1)

    # Each release version gets its own migration directory.
    migrate_dir = os.path.join(PROJECT_BASE, "tools", "migrate", version_to_dirname(args.version))
    logger.info(f"Version: {args.version}")
    logger.info(f"Migration directory: {migrate_dir}")
    os.makedirs(migrate_dir, exist_ok=True)

    logger.info("Loading database models from api/db/db_models.py...")
    models, _unused = load_db_models()
    logger.info(f"Found {len(models)} model classes")

    db = create_database_connection(
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=args.database,
    )
    try:
        db.connect()
        logger.info(f"Connected to database: {args.database}")

        router = Router(db, migrate_dir, ignore=["basemodel", "base_model", "migratehistory"])

        # The order of this table is the order in which the actions run.
        steps = (
            (args.list, lambda: list_migrations(router)),
            (args.create, lambda: create_migration(router, models, db, args.name, drop_fields=args.drop)),
            (args.migrate, lambda: run_migrations(router)),
            (args.diff, lambda: diff_schema(models, db)),
        )
        for requested, step in steps:
            if requested:
                step()
    finally:
        if not db.is_closed():
            db.close()
            logger.info("Database connection closed")

    logger.info("Done.")
# Run the command-line tool only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()