# train_corrector.py
# Trains a RandomForestRegressor on feedback_data.csv and predicts t_real_hours
# for the last row of that CSV, writing the prediction back into the file.
import sys

import joblib
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
  # === Paths ===
# Feedback CSV: read for training below, then overwritten with the prediction.
csv_path = 'C:\\Users\\Administrator\\Desktop\\defrost\\feedback_data.csv'
# Destination of the trained model (written with joblib.dump).
model_save_path = "defrost_time_corrector.pkl"

# === Feature columns (model input, in this exact order) ===
# Must remain a list: pandas treats df[list] as a column selection, whereas a
# tuple would be looked up as a single (MultiIndex) key.
# NOTE(review): units/meaning of these columns are not visible here — confirm
# against the data source.
feature_columns = [
    "w", "rho_coal", "rho_ice", "C_coal", "C_ice",
    "L", "k_coal", "k_ice", "h", "T_air",
    "T_initial", "T_m", "a", "b", "c",
]
  # === 1. Load the CSV and preprocess ===
try:
    df = pd.read_csv(csv_path, parse_dates=["t_formula", "t_real"], encoding='utf-8')
except (OSError, ValueError) as e:
    # OSError: missing / locked / unreadable file. ValueError covers pandas
    # ParserError and EmptyDataError, UnicodeDecodeError, and a parse_dates
    # column that does not exist in the file.
    print(f"❌ 读取CSV失败: {e}")
    sys.exit(1)
print(f"✅ 成功读取CSV文件,共{len(df)}条数据")

# Keep these optional text columns as strings (only if present). The nullable
# "string" dtype keeps empty cells missing; astype(str) would turn NaN into the
# literal text "nan", which is then written back to the CSV at the end.
for col in ["material_name", "manufactured_goods"]:
    if col in df.columns:
        df[col] = df[col].astype("string")

# Target: t_real - t_formula in hours (the real defrost duration).
# If any cell in a date column cannot be parsed, read_csv leaves that column
# as object dtype and a plain subtraction raises TypeError. Coerce on local
# copies instead so bad/blank cells become NaT -> NaN target; the original
# columns are left untouched so the write-back does not rewrite them.
t_formula_ts = pd.to_datetime(df["t_formula"], errors="coerce")
t_real_ts = pd.to_datetime(df["t_real"], errors="coerce")
df["t_real_hours"] = (t_real_ts - t_formula_ts).dt.total_seconds() / 3600

# Abort early if any model input column is absent.
missing_features = [col for col in feature_columns if col not in df.columns]
if missing_features:
    print(f"❌ 缺少必要特征列: {missing_features}")
    sys.exit(1)
  # === 2. Train the model (hold out a test split when there is enough data) ===
# Minimum number of labelled rows before a train/test split is attempted.
MIN_ROWS_FOR_SPLIT = 10

# Rows whose t_real_hours is NaN (blank or unparseable timestamps) cannot be
# used as training targets: RandomForestRegressor.fit raises ValueError on NaN
# in y. Train only on labelled rows; unlabelled rows stay in df untouched.
has_target = df["t_real_hours"].notna()
X = df.loc[has_target, feature_columns]
y = df.loc[has_target, "t_real_hours"]
if X.empty:
    print("❌ 没有可用于训练的数据(所有行的 t_real_hours 均为空)")
    sys.exit(1)

if len(X) >= MIN_ROWS_FOR_SPLIT:
    # Enough data: keep 20% aside to measure accuracy.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f"📚 数据量 {len(X)},已划分训练集和测试集")
else:
    # Too little data for a meaningful split: train on everything.
    X_train, y_train = X, y
    X_test, y_test = None, None
    print(f"⚠️ 数据量太少({len(X)}条),直接全量训练")

# Random-forest regressor, fixed seed for reproducible results.
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Report held-out accuracy; previously the test split was computed but never used.
# NOTE(review): the saved model is fit on the training split only; consider
# refitting on all labelled rows after evaluation.
if X_test is not None:
    test_mae = mean_absolute_error(y_test, model.predict(X_test))
    test_r2 = model.score(X_test, y_test)
    print(f"📈 测试集评估: MAE = {test_mae:.2f} 小时, R^2 = {test_r2:.3f}")

# Persist the model for later reuse.
joblib.dump(model, model_save_path)
print(f"✅ 模型训练完成,已保存为 {model_save_path}")
  # === 3. Predict the last row of the CSV ===
# tail(1) returns a one-row DataFrame that keeps the row's original index label.
new_sample = df.tail(1)
X_new = new_sample[feature_columns]
# NOTE(review): if this row was also used for training, the prediction is
# in-sample — confirm that is intended.
predictions = model.predict(X_new)
predicted_time = predictions[0]
# Write the value onto that row only; other rows of the column keep their
# existing value (NaN if the column is new).
target_rows = new_sample.index
df.loc[target_rows, "predicted_t_real_hours"] = predicted_time
  # === 4. Write the DataFrame (with the prediction) back over the source CSV ===
# A failure is reported but does not stop the script, so the result below is
# still printed.
try:
    df.to_csv(csv_path, index=False, encoding='utf-8')
    print(f"✅ 最新数据预测完成,已更新到 {csv_path}")
except Exception as save_error:
    print(f"❌ 保存CSV失败: {save_error}")
  # === 5. Print the final prediction ===
result_message = f"\n📊 预测最后一条数据真实解冻时间为:{predicted_time:.2f} 小时"
print(result_message)