# train_corrector.py 2.3 KB
import joblib
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
  5. # === 步骤1:读取CSV并预处理 ===
  6. csv_path = 'C:\\Users\\Administrator\\Desktop\\defrost\\feedback_data.csv'
  7. df = pd.read_csv(csv_path, parse_dates=["t_formula", "t_real"], encoding='gbk')
  8. # 确保类型一致
  9. df["material_name"] = df["material_name"].astype(str)
  10. df["manufactured_goods"] = df["manufactured_goods"].astype(str)
  11. # 计算真实解冻时长(单位:小时)
  12. df["t_real_hours"] = (df["t_real"] - df["t_formula"]).dt.total_seconds() / 3600
  13. # 特征列(不包括物料名称和产品名称)
  14. feature_columns = [
  15. "w", "rho_coal", "rho_ice", "C_coal", "C_ice", "L", "k_coal", "k_ice", "h",
  16. "T_air", "T_initial", "T_m", "a", "b", "c"
  17. ]
  18. # 模型输入和标签
  19. X = df[feature_columns].copy()
  20. y = df["t_real_hours"]
  21. # === 步骤2:训练模型 ===
  22. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  23. model = RandomForestRegressor(n_estimators=100, random_state=42)
  24. model.fit(X_train, y_train)
  25. # === 步骤3:保存模型 ===
  26. joblib.dump(model, "defrost_time_corrector.pkl")
  27. print("模型训练完成并已保存为 defrost_time_corrector.pkl")
  28. # === 步骤4:测试一个新样本并判断是否为相同样本类型 ===
  29. new_sample_info = {
  30. "material_name": "国产动力煤",
  31. "manufactured_goods": "龙家堡洗混煤-5206",
  32. "w": 12,
  33. "rho_coal": 3000,
  34. "rho_ice": 917,
  35. "C_coal": 800,
  36. "C_ice": 2100,
  37. "L": 334000,
  38. "k_coal": 20,
  39. "k_ice": 2.2,
  40. "h": 300,
  41. "T_air": 90,
  42. "T_initial": -20,
  43. "T_m": 0,
  44. "a": 13,
  45. "b": 2.72,
  46. "c": 1.6
  47. }
  48. # 构造 DataFrame
  49. new_sample = pd.DataFrame([new_sample_info])
  50. # 一致性判断
  51. is_known = ((df["material_name"] == new_sample_info["material_name"]) &
  52. (df["manufactured_goods"] == new_sample_info["manufactured_goods"])).any()
  53. # 只传入特征列用于模型预测
  54. X_new = new_sample[feature_columns]
  55. predicted_time = model.predict(X_new)[0]
  56. print(f"\n📊 预测真实解冻时间: {predicted_time:.2f} 小时")
  57. if is_known:
  58. print("该样本与历史数据中存在相同物料和制造品,可以认为是同一类样本。")
  59. else:
  60. print("该样本是新的物料或产品组合,可能存在偏差,请注意验证。")