MLflow实验管理全流程指南
MLflow是开源的机器学习生命周期管理平台,其实验管理功能通过Tracking(实验追踪)组件实现,覆盖实验记录、查询、对比及模型版本控制等环节,帮助团队高效管理机器学习实验,确保结果可复现。
实验管理的第一步是创建或指定实验,确保所有运行结果归属同一实验。通过mlflow.set_experiment()
函数设置实验名称,若实验不存在则自动创建,后续所有start_run()
调用都会关联到该实验。
import mlflow
mlflow.set_experiment("Iris_Classification") # 设置实验名称(如“Iris分类”)
这一步是实验管理的基石,避免实验分散在不同目录或工作区。
使用mlflow.start_run()
开启一个实验运行(Run),并在with
块内记录参数(超参数配置)、指标(模型性能)和产物(模型文件、可视化图表)。运行结束后自动关闭,确保数据完整性。
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
with mlflow.start_run(): # 开启运行
# 加载数据
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
# 记录参数(超参数)
mlflow.log_param("n_estimators", 100) # 随机森林树的数量
mlflow.log_param("max_depth", 5) # 树的最大深度
# 训练模型
model = RandomForestClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
# 记录指标(性能评估)
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
mlflow.log_metric("accuracy", accuracy) # 准确率
# 记录产物(模型文件)
mlflow.sklearn.log_model(model, "random_forest_model") # 保存模型到“random_forest_model”目录
# 记录额外产物(如特征重要性)
with open("feature_importance.txt", "w") as f:
f.write(str(model.feature_importances_))
mlflow.log_artifact("feature_importance.txt") # 上传特征重要性文件
通过log_param
、log_metric
和log_artifact
,实验的所有关键信息都被结构化记录,便于后续查询。
实验完成后,可通过MLflow UI或代码查询、对比不同运行的结果:
mlflow ui
),默认访问http://localhost:5000
,界面会展示所有实验的运行列表。支持并排对比(选择多个运行)、过滤(如“accuracy > 0.9”)、可视化工件(如特征重要性图),快速识别最佳实验。mlflow.search_runs()
函数,通过实验ID、参数或指标筛选运行。例如,获取某实验的最新2次运行并按准确率降序排序:import mlflow
experiment_id = "0" # 替换为目标实验ID(可通过mlflow.search_experiments()获取)
runs = mlflow.search_runs(
experiment_ids=[experiment_id],
order_by=["metrics.accuracy DESC"],
max_results=2
)
print(runs[["params.n_estimators", "metrics.accuracy"]]) # 打印参数与指标
这一步帮助团队快速评估不同超参数组合的效果,避免重复实验。
实验完成后,可将模型注册到MLflow Model Registry(模型注册表),实现模型生命周期的集中管理。注册表支持以下阶段:
注册模型的步骤:
from mlflow.tracking import MlflowClient
# 注册模型(runs://格式指向某次运行的模型路径)
model_uri = f"runs:/{runs.iloc[0].info.run_id}/random_forest_model" # 使用最新运行的模型
client = MlflowClient()
registered_model = client.create_registered_model("IrisRandomForest") # 创建注册模型
# 创建模型版本
version = client.create_model_version(
name="IrisRandomForest",
source=model_uri,
run_id=runs.iloc[0].info.run_id
)
# 将模型转为生产状态
client.transition_model_version_stage(
name="IrisRandomForest",
version=version.version,
stage="Production"
)
模型注册表不仅解决了“模型版本混乱”的问题,还支持团队协作与审批流程,确保生产环境使用的是经过验证的模型。
mlflow server --backend-store-uri sqlite:///mlflow.db
),方便团队共享实验数据;MLflow Projects
(MLproject
文件)规范项目结构与依赖,确保实验可复现;