Source code for zoo.pipeline.api.onnx.onnx_helper

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


[docs]class OnnxHelper:
[docs] @staticmethod def parse_attr(attr_proto): """Convert a list of AttributeProto to a dict, with names as keys.""" attrs = {} for a in attr_proto: for f in ['f', 'i', 's']: if a.HasField(f): attrs[a.name] = getattr(a, f) for f in ['floats', 'ints', 'strings']: if list(getattr(a, f)): assert a.name not in attrs, "Only one type of attr is allowed" attrs[a.name] = tuple(getattr(a, f)) for f in ['t', 'g']: if a.HasField(f): attrs[a.name] = getattr(a, f) for f in ['tensors', 'graphs']: if list(getattr(a, f)): raise NotImplementedError("Filed {} is not supported in mxnet.".format(f)) if a.name not in attrs: raise ValueError("Cannot parse attribute: \n{}\n.".format(a)) return attrs
[docs] @staticmethod def to_numpy(tensor_proto): """Grab data in TensorProto and to_tensor to numpy array.""" try: from onnx.numpy_helper import to_array except ImportError as e: raise ImportError("Unable to import onnx which is required {}".format(e)) np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims)) return np_array
[docs] @staticmethod def get_shape_from_node(valueInfoProto): return [int(dim.dim_value) for dim in valueInfoProto.type.tensor_type.shape.dim]
[docs] @staticmethod def get_padds(onnx_attr): border_mode = None pads = None if "auto_pad" in onnx_attr.keys(): if onnx_attr['auto_pad'].decode() == 'SAME_UPPER': border_mode = 'same' elif onnx_attr['auto_pad'].decode() == 'VALID': border_mode = 'valid' elif onnx_attr['auto_pad'].decode() == 'NOTSET': assert "pads" in onnx_attr.keys(), "you should specify pads explicitly" else: raise NotImplementedError('Unknown auto_pad mode "%s", ' 'only SAME_UPPER and VALID are supported.' % onnx_attr['auto_pad']) # In ONNX, pads format is [x1_begin, x2_begin...x1_end, x2_end,...]. # While pads format we supported should be [x1_begin, x1_end, x2_begin, x2_end...] if "pads" in onnx_attr.keys(): pads = [int(i) for i in onnx_attr["pads"]] if len(pads) == 4: assert pads[0] == pads[2] assert pads[1] == pads[3] pads = [pads[0], pads[1]] elif len(pads) == 2: assert pads[0] == pads[1] pads = pads[0] return border_mode, pads