Skip to content

Instantly share code, notes, and snippets.

@l-bat
Created July 14, 2020 05:53
Show Gist options
  • Select an option

  • Save l-bat/17ec566dce64c7bf4f3c6a31eb29ef06 to your computer and use it in GitHub Desktop.

Select an option

Save l-bat/17ec566dce64c7bf4f3c6a31eb29ef06 to your computer and use it in GitHub Desktop.
else if (layer_type == "Gather")
{
    // ONNX Gather restricted to a single constant index:
    //   output = the slice of input(0) at position `index` along `axis`.
    // input(1) must be a constant CV_32S blob holding exactly one element.
    CV_Assert(node_proto.input_size() == 2);
    Mat indexMat = getBlob(node_proto, constBlobs, 1);
    CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
    int index = indexMat.at<int>(0);

    // The ONNX spec gives `axis` a default of 0. The previous code flattened the
    // input (constant case) or threw StsNotImplemented (non-constant case) when the
    // attribute was absent; both disagreed with the spec.
    int axis = layerParams.get<int>("axis", 0);

    if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
    {
        // Constant input: fold the Gather at import time.
        Mat input = getBlob(node_proto, constBlobs, 0);
        const int dims = input.dims;

        // ONNX permits a negative axis in [-r, r-1] and a negative index in [-s, s-1].
        // NOTE(review): `dims` is the cv::Mat rank (always >= 2), so for a 1-D ONNX
        // tensor a negative axis resolves against rank 2 here — confirm this matches
        // how 1-D constants are stored by getBlob.
        if (axis < 0)
            axis += dims;
        CV_Assert(0 <= axis && axis < dims);

        const int axisSize = input.size[axis];
        if (index < 0)
            index += axisSize;
        CV_Assert(0 <= index && index < axisSize);

        std::vector<cv::Range> ranges(dims, Range::all());
        ranges[axis] = Range(index, index + 1);

        // input(ranges) is a ROI view that is non-contiguous whenever axis > 0; clone it
        // so the stored constant is contiguous (Mat::reshape requires that) and does not
        // alias the source blob.
        // The gathered axis is kept with extent 1 here, as before this change.
        // NOTE(review): the non-constant branch below squeezes that axis — confirm
        // which output rank downstream consumers expect.
        Mat out = input(ranges).clone();
        addConstant(layerParams.name, out, constBlobs, outShapes);
        continue;
    }
    else
    {
        // Non-constant input: add a Slice layer that keeps [index, index + 1) along
        // `axis`, then rewrite this node as a Reshape that drops the sliced axis.
        // node_proto is rewired to read the Slice output; the Reshape itself is
        // presumably added by the code following this branch (not visible here).
        shapeIt = outShapes.find(node_proto.input(0));
        CV_Assert(shapeIt != outShapes.end());
        MatShape inpShape = shapeIt->second;
        const int dims = static_cast<int>(inpShape.size());

        if (axis < 0)
            axis += dims;
        CV_Assert(0 <= axis && axis < dims);

        // NOTE(review): a non-positive extent is assumed to mean "unknown at import
        // time" (e.g. a symbolic dim); then only a non-negative index can be resolved.
        const int axisSize = inpShape[axis];
        if (index < 0)
        {
            CV_Assert(axisSize > 0);
            index += axisSize;
        }
        CV_Assert(index >= 0 && (axisSize <= 0 || index < axisSize));

        LayerParams sliceLp;
        sliceLp.type = "Slice";
        sliceLp.name = layerParams.name + "/slice";
        // An "end" of -1 keeps every other axis through to its end.
        std::vector<int> begin(dims, 0);
        std::vector<int> end(dims, -1);
        begin[axis] = index;
        end[axis] = index + 1;
        sliceLp.set("begin", DictValue::arrayInt(begin.data(), static_cast<int>(begin.size())));
        sliceLp.set("end", DictValue::arrayInt(end.data(), static_cast<int>(end.size())));

        opencv_onnx::NodeProto proto;
        proto.add_input(node_proto.input(0));
        proto.add_output(sliceLp.name);
        addLayer(dstNet, sliceLp, proto, layer_id, outShapes);

        // Squeeze the sliced axis out of the shape.
        std::vector<int> squeezeShape = inpShape;
        squeezeShape.erase(squeezeShape.begin() + axis);
        layerParams.type = "Reshape";
        // Reshape interprets "axis" as the first dimension of the reshaped range. The
        // Gather "axis" still stored in layerParams must not leak into it: the
        // squeezed shape describes the whole tensor, so the reshape starts at axis 0.
        layerParams.set("axis", 0);
        // .data() rather than &v[0]: squeezeShape is empty for a rank-1 input.
        layerParams.set("dim", DictValue::arrayInt(squeezeShape.data(), static_cast<int>(squeezeShape.size())));
        node_proto.set_input(0, sliceLp.name);
        node_proto.set_output(0, layerParams.name);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment