码迷,mamicode.com
首页 > Web开发 > 详细

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

时间:2020-06-22 14:46:22      阅读:99      评论:0      收藏:0      [点我收藏+]

标签:catch   its   init   机器学习   out   param   res   path   oba   

ML.NET 在经典机器学习范畴内,对分类、回归、异常检测等问题开发模型已经有非常棒的表现了,我之前的文章都有过介绍。当然我们希望在更高层次的领域加以使用,例如计算机视觉、自然语言处理和信号处理等等领域。

图像识别是计算机视觉的一类分支,AI研发者们较为熟悉的是使用TensorFlow、Pytorch、Keras、MXNET等框架来训练深度神经网络模型,其中会涉及到CNN(卷积神经网络)、DNN(深度神经网络)的相关算法。

ML.NET 在较早期的版本是无法支持这类研究的,可喜的是最新的版本不但能很好地集成 TensorFlow 的模型做迁移学习,还可以直接导入 DNN 常见预编译模型:AlexNet、ResNet18、ResNet50、ResNet101 实现对图像的分类、识别等。

技术图片

 

 

 

我特别想推荐的是,ML.NET 最新版本对 ONNX 的支持也是非常强劲,通过 ONNX 可以把众多其他优秀深度学习框架的模型引入到 .NET Core 运行时中,极大地扩充了 .NET 应用在智能认知服务的丰富程度。在 Microsoft Docs 中已经提供了一个基于 ONNX 使用 Tiny YOLOv2 做对象检测的例子。为了展现 ML.NET 在其他框架上的通用性,本文将介绍使用 Pytorch 训练的垃圾分类的模型,基于 ONNX 导入到 ML.NET 中完成预测。

在2019年9月华为云举办了一次人工智能大赛·垃圾分类挑战杯,首次将AI与环保主题结合,展现人工智能技术在生活中的运用。有幸我看到了本次大赛亚军方案的分享,并且在 github 上找到了开源代码,按照 README 说明,我用 Pytorch 训练出了一个模型,并保存为garbage.pt 文件。

 

生成 ONNX 模型

首先,我使用以下 Pytorch 代码来生成一个garbage.pt 对应的文件,命名为 garbage.onnx

torch_model = torch.load("garbage.pt") # pytorch模型加载
    batch_size = 1  #批处理大小
    input_shape = (3,224,224)   #输入数据

    # # set the model to inference mode
    torch_model.eval()

    x = torch.randn(batch_size, *input_shape, device=cuda)        # 生成张量
    export_onnx_file = "garbage.onnx"                    # 目的ONNX文件名
 
    
    torch.onnx.export(torch_model.module,
                        x,
                        export_onnx_file,
                        input_names=["input"],        # 输入名
                        output_names=["output"]    # 输出名
                        )

 

准备 ML.NET 项目

创建一个 .NET Core 控制台应用,按如下结构创建好合适的目录。assets 目录下的 images 子目录将放置待预测的图片,而 Model 子目录放入前一个步骤生成的 garbage.onnx 模型文件。

技术图片

 

 

ImageNetData 和 ImageNetPrediction 类定义了输入和输出的数据结构。

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace GarbageDemo.DataStructures
{
    public class ImageNetData
    {
        [LoadColumn(0)]
        public string ImagePath;

        [LoadColumn(1)]
        public string Label;

        public static IEnumerable<ImageNetData> ReadFromFile(string imageFolder)
        {
            return Directory
               .GetFiles(imageFolder)
               .Where(filePath => Path.GetExtension(filePath) == ".jpg")
               .Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });

        }
    }

    public class ImageNetPrediction : ImageNetData
    {
        public float[] Score;

        public string PredictedLabelValue;
    }
}

 

OnnxModelScorer 类定义了 ONNX 模型加载、打分预测的过程。注意 ImageNetModelSettings 的属性和第一步中指定的输入输出字段名要一致。
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Onnx;
using Microsoft.ML.Transforms.Image;
using GarbageDemo.DataStructures;

namespace GarbageDemo
{
    class OnnxModelScorer
    {
        private readonly string imagesFolder;
        private readonly string modelLocation;
        private readonly MLContext mlContext;


        public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
        {
            this.imagesFolder = imagesFolder;
            this.modelLocation = modelLocation;
            this.mlContext = mlContext;
        }

        public struct ImageNetSettings
        {
            public const int imageHeight = 224;
            public const int imageWidth = 224;    
            public const float Mean = 127;
            public const float Scale = 1;
            public const bool ChannelsLast = false;
        } 
        
        public struct ImageNetModelSettings
        {
            // input tensor name
            public const string ModelInput = "input";

            // output tensor name
            public const string ModelOutput = "output";
        }

        private ITransformer LoadModel(string modelLocation)
        {
            Console.WriteLine("Read model");
            Console.WriteLine($"Model location: {modelLocation}");
            Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");

            // Create IDataView from empty list to obtain input data schema
            var data = mlContext.Data.LoadFromEnumerable(new List<ImageNetData>());

            // Define scoring pipeline
            var pipeline = mlContext.Transforms.LoadImages(outputColumnName: ImageNetModelSettings.ModelInput, imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))                           
                            .Append(mlContext.Transforms.ResizeImages(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                        imageWidth: ImageNetSettings.imageWidth, 
                                                                        imageHeight: ImageNetSettings.imageHeight, 
                                                                        inputColumnName: ImageNetModelSettings.ModelInput,
                                                                        resizing: ImageResizingEstimator.ResizingKind.IsoCrop,
                                                                        cropAnchor: ImageResizingEstimator.Anchor.Center
                                                                        ))
                            .Append(mlContext.Transforms.ExtractPixels(outputColumnName: ImageNetModelSettings.ModelInput, interleavePixelColors: ImageNetSettings.ChannelsLast))
                            .Append(mlContext.Transforms.NormalizeGlobalContrast(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 inputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 ensureZeroMean : true, 
                                                                                 ensureUnitStandardDeviation: true, 
                                                                                 scale: ImageNetSettings.Scale))
                            .Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { ImageNetModelSettings.ModelOutput }, inputColumnNames: new[] { ImageNetModelSettings.ModelInput }));

            // Fit scoring pipeline
            var model = pipeline.Fit(data);

            return model;
        }

        private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
        {
            Console.WriteLine($"Images location: {imagesFolder}");
            Console.WriteLine("");
            Console.WriteLine("=====Identify the objects in the images=====");
            Console.WriteLine("");

            IDataView scoredData = model.Transform(testData);

            IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(ImageNetModelSettings.ModelOutput);

            return probabilities;
        }

        public IEnumerable<float[]> Score(IDataView data)
        {
            var model = LoadModel(modelLocation);

            return PredictDataUsingModel(data, model);
        }
    }
}

 

Program 类中定义了调用过程,完成预测结果呈现。

using GarbageDemo.DataStructures;
using Microsoft.ML;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace GarbageDemo
{
    class Program
    {
        static void Main(string[] args)
        {
            var assetsRelativePath = @"../../../assets";
            string assetsPath = GetAbsolutePath(assetsRelativePath);
            var modelFilePath = Path.Combine(assetsPath, "Model", "garbage.onnx");
            var imagesFolder = Path.Combine(assetsPath, "images");// Initialize MLContext
            MLContext mlContext = new MLContext();

            try
            {
                // Load Data
                IEnumerable<ImageNetData> images = ImageNetData.ReadFromFile(imagesFolder);
                IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);

                // Create instance of model scorer
                var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);

                // Use model to score data
                IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);

                int index = 0;
                foreach (var probable in probabilities)
                {
                    var scores = Softmax(probable);

                    var (topResultIndex, topResultScore) = scores.Select((predictedClass, index) => (Index: index, Value: predictedClass))
                        .OrderByDescending(result => result.Value)
                        .First();
                    Console.WriteLine("图片:{3} \r\n 分类{2} {0}:{1}", labels[topResultIndex], topResultScore, topResultIndex, images.ElementAt(index).ImagePath);
                    Console.WriteLine("=============================");
                    index++;
                }

            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
            }

            Console.WriteLine("========= End of Process..Hit any Key ========");
            Console.ReadLine();
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }

        private static float[] Softmax(float[] values)
        {
            var maxVal = values.Max();
            var exp = values.Select(v => Math.Exp(v - maxVal));
            var sumExp = exp.Sum();

            return exp.Select(v => (float)(v / sumExp)).ToArray();
        }

        private static string[] labels = new string[]
        {
            "其他垃圾/一次性快餐盒",
            "其他垃圾/污损塑料",
            "其他垃圾/烟蒂",
            "其他垃圾/牙签",
            "其他垃圾/破碎花盆及碟碗",
            "其他垃圾/竹筷",
            "厨余垃圾/剩饭剩菜",
            "厨余垃圾/大骨头",
            "厨余垃圾/水果果皮",
            "厨余垃圾/水果果肉",
            "厨余垃圾/茶叶渣",
            "厨余垃圾/菜叶菜根",
            "厨余垃圾/蛋壳",
            "厨余垃圾/鱼骨",
            "可回收物/充电宝",
            "可回收物/包",
            "可回收物/化妆品瓶",
            "可回收物/塑料玩具",
            "可回收物/塑料碗盆",
            "可回收物/塑料衣架",
            "可回收物/快递纸袋",
            "可回收物/插头电线",
            "可回收物/旧衣服",
            "可回收物/易拉罐",
            "可回收物/枕头",
            "可回收物/毛绒玩具",
            "可回收物/洗发水瓶",
            "可回收物/玻璃杯",
            "可回收物/皮鞋",
            "可回收物/砧板",
            "可回收物/纸板箱",
            "可回收物/调料瓶",
            "可回收物/酒瓶",
            "可回收物/金属食品罐",
            "可回收物/锅",
            "可回收物/食用油桶",
            "可回收物/饮料瓶",
            "有害垃圾/干电池",
            "有害垃圾/软膏",
            "有害垃圾/过期药物",
            "可回收物/毛巾",
            "可回收物/饮料盒",
            "可回收物/纸袋"
        };

 

选择一张图片放到 images 目录中,运行结果如下:

技术图片

 

有 0.88 的得分说明照片中的物品属于污损塑料,让我们看一下图片真相。

技术图片

 

果然是相当准确 ,并且把周边的附属物都过滤掉了。

 

对于 ML.NET 训练深度神经网络模型支持更复杂的场景是不是更有信心了!

基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型

标签:catch   its   init   机器学习   out   param   res   path   oba   

原文地址:https://www.cnblogs.com/BeanHsiang/p/13176454.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!