码迷,mamicode.com
首页 > Web开发 > 详细

MXNet 查看模型 params 的网络结构

时间:2018-06-12 17:09:26      阅读:602      评论:0      收藏:0      [点我收藏+]

标签:key   model   ==   check   int   import   span   dict   tee   

 

import mxnet as mx  
import pdb  
def load_checkpoint(prefix='model', epoch=0):
    """Load the arg/aux parameter dicts from an MXNet ``.params`` checkpoint.

    The file read is ``'%s-%04d.params' % (prefix, epoch)``; the defaults
    reproduce the previously hard-coded ``model-0000.params``.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    """
    save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
    return _split_save_dict(save_dict)


def _split_save_dict(save_dict):
    """Split a loaded params dict into ``(arg_params, aux_params)`` by key prefix.

    Each key is expected to look like ``"<type>:<name>"``. Keys whose type is
    ``arg`` go to the first dict and keys whose type is ``aux`` go to the
    second, both keyed by ``<name>``. Other types are ignored. A key without
    ':' raises ValueError on unpacking.
    """
    arg_params = {}
    aux_params = {}
    for key, value in save_dict.items():
        # Split only on the first ':' so names that themselves contain ':'
        # are kept intact.
        tp, name = key.split(':', 1)
        if tp == 'arg':
            arg_params[name] = value
        elif tp == 'aux':
            aux_params[name] = value
    return arg_params, aux_params
  
  
def convert_context(params, ctx):
    """Copy every NDArray in ``params`` to the device context ``ctx``.

    :param params: dict of str to NDArray
    :param ctx: the context to convert to (e.g. ``mx.cpu()`` or ``mx.gpu(0)``)
    :return: a new dict of str to NDArray with context ctx; the input dict
        itself is not modified
    """
    return {name: array.as_in_context(ctx) for name, array in params.items()}
  
  
def load_param(convert=False, ctx=None):
    """Load the model checkpoint, optionally copying it to another device.

    Wrapper around ``load_checkpoint()``, called with its default arguments;
    this function takes no prefix/epoch of its own.

    :param convert: if True, copy both parameter dicts to ``ctx`` before
        returning them
    :param ctx: target context for the copy; only used when ``convert`` is
        True, and defaults to ``mx.cpu()`` when left as None
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint()
    if not convert:
        return arg_params, aux_params
    if ctx is None:
        ctx = mx.cpu()
    return convert_context(arg_params, ctx), convert_context(aux_params, ctx)
  
  
if __name__ == '__main__':
    # load_param() returns (arg_params, aux_params); print each parameter's
    # name and shape so the network structure can be inspected.
    result = load_param()
    print('result is')
    for params in result:
        for key in params:
            print(key, params[key].shape)

 

python showmxmodel.py 2>&1 | tee log.txt
result is
('stage3_unit2_bn1_beta', (256L,))
('stage3_unit2_bn3_beta', (256L,))
('stage3_unit11_bn1_gamma', (256L,))
('stage3_unit5_bn3_gamma', (256L,))
('stage3_unit3_conv1_weight', (256L, 256L, 3L, 3L))
('stage2_unit1_bn3_gamma', (128L,))
('stage3_unit4_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit12_bn3_beta', (256L,))
('stage2_unit2_bn3_beta', (128L,))
('conv0_weight', (64L, 3L, 3L, 3L))
('stage3_unit11_relu1_gamma', (256L,))
('stage4_unit1_conv1sc_weight', (512L, 256L, 1L, 1L))
('stage3_unit1_conv1sc_weight', (256L, 128L, 1L, 1L))
('bn1_beta', (512L,))
('stage1_unit2_bn2_beta', (64L,))
('stage3_unit2_conv2_weight', (256L, 256L, 3L, 3L))
('stage1_unit2_conv1_weight', (64L, 64L, 3L, 3L))
('stage3_unit14_bn2_beta', (256L,))
('stage4_unit2_bn3_beta', (512L,))
('stage3_unit8_bn1_gamma', (256L,))
('stage3_unit7_bn1_gamma', (256L,))
('stage2_unit3_bn1_beta', (128L,))
('stage2_unit4_conv1_weight', (128L, 128L, 3L, 3L))
('stage3_unit2_bn2_gamma', (256L,))
('stage1_unit1_conv1_weight', (64L, 64L, 3L, 3L))
('stage3_unit9_conv2_weight', (256L, 256L, 3L, 3L))
('stage3_unit13_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit1_relu1_gamma', (256L,))
('stage4_unit1_bn3_beta', (512L,))
('stage2_unit1_bn2_beta', (128L,))
('stage3_unit14_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit8_bn1_beta', (256L,))
('stage3_unit11_conv1_weight', (256L, 256L, 3L, 3L))
('stage1_unit1_bn3_gamma', (64L,))
('stage2_unit2_conv2_weight', (128L, 128L, 3L, 3L))
('stage4_unit2_bn1_gamma', (512L,))
('stage3_unit3_bn1_gamma', (256L,))
('stage1_unit3_bn2_gamma', (64L,))
('stage1_unit3_bn3_gamma', (64L,))
('stage4_unit2_relu1_gamma', (512L,))
('stage3_unit10_conv2_weight', (256L, 256L, 3L, 3L))
('stage3_unit12_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit2_relu1_gamma', (256L,))
('stage3_unit10_bn2_beta', (256L,))
('stage2_unit3_bn3_gamma', (128L,))
('stage2_unit3_bn2_beta', (128L,))
('stage3_unit8_bn3_beta', (256L,))
('fc1_gamma', (512L,))
('stage3_unit14_bn3_gamma', (256L,))
('stage3_unit9_bn3_gamma', (256L,))
('stage2_unit3_bn3_beta', (128L,))
('stage3_unit1_sc_gamma', (256L,))
('stage3_unit7_bn1_beta', (256L,))
('stage1_unit2_bn3_beta', (64L,))
('stage3_unit14_relu1_gamma', (256L,))
('stage3_unit13_bn2_beta', (256L,))
('stage2_unit1_conv1sc_weight', (128L, 64L, 1L, 1L))
('bn0_beta', (64L,))
('stage3_unit12_bn1_gamma', (256L,))
('stage2_unit1_sc_gamma', (128L,))
('relu0_gamma', (64L,))
('stage2_unit2_bn2_gamma', (128L,))
('stage3_unit4_relu1_gamma', (256L,))

MXNet 查看模型 params 的网络结构

标签:key   model   ==   check   int   import   span   dict   tee   

原文地址:https://www.cnblogs.com/adong7639/p/9173854.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!