这是由国际计算机科学研究院的 Jaeyoung Choi 和加州大学伯克利分校的 Kevin Li 所著的一篇访客文章。本项目演示学术研究人员如何利用我们的 AWS Cloud Credits for Research Program 实现科学突破。
当您拍摄照片时,现代移动设备可以自动向图像分配地理坐标。不过,网络上的大多数图像仍缺少该位置元数据。图像定位是估计图像位置并应用位置标签的过程。根据您的数据集大小以及提出问题的方式,分配的位置标签可以是建筑物或地标名称或实际地理坐标 (纬度、经度)。
在本文中,我们会展示如何使用通过 Apache MXNet 创建的预训练模型对图像进行地理分类。我们使用的数据集包含拍摄于全球各地的数百万张 Flickr 图像。我们还会展示如何将结果制成地图以直观地显示结果。
我们的方法
图像定位方法可以分为两类:图像检索搜索法和分类法。(该博文将对这两个类别中最先进的方法进行比较。)
Weyand 等人近期的作品提出图像定位是一个分类问题。在这种方法中,作者将地球表面细分为数千个地理单元格,并利用带地理标记的图像训练了深层神经网路。有关他们的试验更通俗的描述,请参阅该文章。
由于作者没有公开他们的训练数据或训练模型 (即 PlaNet),因此我们决定训练我们自己的图像定位器。我们训练模型的场景灵感来自于 Weyand 等人描述的方法,但是我们对几个设置作了改动。
我们在单个 p2.16xlarge 实例上使用 MXNet 来训练我们的模型 LocationNet,该实例包含来自 AWS Multimedia Commons 数据集的带有地理标记的图像。
我们将训练、验证和测试图像分离,以便同一人上传的图像不会出现在多个集合中。我们使用 Google 的 S2 Geometry Library 通过训练数据创建类。该模型经过 12 个训练周期后收敛,完成 p2.16xlarge 实例训练大约花了 9 天时间。GitHub 上提供了采用 Jupyter Notebook 的完整教程。
下表对用于训练和测试 LocationNet 和 PlaNet 的设置进行了比较。
数据集来源 | Multimedia Commons | 从网络抓取的图像
训练集 | 3390 万 | 9100 万
验证 | 180 万 | 3400 万
S2 单元分区 | t1=5000, t2=500
→ 15,527 个单元格 | t1=10,000, t2=50
→ 26,263 个单元格
模型 | ResNet-101 | GoogleNet
优化 | 使用动量和 LR 计划的 SGD | Adagrad
训练时间 | 采用 16 个 NVIDIA K80 GPU (p2.16xlarge EC2 实例) 时为 9 天
12 个训练周期 | 采用 200 个 CPU 内核时为两个半月
框架 | MXNet | DistBelief
测试集 | Placing Task 2016 测试集 (150 万张 Flickr 图像) | 230 万张有地理标记的 Flickr 图像
在推理时,LocationNet 会输出地理单元格间的概率分布。单元格中概率最高的图像的质心地理坐标会被分配为查询图像的地理坐标。
LocationNet 会在 MXNet Model Zoo 中公开分享。
下载 LocationNet
现在下载 LocationNet 预训练模型。LocationNet 已使用 AWS Multimedia Commons 数据集中带地理标记的图像子集进行了训练。Multimedia Commons 数据集包含 3900 多万张图像和 15000 个地理单元格 (类)。
LocationNet 包括两部分:一个包含模型定义的 JSON 文件和一个包含参数的二进制文件。我们从 S3 加载必要的软件包并下载文件。
Java
import os
import urllib
import mxnet as mx
import logging
import numpy as np
from skimage import io, transform
from collections import namedtuple
from math import radians, sin, cos, sqrt, asin
path = 'https://s3.amazonaws.com/mmcommons-tutorial/models/'
model_path = 'models/'
if not os.path.exists(model_path):
os.mkdir(model_path)
urllib.urlretrieve(path+'RN101-5k500-symbol.json', model_path+'RN101-5k500-symbol.json')
urllib.urlretrieve(path+'RN101-5k500-0012.params', model_path+'RN101-5k500-0012.params')
复制代码
然后,加载下载的模型。如果您没有可用 GPU,请将 mx.gpu() 替换为 mx.cpu():
Java
# Load the pre-trained model
prefix = "models/RN101-5k500"
load_epoch = 12
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, load_epoch)
mod = mx.mod.Module(symbol=sym, context=mx.gpu())
mod.bind([('data', (1,3,224,224))], for_training=False)
mod.set_params(arg_params, aux_params, allow_missing=True)
复制代码
grids.txt 文件包含用于训练模型的地理单元格。
第 i 行是第 i 个类,列分别代表:S2 单元格标记、纬度和经度。我们将标签加载到名为 grids 的列表中。
Java
# Download and load grids file
urllib.urlretrieve('https://raw.githubusercontent.com/multimedia-berkeley/tutorials/master/grids.txt','grids.txt')
# Load labels.
grids = []
with open('grids.txt', 'r') as f:
for line in f:
line = line.strip().split('\t')
lat = float(line[1])
lng = float(line[2])
grids.append((lat, lng))
复制代码
该模型使用半径公式来测量点 p1 和 p2 之间的大圆弧距离,以千米为单位:
Java
def distance(p1, p2):
R = 6371 # Earth radius in km
lat1, lng1, lat2, lng2 = map(radians, (p1[0], p1[1], p2[0], p2[1]))
dlat = lat2 - lat1
dlng = lng2 - lng1
a = sin(dlat * 0.5) ** 2 + cos(lat1) * cos(lat2) * (sin(dlng * 0.5) ** 2)
复制代码
Java
return 2 * R * asin(sqrt(a))
复制代码
在将图像提供给深度学习网络之前,该模型会通过裁剪以及减去均值来预处理图像:
Java
# mean image for preprocessing
mean_rgb = np.array([123.68, 116.779, 103.939])
mean_rgb = mean_rgb.reshape((3, 1, 1))
def PreprocessImage(path, show_img=False):
# load image.
img = io.imread(path)
# We crop image from center to get size 224x224.
short_side = min(img.shape[:2])
yy = int((img.shape[0] - short_side) / 2)
xx = int((img.shape[1] - short_side) / 2)
crop_img = img[yy : yy + short_side, xx : xx + short_side]
resized_img = transform.resize(crop_img, (224,224))
if show_img:
io.imshow(resized_img)
# convert to numpy.ndarray
sample = np.asarray(resized_img) * 256
# swap axes to make image from (224, 224, 3) to (3, 224, 224)
sample = np.swapaxes(sample, 0, 2)
sample = np.swapaxes(sample, 1, 2)
# sub mean
normed_img = sample - mean_rgb
normed_img = normed_img.reshape((1, 3, 224, 224))
return [mx.nd.array(normed_img)]
复制代码
评估并比较模型
为了进行评估,我们使用两个数据集:IM2GPS 数据集和 Flickr 图像测试数据集,后者用于 MediaEval Placing 2016 基准测试。
IM2GPS 测试集结果
以下值表示 IM2GPS 测试集中正确位于与实际位置的每个距离内的图像的百分比。
Flickr 图像结果
由于 PlaNet 中使用的测试集图像尚未公开发布,因此不能直接比较这些结果。这些值表示测试集中正确位于与实际位置的每个距离内的图像的百分比。
通过目测检查定位图像,我们可以看到该模型不仅在地标位置方面表现出色,而且也能准确定位非标志性场景。
使用 URL 估算图像的地理位置
现在我们试着用 URL 对网页上的图像进行定位。
Java
Batch = namedtuple('Batch', ['data'])
def predict(imgurl, prefix='images/'):
download_url(imgurl, prefix)
imgname = imgurl.split('/')[-1]
batch = PreprocessImage(prefix + imgname, True)
#predict and show top 5 results
mod.forward(Batch(batch), is_train=False)
prob = mod.get_outputs()[0].asnumpy()[0]
pred = np.argsort(prob)[::-1]
result = list()
for i in range(5):
pred_loc = grids[int(pred[i])]
res = (i+1, prob[pred[i]], pred_loc)
print('rank=%d, prob=%f, lat=%s, lng=%s' \
% (i+1, prob[pred[i]], pred_loc[0], pred_loc[1]))
result.append(res[2])
return result
def download_url(imgurl, img_directory):
if not os.path.exists(img_directory):
os.mkdir(img_directory)
imgname = imgurl.split('/')[-1]
filepath = os.path.join(img_directory, imgname)
if not os.path.exists(filepath):
filepath, _ = urllib.urlretrieve(imgurl, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', imgname, statinfo.st_size, 'bytes.')
复制代码
Java
来看看我们的模型如何处理东京塔图片。以下代码从 URL 下载图像,并输出模型的位置预测。
Java
#download and predict geo-location of an image of Tokyo Tower
url = 'https://farm5.staticflickr.com/4275/34103081894_f7c9bfa86c_k_d.jpg'
复制代码
Java
结果列出了置信度分数 (概率) 排在前五位的输出以及地理坐标:
Java
rank=1, prob=0.139923, lat=35.6599344486, lng=139.728919109
rank=2, prob=0.095210, lat=35.6546613641, lng=139.745685815
rank=3, prob=0.042224, lat=35.7098435803, lng=139.810458528
rank=4, prob=0.032602, lat=35.6641725688, lng=139.746648114
rank=5, prob=0.023119, lat=35.6901996892, lng=139.692857396
复制代码
仅通过原始纬度和经度值,很难判断地理位置输出的质量。我们可以通过将输出制成地图来直观地显示结果。
在 Jupyter Notebook 上使用 Google Maps 直观显示结果
为了直观地显示预测结果,我们可以在 Jupyter Notebook 中使用 Google Maps。它让您能够看到预测是否有意义。我们使用一个名为 gmaps 的插件,它允许我们在 Jupyter Notebook 中使用 Google Maps。要安装 gmaps,请按照 gmaps GitHub 页面上的安装说明操作。
使用 gmaps 直观显示结果只需几行代码。请在您的 Notebook 输入以下内容:
Java
import gmaps
gmaps.configure(api_key="") # Fill in with your API key
fig = gmaps.figure()
for i in range(len(result)):
marker = gmaps.marker_layer([result[i]], label=str(i+1))
fig.add_layer(marker)
fig
复制代码
事实上,排在第一位的定位估算结果就是东京塔所在的位置。
现在,试着对您选择的图像进行定位吧!
鸣谢
在 AWS 上训练 LocationNet 的工作得到了 AWS 研究与教育计划的大力支持。我们还要感谢 AWS 公共数据集计划托管 Multimedia Commons 数据集以供公众使用。我们的工作也得到了劳伦斯·利弗莫尔国家实验室领导的合作 LDRD 的部分支持 (美国能源部合同 DE-AC52-07NA27344)。
本文转载自 AWS 技术博客。
原文链接:
https://amazonaws-china.com/cn/blogs/china/estimating-image-locations-using-the-apache-mxnet-and-multimedia-commons-datasets-on-aws-ec2/
评论