Skip to content

Instantly share code, notes, and snippets.

@RussellLuo
Last active April 15, 2020 10:35
Show Gist options
  • Select an option

  • Save RussellLuo/9ee9585e3c2b0dbd0298574c241e1bcf to your computer and use it in GitHub Desktop.

Select an option

Save RussellLuo/9ee9585e3c2b0dbd0298574c241e1bcf to your computer and use it in GitHub Desktop.
gRPC client interface for Python: generation script and mocking class.
# -*- coding=utf-8 -*-
"""Generate a more pythonic interface based on the code generated by
`grpcio-tools`.

The generated source is written to stdout. It contains an ``XxInterface``
class with a ``stub`` property (to be filled in with a real channel), a
single core method (``call_rpc`` by default) that builds the request message
and invokes the RPC, and one lowercase, underscore-named wrapper method per
RPC of the given stub class.

Example:
$ python grpc_pi.py --pb2-module-name='python.path.xx_pb2' --stub-class-name='XxStub'
"""
import argparse
import re
import sys
from collections import OrderedDict
from importlib import import_module
import grpc
class Generator(object):
    """Emit a pythonic wrapper class for a grpcio-tools generated stub.

    All generated source is written to ``writer`` (``sys.stdout`` by
    default; assign another file-like object to redirect the output).
    """

    # Destination of the generated source; bound when the class is created.
    writer = sys.stdout

    def __init__(self, pb2_module_name, stub_class_name, core_method_name):
        """
        :param pb2_module_name: dotted path of the generated ``xx_pb2``
            module, e.g. ``'python.path.xx_pb2'`` or just ``'xx_pb2'``.
        :param stub_class_name: name of the stub class in that module,
            e.g. ``'XxStub'``.
        :param core_method_name: name of the generic dispatch method that
            every generated RPC wrapper calls.
        """
        self.pb2_module_name = pb2_module_name
        self.stub_class_name = stub_class_name
        self.core_method_name = core_method_name
        # rpartition (unlike unpacking rsplit('.', 1)) also accepts a
        # top-level module name without any dot; pb2_path is then ''.
        self.pb2_path, _, self.pb2_module = pb2_module_name.rpartition('.')

    @staticmethod
    def underscore(word):
        """Return the lowercase, underscore-separated form of *word*.

        >>> Generator.underscore('SayHello')
        'say_hello'
        >>> Generator.underscore('HTTPRequest')
        'http_request'
        """
        word = re.sub(r"([A-Z]+)([A-Z][a-z])", r'\1_\2', word)
        word = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', word)
        word = word.replace("-", "_")
        return word.lower()

    def write_module_header(self):
        """Write the coding line and the import of the pb2 module."""
        if self.pb2_path:
            import_line = 'from {} import {}'.format(
                self.pb2_path, self.pb2_module)
        else:
            # A top-level module cannot be imported with ``from  import x``.
            import_line = 'import {}'.format(self.pb2_module)
        self.writer.write('# -*- coding=utf-8 -*-\n\n' + import_line)

    def write_class_header(self):
        """Write the ``class XxInterface(object):`` line and default timeout."""
        suffix = 'Stub'
        class_prefix = self.stub_class_name
        # Remove the literal suffix only: str.rstrip('Stub') would strip any
        # trailing run of the characters S/t/u/b ('TestStub' -> 'Tes').
        if class_prefix.endswith(suffix):
            class_prefix = class_prefix[:-len(suffix)]
        self.writer.write(
            '\n\n\nclass {}Interface(object):\n'
            '\n    timeout = 10\n'.format(class_prefix)
        )

    def write_stub_property(self):
        """Write a placeholder ``stub`` property for the user to complete.

        The emitted body raises NotImplementedError so that the generated
        module is importable (a body made only of comments is a syntax
        error); the comments show the intended implementation.
        """
        self.writer.write(
            '\n    @property\n'
            '    def stub(self):\n'
            '        # channel = grpc.insecure_channel(...)\n'
            '        # return {pb2_module}.{stub_class_name}(channel)\n'
            '        raise NotImplementedError\n'.format(
                pb2_module=self.pb2_module,
                stub_class_name=self.stub_class_name
            )
        )

    def write_core_method(self):
        """Write the generic method that builds a request and calls an RPC."""
        self.writer.write(
            '\n    def {core_method_name}(self, rpc_name, req_name, **kwargs):\n'
            '        req_class = getattr({pb2_module}, req_name)\n'
            '        req = req_class(**kwargs)\n\n'
            '        rpc = getattr(self.stub, rpc_name)\n'
            '        resp = rpc(req, self.timeout)\n'
            '        return resp\n'.format(
                core_method_name=self.core_method_name,
                pb2_module=self.pb2_module
            )
        )

    def write_rpc_method(self, method_name, req_name, req_param_names,
                         rpc_name=None, req_field_names=None):
        """Write one wrapper method that forwards to the core method.

        :param method_name: name of the generated Python method.
        :param req_name: name of the request message class in the pb2 module.
        :param req_param_names: parameter names of the generated method, one
            per request field.
        :param rpc_name: attribute name of the RPC on the stub (e.g.
            ``'SayHello'``); defaults to ``method_name``.
        :param req_field_names: request field names passed as keyword
            arguments to the message constructor, parallel to
            ``req_param_names``; defaults to ``req_param_names``.
        """
        if rpc_name is None:
            rpc_name = method_name
        if req_field_names is None:
            req_field_names = req_param_names

        header = '    def {}('.format(method_name)
        # Continuation lines line up with the first parameter after '('.
        param_sep = ',\n' + ' ' * len(header)
        signature = (
            header + param_sep.join(['self'] + list(req_param_names)) + '):\n'
        )

        call_args = ["'{}'".format(rpc_name), "'{}'".format(req_name)]
        # The message constructor needs the real field name, while the
        # Python parameter may be its underscored form (userId=user_id).
        call_args.extend(
            '{}={}'.format(field, param)
            for field, param in zip(req_field_names, req_param_names)
        )
        body = (
            '        resp = self.{}(\n'.format(self.core_method_name) +
            ',\n'.join('            ' + arg for arg in call_args) +
            '\n        )\n        return resp\n'
        )
        self.writer.write('\n' + signature + body)

    def write_rpc_methods(self):
        """Write one wrapper method per RPC exposed by the stub, sorted by name.

        Imports the pb2 module, builds the stub on a local insecure channel
        (no RPC is issued) and introspects each RPC's request message.
        """
        pb2_module = import_module(self.pb2_module_name)
        stub_class = getattr(pb2_module, self.stub_class_name)
        # The channel is only needed to construct the stub.
        channel = grpc.insecure_channel('localhost')
        stub = stub_class(channel)
        rpc_names = sorted(
            attr for attr in dir(stub) if not attr.startswith('__')
        )
        for rpc_name in rpc_names:
            stub_method = getattr(stub, rpc_name)
            # NOTE(review): `im_class` exists only on Python 2 unbound
            # methods; on Python 3 the request class must be found another
            # way (e.g. via the service descriptor).
            req_class = stub_method._request_serializer.im_class
            field_names = [field.name for field in req_class.DESCRIPTOR.fields]
            self.write_rpc_method(
                self.underscore(rpc_name),
                req_class.__name__,
                [self.underscore(name) for name in field_names],
                rpc_name=rpc_name,
                req_field_names=field_names
            )

    def generate(self):
        """Write the complete interface module to ``writer``."""
        self.write_module_header()
        self.write_class_header()
        self.write_stub_property()
        self.write_core_method()
        self.write_rpc_methods()
def main(argv=None):
    """Parse command-line options and run the Generator.

    The generated module is written to ``Generator.writer``
    (``sys.stdout`` by default).

    :param argv: argument list to parse instead of ``sys.argv[1:]``
        (optional; mainly useful for testing).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the docstring's line breaks so the usage example stays intact.
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--pb2-module-name', required=True,
                        help='The name of the generated `xx_pb2.py` '
                             'module with the full Python path.')
    parser.add_argument('--stub-class-name', required=True,
                        help='The name of the generated `XxStub` class.')
    parser.add_argument('--core-method-name', default='call_rpc',
                        help='The name of the core method that will be '
                             'used to call the actual rpc methods.')
    args = parser.parse_args(argv)

    generator = Generator(args.pb2_module_name,
                          args.stub_class_name,
                          args.core_method_name)
    generator.generate()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment