train_corrector.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import pandas as pd
  2. import joblib
  3. from apscheduler.schedulers.blocking import BlockingScheduler
  4. from sklearn.ensemble import RandomForestRegressor
  5. from sklearn.model_selection import train_test_split
  6. # === 配置路径 ===
  7. csv_path = 'C:\\Users\\Administrator\\Desktop\\defrost\\feedback_data.csv' # 你的csv
  8. model_save_path = "defrost_time_corrector.pkl" # 模型保存路径
  9. # === 特征列定义 ===
  10. feature_columns = [
  11. "w", "rho_coal", "rho_ice", "C_coal", "C_ice", "L", "k_coal", "k_ice", "h",
  12. "T_air", "T_initial", "T_m", "a", "b", "c"
  13. ]
  14. # 定义定时任务的训练函数
  15. def train_and_save_model():
  16. print("🔄 定时任务开始:重新训练模型...")
  17. # === 1. 读取CSV并预处理 ===
  18. try:
  19. df = pd.read_csv(csv_path, parse_dates=["t_formula", "t_real"], encoding='utf-8')
  20. print(f"✅ 成功读取CSV文件,共{len(df)}条数据")
  21. except Exception as e:
  22. print(f"❌ 读取CSV失败: {e}")
  23. return
  24. # 强制转换为 datetime 类型
  25. df["t_real"] = pd.to_datetime(df.get("t_real"), errors="coerce")
  26. df["t_formula"] = pd.to_datetime(df.get("t_formula"), errors="coerce")
  27. # 删除无法转换时间的行
  28. before_drop = len(df)
  29. df = df.dropna(subset=["t_real", "t_formula"])
  30. after_drop = len(df)
  31. if before_drop != after_drop:
  32. print(f"⚠️ 删除了 {before_drop - after_drop} 行非法时间数据")
  33. # 计算真实解冻时长(小时)
  34. df["t_real_hours"] = (df["t_real"] - df["t_formula"]).dt.total_seconds() / 3600
  35. # 再次检查并去除无效目标值
  36. if df["t_real_hours"].isna().any():
  37. invalid_count = df["t_real_hours"].isna().sum()
  38. df = df.dropna(subset=["t_real_hours"])
  39. print(f"⚠️ 删除了 {invalid_count} 行无效的解冻时长")
  40. # 确保字段类型正确(如果这两列存在)
  41. for col in ["material_name", "manufactured_goods"]:
  42. if col in df.columns:
  43. df[col] = df[col].astype(str)
  44. # 检查有没有缺失特征
  45. missing_features = [col for col in feature_columns if col not in df.columns]
  46. if missing_features:
  47. print(f"❌ 缺少必要特征列: {missing_features}")
  48. return
  49. # === 2. 智能训练模型 ===
  50. X = df[feature_columns]
  51. y = df["t_real_hours"]
  52. if len(X) >= 10:
  53. # 数据够多,做train_test_split
  54. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  55. print(f"📚 数据量 {len(X)},已划分训练集和测试集")
  56. else:
  57. # 数据少,直接全量训练
  58. X_train, y_train = X, y
  59. X_test, y_test = None, None
  60. print(f"⚠️ 数据量太少({len(X)}条),直接全量训练")
  61. # 建立随机森林回归模型
  62. model = RandomForestRegressor(n_estimators=100, random_state=42)
  63. model.fit(X_train, y_train)
  64. # 保存模型
  65. joblib.dump(model, model_save_path)
  66. print(f"✅ 模型训练完成,已保存为 {model_save_path}")
  67. # === 启动定时任务调度器 ===
  68. if __name__ == '__main__':
  69. # 先执行一次
  70. train_and_save_model()
  71. # 然后设置每小时定时执行
  72. scheduler = BlockingScheduler()
  73. scheduler.add_job(train_and_save_model, 'interval', hours=1)
  74. print("⏰ 启动定时任务调度器:每小时自动训练模型...")
  75. scheduler.start()