今天,我们非常高兴地宣布推出 Amazon SageMaker 模型监控器。这是 Amazon SageMaker 的一项新功能,可以自动监控生产中的机器学习 (ML) 模型,并在出现数据质量问题时向您发出警报。
在我从事数据处理之初时,我学会了一样东西,那便是再关注数据质量都不为过。不知道您是否有过这样的经历:您花费数小时排查问题,最后知道是意外的 NULL 值或不知怎么就到了您的一个数据库的外来字符编码导致的。
由于模型实际上是根据大量数据构建的,因此不难理解,为什么 ML 从业人员会花费大量时间来维护数据集。特别是,他们会确保训练集(用于训练模型)和验证集(用于测量模型的准确性)中的数据样本具有相同的统计属性。
现在还不是松懈的时候! 尽管您可以完全控制实验数据集,但对于模型将要接收的真实数据就不是那回事了。当然,这些数据将是未经清理的,但是更令人担忧的问题是“数据漂移”,即您所接收数据的统计性质发生渐变。最小值和最大值、平均值、中位数、方差等等:所有这些都是决定模型训练期间做出的假设和决策的关键属性。我们的直觉告诉我们,这些值的任何重大变化都会影响预测的准确性:设想一下,要是由于输入特征出现漂移甚至缺失,导致一个贷款应用程序预测的金额升高,那多可怕!
检测这些条件非常困难:您将需要捕获模型接收的数据,运行各种统计分析以将这些数据与训练集进行比较,定义规则以检测漂移,并在发生漂移时发出警报……并在每次更新模型时从头再来一遍。专家级 ML 从业人员当然知道如何构建这些复杂的工具,但是却要花费大量的时间和耗费大量的资源。这不就是眉毛胡子一把抓么……
为了帮助所有客户专注于创造价值,我们构建了 Amazon SageMaker 模型监控器。下面我来进行更多介绍。
Amazon SageMaker 模型监控器简介
典型的监控会话如下。首先,我们要从 SageMaker 终端节点开始,可以使用现有的终端节点,也可以专门为了监控目的而创建新的终端节点。您可以在任何终端节点上使用 SageMaker 模型监控器,无论模型是从内置算法、内置框架,还是从您自己的容器训练而来。
使用 SageMaker 开发工具包,您可以捕获发送到终端节点的部分数据(可配置),您也可以根据需要捕获预测,并将这些数据存储在您的 Amazon Simple Storage Service (S3) 存储桶中。捕获的数据会附加上元数据(内容类型、时间戳等),您可以像使用任何 S3 对象一样保护和访问它。
然后,从用于训练在终端节点上部署的模型的数据集建立基线。当然,您也可以选择使用已有的基线。这将启动 Amazon SageMaker 处理作业,其中 SageMaker 模型监控器将执行以下操作:
推断输入数据的架构,即有关每个特征的类型和完整性的信息。您应该对其进行检查,并在需要时进行更新。
(仅对于预构建的容器)使用 Deequ(基于由 Amazon 开发并在 Amazon 使用的 Apache Spark 的开放源代码工具)来计算特征统计信息(博客文章和研究论文)。这些统计信息包括 KLL 草图,这是一种用于在数据流上计算准确分位数的高级技术,这也是我们最近对 Deequ 做出的一项贡献。
使用这些构件,下一步是启动监控计划,以使 SageMaker 模型监控器检查收集的数据和预测质量。无论使用的是内置容器还是自定义容器,都需要应用许多内置规则,并且报告会定期推送到 S3。这些报告包含在上一个时间段内接收到的数据的统计和架构信息以及检测到的任何违规情况。
最后但并非最不重要的一点是, SageMaker 模型监控器会向 Amazon CloudWatch 发出与特征相对应的指标,可用于设置控制面板和警报。CloudWatch 的摘要指标也可以在 Amazon SageMaker Studio 中看到,当然所有统计数据、监控结果和收集的数据都可以在笔记本中查看和进一步分析。
有关更多信息以及有关如何通过 AWS CloudFormation 使用 SageMaker 模型监控器的示例,请参阅开发人员指南。
现在,让我们使用经过内置 XGBoost 算法训练的用户流失预测模型进行演示。
启用数据捕获
第一步是创建终端节点配置以启用数据捕获。在这里,我决定捕获 100% 的传入数据以及模型输出(即预测)。我还传递了 CSV 和 JSON 数据的内容类型。
Python
接下来,我使用常规的 CreateEndpoint
API 创建终端节点。
Python
对于已有的终端节点,我可以使用 UpdateEndpoint
API 来无缝更新终端节点配置。
反复调用终端节点后,我可以在 S3 中看到一些捕获的数据(为清晰起见,对输出进行了编辑)。
Bash
AllTrafficVariant/2019/11/22/08/24-40-519-9a9273ca-09c2-45d3-96ab-fc7be2402d43.jsonl
AllTrafficVariant/2019/11/22/08/25-42-243-3e1c653b-8809-4a6b-9d51-69ada40bc809.jsonl
“endpointInput”:{
“observedContentType”:“text/csv”,
“mode”:“INPUT”,
“data”:“132,25,113.2,96,269.9,107,229.1,87,7.1,7,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,1”,
“encoding”:“CSV”
},
“endpointOutput”:{
“observedContentType”:“text/csv; charset=utf-8”,
“mode”:“OUTPUT”,
“data”:“0.01076381653547287”,
“encoding”:“CSV”}
},
“eventMetadata”:{
“eventId”:“6ece5c74-7497-43f1-a263-4833557ffd63”,
“inferenceTime”:“2019-11-22T08:24:40Z”},
“eventVersion”:“0”}
from processingjob_wrapper import ProcessingJob
processing_job = ProcessingJob(sm_client, role).
create(job_name, baseline_data_uri, baseline_results_uri)
aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/baselining/results/
constraints.json
statistics.json
{
“version” : 0.0,
“features” : [ {
“name” : “Churn”,
“inferred_type” : “Integral”,
“completeness” : 1.0
}, {
“name” : “Account Length”,
“inferred_type” : “Integral”,
“completeness” : 1.0
}, {
“name” : “VMail Message”,
“inferred_type” : “Integral”,
“completeness” : 1.0
}, {
“name” : “Day Mins”,
“inferred_type” : “Fractional”,
“completeness” : 1.0
}, {
“name” : “Day Calls”,
“inferred_type” : “Integral”,
“completeness” : 1.0
“monitoring_config” : {
“evaluate_constraints” : “Enabled”,
“emit_metrics” : “Enabled”,
“distribution_constraints” : {
“enable_comparisons” : true,
“min_domain_mass” : 1.0,
“comparison_threshold” : 1.0
}
}
“name” : “Day Mins”,
“inferred_type” : “Fractional”,
“numerical_statistics” : {
“common” : {
“num_present” : 2333,
“num_missing” : 0
},
“mean” : 180.22648949849963,
“sum” : 420468.3999999996,
“std_dev” : 53.987178959901556,
“min” : 0.0,
“max” : 350.8,
“distribution” : {
“kll” : {
“buckets” : [ {
“lower_bound” : 0.0,
“upper_bound” : 35.08,
“count” : 14.0
}, {
“lower_bound” : 35.08,
“upper_bound” : 70.16,
“count” : 48.0
}, {
“lower_bound” : 70.16,
“upper_bound” : 105.24000000000001,
“count” : 130.0
}, {
“lower_bound” : 105.24000000000001,
“upper_bound” : 140.32,
“count” : 318.0
}, {
“lower_bound” : 140.32,
“upper_bound” : 175.4,
“count” : 565.0
}, {
“lower_bound” : 175.4,
“upper_bound” : 210.48000000000002,
“count” : 587.0
}, {
“lower_bound” : 210.48000000000002,
“upper_bound” : 245.56,
“count” : 423.0
}, {
“lower_bound” : 245.56,
“upper_bound” : 280.64,
“count” : 180.0
}, {
“lower_bound” : 280.64,
“upper_bound” : 315.72,
“count” : 58.0
}, {
“lower_bound” : 315.72,
“upper_bound” : 350.8,
“count” : 10.0
} ],
“sketch” : {
“parameters” : {
“c” : 0.64,
“k” : 2048.0
},
“data” : [ [ 178.1, 160.3, 197.1, 105.2, 283.1, 113.6, 232.1, 212.7, 73.3, 176.9, 161.9, 128.6, 190.5, 223.2, 157.9, 173.1, 273.5, 275.8, 119.2, 174.6, 133.3, 145.0, 150.6, 220.2, 109.7, 155.4, 172.0, 235.6, 218.5, 92.7, 90.7, 162.3, 146.5, 210.1, 214.4, 194.4, 237.3, 255.9, 197.9, 200.2, 120, …
ms = MonitoringSchedule(sm_client, role)
schedule = ms.create(
mon_schedule_name,
endpoint_name,
s3_report_path,
record_preprocessor_source_uri=s3_code_preprocessor_uri,
post_analytics_source_uri=s3_code_postprocessor_uri,
baseline_statistics_uri=baseline_results_uri + ‘/statistics.json’,
baseline_constraints_uri=baseline_results_uri+ ‘/constraints.json’
)
mon_executions = sm_client.list_monitoring_executions(MonitoringScheduleName=mon_schedule_name, MaxResults=3)
for execution_summary in mon_executions[‘MonitoringExecutionSummaries’]:
print(“ProcessingJob: {}”.format(execution_summary[‘ProcessingJobArn’].split(’/’)[1]))
print(‘MonitoringExecutionStatus: {} \n’.format(execution_summary[‘MonitoringExecutionStatus’]))
ProcessingJob: model-monitoring-201911221050-df2c7fc4
MonitoringExecutionStatus: Completed
ProcessingJob: model-monitoring-201911221040-3a738dd7
MonitoringExecutionStatus: Completed
ProcessingJob: model-monitoring-201911221030-83f15fb9
MonitoringExecutionStatus: Completed
desc_analytics_job_result=sm_client.describe_processing_job(ProcessingJobName=job_name)
report_uri=desc_analytics_job_result[‘ProcessingOutputConfig’][‘Outputs’][0][‘S3Output’][‘S3Uri’]
print(‘Report Uri: {}’.format(report_uri))
Report Uri: s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209
aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209/
constraint_violations.json
constraints.json
statistics.json
violations" : [ {
“feature_name” : “State_AL”,
“constraint_check_type” : “data_type_check”,
“description” : "Value: 0.8 does not meet the constraint requirement! "
}, {
“feature_name” : “Eve Mins”,
“constraint_check_type” : “baseline_drift_check”,
“description” : “Numerical distance: 0.2711598746081505 exceeds numerical threshold: 0”
}, {
“feature_name” : “CustServ Calls”,
“constraint_check_type” : “baseline_drift_check”,
“description” : “Numerical distance: 0.6470588235294117 exceeds numerical threshold: 0”
}
评论