Mapillary Vistas数据集实战:用Python快速加载并可视化66类街景语义分割标签
Mapillary Vistas数据集实战:用Python快速加载并可视化66类街景语义分割标签
第一次接触Mapillary Vistas数据集时,我被它丰富的标注类别和真实场景多样性所震撼。这个包含25,000张高分辨率街景图像的数据集,不仅覆盖了常规的道路、建筑等静态元素,还细致标注了不同天气条件下的动态物体,如骑行者和特殊车辆。对于从事自动驾驶或街景分析的研究者来说,掌握如何高效使用这个数据集是快速开展实验的基础。本文将带你从零开始,用Python实现数据加载、标签解析和可视化全流程,重点解决实际项目中常见的三个痛点:如何正确处理JSON格式的复杂标签定义、如何高效映射66类颜色编码、如何生成可用于论文发表的专业可视化效果。
1. 环境准备与数据获取
在开始处理Mapillary Vistas数据集前,需要配置合适的Python环境。推荐使用conda创建独立环境以避免依赖冲突:
conda create -n mapillary python=3.8 conda activate mapillary pip install numpy pillow opencv-python matplotlib jsonlines数据集可通过官网注册后下载,建议选择v2.0版本,其文件结构如下:
mapillary_vistas/ ├── training/ │ ├── images/ # 18,000张训练图像 │ └── labels/ # 对应的PNG标注文件 ├── validation/ │ ├── images/ # 2,000张验证图像 │ └── labels/ └── config_v2.0.json # 标签定义文件注意:下载后的压缩包约65GB,解压需要确保磁盘有足够空间。可使用
tar -xvf命令分卷解压。
标签配置文件config_v2.0.json包含三个关键信息:
labels: 每个类别的名称、颜色值和评估标志mapping: 标注与评估的映射关系folder_structure: 数据集存储路径模板
2. 标签解析与颜色映射
Mapillary的标注文件采用PNG格式存储,每个像素点的RGB值对应特定类别。我们需要先将JSON配置转换为可操作的色彩查找表:
import json import numpy as np def load_label_config(config_path): with open(config_path) as f: config = json.load(f) color_map = {} class_names = [] for label in config['labels']: rgb = tuple(label['color']) color_map[rgb] = label['name'] class_names.append(label['readable']) return color_map, class_names # 使用示例 color_map, class_names = load_label_config('config_v2.0.json') print(f"加载完成 {len(class_names)} 个类别")为提升后续处理效率,我们预生成颜色映射矩阵:
def build_color_matrix(color_map): max_color = max(max(rgb) for rgb in color_map.keys()) size = max_color + 1 matrix = np.zeros((size, size, size), dtype=np.int32) - 1 for idx, rgb in enumerate(color_map.keys()): r, g, b = rgb matrix[r, g, b] = idx return matrix color_matrix = build_color_matrix(color_map)这个三维矩阵让我们可以用O(1)时间复杂度查询任意RGB值对应的类别ID,比字典查询快3-5倍。
3. 数据加载与可视化
使用OpenCV和Matplotlib实现高效的图像和标注加载:
import cv2 import matplotlib.pyplot as plt def visualize_sample(image_path, label_path, color_map): image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) label = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB) # 创建带透明度的叠加效果 overlay = image.copy() alpha = 0.6 for rgb, name in color_map.items(): mask = np.all(label == np.array(rgb), axis=-1) overlay[mask] = rgb result = cv2.addWeighted(overlay, alpha, image, 1-alpha, 0) # 绘制图例 plt.figure(figsize=(16, 9)) plt.imshow(result) plt.axis('off') # 随机选择10个类别显示图例 import random sampled = random.sample(list(color_map.items()), 10) patches = [plt.Rectangle((0,0),1,1, fc=np.array(rgb)/255) for rgb, name in sampled] plt.legend(patches, [name for rgb,name in sampled], loc='lower right', bbox_to_anchor=(1.3, 0)) plt.tight_layout() return result典型调用方式:
sample_img = 'training/images/0.jpg' sample_label = 'training/labels/0.png' result = visualize_sample(sample_img, sample_label, color_map) plt.savefig('visualization.jpg', dpi=300, bbox_inches='tight')4. 高级可视化技巧
4.1 类别过滤显示
实际分析时往往只需要关注特定类别,修改可视化函数:
def visualize_selected_classes(image_path, label_path, target_classes): """ target_classes: 如 ['road', 'car', 'person'] """ image = cv2.imread(image_path) label_rgb = cv2.imread(label_path) # 转换为类别ID矩阵 h, w = label_rgb.shape[:2] label_ids = np.zeros((h, w), dtype=np.int32) for rgb, name in color_map.items(): if name in target_classes: mask = np.all(label_rgb == np.array(rgb), axis=-1) label_ids[mask] = list(color_map.keys()).index(rgb) + 1 plt.imshow(image) plt.imshow(label_ids, alpha=0.5, cmap='jet') plt.colorbar(label='Class IDs')4.2 批量生成可视化结果
使用多进程加速批量处理:
from multiprocessing import Pool def batch_visualize(args): img_path, label_path, output_dir = args try: result = visualize_sample(img_path, label_path, color_map) output_path = f"{output_dir}/{Path(img_path).stem}.jpg" plt.savefig(output_path) plt.close() except Exception as e: print(f"Error processing {img_path}: {str(e)}") # 使用示例 image_paths = [...] # 列表形式存储所有图像路径 label_paths = [...] # 对应的标注路径 with Pool(8) as p: # 8个进程并行 p.map(batch_visualize, zip(image_paths, label_paths, ['output']*len(image_paths)))5. 实际应用中的问题解决
5.1 内存优化技巧
处理高分辨率图像时(如4000×6000像素),可采用分块加载:
def load_image_tiles(path, tile_size=1024): img = Image.open(path) width, height = img.size for i in range(0, width, tile_size): for j in range(0, height, tile_size): box = (i, j, min(i+tile_size, width), min(j+tile_size, height)) yield img.crop(box), box5.2 标签不一致处理
部分早期版本存在标注错误,可通过后处理修正:
def validate_labels(label_path, color_map): label = cv2.imread(label_path) unique_colors = np.unique(label.reshape(-1, 3), axis=0) invalid = [] for color in unique_colors: if tuple(color) not in color_map: invalid.append(color) if invalid: print(f"发现{len(invalid)}种无效颜色值") # 自动替换为最近的合法颜色 for bad_color in invalid: distances = [np.linalg.norm(np.array(bad_color)-np.array(c)) for c in color_map] closest = list(color_map.keys())[np.argmin(distances)] label[np.all(label == bad_color, axis=-1)] = closest return label5.3 与常用框架集成
将数据集转换为COCO格式以便使用MMSegmentation等框架:
def convert_to_coco(output_path): coco = { "info": {...}, "licenses": [...], "categories": [ {"id": idx, "name": name, "supercategory": name.split('--')[0]} for idx, name in enumerate(class_names) ], "images": [], "annotations": [] } # 需要实现图像遍历和标注转换逻辑 ... with open(output_path, 'w') as f: json.dump(coco, f)