标签:weight copy txt prot stat child net 类型 name
#由于系统没有同时安装caffe和pytorch,一个在系统下,一个在conda中,应该是隔离的python环境,一般不能用。
#因而只能用numpy当做中间媒介,下面代码是numpy存储的caffe网络,将之转成pytorch
#我没有自动化那个prototxt的转换,没没必要,自己写的一摸一样的pytorch网络
def net_from_caffe(n,re): #n是pytorch的model, re 是numpy存储的caffemodel i=-1 for name, l1 in n.named_children(): try: l2 = getattr(n, name) l2.weight # skip ReLU / Dropout except Exception: continue i+=1 while len(re[i][‘weights‘])==0 and i<len(re): #在numpy中非conv和全连接层是没有weights的,只对齐这两个layer就行了 i+=1 w=torch.from_numpy(re[i][‘weights‘][0])# b=torch.from_numpy(re[i][‘weights‘][1]) assert w.size() == l2.weight.size() assert b.size() == l2.bias.size() l2.weight.data.copy_(w) l2.bias.data.copy_(b)
坑点:
1.pil在打开图片时,默认rgb,默认0-1范围。要搞成0-255的自己去乘
2.有个注意的点,pytorch在第一次con到全联接的时候,要做一个展开操作,直接h=h.view(h.size(0),-1)就可以和caffe的一一对应
3.rgb转bgr:im=im[[2,0,1],...]
torch.load的两种方式:
1.直接存model
但是这样子model的数据类型是固定的,你必须让这个数据类型在调用出可见才能打开
2.存state_dict
比较灵活,直接对参数赋值,没有外面包裹的数据类型,就是多了点麻烦
caffe的model和prototxt转pytorch的model
标签:weight copy txt prot stat child net 类型 name
原文地址:https://www.cnblogs.com/waldenlake/p/9838377.html