Merge commit '36bca61764984ff5395653cf8377ec5daa71b709' as 'libs/protobuf'

This commit is contained in:
Henry Winkel
2022-10-22 14:46:58 +02:00
2186 changed files with 838730 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
try:
__import__('pkg_resources').declare_namespace(__name__)
except ImportError:
__path__ = __import__('pkgutil').extend_path(__path__, __name__)

View File

@@ -0,0 +1,33 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Copyright 2007 Google Inc. All Rights Reserved.
__version__ = '4.21.8'

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,177 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides a container for DescriptorProtos."""
__author__ = 'matthewtoia@google.com (Matt Toia)'
import warnings
class Error(Exception):
pass
class DescriptorDatabaseConflictingDefinitionError(Error):
"""Raised when a proto is added with the same name & different descriptor."""
class DescriptorDatabase(object):
"""A container accepting FileDescriptorProtos and maps DescriptorProtos."""
def __init__(self):
self._file_desc_protos_by_file = {}
self._file_desc_protos_by_symbol = {}
def Add(self, file_desc_proto):
"""Adds the FileDescriptorProto and its types to this database.
Args:
file_desc_proto: The FileDescriptorProto to add.
Raises:
DescriptorDatabaseConflictingDefinitionError: if an attempt is made to
add a proto with the same name but different definition than an
existing proto in the database.
"""
proto_name = file_desc_proto.name
if proto_name not in self._file_desc_protos_by_file:
self._file_desc_protos_by_file[proto_name] = file_desc_proto
elif self._file_desc_protos_by_file[proto_name] != file_desc_proto:
raise DescriptorDatabaseConflictingDefinitionError(
'%s already added, but with different descriptor.' % proto_name)
else:
return
# Add all the top-level descriptors to the index.
package = file_desc_proto.package
for message in file_desc_proto.message_type:
for name in _ExtractSymbols(message, package):
self._AddSymbol(name, file_desc_proto)
for enum in file_desc_proto.enum_type:
self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto)
for enum_value in enum.value:
self._file_desc_protos_by_symbol[
'.'.join((package, enum_value.name))] = file_desc_proto
for extension in file_desc_proto.extension:
self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto)
for service in file_desc_proto.service:
self._AddSymbol(('.'.join((package, service.name))), file_desc_proto)
def FindFileByName(self, name):
"""Finds the file descriptor proto by file name.
Typically the file name is a relative path ending to a .proto file. The
proto with the given name will have to have been added to this database
using the Add method or else an error will be raised.
Args:
name: The file name to find.
Returns:
The file descriptor proto matching the name.
Raises:
KeyError if no file by the given name was added.
"""
return self._file_desc_protos_by_file[name]
def FindFileContainingSymbol(self, symbol):
"""Finds the file descriptor proto containing the specified symbol.
The symbol should be a fully qualified name including the file descriptor's
package and any containing messages. Some examples:
'some.package.name.Message'
'some.package.name.Message.NestedEnum'
'some.package.name.Message.some_field'
The file descriptor proto containing the specified symbol must be added to
this database using the Add method or else an error will be raised.
Args:
symbol: The fully qualified symbol name.
Returns:
The file descriptor proto containing the symbol.
Raises:
KeyError if no file contains the specified symbol.
"""
try:
return self._file_desc_protos_by_symbol[symbol]
except KeyError:
# Fields, enum values, and nested extensions are not in
# _file_desc_protos_by_symbol. Try to find the top level
# descriptor. Non-existent nested symbol under a valid top level
# descriptor can also be found. The behavior is the same with
# protobuf C++.
top_level, _, _ = symbol.rpartition('.')
try:
return self._file_desc_protos_by_symbol[top_level]
except KeyError:
# Raise the original symbol as a KeyError for better diagnostics.
raise KeyError(symbol)
def FindFileContainingExtension(self, extendee_name, extension_number):
# TODO(jieluo): implement this API.
return None
def FindAllExtensionNumbers(self, extendee_name):
# TODO(jieluo): implement this API.
return []
def _AddSymbol(self, name, file_desc_proto):
if name in self._file_desc_protos_by_symbol:
warn_msg = ('Conflict register for file "' + file_desc_proto.name +
'": ' + name +
' is already defined in file "' +
self._file_desc_protos_by_symbol[name].name + '"')
warnings.warn(warn_msg, RuntimeWarning)
self._file_desc_protos_by_symbol[name] = file_desc_proto
def _ExtractSymbols(desc_proto, package):
"""Pulls out all the symbols from a descriptor proto.
Args:
desc_proto: The proto to extract symbols from.
package: The package containing the descriptor type.
Yields:
The fully qualified name found in the descriptor.
"""
message_name = package + '.' + desc_proto.name if package else desc_proto.name
yield message_name
for nested_type in desc_proto.nested_type:
for symbol in _ExtractSymbols(nested_type, message_name):
yield symbol
for enum_type in desc_proto.enum_type:
yield '.'.join((message_name, enum_type.name))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,443 @@
#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Adds support for parameterized tests to Python's unittest TestCase class.
A parameterized test is a method in a test case that is invoked with different
argument tuples.
A simple example:
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
(1, 2, 3),
(4, 5, 9),
(1, 1, 3))
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
Each invocation is a separate test case and properly isolated just
like a normal test method, with its own setUp/tearDown cycle. In the
example above, there are three separate testcases, one of which will
fail due to an assertion error (1 + 1 != 3).
Parameters for individual test cases can be tuples (with positional parameters)
or dictionaries (with named parameters):
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
{'op1': 1, 'op2': 2, 'result': 3},
{'op1': 4, 'op2': 5, 'result': 9},
)
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
If a parameterized test fails, the error message will show the
original test name (which is modified internally) and the arguments
for the specific invocation, which are part of the string returned by
the shortDescription() method on test cases.
The id method of the test, used internally by the unittest framework,
is also modified to show the arguments. To make sure that test names
stay the same across several invocations, object representations like
>>> class Foo(object):
... pass
>>> repr(Foo())
'<__main__.Foo object at 0x23d8610>'
are turned into '<__main__.Foo>'. For even more descriptive names,
especially in test logs, you can use the named_parameters decorator. In
this case, only tuples are supported, and the first parameters has to
be a string (or an object that returns an apt name when converted via
str()):
class NamedExample(parameterized.TestCase):
@parameterized.named_parameters(
('Normal', 'aa', 'aaa', True),
('EmptyPrefix', '', 'abc', True),
('BothEmpty', '', '', True))
def testStartsWith(self, prefix, string, result):
self.assertEqual(result, strings.startswith(prefix))
Named tests also have the benefit that they can be run individually
from the command line:
$ testmodule.py NamedExample.testStartsWithNormal
.
--------------------------------------------------------------------
Ran 1 test in 0.000s
OK
Parameterized Classes
=====================
If invocation arguments are shared across test methods in a single
TestCase class, instead of decorating all test methods
individually, the class itself can be decorated:
@parameterized.parameters(
(1, 2, 3)
(4, 5, 9))
class ArithmeticTest(parameterized.TestCase):
def testAdd(self, arg1, arg2, result):
self.assertEqual(arg1 + arg2, result)
def testSubtract(self, arg2, arg2, result):
self.assertEqual(result - arg1, arg2)
Inputs from Iterables
=====================
If parameters should be shared across several test cases, or are dynamically
created from other sources, a single non-tuple iterable can be passed into
the decorator. This iterable will be used to obtain the test cases:
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
c.op1, c.op2, c.result for c in testcases
)
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
Single-Argument Test Methods
============================
If a test method takes only one argument, the single argument does not need to
be wrapped into a tuple:
class NegativeNumberExample(parameterized.TestCase):
@parameterized.parameters(
-1, -3, -4, -5
)
def testIsNegative(self, arg):
self.assertTrue(IsNegative(arg))
"""
__author__ = 'tmarek@google.com (Torsten Marek)'
import functools
import re
import types
import unittest
import uuid
try:
# Since python 3
import collections.abc as collections_abc
except ImportError:
# Won't work after python 3.8
import collections as collections_abc
ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
_SEPARATOR = uuid.uuid1().hex
_FIRST_ARG = object()
_ARGUMENT_REPR = object()
def _CleanRepr(obj):
return ADDR_RE.sub(r'<\1>', repr(obj))
# Helper function formerly from the unittest module, removed from it in
# Python 2.7.
def _StrClass(cls):
return '%s.%s' % (cls.__module__, cls.__name__)
def _NonStringIterable(obj):
return (isinstance(obj, collections_abc.Iterable) and
not isinstance(obj, str))
def _FormatParameterList(testcase_params):
if isinstance(testcase_params, collections_abc.Mapping):
return ', '.join('%s=%s' % (argname, _CleanRepr(value))
for argname, value in testcase_params.items())
elif _NonStringIterable(testcase_params):
return ', '.join(map(_CleanRepr, testcase_params))
else:
return _FormatParameterList((testcase_params,))
class _ParameterizedTestIter(object):
"""Callable and iterable class for producing new test cases."""
def __init__(self, test_method, testcases, naming_type):
"""Returns concrete test functions for a test and a list of parameters.
The naming_type is used to determine the name of the concrete
functions as reported by the unittest framework. If naming_type is
_FIRST_ARG, the testcases must be tuples, and the first element must
have a string representation that is a valid Python identifier.
Args:
test_method: The decorated test method.
testcases: (list of tuple/dict) A list of parameter
tuples/dicts for individual test invocations.
naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
"""
self._test_method = test_method
self.testcases = testcases
self._naming_type = naming_type
def __call__(self, *args, **kwargs):
raise RuntimeError('You appear to be running a parameterized test case '
'without having inherited from parameterized.'
'TestCase. This is bad because none of '
'your test cases are actually being run.')
def __iter__(self):
test_method = self._test_method
naming_type = self._naming_type
def MakeBoundParamTest(testcase_params):
@functools.wraps(test_method)
def BoundParamTest(self):
if isinstance(testcase_params, collections_abc.Mapping):
test_method(self, **testcase_params)
elif _NonStringIterable(testcase_params):
test_method(self, *testcase_params)
else:
test_method(self, testcase_params)
if naming_type is _FIRST_ARG:
# Signal the metaclass that the name of the test function is unique
# and descriptive.
BoundParamTest.__x_use_name__ = True
BoundParamTest.__name__ += str(testcase_params[0])
testcase_params = testcase_params[1:]
elif naming_type is _ARGUMENT_REPR:
# __x_extra_id__ is used to pass naming information to the __new__
# method of TestGeneratorMetaclass.
# The metaclass will make sure to create a unique, but nondescriptive
# name for this test.
BoundParamTest.__x_extra_id__ = '(%s)' % (
_FormatParameterList(testcase_params),)
else:
raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
BoundParamTest.__doc__ = '%s(%s)' % (
BoundParamTest.__name__, _FormatParameterList(testcase_params))
if test_method.__doc__:
BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,)
return BoundParamTest
return (MakeBoundParamTest(c) for c in self.testcases)
def _IsSingletonList(testcases):
"""True iff testcases contains only a single non-tuple element."""
return len(testcases) == 1 and not isinstance(testcases[0], tuple)
def _ModifyClass(class_object, testcases, naming_type):
assert not getattr(class_object, '_id_suffix', None), (
'Cannot add parameters to %s,'
' which already has parameterized methods.' % (class_object,))
class_object._id_suffix = id_suffix = {}
# We change the size of __dict__ while we iterate over it,
# which Python 3.x will complain about, so use copy().
for name, obj in class_object.__dict__.copy().items():
if (name.startswith(unittest.TestLoader.testMethodPrefix)
and isinstance(obj, types.FunctionType)):
delattr(class_object, name)
methods = {}
_UpdateClassDictForParamTestCase(
methods, id_suffix, name,
_ParameterizedTestIter(obj, testcases, naming_type))
for name, meth in methods.items():
setattr(class_object, name, meth)
def _ParameterDecorator(naming_type, testcases):
"""Implementation of the parameterization decorators.
Args:
naming_type: The naming type.
testcases: Testcase parameters.
Returns:
A function for modifying the decorated object.
"""
def _Apply(obj):
if isinstance(obj, type):
_ModifyClass(
obj,
list(testcases) if not isinstance(testcases, collections_abc.Sequence)
else testcases,
naming_type)
return obj
else:
return _ParameterizedTestIter(obj, testcases, naming_type)
if _IsSingletonList(testcases):
assert _NonStringIterable(testcases[0]), (
'Single parameter argument must be a non-string iterable')
testcases = testcases[0]
return _Apply
def parameters(*testcases): # pylint: disable=invalid-name
"""A decorator for creating parameterized tests.
See the module docstring for a usage example.
Args:
*testcases: Parameters for the decorated method, either a single
iterable, or a list of tuples/dicts/objects (for tests
with only one argument).
Returns:
A test generator to be handled by TestGeneratorMetaclass.
"""
return _ParameterDecorator(_ARGUMENT_REPR, testcases)
def named_parameters(*testcases): # pylint: disable=invalid-name
"""A decorator for creating parameterized tests.
See the module docstring for a usage example. The first element of
each parameter tuple should be a string and will be appended to the
name of the test method.
Args:
*testcases: Parameters for the decorated method, either a single
iterable, or a list of tuples.
Returns:
A test generator to be handled by TestGeneratorMetaclass.
"""
return _ParameterDecorator(_FIRST_ARG, testcases)
class TestGeneratorMetaclass(type):
"""Metaclass for test cases with test generators.
A test generator is an iterable in a testcase that produces callables. These
callables must be single-argument methods. These methods are injected into
the class namespace and the original iterable is removed. If the name of the
iterable conforms to the test pattern, the injected methods will be picked
up as tests by the unittest framework.
In general, it is supposed to be used in conjunction with the
parameters decorator.
"""
def __new__(mcs, class_name, bases, dct):
dct['_id_suffix'] = id_suffix = {}
for name, obj in dct.copy().items():
if (name.startswith(unittest.TestLoader.testMethodPrefix) and
_NonStringIterable(obj)):
iterator = iter(obj)
dct.pop(name)
_UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator)
return type.__new__(mcs, class_name, bases, dct)
def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
"""Adds individual test cases to a dictionary.
Args:
dct: The target dictionary.
id_suffix: The dictionary for mapping names to test IDs.
name: The original name of the test case.
iterator: The iterator generating the individual test cases.
"""
for idx, func in enumerate(iterator):
assert callable(func), 'Test generators must yield callables, got %r' % (
func,)
if getattr(func, '__x_use_name__', False):
new_name = func.__name__
else:
new_name = '%s%s%d' % (name, _SEPARATOR, idx)
assert new_name not in dct, (
'Name of parameterized test case "%s" not unique' % (new_name,))
dct[new_name] = func
id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
class TestCase(unittest.TestCase, metaclass=TestGeneratorMetaclass):
"""Base class for test cases using the parameters decorator."""
def _OriginalName(self):
return self._testMethodName.split(_SEPARATOR)[0]
def __str__(self):
return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__))
def id(self): # pylint: disable=invalid-name
"""Returns the descriptive ID of the test.
This is used internally by the unittesting framework to get a name
for the test to be used in reports.
Returns:
The test id.
"""
return '%s.%s%s' % (_StrClass(self.__class__),
self._OriginalName(),
self._id_suffix.get(self._testMethodName, ''))
def CoopTestCase(other_base_class):
"""Returns a new base class with a cooperative metaclass base.
This enables the TestCase to be used in combination
with other base classes that have custom metaclasses, such as
mox.MoxTestBase.
Only works with metaclasses that do not override type.__new__.
Example:
import google3
import mox
from google3.testing.pybase import parameterized
class ExampleTest(parameterized.CoopTestCase(mox.MoxTestBase)):
...
Args:
other_base_class: (class) A test case base class.
Returns:
A new class object.
"""
metaclass = type(
'CoopMetaclass',
(other_base_class.__metaclass__,
TestGeneratorMetaclass), {})
return metaclass(
'CoopTestCase',
(other_base_class, TestCase), {})

View File

@@ -0,0 +1,51 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: jieluo@google.com (Jie Luo)
syntax = "proto2";
package google.protobuf.internal;
import "google/protobuf/any.proto";
message TestAny {
optional google.protobuf.Any value = 1;
optional int32 int_value = 2;
map<string,int32> map_value = 3;
extensions 10 to max;
}
message TestAnyExtension1 {
extend TestAny {
optional TestAnyExtension1 extension1 = 98418603;
}
optional int32 i = 15;
}

View File

@@ -0,0 +1,111 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define PY_SSIZE_T_CLEAN
#include <Python.h>
namespace google {
namespace protobuf {
namespace python {
// Version constant.
// This is either 0 for python, 1 for CPP V1, 2 for CPP V2.
//
// 0 is default and is equivalent to
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
//
// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
// and
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1
//
// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
// and
// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
#ifdef PYTHON_PROTO2_CPP_IMPL_V1
#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported."
#else
#ifdef PYTHON_PROTO2_CPP_IMPL_V2
static int kImplVersion = 2;
#else
#ifdef PYTHON_PROTO2_PYTHON_IMPL
static int kImplVersion = 0;
#else
static int kImplVersion = -1; // -1 means "Unspecified by compiler flags".
#endif // PYTHON_PROTO2_PYTHON_IMPL
#endif // PYTHON_PROTO2_CPP_IMPL_V2
#endif // PYTHON_PROTO2_CPP_IMPL_V1
static const char* kImplVersionName = "api_version";
static const char* kModuleName = "_api_implementation";
static const char kModuleDocstring[] =
"_api_implementation is a module that exposes compile-time constants that\n"
"determine the default API implementation to use for Python proto2.\n"
"\n"
"It complements api_implementation.py by setting defaults using "
"compile-time\n"
"constants defined in C, such that one can set defaults at compilation\n"
"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2).";
static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT,
kModuleName,
kModuleDocstring,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr};
extern "C" {
PyMODINIT_FUNC PyInit__api_implementation() {
PyObject* module = PyModule_Create(&_module);
if (module == nullptr) {
return nullptr;
}
// Adds the module variable "api_version".
if (PyModule_AddIntConstant(module, const_cast<char*>(kImplVersionName),
kImplVersion)) {
Py_DECREF(module);
return nullptr;
}
return module;
}
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,162 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Determine which implementation of the protobuf API is used in this process.
"""
import importlib
import os
import sys
import warnings
def _ApiVersionToImplementationType(api_version):
if api_version == 2:
return 'cpp'
if api_version == 1:
raise ValueError('api_version=1 is no longer supported.')
if api_version == 0:
return 'python'
return None
_implementation_type = None
try:
# pylint: disable=g-import-not-at-top
from google.protobuf.internal import _api_implementation
# The compile-time constants in the _api_implementation module can be used to
# switch to a certain implementation of the Python API at build time.
_implementation_type = _ApiVersionToImplementationType(
_api_implementation.api_version)
except ImportError:
pass # Unspecified by compiler flags.
def _CanImport(mod_name):
try:
mod = importlib.import_module(mod_name)
# Work around a known issue in the classic bootstrap .par import hook.
if not mod:
raise ImportError(mod_name + ' import succeeded but was None')
return True
except ImportError:
return False
if _implementation_type is None:
if _CanImport('google._upb._message'):
_implementation_type = 'upb'
elif _CanImport('google.protobuf.pyext._message'):
_implementation_type = 'cpp'
else:
_implementation_type = 'python'
# This environment variable can be used to switch to a certain implementation
# of the Python API, overriding the compile-time constants in the
# _api_implementation module. Right now only 'python', 'cpp' and 'upb' are
# valid values. Any other value will raise error.
_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION',
_implementation_type)
if _implementation_type not in ('python', 'cpp', 'upb'):
raise ValueError('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION {0} is not '
'supported. Please set to \'python\', \'cpp\' or '
'\'upb\'.'.format(_implementation_type))
if 'PyPy' in sys.version and _implementation_type == 'cpp':
warnings.warn('PyPy does not work yet with cpp protocol buffers. '
'Falling back to the python implementation.')
_implementation_type = 'python'
_c_module = None
if _implementation_type == 'cpp':
try:
# pylint: disable=g-import-not-at-top
from google.protobuf.pyext import _message
_c_module = _message
del _message
except ImportError:
# TODO(jieluo): fail back to python
warnings.warn(
'Selected implementation cpp is not available.')
pass
if _implementation_type == 'upb':
try:
# pylint: disable=g-import-not-at-top
from google._upb import _message
_c_module = _message
del _message
except ImportError:
warnings.warn('Selected implementation upb is not available. '
'Falling back to the python implementation.')
_implementation_type = 'python'
pass
# Detect if serialization should be deterministic by default
try:
# The presence of this module in a build allows the proto implementation to
# be upgraded merely via build deps.
#
# NOTE: Merely importing this automatically enables deterministic proto
# serialization for C++ code, but we still need to export it as a boolean so
# that we can do the same for `_implementation_type == 'python'`.
#
# NOTE2: It is possible for C++ code to enable deterministic serialization by
# default _without_ affecting Python code, if the C++ implementation is not in
# use by this module. That is intended behavior, so we don't actually expose
# this boolean outside of this module.
#
# pylint: disable=g-import-not-at-top,unused-import
from google.protobuf import enable_deterministic_proto_serialization
_python_deterministic_proto_serialization = True
except ImportError:
_python_deterministic_proto_serialization = False
# Usage of this function is discouraged. Clients shouldn't care which
# implementation of the API is in use. Note that there is no guarantee
# that differences between APIs will be maintained.
# Please don't use this function if possible.
def Type():
return _implementation_type
# See comment on 'Type' above.
# TODO(jieluo): Remove the API, it returns a constant. b/228102101
def Version():
return 2
# For internal use only
def IsPythonDefaultSerializationDeterministic():
return _python_deterministic_proto_serialization

View File

@@ -0,0 +1,130 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Builds descriptors, message classes and services for generated _pb2.py.
This file is only called in python generated _pb2.py files. It builds
descriptors, message classes and services that users can directly use
in generated code.
"""
__author__ = 'jieluo@google.com (Jie Luo)'
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
_sym_db = _symbol_database.Default()
def BuildMessageAndEnumDescriptors(file_des, module):
"""Builds message and enum descriptors.
Args:
file_des: FileDescriptor of the .proto file
module: Generated _pb2 module
"""
def BuildNestedDescriptors(msg_des, prefix):
for (name, nested_msg) in msg_des.nested_types_by_name.items():
module_name = prefix + name.upper()
module[module_name] = nested_msg
BuildNestedDescriptors(nested_msg, module_name + '_')
for enum_des in msg_des.enum_types:
module[prefix + enum_des.name.upper()] = enum_des
for (name, msg_des) in file_des.message_types_by_name.items():
module_name = '_' + name.upper()
module[module_name] = msg_des
BuildNestedDescriptors(msg_des, module_name + '_')
def BuildTopDescriptorsAndMessages(file_des, module_name, module):
"""Builds top level descriptors and message classes.
Args:
file_des: FileDescriptor of the .proto file
module_name: str, the name of generated _pb2 module
module: Generated _pb2 module
"""
def BuildMessage(msg_des):
create_dict = {}
for (name, nested_msg) in msg_des.nested_types_by_name.items():
create_dict[name] = BuildMessage(nested_msg)
create_dict['DESCRIPTOR'] = msg_des
create_dict['__module__'] = module_name
message_class = _reflection.GeneratedProtocolMessageType(
msg_des.name, (_message.Message,), create_dict)
_sym_db.RegisterMessage(message_class)
return message_class
# top level enums
for (name, enum_des) in file_des.enum_types_by_name.items():
module['_' + name.upper()] = enum_des
module[name] = enum_type_wrapper.EnumTypeWrapper(enum_des)
for enum_value in enum_des.values:
module[enum_value.name] = enum_value.number
# top level extensions
for (name, extension_des) in file_des.extensions_by_name.items():
module[name.upper() + '_FIELD_NUMBER'] = extension_des.number
module[name] = extension_des
# services
for (name, service) in file_des.services_by_name.items():
module['_' + name.upper()] = service
# Build messages.
for (name, msg_des) in file_des.message_types_by_name.items():
module[name] = BuildMessage(msg_des)
def BuildServices(file_des, module_name, module):
"""Builds services classes and services stub class.
Args:
file_des: FileDescriptor of the .proto file
module_name: str, the name of generated _pb2 module
module: Generated _pb2 module
"""
# pylint: disable=g-import-not-at-top
from google.protobuf import service as _service
from google.protobuf import service_reflection
# pylint: enable=g-import-not-at-top
for (name, service) in file_des.services_by_name.items():
module[name] = service_reflection.GeneratedServiceType(
name, (_service.Service,),
dict(DESCRIPTOR=service, __module__=module_name))
stub_name = name + '_Stub'
module[stub_name] = service_reflection.GeneratedServiceStubType(
stub_name, (module[name],),
dict(DESCRIPTOR=service, __module__=module_name))

View File

@@ -0,0 +1,710 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains container classes to represent different protocol buffer types.
This file defines container classes which represent categories of protocol
buffer field types which need extra maintenance. Currently these categories
are:
- Repeated scalar fields - These are all repeated fields which aren't
composite (e.g. they are of simple types like int32, string, etc).
- Repeated composite fields - Repeated fields which are composite. This
includes groups and nested messages.
"""
import collections.abc
import copy
import pickle
from typing import (
Any,
Iterable,
Iterator,
List,
MutableMapping,
MutableSequence,
NoReturn,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
_T = TypeVar('_T')
_K = TypeVar('_K')
_V = TypeVar('_V')
class BaseContainer(Sequence[_T]):
"""Base container class."""
# Minimizes memory usage and disallows assignment to other attributes.
__slots__ = ['_message_listener', '_values']
def __init__(self, message_listener: Any) -> None:
"""
Args:
message_listener: A MessageListener implementation.
The RepeatedScalarFieldContainer will call this object's
Modified() method when it is modified.
"""
self._message_listener = message_listener
self._values = []
@overload
def __getitem__(self, key: int) -> _T:
...
@overload
def __getitem__(self, key: slice) -> List[_T]:
...
def __getitem__(self, key):
"""Retrieves item by the specified key."""
return self._values[key]
def __len__(self) -> int:
"""Returns the number of elements in the container."""
return len(self._values)
def __ne__(self, other: Any) -> bool:
"""Checks if another instance isn't equal to this one."""
# The concrete classes should define __eq__.
return not self == other
__hash__ = None
def __repr__(self) -> str:
return repr(self._values)
def sort(self, *args, **kwargs) -> None:
# Continue to support the old sort_function keyword argument.
# This is expected to be a rare occurrence, so use LBYL to avoid
# the overhead of actually catching KeyError.
if 'sort_function' in kwargs:
kwargs['cmp'] = kwargs.pop('sort_function')
self._values.sort(*args, **kwargs)
def reverse(self) -> None:
self._values.reverse()
# TODO(slebedev): Remove this. BaseContainer does *not* conform to
# MutableSequence, only its subclasses do.
collections.abc.MutableSequence.register(BaseContainer)
class RepeatedScalarFieldContainer(BaseContainer[_T], MutableSequence[_T]):
"""Simple, type-checked, list-like container for holding repeated scalars."""
# Disallows assignment to other attributes.
__slots__ = ['_type_checker']
def __init__(
self,
message_listener: Any,
type_checker: Any,
) -> None:
"""Args:
message_listener: A MessageListener implementation. The
RepeatedScalarFieldContainer will call this object's Modified() method
when it is modified.
type_checker: A type_checkers.ValueChecker instance to run on elements
inserted into this container.
"""
super().__init__(message_listener)
self._type_checker = type_checker
def append(self, value: _T) -> None:
"""Appends an item to the list. Similar to list.append()."""
self._values.append(self._type_checker.CheckValue(value))
if not self._message_listener.dirty:
self._message_listener.Modified()
def insert(self, key: int, value: _T) -> None:
"""Inserts the item at the specified position. Similar to list.insert()."""
self._values.insert(key, self._type_checker.CheckValue(value))
if not self._message_listener.dirty:
self._message_listener.Modified()
def extend(self, elem_seq: Iterable[_T]) -> None:
"""Extends by appending the given iterable. Similar to list.extend()."""
if elem_seq is None:
return
try:
elem_seq_iter = iter(elem_seq)
except TypeError:
if not elem_seq:
# silently ignore falsy inputs :-/.
# TODO(ptucker): Deprecate this behavior. b/18413862
return
raise
new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
if new_values:
self._values.extend(new_values)
self._message_listener.Modified()
def MergeFrom(
self,
other: Union['RepeatedScalarFieldContainer[_T]', Iterable[_T]],
) -> None:
"""Appends the contents of another repeated field of the same type to this
one. We do not check the types of the individual fields.
"""
self._values.extend(other)
self._message_listener.Modified()
def remove(self, elem: _T):
"""Removes an item from the list. Similar to list.remove()."""
self._values.remove(elem)
self._message_listener.Modified()
def pop(self, key: Optional[int] = -1) -> _T:
"""Removes and returns an item at a given index. Similar to list.pop()."""
value = self._values[key]
self.__delitem__(key)
return value
@overload
def __setitem__(self, key: int, value: _T) -> None:
...
@overload
def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
...
def __setitem__(self, key, value) -> None:
"""Sets the item on the specified position."""
if isinstance(key, slice):
if key.step is not None:
raise ValueError('Extended slices not supported')
self._values[key] = map(self._type_checker.CheckValue, value)
self._message_listener.Modified()
else:
self._values[key] = self._type_checker.CheckValue(value)
self._message_listener.Modified()
def __delitem__(self, key: Union[int, slice]) -> None:
"""Deletes the item at the specified position."""
del self._values[key]
self._message_listener.Modified()
def __eq__(self, other: Any) -> bool:
"""Compares the current instance with another one."""
if self is other:
return True
# Special case for the same type which should be common and fast.
if isinstance(other, self.__class__):
return other._values == self._values
# We are presumably comparing against some other sequence type.
return other == self._values
def __deepcopy__(
self,
unused_memo: Any = None,
) -> 'RepeatedScalarFieldContainer[_T]':
clone = RepeatedScalarFieldContainer(
copy.deepcopy(self._message_listener), self._type_checker)
clone.MergeFrom(self)
return clone
def __reduce__(self, **kwargs) -> NoReturn:
raise pickle.PickleError(
"Can't pickle repeated scalar fields, convert to list first")
# TODO(slebedev): Constrain T to be a subtype of Message.
class RepeatedCompositeFieldContainer(BaseContainer[_T], MutableSequence[_T]):
"""Simple, list-like container for holding repeated composite fields."""
# Disallows assignment to other attributes.
__slots__ = ['_message_descriptor']
def __init__(self, message_listener: Any, message_descriptor: Any) -> None:
"""
Note that we pass in a descriptor instead of the generated directly,
since at the time we construct a _RepeatedCompositeFieldContainer we
haven't yet necessarily initialized the type that will be contained in the
container.
Args:
message_listener: A MessageListener implementation.
The RepeatedCompositeFieldContainer will call this object's
Modified() method when it is modified.
message_descriptor: A Descriptor instance describing the protocol type
that should be present in this container. We'll use the
_concrete_class field of this descriptor when the client calls add().
"""
super().__init__(message_listener)
self._message_descriptor = message_descriptor
def add(self, **kwargs: Any) -> _T:
"""Adds a new element at the end of the list and returns it. Keyword
arguments may be used to initialize the element.
"""
new_element = self._message_descriptor._concrete_class(**kwargs)
new_element._SetListener(self._message_listener)
self._values.append(new_element)
if not self._message_listener.dirty:
self._message_listener.Modified()
return new_element
def append(self, value: _T) -> None:
"""Appends one element by copying the message."""
new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener)
new_element.CopyFrom(value)
self._values.append(new_element)
if not self._message_listener.dirty:
self._message_listener.Modified()
def insert(self, key: int, value: _T) -> None:
"""Inserts the item at the specified position by copying."""
new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener)
new_element.CopyFrom(value)
self._values.insert(key, new_element)
if not self._message_listener.dirty:
self._message_listener.Modified()
def extend(self, elem_seq: Iterable[_T]) -> None:
"""Extends by appending the given sequence of elements of the same type
as this one, copying each individual message.
"""
message_class = self._message_descriptor._concrete_class
listener = self._message_listener
values = self._values
for message in elem_seq:
new_element = message_class()
new_element._SetListener(listener)
new_element.MergeFrom(message)
values.append(new_element)
listener.Modified()
def MergeFrom(
self,
other: Union['RepeatedCompositeFieldContainer[_T]', Iterable[_T]],
) -> None:
"""Appends the contents of another repeated field of the same type to this
one, copying each individual message.
"""
self.extend(other)
def remove(self, elem: _T) -> None:
"""Removes an item from the list. Similar to list.remove()."""
self._values.remove(elem)
self._message_listener.Modified()
def pop(self, key: Optional[int] = -1) -> _T:
"""Removes and returns an item at a given index. Similar to list.pop()."""
value = self._values[key]
self.__delitem__(key)
return value
@overload
def __setitem__(self, key: int, value: _T) -> None:
...
@overload
def __setitem__(self, key: slice, value: Iterable[_T]) -> None:
...
def __setitem__(self, key, value):
# This method is implemented to make RepeatedCompositeFieldContainer
# structurally compatible with typing.MutableSequence. It is
# otherwise unsupported and will always raise an error.
raise TypeError(
f'{self.__class__.__name__} object does not support item assignment')
def __delitem__(self, key: Union[int, slice]) -> None:
"""Deletes the item at the specified position."""
del self._values[key]
self._message_listener.Modified()
def __eq__(self, other: Any) -> bool:
"""Compares the current instance with another one."""
if self is other:
return True
if not isinstance(other, self.__class__):
raise TypeError('Can only compare repeated composite fields against '
'other repeated composite fields.')
return self._values == other._values
class ScalarMap(MutableMapping[_K, _V]):
"""Simple, type-checked, dict-like container for holding repeated scalars."""
# Disallows assignment to other attributes.
__slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
'_entry_descriptor']
def __init__(
self,
message_listener: Any,
key_checker: Any,
value_checker: Any,
entry_descriptor: Any,
) -> None:
"""
Args:
message_listener: A MessageListener implementation.
The ScalarMap will call this object's Modified() method when it
is modified.
key_checker: A type_checkers.ValueChecker instance to run on keys
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._key_checker = key_checker
self._value_checker = value_checker
self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key: _K) -> _V:
try:
return self._values[key]
except KeyError:
key = self._key_checker.CheckValue(key)
val = self._value_checker.DefaultValue()
self._values[key] = val
return val
def __contains__(self, item: _K) -> bool:
# We check the key's type to match the strong-typing flavor of the API.
# Also this makes it easier to match the behavior of the C++ implementation.
self._key_checker.CheckValue(item)
return item in self._values
@overload
def get(self, key: _K) -> Optional[_V]:
...
@overload
def get(self, key: _K, default: _T) -> Union[_V, _T]:
...
# We need to override this explicitly, because our defaultdict-like behavior
# will make the default implementation (from our base class) always insert
# the key.
def get(self, key, default=None):
if key in self:
return self[key]
else:
return default
def __setitem__(self, key: _K, value: _V) -> _T:
checked_key = self._key_checker.CheckValue(key)
checked_value = self._value_checker.CheckValue(value)
self._values[checked_key] = checked_value
self._message_listener.Modified()
def __delitem__(self, key: _K) -> None:
del self._values[key]
self._message_listener.Modified()
def __len__(self) -> int:
return len(self._values)
def __iter__(self) -> Iterator[_K]:
return iter(self._values)
def __repr__(self) -> str:
return repr(self._values)
def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
self._values.update(other._values)
self._message_listener.Modified()
def InvalidateIterators(self) -> None:
# It appears that the only way to reliably invalidate iterators to
# self._values is to ensure that its size changes.
original = self._values
self._values = original.copy()
original[None] = None
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self) -> None:
self._values.clear()
self._message_listener.Modified()
def GetEntryClass(self) -> Any:
return self._entry_descriptor._concrete_class
class MessageMap(MutableMapping[_K, _V]):
"""Simple, type-checked, dict-like container for with submessage values."""
# Disallows assignment to other attributes.
__slots__ = ['_key_checker', '_values', '_message_listener',
'_message_descriptor', '_entry_descriptor']
def __init__(
self,
message_listener: Any,
message_descriptor: Any,
key_checker: Any,
entry_descriptor: Any,
) -> None:
"""
Args:
message_listener: A MessageListener implementation.
The ScalarMap will call this object's Modified() method when it
is modified.
key_checker: A type_checkers.ValueChecker instance to run on keys
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._message_descriptor = message_descriptor
self._key_checker = key_checker
self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key: _K) -> _V:
key = self._key_checker.CheckValue(key)
try:
return self._values[key]
except KeyError:
new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener)
self._values[key] = new_element
self._message_listener.Modified()
return new_element
def get_or_create(self, key: _K) -> _V:
"""get_or_create() is an alias for getitem (ie. map[key]).
Args:
key: The key to get or create in the map.
This is useful in cases where you want to be explicit that the call is
mutating the map. This can avoid lint errors for statements like this
that otherwise would appear to be pointless statements:
msg.my_map[key]
"""
return self[key]
@overload
def get(self, key: _K) -> Optional[_V]:
...
@overload
def get(self, key: _K, default: _T) -> Union[_V, _T]:
...
# We need to override this explicitly, because our defaultdict-like behavior
# will make the default implementation (from our base class) always insert
# the key.
def get(self, key, default=None):
if key in self:
return self[key]
else:
return default
def __contains__(self, item: _K) -> bool:
item = self._key_checker.CheckValue(item)
return item in self._values
def __setitem__(self, key: _K, value: _V) -> NoReturn:
raise ValueError('May not set values directly, call my_map[key].foo = 5')
def __delitem__(self, key: _K) -> None:
key = self._key_checker.CheckValue(key)
del self._values[key]
self._message_listener.Modified()
def __len__(self) -> int:
return len(self._values)
def __iter__(self) -> Iterator[_K]:
return iter(self._values)
def __repr__(self) -> str:
return repr(self._values)
def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
# pylint: disable=protected-access
for key in other._values:
# According to documentation: "When parsing from the wire or when merging,
# if there are duplicate map keys the last key seen is used".
if key in self:
del self[key]
self[key].CopyFrom(other[key])
# self._message_listener.Modified() not required here, because
# mutations to submessages already propagate.
def InvalidateIterators(self) -> None:
# It appears that the only way to reliably invalidate iterators to
# self._values is to ensure that its size changes.
original = self._values
self._values = original.copy()
original[None] = None
# This is defined in the abstract base, but we can do it much more cheaply.
def clear(self) -> None:
self._values.clear()
self._message_listener.Modified()
def GetEntryClass(self) -> Any:
return self._entry_descriptor._concrete_class
class _UnknownField:
"""A parsed unknown field."""
# Disallows assignment to other attributes.
__slots__ = ['_field_number', '_wire_type', '_data']
def __init__(self, field_number, wire_type, data):
self._field_number = field_number
self._wire_type = wire_type
self._data = data
return
def __lt__(self, other):
# pylint: disable=protected-access
return self._field_number < other._field_number
def __eq__(self, other):
if self is other:
return True
# pylint: disable=protected-access
return (self._field_number == other._field_number and
self._wire_type == other._wire_type and
self._data == other._data)
class UnknownFieldRef: # pylint: disable=missing-class-docstring
def __init__(self, parent, index):
self._parent = parent
self._index = index
def _check_valid(self):
if not self._parent:
raise ValueError('UnknownField does not exist. '
'The parent message might be cleared.')
if self._index >= len(self._parent):
raise ValueError('UnknownField does not exist. '
'The parent message might be cleared.')
@property
def field_number(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._field_number
@property
def wire_type(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._wire_type
@property
def data(self):
self._check_valid()
# pylint: disable=protected-access
return self._parent._internal_get(self._index)._data
class UnknownFieldSet:
"""UnknownField container"""
# Disallows assignment to other attributes.
__slots__ = ['_values']
def __init__(self):
self._values = []
def __getitem__(self, index):
if self._values is None:
raise ValueError('UnknownFields does not exist. '
'The parent message might be cleared.')
size = len(self._values)
if index < 0:
index += size
if index < 0 or index >= size:
raise IndexError('index %d out of range'.index)
return UnknownFieldRef(self, index)
def _internal_get(self, index):
return self._values[index]
def __len__(self):
if self._values is None:
raise ValueError('UnknownFields does not exist. '
'The parent message might be cleared.')
return len(self._values)
def _add(self, field_number, wire_type, data):
unknown_field = _UnknownField(field_number, wire_type, data)
self._values.append(unknown_field)
return unknown_field
def __iter__(self):
for i in range(len(self)):
yield UnknownFieldRef(self, i)
def _extend(self, other):
if other is None:
return
# pylint: disable=protected-access
self._values.extend(other._values)
def __eq__(self, other):
if self is other:
return True
# Sort unknown fields because their order shouldn't
# affect equality test.
values = list(self._values)
if other is None:
return not values
values.sort()
# pylint: disable=protected-access
other_values = sorted(other._values)
return values == other_values
def _clear(self):
for value in self._values:
# pylint: disable=protected-access
if isinstance(value._data, UnknownFieldSet):
value._data._clear() # pylint: disable=protected-access
self._values = None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,127 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.descriptor_database."""
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
import warnings
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import no_package_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf import descriptor_database
@testing_refleaks.TestCase
class DescriptorDatabaseTest(unittest.TestCase):
def testAdd(self):
db = descriptor_database.DescriptorDatabase()
file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
factory_test2_pb2.DESCRIPTOR.serialized_pb)
file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
no_package_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto)
db.Add(file_desc_proto2)
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
# Can find message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message'))
# Can find nested message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
# Can find enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum'))
# Can find nested enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum'))
# Can find field.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.list_field'))
# Can find enum value.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.FACTORY_2_VALUE_0'))
self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
'.NO_PACKAGE_VALUE_0'))
# Can find top level extension.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.another_field'))
# Can find nested extension inside a message.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field'))
# Can find service.
file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto2)
self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
'protobuf_unittest.TestService'))
# Non-existent field under a valid top level symbol can also be
# found. The behavior is the same with protobuf C++.
self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.none_field'))
with self.assertRaisesRegex(KeyError, r'\'protobuf_unittest\.NoneMessage\''):
db.FindFileContainingSymbol('protobuf_unittest.NoneMessage')
def testConflictRegister(self):
db = descriptor_database.DescriptorDatabase()
unittest_fd = descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb)
db.Add(unittest_fd)
conflict_fd = descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb)
conflict_fd.name = 'other_file2'
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter('always')
db.Add(conflict_fd)
self.assertTrue(len(w))
self.assertIs(w[0].category, RuntimeWarning)
self.assertIn('Conflict register for file "other_file2": ',
str(w[0].message))
self.assertIn('already defined in file '
'"google/protobuf/unittest.proto"',
str(w[0].message))
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,96 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal;
message DescriptorPoolTest1 {
extensions 1000 to max;
enum NestedEnum {
ALPHA = 1;
BETA = 2;
}
optional NestedEnum nested_enum = 1 [default = BETA];
message NestedMessage {
enum NestedEnum {
EPSILON = 5;
ZETA = 6;
}
optional NestedEnum nested_enum = 1 [default = ZETA];
optional string nested_field = 2 [default = "beta"];
optional DeepNestedMessage deep_nested_message = 3;
message DeepNestedMessage {
enum NestedEnum {
ETA = 7;
THETA = 8;
}
optional NestedEnum nested_enum = 1 [default = ETA];
optional string nested_field = 2 [default = "theta"];
}
}
optional NestedMessage nested_message = 2;
}
message DescriptorPoolTest2 {
enum NestedEnum {
GAMMA = 3;
DELTA = 4;
}
optional NestedEnum nested_enum = 1 [default = GAMMA];
message NestedMessage {
enum NestedEnum {
IOTA = 9;
KAPPA = 10;
}
optional NestedEnum nested_enum = 1 [default = IOTA];
optional string nested_field = 2 [default = "delta"];
optional DeepNestedMessage deep_nested_message = 3;
message DeepNestedMessage {
enum NestedEnum {
LAMBDA = 11;
MU = 12;
}
optional NestedEnum nested_enum = 1 [default = MU];
optional string nested_field = 2 [default = "lambda"];
}
}
optional NestedMessage nested_message = 2;
}

View File

@@ -0,0 +1,73 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal;
import "google/protobuf/internal/descriptor_pool_test1.proto";
import public "google/protobuf/internal/more_messages.proto";
message DescriptorPoolTest3 {
extend DescriptorPoolTest1 {
optional DescriptorPoolTest3 descriptor_pool_test = 1001;
}
enum NestedEnum {
NU = 13;
XI = 14;
}
optional NestedEnum nested_enum = 1 [default = XI];
message NestedMessage {
enum NestedEnum {
OMICRON = 15;
PI = 16;
}
optional NestedEnum nested_enum = 1 [default = PI];
optional string nested_field = 2 [default = "nu"];
optional DeepNestedMessage deep_nested_message = 3;
message DeepNestedMessage {
enum NestedEnum {
RHO = 17;
SIGMA = 18;
}
optional NestedEnum nested_enum = 1 [default = RHO];
optional string nested_field = 2 [default = "sigma"];
}
}
optional NestedMessage nested_message = 2;
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,829 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type
into one of the 5 physical wire types.
This code is designed to push the Python interpreter's performance to the
limits.
The basic idea is that at startup time, for every field (i.e. every
FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
sizer takes a value of this field's type and computes its byte size. The
encoder takes a writer function and a value. It encodes the value into byte
strings and invokes the writer function to write those strings. Typically the
writer function is the write() method of a BytesIO.
We try to do as much work as possible when constructing the writer and the
sizer rather than when calling them. In particular:
* We copy any needed global functions to local variables, so that we do not need
to do costly global table lookups at runtime.
* Similarly, we try to do any attribute lookups at startup time if possible.
* Every field's tag is encoded to bytes at startup, since it can't change at
runtime.
* Whatever component of the field size we can compute at startup, we do.
* We *avoid* sharing code if doing so would make the code slower and not sharing
does not burden us too much. For example, encoders for repeated fields do
not just call the encoders for singular fields in a loop because this would
add an extra function call overhead for every loop iteration; instead, we
manually inline the single-value encoder into the loop.
* If a Python function lacks a return statement, Python actually generates
instructions to pop the result of the last statement off the stack, push
None onto the stack, and then return that. If we really don't care what
value is returned, then we can save two instructions by returning the
result of the last statement. It looks funny but it helps.
* We assume that type and bounds checking has happened at a higher level.
"""
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
from google.protobuf.internal import wire_format
# This will overflow and thus become IEEE-754 "infinity". We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
def _VarintSize(value):
"""Compute the size of a varint value."""
if value <= 0x7f: return 1
if value <= 0x3fff: return 2
if value <= 0x1fffff: return 3
if value <= 0xfffffff: return 4
if value <= 0x7ffffffff: return 5
if value <= 0x3ffffffffff: return 6
if value <= 0x1ffffffffffff: return 7
if value <= 0xffffffffffffff: return 8
if value <= 0x7fffffffffffffff: return 9
return 10
def _SignedVarintSize(value):
"""Compute the size of a signed varint value."""
if value < 0: return 10
if value <= 0x7f: return 1
if value <= 0x3fff: return 2
if value <= 0x1fffff: return 3
if value <= 0xfffffff: return 4
if value <= 0x7ffffffff: return 5
if value <= 0x3ffffffffff: return 6
if value <= 0x1ffffffffffff: return 7
if value <= 0xffffffffffffff: return 8
if value <= 0x7fffffffffffffff: return 9
return 10
def _TagSize(field_number):
"""Returns the number of bytes required to serialize a tag with this field
number."""
# Just pass in type 0, since the type won't affect the tag+type size.
return _VarintSize(wire_format.PackTag(field_number, 0))
# --------------------------------------------------------------------
# In this section we define some generic sizers. Each of these functions
# takes parameters specific to a particular field type, e.g. int32 or fixed64.
# It returns another function which in turn takes parameters specific to a
# particular field, e.g. the field number and whether it is repeated or packed.
# Look at the next section to see how these are used.
def _SimpleSizer(compute_value_size):
"""A sizer which uses the function compute_value_size to compute the size of
each value. Typically compute_value_size is _VarintSize."""
def SpecificSizer(field_number, is_repeated, is_packed):
tag_size = _TagSize(field_number)
if is_packed:
local_VarintSize = _VarintSize
def PackedFieldSize(value):
result = 0
for element in value:
result += compute_value_size(element)
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += compute_value_size(element)
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + compute_value_size(value)
return FieldSize
return SpecificSizer
def _ModifiedSizer(compute_value_size, modify_value):
"""Like SimpleSizer, but modify_value is invoked on each value before it is
passed to compute_value_size. modify_value is typically ZigZagEncode."""
def SpecificSizer(field_number, is_repeated, is_packed):
tag_size = _TagSize(field_number)
if is_packed:
local_VarintSize = _VarintSize
def PackedFieldSize(value):
result = 0
for element in value:
result += compute_value_size(modify_value(element))
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += compute_value_size(modify_value(element))
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + compute_value_size(modify_value(value))
return FieldSize
return SpecificSizer
def _FixedSizer(value_size):
"""Like _SimpleSizer except for a fixed-size field. The input is the size
of one value."""
def SpecificSizer(field_number, is_repeated, is_packed):
tag_size = _TagSize(field_number)
if is_packed:
local_VarintSize = _VarintSize
def PackedFieldSize(value):
result = len(value) * value_size
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
element_size = value_size + tag_size
def RepeatedFieldSize(value):
return len(value) * element_size
return RepeatedFieldSize
else:
field_size = value_size + tag_size
def FieldSize(value):
return field_size
return FieldSize
return SpecificSizer
# ====================================================================
# Here we declare a sizer constructor for each field type. Each "sizer
# constructor" is a function that takes (field_number, is_repeated, is_packed)
# as parameters and returns a sizer, which in turn takes a field value as
# a parameter and returns its encoded size.
Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
SInt32Sizer = SInt64Sizer = _ModifiedSizer(
_SignedVarintSize, wire_format.ZigZagEncode)
Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
BoolSizer = _FixedSizer(1)
def StringSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a string field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
local_len = len
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = local_len(element.encode('utf-8'))
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = local_len(value.encode('utf-8'))
return tag_size + local_VarintSize(l) + l
return FieldSize
def BytesSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a bytes field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
local_len = len
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = local_len(element)
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = local_len(value)
return tag_size + local_VarintSize(l) + l
return FieldSize
def GroupSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a group field."""
tag_size = _TagSize(field_number) * 2
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += element.ByteSize()
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + value.ByteSize()
return FieldSize
def MessageSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a message field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = element.ByteSize()
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = value.ByteSize()
return tag_size + local_VarintSize(l) + l
return FieldSize
# --------------------------------------------------------------------
# MessageSet is special: it needs custom logic to compute its size properly.
def MessageSetItemSizer(field_number):
"""Returns a sizer for extensions of MessageSet.
The message set message looks like this:
message MessageSet {
repeated group Item = 1 {
required int32 type_id = 2;
required string message = 3;
}
}
"""
static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
_TagSize(3))
local_VarintSize = _VarintSize
def FieldSize(value):
l = value.ByteSize()
return static_size + local_VarintSize(l) + l
return FieldSize
# --------------------------------------------------------------------
# Map is special: it needs custom logic to compute its size properly.
def MapSizer(field_descriptor, is_message_map):
"""Returns a sizer for a map field."""
# Can't look at field_descriptor.message_type._concrete_class because it may
# not have been initialized yet.
message_type = field_descriptor.message_type
message_sizer = MessageSizer(field_descriptor.number, False, False)
def FieldSize(map_value):
total = 0
for key in map_value:
value = map_value[key]
# It's wasteful to create the messages and throw them away one second
# later since we'll do the same for the actual encode. But there's not an
# obvious way to avoid this within the current design without tons of code
# duplication. For message map, value.ByteSize() should be called to
# update the status.
entry_msg = message_type._concrete_class(key=key, value=value)
total += message_sizer(entry_msg)
if is_message_map:
value.ByteSize()
return total
return FieldSize
# ====================================================================
# Encoders!
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
local_int2byte = struct.Struct('>B').pack
def EncodeVarint(write, value, unused_deterministic=None):
bits = value & 0x7f
value >>= 7
while value:
write(local_int2byte(0x80|bits))
bits = value & 0x7f
value >>= 7
return write(local_int2byte(bits))
return EncodeVarint
def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
local_int2byte = struct.Struct('>B').pack
def EncodeSignedVarint(write, value, unused_deterministic=None):
if value < 0:
value += (1 << 64)
bits = value & 0x7f
value >>= 7
while value:
write(local_int2byte(0x80|bits))
bits = value & 0x7f
value >>= 7
return write(local_int2byte(bits))
return EncodeSignedVarint
_EncodeVarint = _VarintEncoder()
_EncodeSignedVarint = _SignedVarintEncoder()
def _VarintBytes(value):
"""Encode the given integer as a varint and return the bytes. This is only
called at startup time so it doesn't need to be fast."""
pieces = []
_EncodeVarint(pieces.append, value, True)
return b"".join(pieces)
def TagBytes(field_number, wire_type):
"""Encode the given tag and return the bytes. Only called at startup."""
return bytes(_VarintBytes(wire_format.PackTag(field_number, wire_type)))
# --------------------------------------------------------------------
# As with sizers (see above), we have a number of common encoder
# implementations.
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
"""Return a constructor for an encoder for fields of a particular type.
Args:
wire_type: The field's wire type, for encoding tags.
encode_value: A function which encodes an individual value, e.g.
_EncodeVarint().
compute_value_size: A function which computes the size of an individual
value, e.g. _VarintSize().
"""
def SpecificEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(element)
local_EncodeVarint(write, size, deterministic)
for element in value:
encode_value(write, element, deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
encode_value(write, element, deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value, deterministic):
write(tag_bytes)
return encode_value(write, value, deterministic)
return EncodeField
return SpecificEncoder
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
"""Like SimpleEncoder but additionally invokes modify_value on every value
before passing it to encode_value. Usually modify_value is ZigZagEncode."""
def SpecificEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(modify_value(element))
local_EncodeVarint(write, size, deterministic)
for element in value:
encode_value(write, modify_value(element), deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
encode_value(write, modify_value(element), deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value, deterministic):
write(tag_bytes)
return encode_value(write, modify_value(value), deterministic)
return EncodeField
return SpecificEncoder
def _StructPackEncoder(wire_type, format):
"""Return a constructor for an encoder for a fixed-width field.
Args:
wire_type: The field's wire type, for encoding tags.
format: The format string to pass to struct.pack().
"""
value_size = struct.calcsize(format)
def SpecificEncoder(field_number, is_repeated, is_packed):
local_struct_pack = struct.pack
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value, deterministic):
write(tag_bytes)
local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
write(local_struct_pack(format, element))
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
write(local_struct_pack(format, element))
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
return write(local_struct_pack(format, value))
return EncodeField
return SpecificEncoder
def _FloatingPointEncoder(wire_type, format):
"""Return a constructor for an encoder for float fields.
This is like StructPackEncoder, but catches errors that may be due to
passing non-finite floating-point values to struct.pack, and makes a
second attempt to encode those values.
Args:
wire_type: The field's wire type, for encoding tags.
format: The format string to pass to struct.pack().
"""
value_size = struct.calcsize(format)
if value_size == 4:
def EncodeNonFiniteOrRaise(write, value):
# Remember that the serialized form uses little-endian byte order.
if value == _POS_INF:
write(b'\x00\x00\x80\x7F')
elif value == _NEG_INF:
write(b'\x00\x00\x80\xFF')
elif value != value: # NaN
write(b'\x00\x00\xC0\x7F')
else:
raise
elif value_size == 8:
def EncodeNonFiniteOrRaise(write, value):
if value == _POS_INF:
write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
elif value == _NEG_INF:
write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
elif value != value: # NaN
write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
else:
raise
else:
raise ValueError('Can\'t encode floating-point values that are '
'%d bytes long (only 4 or 8)' % value_size)
def SpecificEncoder(field_number, is_repeated, is_packed):
local_struct_pack = struct.pack
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value, deterministic):
write(tag_bytes)
local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
# This try/except block is going to be faster than any code that
# we could write to check whether element is finite.
try:
write(local_struct_pack(format, element))
except SystemError:
EncodeNonFiniteOrRaise(write, element)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
try:
write(local_struct_pack(format, element))
except SystemError:
EncodeNonFiniteOrRaise(write, element)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
try:
write(local_struct_pack(format, value))
except SystemError:
EncodeNonFiniteOrRaise(write, value)
return EncodeField
return SpecificEncoder
# ====================================================================
# Here we declare an encoder constructor for each field type. These work
# very similarly to sizer constructors, described earlier.
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
wire_format.ZigZagEncode)
# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field."""
false_byte = b'\x00'
true_byte = b'\x01'
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value, deterministic):
write(tag_bytes)
local_EncodeVarint(write, len(value), deterministic)
for element in value:
if element:
write(true_byte)
else:
write(false_byte)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
if element:
write(true_byte)
else:
write(false_byte)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
if value:
return write(true_byte)
return write(false_byte)
return EncodeField
def StringEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a string field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
local_len = len
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value, deterministic):
for element in value:
encoded = element.encode('utf-8')
write(tag)
local_EncodeVarint(write, local_len(encoded), deterministic)
write(encoded)
return EncodeRepeatedField
else:
def EncodeField(write, value, deterministic):
encoded = value.encode('utf-8')
write(tag)
local_EncodeVarint(write, local_len(encoded), deterministic)
return write(encoded)
return EncodeField
def BytesEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a bytes field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
local_len = len
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
local_EncodeVarint(write, local_len(element), deterministic)
write(element)
return EncodeRepeatedField
else:
def EncodeField(write, value, deterministic):
write(tag)
local_EncodeVarint(write, local_len(value), deterministic)
return write(value)
return EncodeField
def GroupEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a group field."""
start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(start_tag)
element._InternalSerialize(write, deterministic)
write(end_tag)
return EncodeRepeatedField
else:
def EncodeField(write, value, deterministic):
write(start_tag)
value._InternalSerialize(write, deterministic)
return write(end_tag)
return EncodeField
def MessageEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a message field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
local_EncodeVarint(write, element.ByteSize(), deterministic)
element._InternalSerialize(write, deterministic)
return EncodeRepeatedField
else:
def EncodeField(write, value, deterministic):
write(tag)
local_EncodeVarint(write, value.ByteSize(), deterministic)
return value._InternalSerialize(write, deterministic)
return EncodeField
# --------------------------------------------------------------------
# As before, MessageSet is special.
def MessageSetItemEncoder(field_number):
"""Encoder for extensions of MessageSet.
The message set message looks like this:
message MessageSet {
repeated group Item = 1 {
required int32 type_id = 2;
required string message = 3;
}
}
"""
start_bytes = b"".join([
TagBytes(1, wire_format.WIRETYPE_START_GROUP),
TagBytes(2, wire_format.WIRETYPE_VARINT),
_VarintBytes(field_number),
TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
local_EncodeVarint = _EncodeVarint
def EncodeField(write, value, deterministic):
write(start_bytes)
local_EncodeVarint(write, value.ByteSize(), deterministic)
value._InternalSerialize(write, deterministic)
return write(end_bytes)
return EncodeField
# --------------------------------------------------------------------
# As before, Map is special.
def MapEncoder(field_descriptor):
"""Encoder for extensions of MessageSet.
Maps always have a wire format like this:
message MapEntry {
key_type key = 1;
value_type value = 2;
}
repeated MapEntry map = N;
"""
# Can't look at field_descriptor.message_type._concrete_class because it may
# not have been initialized yet.
message_type = field_descriptor.message_type
encode_message = MessageEncoder(field_descriptor.number, False, False)
def EncodeField(write, value, deterministic):
value_keys = sorted(value.keys()) if deterministic else value
for key in value_keys:
entry_msg = message_type._concrete_class(key=key, value=value[key])
encode_message(write, entry_msg, deterministic)
return EncodeField

View File

@@ -0,0 +1,124 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""A simple wrapper around enum types to expose utility functions.
Instances are created as properties with the same name as the enum they wrap
on proto classes. For usage, see:
reflection_test.py
"""
__author__ = 'rabsatt@google.com (Kevin Rabsatt)'
class EnumTypeWrapper(object):
"""A utility for finding the names of enum values."""
DESCRIPTOR = None
# This is a type alias, which mypy typing stubs can type as
# a genericized parameter constrained to an int, allowing subclasses
# to be typed with more constraint in .pyi stubs
# Eg.
# def MyGeneratedEnum(Message):
# ValueType = NewType('ValueType', int)
# def Name(self, number: MyGeneratedEnum.ValueType) -> str
ValueType = int
def __init__(self, enum_type):
"""Inits EnumTypeWrapper with an EnumDescriptor."""
self._enum_type = enum_type
self.DESCRIPTOR = enum_type # pylint: disable=invalid-name
def Name(self, number): # pylint: disable=invalid-name
"""Returns a string containing the name of an enum value."""
try:
return self._enum_type.values_by_number[number].name
except KeyError:
pass # fall out to break exception chaining
if not isinstance(number, int):
raise TypeError(
'Enum value for {} must be an int, but got {} {!r}.'.format(
self._enum_type.name, type(number), number))
else:
# repr here to handle the odd case when you pass in a boolean.
raise ValueError('Enum {} has no name defined for value {!r}'.format(
self._enum_type.name, number))
def Value(self, name): # pylint: disable=invalid-name
"""Returns the value corresponding to the given enum name."""
try:
return self._enum_type.values_by_name[name].number
except KeyError:
pass # fall out to break exception chaining
raise ValueError('Enum {} has no value defined for name {!r}'.format(
self._enum_type.name, name))
def keys(self):
"""Return a list of the string names in the enum.
Returns:
A list of strs, in the order they were defined in the .proto file.
"""
return [value_descriptor.name
for value_descriptor in self._enum_type.values]
def values(self):
"""Return a list of the integer values in the enum.
Returns:
A list of ints, in the order they were defined in the .proto file.
"""
return [value_descriptor.number
for value_descriptor in self._enum_type.values]
def items(self):
"""Return a list of the (name, value) pairs of the enum.
Returns:
A list of (str, int) pairs, in the order they were defined
in the .proto file.
"""
return [(value_descriptor.name, value_descriptor.number)
for value_descriptor in self._enum_type.values]
def __getattr__(self, name):
"""Returns the value corresponding to the given enum name."""
try:
return super(
EnumTypeWrapper,
self).__getattribute__('_enum_type').values_by_name[name].number
except KeyError:
pass # fall out to break exception chaining
raise AttributeError('Enum {} has no value defined for name {!r}'.format(
self._enum_type.name, name))

View File

@@ -0,0 +1,213 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains _ExtensionDict class to represent extensions.
"""
from google.protobuf.internal import type_checkers
from google.protobuf.descriptor import FieldDescriptor
def _VerifyExtensionHandle(message, extension_handle):
"""Verify that the given extension handle is valid."""
if not isinstance(extension_handle, FieldDescriptor):
raise KeyError('HasExtension() expects an extension handle, got: %s' %
extension_handle)
if not extension_handle.is_extension:
raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
if not extension_handle.containing_type:
raise KeyError('"%s" is missing a containing_type.'
% extension_handle.full_name)
if extension_handle.containing_type is not message.DESCRIPTOR:
raise KeyError('Extension "%s" extends message type "%s", but this '
'message is of type "%s".' %
(extension_handle.full_name,
extension_handle.containing_type.full_name,
message.DESCRIPTOR.full_name))
# TODO(robinson): Unify error handling of "unknown extension" crap.
# TODO(robinson): Support iteritems()-style iteration over all
# extensions with the "has" bits turned on?
class _ExtensionDict(object):
"""Dict-like container for Extension fields on proto instances.
Note that in all cases we expect extension handles to be
FieldDescriptors.
"""
def __init__(self, extended_message):
"""
Args:
extended_message: Message instance for which we are the Extensions dict.
"""
self._extended_message = extended_message
def __getitem__(self, extension_handle):
"""Returns the current value of the given extension handle."""
_VerifyExtensionHandle(self._extended_message, extension_handle)
result = self._extended_message._fields.get(extension_handle)
if result is not None:
return result
if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
result = extension_handle._default_constructor(self._extended_message)
elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
message_type = extension_handle.message_type
if not hasattr(message_type, '_concrete_class'):
# pylint: disable=protected-access
self._extended_message._FACTORY.GetPrototype(message_type)
assert getattr(extension_handle.message_type, '_concrete_class', None), (
'Uninitialized concrete class found for field %r (message type %r)'
% (extension_handle.full_name,
extension_handle.message_type.full_name))
result = extension_handle.message_type._concrete_class()
try:
result._SetListener(self._extended_message._listener_for_children)
except ReferenceError:
pass
else:
# Singular scalar -- just return the default without inserting into the
# dict.
return extension_handle.default_value
# Atomically check if another thread has preempted us and, if not, swap
# in the new object we just created. If someone has preempted us, we
# take that object and discard ours.
# WARNING: We are relying on setdefault() being atomic. This is true
# in CPython but we haven't investigated others. This warning appears
# in several other locations in this file.
result = self._extended_message._fields.setdefault(
extension_handle, result)
return result
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
my_fields = self._extended_message.ListFields()
other_fields = other._extended_message.ListFields()
# Get rid of non-extension fields.
my_fields = [field for field in my_fields if field.is_extension]
other_fields = [field for field in other_fields if field.is_extension]
return my_fields == other_fields
def __ne__(self, other):
return not self == other
def __len__(self):
fields = self._extended_message.ListFields()
# Get rid of non-extension fields.
extension_fields = [field for field in fields if field[0].is_extension]
return len(extension_fields)
def __hash__(self):
raise TypeError('unhashable object')
# Note that this is only meaningful for non-repeated, scalar extension
# fields. Note also that we may have to call _Modified() when we do
# successfully set a field this way, to set any necessary "has" bits in the
# ancestors of the extended message.
def __setitem__(self, extension_handle, value):
"""If extension_handle specifies a non-repeated, scalar extension
field, sets the value of that field.
"""
_VerifyExtensionHandle(self._extended_message, extension_handle)
if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or
extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE):
raise TypeError(
'Cannot assign to extension "%s" because it is a repeated or '
'composite type.' % extension_handle.full_name)
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
type_checker.CheckValue(value))
self._extended_message._Modified()
def __delitem__(self, extension_handle):
self._extended_message.ClearExtension(extension_handle)
def _FindExtensionByName(self, name):
"""Tries to find a known extension with the specified name.
Args:
name: Extension full name.
Returns:
Extension field descriptor.
"""
return self._extended_message._extensions_by_name.get(name, None)
def _FindExtensionByNumber(self, number):
"""Tries to find a known extension with the field number.
Args:
number: Extension field number.
Returns:
Extension field descriptor.
"""
return self._extended_message._extensions_by_number.get(number, None)
def __iter__(self):
# Return a generator over the populated extension fields
return (f[0] for f in self._extended_message.ListFields()
if f[0].is_extension)
def __contains__(self, extension_handle):
_VerifyExtensionHandle(self._extended_message, extension_handle)
if extension_handle not in self._extended_message._fields:
return False
if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
return bool(self._extended_message._fields.get(extension_handle))
if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
value = self._extended_message._fields.get(extension_handle)
# pylint: disable=protected-access
return value is not None and value._is_present_in_parent
return True

View File

@@ -0,0 +1,70 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: matthewtoia@google.com (Matt Toia)
syntax = "proto2";
package google.protobuf.python.internal;
enum Factory1Enum {
FACTORY_1_VALUE_0 = 0;
FACTORY_1_VALUE_1 = 1;
}
message Factory1Message {
optional Factory1Enum factory_1_enum = 1;
enum NestedFactory1Enum {
NESTED_FACTORY_1_VALUE_0 = 0;
NESTED_FACTORY_1_VALUE_1 = 1;
}
optional NestedFactory1Enum nested_factory_1_enum = 2;
message NestedFactory1Message {
optional string value = 1;
}
optional NestedFactory1Message nested_factory_1_message = 3;
optional int32 scalar_value = 4;
repeated string list_value = 5;
extensions 1000 to max;
}
message Factory1MethodRequest {
optional string argument = 1;
}
message Factory1MethodResponse {
optional string result = 1;
}
service Factory1Service {
// Dummy method for this dummy service.
rpc Factory1Method(Factory1MethodRequest) returns (Factory1MethodResponse) {}
}

View File

@@ -0,0 +1,104 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: matthewtoia@google.com (Matt Toia)
syntax = "proto2";
package google.protobuf.python.internal;
import "google/protobuf/internal/factory_test1.proto";
enum Factory2Enum {
FACTORY_2_VALUE_0 = 0;
FACTORY_2_VALUE_1 = 1;
}
message Factory2Message {
required int32 mandatory = 1;
optional Factory2Enum factory_2_enum = 2;
enum NestedFactory2Enum {
NESTED_FACTORY_2_VALUE_0 = 0;
NESTED_FACTORY_2_VALUE_1 = 1;
}
optional NestedFactory2Enum nested_factory_2_enum = 3;
message NestedFactory2Message {
optional string value = 1;
}
optional NestedFactory2Message nested_factory_2_message = 4;
optional Factory1Message factory_1_message = 5;
optional Factory1Enum factory_1_enum = 6;
optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7;
optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8;
optional Factory2Message circular_message = 9;
optional string scalar_value = 10;
repeated string list_value = 11;
repeated group Grouped = 12 {
optional string part_1 = 13;
optional string part_2 = 14;
}
optional LoopMessage loop = 15;
optional int32 int_with_default = 16 [default = 1776];
optional double double_with_default = 17 [default = 9.99];
optional string string_with_default = 18 [default = "hello world"];
optional bool bool_with_default = 19 [default = false];
optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1];
optional bytes bytes_with_default = 21 [default = "a\373\000c"];
extend Factory1Message {
optional string one_more_field = 1001;
}
oneof oneof_field {
int32 oneof_int = 22;
string oneof_string = 23;
}
}
message LoopMessage {
optional Factory2Message loop = 1;
}
message MessageWithNestedEnumOnly {
enum NestedEnum {
NESTED_MESSAGE_ENUM_0 = 0;
}
}
extend Factory1Message {
optional string another_field = 1002;
}
message MessageWithOption {
option no_standard_descriptor_accessor = true;
optional int32 field1 = 1;
}

View File

@@ -0,0 +1,333 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains FieldMask class."""
from google.protobuf.descriptor import FieldDescriptor
class FieldMask(object):
"""Class for FieldMask message type."""
__slots__ = ()
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
camelcase_paths = []
for path in self.paths:
camelcase_paths.append(_SnakeCaseToCamelCase(path))
return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
if not isinstance(value, str):
raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
self.Clear()
if value:
for path in value.split(','):
self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
for path in self.paths:
if not _IsValidPath(message_descriptor, path):
return False
return True
def AllFieldsFromDescriptor(self, message_descriptor):
"""Gets all direct fields of Message Descriptor to FieldMask."""
self.Clear()
for field in message_descriptor.fields:
self.paths.append(field.name)
def CanonicalFormFromMask(self, mask):
"""Converts a FieldMask to the canonical form.
Removes paths that are covered by another path. For example,
"foo.bar" is covered by "foo" and will be removed if "foo"
is also in the FieldMask. Then sorts all paths in alphabetical order.
Args:
mask: The original FieldMask to be converted.
"""
tree = _FieldMaskTree(mask)
tree.ToFieldMask(self)
def Union(self, mask1, mask2):
"""Merges mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
tree.MergeFromFieldMask(mask2)
tree.ToFieldMask(self)
def Intersect(self, mask1, mask2):
"""Intersects mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
intersection = _FieldMaskTree()
for path in mask2.paths:
tree.IntersectPath(path, intersection)
intersection.ToFieldMask(self)
def MergeMessage(
self, source, destination,
replace_message_field=False, replace_repeated_field=False):
"""Merges fields specified in FieldMask from source to destination.
Args:
source: Source message.
destination: The destination message to be merged into.
replace_message_field: Replace message field if True. Merge message
field if False.
replace_repeated_field: Replace repeated field if True. Append
elements of repeated field if False.
"""
tree = _FieldMaskTree(self)
tree.MergeMessage(
source, destination, replace_message_field, replace_repeated_field)
def _IsValidPath(message_descriptor, path):
"""Checks whether the path is valid for Message Descriptor."""
parts = path.split('.')
last = parts.pop()
for name in parts:
field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
return False
message_descriptor = field.message_type
return last in message_descriptor.fields_by_name
def _CheckFieldMaskMessage(message):
"""Raises ValueError if message is not a FieldMask."""
message_descriptor = message.DESCRIPTOR
if (message_descriptor.name != 'FieldMask' or
message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
raise ValueError('Message {0} is not a FieldMask.'.format(
message_descriptor.full_name))
def _SnakeCaseToCamelCase(path_name):
"""Converts a path name from snake_case to camelCase."""
result = []
after_underscore = False
for c in path_name:
if c.isupper():
raise ValueError(
'Fail to print FieldMask to Json string: Path name '
'{0} must not contain uppercase letters.'.format(path_name))
if after_underscore:
if c.islower():
result.append(c.upper())
after_underscore = False
else:
raise ValueError(
'Fail to print FieldMask to Json string: The '
'character after a "_" must be a lowercase letter '
'in path name {0}.'.format(path_name))
elif c == '_':
after_underscore = True
else:
result += c
if after_underscore:
raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
'in path name {0}.'.format(path_name))
return ''.join(result)
def _CamelCaseToSnakeCase(path_name):
"""Converts a field name from camelCase to snake_case."""
result = []
for c in path_name:
if c == '_':
raise ValueError('Fail to parse FieldMask: Path name '
'{0} must not contain "_"s.'.format(path_name))
if c.isupper():
result += '_'
result += c.lower()
else:
result += c
return ''.join(result)
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
the FieldMaskTree will be:
[_root] -+- foo -+- bar
| |
| +- baz
|
+- bar --- baz
In the tree, each leaf node represents a field path.
"""
__slots__ = ('_root',)
def __init__(self, field_mask=None):
"""Initializes the tree by FieldMask."""
self._root = {}
if field_mask:
self.MergeFromFieldMask(field_mask)
def MergeFromFieldMask(self, field_mask):
"""Merges a FieldMask to the tree."""
for path in field_mask.paths:
self.AddPath(path)
def AddPath(self, path):
"""Adds a field path into the tree.
If the field path to add is a sub-path of an existing field path
in the tree (i.e., a leaf node), it means the tree already matches
the given path so nothing will be added to the tree. If the path
matches an existing non-leaf node in the tree, that non-leaf node
will be turned into a leaf node with all its children removed because
the path matches all the node's children. Otherwise, a new path will
be added.
Args:
path: The field path to add.
"""
node = self._root
for name in path.split('.'):
if name not in node:
node[name] = {}
elif not node[name]:
# Pre-existing empty node implies we already have this entire tree.
return
node = node[name]
# Remove any sub-trees we might have had.
node.clear()
def ToFieldMask(self, field_mask):
"""Converts the tree to a FieldMask."""
field_mask.Clear()
_AddFieldPaths(self._root, '', field_mask)
def IntersectPath(self, path, intersection):
"""Calculates the intersection part of a field path with this tree.
Args:
path: The field path to calculates.
intersection: The out tree to record the intersection part.
"""
node = self._root
for name in path.split('.'):
if name not in node:
return
elif not node[name]:
intersection.AddPath(path)
return
node = node[name]
intersection.AddLeafNodes(path, node)
def AddLeafNodes(self, prefix, node):
"""Adds leaf nodes begin with prefix to this tree."""
if not node:
self.AddPath(prefix)
for name in node:
child_path = prefix + '.' + name
self.AddLeafNodes(child_path, node[name])
def MergeMessage(
self, source, destination,
replace_message, replace_repeated):
"""Merge all fields specified by this tree from source to destination."""
_MergeMessage(
self._root, source, destination, replace_message, replace_repeated)
def _StrConvert(value):
"""Converts value to str if it is not."""
# This file is imported by c extension and some methods like ClearField
# requires string for the field name. py2/py3 has different text
# type and may use unicode.
if not isinstance(value, str):
return value.encode('utf-8')
return value
def _MergeMessage(
node, source, destination, replace_message, replace_repeated):
"""Merge all fields specified by a sub-tree from source to destination."""
source_descriptor = source.DESCRIPTOR
for name in node:
child = node[name]
field = source_descriptor.fields_by_name[name]
if field is None:
raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
name, source_descriptor.full_name))
if child:
# Sub-paths are only allowed for singular message fields.
if (field.label == FieldDescriptor.LABEL_REPEATED or
field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
raise ValueError('Error: Field {0} in message {1} is not a singular '
'message field and cannot have sub-fields.'.format(
name, source_descriptor.full_name))
if source.HasField(name):
_MergeMessage(
child, getattr(source, name), getattr(destination, name),
replace_message, replace_repeated)
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
if replace_repeated:
destination.ClearField(_StrConvert(name))
repeated_source = getattr(source, name)
repeated_destination = getattr(destination, name)
repeated_destination.MergeFrom(repeated_source)
else:
if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
if replace_message:
destination.ClearField(_StrConvert(name))
if source.HasField(name):
getattr(destination, name).MergeFrom(getattr(source, name))
else:
setattr(destination, name, getattr(source, name))
def _AddFieldPaths(node, prefix, field_mask):
"""Adds the field paths descended from node to field_mask."""
if not node and prefix:
field_mask.paths.append(prefix)
return
for name in sorted(node):
if prefix:
child_path = prefix + '.' + name
else:
child_path = name
_AddFieldPaths(node[name], child_path, field_mask)

View File

@@ -0,0 +1,400 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.well_known_types."""
import unittest
from google.protobuf import field_mask_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf.internal import field_mask
from google.protobuf.internal import test_util
from google.protobuf import descriptor
class FieldMaskTest(unittest.TestCase):
def testStringFormat(self):
mask = field_mask_pb2.FieldMask()
self.assertEqual('', mask.ToJsonString())
mask.paths.append('foo')
self.assertEqual('foo', mask.ToJsonString())
mask.paths.append('bar')
self.assertEqual('foo,bar', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
mask.FromJsonString('foo')
self.assertEqual(['foo'], mask.paths)
mask.FromJsonString('foo,bar')
self.assertEqual(['foo', 'bar'], mask.paths)
# Test camel case
mask.Clear()
mask.paths.append('foo_bar')
self.assertEqual('fooBar', mask.ToJsonString())
mask.paths.append('bar_quz')
self.assertEqual('fooBar,barQuz', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
self.assertEqual([], mask.paths)
mask.FromJsonString('fooBar')
self.assertEqual(['foo_bar'], mask.paths)
mask.FromJsonString('fooBar,barQuz')
self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
def testDescriptorToFieldMask(self):
mask = field_mask_pb2.FieldMask()
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertEqual(76, len(mask.paths))
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
for field in msg_descriptor.fields:
self.assertTrue(field.name in mask.paths)
def testIsValidForDescriptor(self):
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
# Empty mask
mask = field_mask_pb2.FieldMask()
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# All fields from descriptor
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Child under optional message
mask.paths.append('optional_nested_message.bb')
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Repeated field is only allowed in the last position of path
mask.paths.append('repeated_nested_message.bb')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid top level field
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in root
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in internal node
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in leaf
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
def testCanonicalFrom(self):
mask = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Paths will be sorted.
mask.FromJsonString('baz.quz,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
# Duplicated paths will be removed.
mask.FromJsonString('foo,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo', out_mask.ToJsonString())
# Sub-paths of other paths will be removed.
mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
# Test more deeply nested cases.
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar', out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo', out_mask.ToJsonString())
def testUnion(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
src = unittest_pb2.TestAllTypes()
with self.assertRaises(ValueError):
out_mask.Union(src, mask2)
def testIntersect(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Test cases without overlapping.
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('', out_mask.ToJsonString())
self.assertEqual(len(out_mask.paths), 0)
self.assertEqual(out_mask.paths, [])
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('baz.bb', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
mask1.FromJsonString('foo.bar,bar')
mask2.FromJsonString('foo.bar.baz,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
# Intersect '' with ''
mask1.Clear()
mask2.Clear()
mask1.paths.append('')
mask2.paths.append('')
self.assertEqual(mask1.paths, [''])
self.assertEqual('', mask1.ToJsonString())
out_mask.Intersect(mask1, mask2)
self.assertEqual(out_mask.paths, [])
def testMergeMessageWithoutMapFields(self):
# Test merge one field.
src = unittest_pb2.TestAllTypes()
test_util.SetAllFields(src)
for field in src.DESCRIPTOR.fields:
if field.containing_oneof:
continue
field_name = field.name
dst = unittest_pb2.TestAllTypes()
# Only set one path to mask.
mask = field_mask_pb2.FieldMask()
mask.paths.append(field_name)
mask.MergeMessage(src, dst)
# The expected result message.
msg = unittest_pb2.TestAllTypes()
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
repeated_src = getattr(src, field_name)
repeated_msg = getattr(msg, field_name)
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
for item in repeated_src:
repeated_msg.add().CopyFrom(item)
else:
repeated_msg.extend(repeated_src)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
getattr(msg, field_name).CopyFrom(getattr(src, field_name))
else:
setattr(msg, field_name, getattr(src, field_name))
# Only field specified in mask is merged.
self.assertEqual(msg, dst)
# Test merge nested fields.
nested_src = unittest_pb2.NestedTestAllTypes()
nested_dst = unittest_pb2.NestedTestAllTypes()
nested_src.child.payload.optional_int32 = 1234
nested_src.child.child.payload.optional_int32 = 5678
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(0, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
# Test MergeOptions.
nested_dst.Clear()
nested_dst.child.payload.optional_int64 = 4321
# Message fields will be merged by default.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(4321, nested_dst.child.payload.optional_int64)
# Change the behavior to replace message fields.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.payload.optional_int64)
# By default, fields missing in source are not cleared in destination.
nested_dst.payload.optional_int32 = 1234
self.assertTrue(nested_dst.HasField('payload'))
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertTrue(nested_dst.HasField('payload'))
# But they are cleared when replacing message fields.
nested_dst.Clear()
nested_dst.payload.optional_int32 = 1234
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertFalse(nested_dst.HasField('payload'))
nested_src.payload.repeated_int32.append(1234)
nested_dst.payload.repeated_int32.append(5678)
# Repeated fields will be appended by default.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(2, len(nested_dst.payload.repeated_int32))
self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
# Change the behavior to replace repeated fields.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst, False, True)
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
# Test Merge oneof field.
new_msg = unittest_pb2.TestOneof2()
dst = unittest_pb2.TestOneof2()
dst.foo_message.moo_int = 1
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('fooMessage,fooLazyMessage.mooInt')
mask.MergeMessage(new_msg, dst)
self.assertTrue(dst.HasField('foo_message'))
self.assertFalse(dst.HasField('foo_lazy_message'))
def testMergeMessageWithMapField(self):
empty_map = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2.a['src level 2'].CopyFrom(empty_map)
src = map_unittest_pb2.TestRecursiveMapMessage()
src.a['common key'].CopyFrom(src_level_2)
src.a['src level 1'].CopyFrom(src_level_2)
dst_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
dst_level_2.a['dst level 2'].CopyFrom(empty_map)
dst = map_unittest_pb2.TestRecursiveMapMessage()
dst.a['common key'].CopyFrom(dst_level_2)
dst.a['dst level 1'].CopyFrom(empty_map)
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('a')
mask.MergeMessage(src, dst)
# map from dst is replaced with map from src.
self.assertEqual(dst.a['common key'], src_level_2)
self.assertEqual(dst.a['src level 1'], src_level_2)
self.assertEqual(dst.a['dst level 1'], empty_map)
def testMergeErrors(self):
src = unittest_pb2.TestAllTypes()
dst = unittest_pb2.TestAllTypes()
mask = field_mask_pb2.FieldMask()
test_util.SetAllFields(src)
mask.FromJsonString('optionalInt32.field')
with self.assertRaises(ValueError) as e:
mask.MergeMessage(src, dst)
self.assertEqual('Error: Field optional_int32 in message '
'protobuf_unittest.TestAllTypes is not a singular '
'message field and cannot have sub-fields.',
str(e.exception))
def testSnakeCaseToCamelCase(self):
self.assertEqual('fooBar',
field_mask._SnakeCaseToCamelCase('foo_bar'))
self.assertEqual('FooBar',
field_mask._SnakeCaseToCamelCase('_foo_bar'))
self.assertEqual('foo3Bar',
field_mask._SnakeCaseToCamelCase('foo3_bar'))
# No uppercase letter is allowed.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Path name Foo must '
'not contain uppercase letters.',
field_mask._SnakeCaseToCamelCase, 'Foo')
# Any character after a "_" must be a lowercase letter.
# 1. "_" cannot be followed by another "_".
# 2. "_" cannot be followed by a digit.
# 3. "_" cannot appear as the last character.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo__bar.',
field_mask._SnakeCaseToCamelCase, 'foo__bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo_3bar.',
field_mask._SnakeCaseToCamelCase, 'foo_3bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Trailing "_" in path '
'name foo_bar_.', field_mask._SnakeCaseToCamelCase, 'foo_bar_')
def testCamelCaseToSnakeCase(self):
self.assertEqual('foo_bar',
field_mask._CamelCaseToSnakeCase('fooBar'))
self.assertEqual('_foo_bar',
field_mask._CamelCaseToSnakeCase('FooBar'))
self.assertEqual('foo3_bar',
field_mask._CamelCaseToSnakeCase('foo3Bar'))
self.assertRaisesRegex(
ValueError,
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
field_mask._CamelCaseToSnakeCase, 'foo_bar')
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,43 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
import "google/protobuf/descriptor.proto";
package google.protobuf.python.internal;
message FooOptions {
optional string foo_name = 1;
}
extend .google.protobuf.FileOptions {
optional FooOptions foo_options = 120436268;
}

View File

@@ -0,0 +1,354 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py
# first, since it's testing the subtler code, and since it provides decent
# indirect testing of the protocol compiler output.
"""Unittest that directly tests the output of the pure-Python protocol
compiler. See //google/protobuf/internal/reflection_test.py for a test which
further ensures that we can use Python protocol message objects as we expect.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_mset_wire_format_pb2
from google.protobuf import unittest_no_generic_services_pb2
from google.protobuf import unittest_pb2
from google.protobuf import service
from google.protobuf import symbol_database
MAX_EXTENSION = 536870912
class GeneratorTest(unittest.TestCase):
def testNestedMessageDescriptor(self):
field_name = 'optional_nested_message'
proto_type = unittest_pb2.TestAllTypes
self.assertEqual(
proto_type.NestedMessage.DESCRIPTOR,
proto_type.DESCRIPTOR.fields_by_name[field_name].message_type)
def testEnums(self):
# We test only module-level enums here.
# TODO(robinson): Examine descriptors directly to check
# enum descriptor output.
self.assertEqual(4, unittest_pb2.FOREIGN_FOO)
self.assertEqual(5, unittest_pb2.FOREIGN_BAR)
self.assertEqual(6, unittest_pb2.FOREIGN_BAZ)
proto = unittest_pb2.TestAllTypes()
self.assertEqual(1, proto.FOO)
self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
self.assertEqual(2, proto.BAR)
self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
def testExtremeDefaultValues(self):
message = unittest_pb2.TestExtremeDefaultValues()
# Python pre-2.6 does not have isinf() or isnan() functions, so we have
# to provide our own.
def isnan(val):
# NaN is never equal to itself.
return val != val
def isinf(val):
# Infinity times zero equals NaN.
return not isnan(val) and isnan(val * 0)
self.assertTrue(isinf(message.inf_double))
self.assertTrue(message.inf_double > 0)
self.assertTrue(isinf(message.neg_inf_double))
self.assertTrue(message.neg_inf_double < 0)
self.assertTrue(isnan(message.nan_double))
self.assertTrue(isinf(message.inf_float))
self.assertTrue(message.inf_float > 0)
self.assertTrue(isinf(message.neg_inf_float))
self.assertTrue(message.neg_inf_float < 0)
self.assertTrue(isnan(message.nan_float))
self.assertEqual("? ? ?? ?? ??? ??/ ??-", message.cpp_trigraph)
def testHasDefaultValues(self):
desc = unittest_pb2.TestAllTypes.DESCRIPTOR
expected_has_default_by_name = {
'optional_int32': False,
'repeated_int32': False,
'optional_nested_message': False,
'default_int32': True,
}
has_default_by_name = dict(
[(f.name, f.has_default_value)
for f in desc.fields
if f.name in expected_has_default_by_name])
self.assertEqual(expected_has_default_by_name, has_default_by_name)
def testContainingTypeBehaviorForExtensions(self):
self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
unittest_pb2.TestAllExtensions.DESCRIPTOR)
self.assertEqual(unittest_pb2.TestRequired.single.containing_type,
unittest_pb2.TestAllExtensions.DESCRIPTOR)
def testExtensionScope(self):
self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope,
None)
self.assertEqual(unittest_pb2.TestRequired.single.extension_scope,
unittest_pb2.TestRequired.DESCRIPTOR)
def testIsExtension(self):
self.assertTrue(unittest_pb2.optional_int32_extension.is_extension)
self.assertTrue(unittest_pb2.TestRequired.single.is_extension)
message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR
non_extension_descriptor = message_descriptor.fields_by_name['a']
self.assertTrue(not non_extension_descriptor.is_extension)
def testOptions(self):
proto = unittest_mset_wire_format_pb2.TestMessageSet()
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
def testMessageWithCustomOptions(self):
proto = unittest_custom_options_pb2.TestMessageWithCustomOptions()
enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions()
self.assertTrue(enum_options is not None)
# TODO(gps): We really should test for the presence of the enum_opt1
# extension and for its value to be set to -789.
def testNestedTypes(self):
self.assertEqual(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
set([
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR,
unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR,
]))
self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, [])
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, [])
def testContainingType(self):
self.assertTrue(
unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None)
self.assertTrue(
unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None)
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
def testContainingTypeInEnumDescriptor(self):
self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None)
self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
def testPackage(self):
self.assertEqual(
unittest_pb2.TestAllTypes.DESCRIPTOR.file.package,
'protobuf_unittest')
desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR
self.assertEqual(desc.file.package, 'protobuf_unittest')
self.assertEqual(
unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package,
'protobuf_unittest_import')
self.assertEqual(
unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest')
self.assertEqual(
unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package,
'protobuf_unittest')
self.assertEqual(
unittest_import_pb2._IMPORTENUM.file.package,
'protobuf_unittest_import')
def testExtensionRange(self):
self.assertEqual(
unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, [])
self.assertEqual(
unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
[(1, MAX_EXTENSION)])
self.assertEqual(
unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
[(42, 43), (4143, 4244), (65536, MAX_EXTENSION)])
def testFileDescriptor(self):
self.assertEqual(unittest_pb2.DESCRIPTOR.name,
'google/protobuf/unittest.proto')
self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
self.assertEqual(unittest_pb2.DESCRIPTOR.dependencies,
[unittest_import_pb2.DESCRIPTOR])
self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies,
[unittest_import_public_pb2.DESCRIPTOR])
self.assertEqual(unittest_import_pb2.DESCRIPTOR.public_dependencies,
[unittest_import_public_pb2.DESCRIPTOR])
def testNoGenericServices(self):
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
# Make sure unittest_no_generic_services_pb2 has no services subclassing
# Proto2 Service class.
if hasattr(unittest_no_generic_services_pb2, "TestService"):
self.assertFalse(issubclass(unittest_no_generic_services_pb2.TestService,
service.Service))
def testMessageTypesByName(self):
file_type = unittest_pb2.DESCRIPTOR
self.assertEqual(
unittest_pb2._TESTALLTYPES,
file_type.message_types_by_name[unittest_pb2._TESTALLTYPES.name])
# Nested messages shouldn't be included in the message_types_by_name
# dictionary (like in the C++ API).
self.assertFalse(
unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in
file_type.message_types_by_name)
def testEnumTypesByName(self):
file_type = unittest_pb2.DESCRIPTOR
self.assertEqual(
unittest_pb2._FOREIGNENUM,
file_type.enum_types_by_name[unittest_pb2._FOREIGNENUM.name])
def testExtensionsByName(self):
file_type = unittest_pb2.DESCRIPTOR
self.assertEqual(
unittest_pb2.my_extension_string,
file_type.extensions_by_name[unittest_pb2.my_extension_string.name])
def testPublicImports(self):
# Test public imports as embedded message.
all_type_proto = unittest_pb2.TestAllTypes()
self.assertEqual(0, all_type_proto.optional_public_import_message.e)
# PublicImportMessage is actually defined in unittest_import_public_pb2
# module, and is public imported by unittest_import_pb2 module.
public_import_proto = unittest_import_pb2.PublicImportMessage()
self.assertEqual(0, public_import_proto.e)
self.assertTrue(unittest_import_public_pb2.PublicImportMessage is
unittest_import_pb2.PublicImportMessage)
def testBadIdentifiers(self):
# We're just testing that the code was imported without problems.
message = test_bad_identifiers_pb2.TestBadIdentifiers()
self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message],
"foo")
self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor],
"bar")
self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection],
"baz")
self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service],
"qux")
def testOneof(self):
desc = unittest_pb2.TestAllTypes.DESCRIPTOR
self.assertEqual(1, len(desc.oneofs))
self.assertEqual('oneof_field', desc.oneofs[0].name)
self.assertEqual(0, desc.oneofs[0].index)
self.assertIs(desc, desc.oneofs[0].containing_type)
self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field'])
nested_names = set(['oneof_uint32', 'oneof_nested_message',
'oneof_string', 'oneof_bytes'])
self.assertEqual(
nested_names,
set([field.name for field in desc.oneofs[0].fields]))
for field_name, field_desc in desc.fields_by_name.items():
if field_name in nested_names:
self.assertIs(desc.oneofs[0], field_desc.containing_oneof)
else:
self.assertIsNone(field_desc.containing_oneof)
def testEnumWithDupValue(self):
self.assertEqual('FOO1',
unittest_pb2.TestEnumWithDupValue.Name(unittest_pb2.FOO1))
self.assertEqual('FOO1',
unittest_pb2.TestEnumWithDupValue.Name(unittest_pb2.FOO2))
self.assertEqual('BAR1',
unittest_pb2.TestEnumWithDupValue.Name(unittest_pb2.BAR1))
self.assertEqual('BAR1',
unittest_pb2.TestEnumWithDupValue.Name(unittest_pb2.BAR2))
class SymbolDatabaseRegistrationTest(unittest.TestCase):
"""Checks that messages, enums and files are correctly registered."""
def testGetSymbol(self):
self.assertEqual(
unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes'))
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage'))
with self.assertRaises(KeyError):
symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage')
self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup'))
self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup,
symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self):
self.assertEqual(
'protobuf_unittest.ForeignEnum',
symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name)
self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum',
symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindFileByName(self):
self.assertEqual(
'google/protobuf/unittest.proto',
symbol_database.Default().pool.FindFileByName(
'google/protobuf/unittest.proto').name)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Unittest for nested public imports."""
import unittest
from google.protobuf.internal.import_test_package import outer_pb2
class ImportTest(unittest.TestCase):
def testPackageInitializationImport(self):
"""Test that we can import nested import public messages."""
msg = outer_pb2.Outer()
self.assertEqual(58, msg.import_public_nested.value)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,33 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Sample module importing a nested proto from itself."""
from google.protobuf.internal.import_test_package import outer_pb2 as myproto

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// A proto file which is imported by inner.proto to test public importing.
syntax = "proto2";
package google.protobuf.python.internal.import_test_package;
option optimize_for = SPEED;
// Test nested public import
import public "google/protobuf/internal/import_test_package/import_public_nested.proto";

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// A proto file which is imported by import_public.proto to test nested public
// importing.
syntax = "proto2";
package google.protobuf.python.internal.import_test_package;
message ImportPublicNestedMessage {
optional int32 value = 1 [default = 58];
}

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal.import_test_package;
// Test public import
import public "google/protobuf/internal/import_test_package/import_public.proto";
message Inner {
optional int32 value = 1 [default = 57];
}

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal.import_test_package;
import "google/protobuf/internal/import_test_package/inner.proto";
message Outer {
optional Inner inner = 1;
optional ImportPublicNestedMessage import_public_nested = 2;
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,103 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.internal.keywords."""
import unittest
from google.protobuf.internal import more_messages_pb2
from google.protobuf import descriptor_pool
class KeywordsConflictTest(unittest.TestCase):
def setUp(self):
super(KeywordsConflictTest, self).setUp()
self.pool = descriptor_pool.Default()
def testMessage(self):
message = getattr(more_messages_pb2, 'class')()
message.int_field = 123
self.assertEqual(message.int_field, 123)
des = self.pool.FindMessageTypeByName('google.protobuf.internal.class')
self.assertEqual(des.name, 'class')
def testNestedMessage(self):
message = getattr(more_messages_pb2, 'class')()
message.nested_message.field = 234
self.assertEqual(message.nested_message.field, 234)
des = self.pool.FindMessageTypeByName('google.protobuf.internal.class.try')
self.assertEqual(des.name, 'try')
def testField(self):
message = getattr(more_messages_pb2, 'class')()
setattr(message, 'if', 123)
setattr(message, 'as', 1)
self.assertEqual(getattr(message, 'if'), 123)
self.assertEqual(getattr(message, 'as'), 1)
def testEnum(self):
class_ = getattr(more_messages_pb2, 'class')
message = class_()
# Normal enum value.
message.enum_field = more_messages_pb2.default
self.assertEqual(message.enum_field, more_messages_pb2.default)
# Top level enum value.
message.enum_field = getattr(more_messages_pb2, 'else')
self.assertEqual(message.enum_field, 1)
# Nested enum value
message.nested_enum_field = getattr(class_, 'True')
self.assertEqual(message.nested_enum_field, 1)
def testExtension(self):
message = getattr(more_messages_pb2, 'class')()
# Top level extension
extension1 = getattr(more_messages_pb2, 'continue')
message.Extensions[extension1] = 456
self.assertEqual(message.Extensions[extension1], 456)
# None top level extension
extension2 = getattr(more_messages_pb2.ExtendClass, 'return')
message.Extensions[extension2] = 789
self.assertEqual(message.Extensions[extension2], 789)
def testExtensionForNestedMessage(self):
message = getattr(more_messages_pb2, 'class')()
extension = getattr(more_messages_pb2, 'with')
message.nested_message.Extensions[extension] = 999
self.assertEqual(message.nested_message.Extensions[extension], 999)
def TestFullKeywordUsed(self):
message = more_messages_pb2.TestFullKeyword()
message.field2.int_field = 123
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,309 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.message_factory."""
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
@testing_refleaks.TestCase
class MessageFactoryTest(unittest.TestCase):
def setUp(self):
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test1_pb2.DESCRIPTOR.serialized_pb)
self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
factory_test2_pb2.DESCRIPTOR.serialized_pb)
def _ExerciseDynamicClass(self, cls):
msg = cls()
msg.mandatory = 42
msg.nested_factory_2_enum = 0
msg.nested_factory_2_message.value = 'nested message value'
msg.factory_1_message.factory_1_enum = 1
msg.factory_1_message.nested_factory_1_enum = 0
msg.factory_1_message.nested_factory_1_message.value = (
'nested message value')
msg.factory_1_message.scalar_value = 22
msg.factory_1_message.list_value.extend([u'one', u'two', u'three'])
msg.factory_1_message.list_value.append(u'four')
msg.factory_1_enum = 1
msg.nested_factory_1_enum = 0
msg.nested_factory_1_message.value = 'nested message value'
msg.circular_message.mandatory = 1
msg.circular_message.circular_message.mandatory = 2
msg.circular_message.scalar_value = 'one deep'
msg.scalar_value = 'zero deep'
msg.list_value.extend([u'four', u'three', u'two'])
msg.list_value.append(u'one')
msg.grouped.add()
msg.grouped[0].part_1 = 'hello'
msg.grouped[0].part_2 = 'world'
msg.grouped.add(part_1='testing', part_2='123')
msg.loop.loop.mandatory = 2
msg.loop.loop.loop.loop.mandatory = 4
serialized = msg.SerializeToString()
converted = factory_test2_pb2.Factory2Message.FromString(serialized)
reserialized = converted.SerializeToString()
self.assertEqual(serialized, reserialized)
result = cls.FromString(reserialized)
self.assertEqual(msg, result)
def testGetPrototype(self):
db = descriptor_database.DescriptorDatabase()
pool = descriptor_pool.DescriptorPool(db)
db.Add(self.factory_test1_fd)
db.Add(self.factory_test2_fd)
factory = message_factory.MessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertFalse(cls is factory_test2_pb2.Factory2Message)
self._ExerciseDynamicClass(cls)
cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertTrue(cls is cls2)
def testCreatePrototypeOverride(self):
class MyMessageFactory(message_factory.MessageFactory):
def CreatePrototype(self, descriptor):
cls = super(MyMessageFactory, self).CreatePrototype(descriptor)
cls.additional_field = 'Some value'
return cls
db = descriptor_database.DescriptorDatabase()
pool = descriptor_pool.DescriptorPool(db)
db.Add(self.factory_test1_fd)
db.Add(self.factory_test2_fd)
factory = MyMessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message'))
self.assertTrue(hasattr(cls, 'additional_field'))
def testGetExistingPrototype(self):
factory = message_factory.MessageFactory()
# Get Existing Prototype should not create a new class.
cls = factory.GetPrototype(
descriptor=factory_test2_pb2.Factory2Message.DESCRIPTOR)
msg = factory_test2_pb2.Factory2Message()
self.assertIsInstance(msg, cls)
self.assertIsInstance(msg.factory_1_message,
factory_test1_pb2.Factory1Message)
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):
# GetMessage should work regardless of the order the FileDescriptorProto
# are provided. In particular, the function should succeed when the files
# are not in the topological order of dependencies.
# Assuming factory_test2_fd depends on factory_test1_fd.
self.assertIn(self.factory_test1_fd.name,
self.factory_test2_fd.dependency)
# Get messages should work when a file comes before its dependencies:
# factory_test2_fd comes before factory_test1_fd.
messages = message_factory.GetMessages([self.factory_test2_fd,
self.factory_test1_fd])
self.assertTrue(
set(['google.protobuf.python.internal.Factory2Message',
'google.protobuf.python.internal.Factory1Message'],
).issubset(set(messages.keys())))
self._ExerciseDynamicClass(
messages['google.protobuf.python.internal.Factory2Message'])
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
self.assertTrue(set(
['google.protobuf.python.internal.Factory2Message.one_more_field',
'google.protobuf.python.internal.another_field'],).issubset(set(
ext.full_name
for ext in factory_msg1.DESCRIPTOR.file.pool.FindAllExtensions(
factory_msg1.DESCRIPTOR))))
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
ext1 = msg1.Extensions._FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
ext2 = msg1.Extensions._FindExtensionByName(
'google.protobuf.python.internal.another_field')
self.assertEqual(0, len(msg1.Extensions))
msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2'
self.assertEqual('test1', msg1.Extensions[ext1])
self.assertEqual('test2', msg1.Extensions[ext2])
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(12321))
self.assertEqual(2, len(msg1.Extensions))
if api_implementation.Type() == 'cpp':
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByNumber, '')
else:
self.assertEqual(None,
msg1.Extensions._FindExtensionByName(0))
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(''))
def testDuplicateExtensionNumber(self):
pool = descriptor_pool.DescriptorPool()
factory = message_factory.MessageFactory(pool=pool)
# Add Container message.
f = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/container.proto',
package='google.protobuf.python.internal')
f.message_type.add(name='Container').extension_range.add(start=1, end=10)
pool.Add(f)
msgs = factory.GetMessages([f.name])
self.assertIn('google.protobuf.python.internal.Container', msgs)
# Extend container.
f = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/extension.proto',
package='google.protobuf.python.internal',
dependency=['google/protobuf/internal/container.proto'])
msg = f.message_type.add(name='Extension')
msg.extension.add(
name='extension_field',
number=2,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='Extension',
extendee='Container')
pool.Add(f)
msgs = factory.GetMessages([f.name])
self.assertIn('google.protobuf.python.internal.Extension', msgs)
# Add Duplicate extending the same field number.
f = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/duplicate.proto',
package='google.protobuf.python.internal',
dependency=['google/protobuf/internal/container.proto'])
msg = f.message_type.add(name='Duplicate')
msg.extension.add(
name='extension_field',
number=2,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='Duplicate',
extendee='Container')
pool.Add(f)
with self.assertRaises(Exception) as cm:
factory.GetMessages([f.name])
self.assertIn(str(cm.exception),
['Extensions '
'"google.protobuf.python.internal.Duplicate.extension_field" and'
' "google.protobuf.python.internal.Extension.extension_field"'
' both try to extend message type'
' "google.protobuf.python.internal.Container"'
' with field number 2.',
'Double registration of Extensions'])
def testExtensionValueInDifferentFile(self):
# Add Container message.
f1 = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/container.proto',
package='google.protobuf.python.internal')
f1.message_type.add(name='Container').extension_range.add(start=1, end=10)
# Add ValueType message.
f2 = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/value_type.proto',
package='google.protobuf.python.internal')
f2.message_type.add(name='ValueType').field.add(
name='setting',
number=1,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
default_value='123')
# Extend container with field of ValueType.
f3 = descriptor_pb2.FileDescriptorProto(
name='google/protobuf/internal/extension.proto',
package='google.protobuf.python.internal',
dependency=[f1.name, f2.name])
f3.extension.add(
name='top_level_extension_field',
number=2,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='ValueType',
extendee='Container')
f3.message_type.add(name='Extension').extension.add(
name='nested_extension_field',
number=3,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='ValueType',
extendee='Container')
class SimpleDescriptorDB:
def __init__(self, files):
self._files = files
def FindFileByName(self, name):
return self._files[name]
db = SimpleDescriptorDB({f1.name: f1, f2.name: f2, f3.name: f3})
pool = descriptor_pool.DescriptorPool(db)
factory = message_factory.MessageFactory(pool=pool)
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
msg = msgs['google.protobuf.python.internal.Container']
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
ext1 = desc.file.extensions_by_name['top_level_extension_field']
ext2 = desc.extensions_by_name['nested_extension_field']
m = msg()
m.Extensions[ext1].setting = 234
m.Extensions[ext2].setting = 345
serialized = m.SerializeToString()
pool = descriptor_pool.DescriptorPool(db)
factory = message_factory.MessageFactory(pool=pool)
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
msg = msgs['google.protobuf.python.internal.Container']
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
ext1 = desc.file.extensions_by_name['top_level_extension_field']
ext2 = desc.extensions_by_name['nested_extension_field']
m = msg.FromString(serialized)
self.assertEqual(2, len(m.ListFields()))
self.assertEqual(234, m.Extensions[ext1].setting)
self.assertEqual(345, m.Extensions[ext2].setting)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,78 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines a listener interface for observing certain
state transitions on Message objects.
Also defines a null implementation of this interface.
"""
__author__ = 'robinson@google.com (Will Robinson)'
class MessageListener(object):
"""Listens for modifications made to a message. Meant to be registered via
Message._SetListener().
Attributes:
dirty: If True, then calling Modified() would be a no-op. This can be
used to avoid these calls entirely in the common case.
"""
def Modified(self):
"""Called every time the message is modified in such a way that the parent
message may need to be updated. This currently means either:
(a) The message was modified for the first time, so the parent message
should henceforth mark the message as present.
(b) The message's cached byte size became dirty -- i.e. the message was
modified for the first time after a previous call to ByteSize().
Therefore the parent should also mark its byte size as dirty.
Note that (a) implies (b), since new objects start out with a client cached
size (zero). However, we document (a) explicitly because it is important.
Modified() will *only* be called in response to one of these two events --
not every time the sub-message is modified.
Note that if the listener's |dirty| attribute is true, then calling
Modified at the moment would be a no-op, so it can be skipped. Performance-
sensitive callers should check this attribute directly before calling since
it will be true most of the time.
"""
raise NotImplementedError
class NullMessageListener(object):
"""No-op MessageListener implementation."""
def Modified(self):
pass

View File

@@ -0,0 +1,74 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This file contains messages that extend MessageSet.
syntax = "proto2";
package google.protobuf.internal;
// A message with message_set_wire_format.
message TestMessageSet {
option message_set_wire_format = true;
extensions 4 to max;
}
message TestMessageSetExtension1 {
extend TestMessageSet {
optional TestMessageSetExtension1 message_set_extension = 98418603;
}
optional int32 i = 15;
}
message TestMessageSetExtension2 {
extend TestMessageSet {
optional TestMessageSetExtension2 message_set_extension = 98418634;
}
optional string str = 25;
}
message TestMessageSetExtension3 {
optional string text = 35;
}
extend TestMessageSet {
optional TestMessageSetExtension3 message_set_extension3 = 98418655;
}
// This message was used to generate
// //net/proto2/python/internal/testdata/message_set_message, but is commented
// out since it must not actually exist in code, to simulate an "unknown"
// extension.
// message TestMessageSetUnknownExtension {
// extend TestMessageSet {
// optional TestMessageSetUnknownExtension message_set_extension = 56141421;
// }
// optional int64 a = 1;
// }

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal;
message TestEnumValues {
enum NestedEnum {
ZERO = 0;
ONE = 1;
}
optional NestedEnum optional_nested_enum = 1;
repeated NestedEnum repeated_nested_enum = 2;
repeated NestedEnum packed_nested_enum = 3 [packed = true];
}
message TestMissingEnumValues {
enum NestedEnum { TWO = 2; }
optional NestedEnum optional_nested_enum = 1;
repeated NestedEnum repeated_nested_enum = 2;
repeated NestedEnum packed_nested_enum = 3 [packed = true];
}
message JustString {
required string dummy = 1;
}

View File

@@ -0,0 +1,62 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: robinson@google.com (Will Robinson)
syntax = "proto2";
package google.protobuf.internal;
message TopLevelMessage {
optional ExtendedMessage submessage = 1 [lazy = true];
optional NestedMessage nested_message = 2 [lazy = true];
}
message NestedMessage {
optional ExtendedMessage submessage = 1 [lazy = true];
}
message ExtendedMessage {
optional int32 optional_int32 = 1001;
repeated string repeated_string = 1002;
extensions 1 to 999;
}
message ForeignMessage {
optional int32 foreign_message_int = 1;
}
extend ExtendedMessage {
optional int32 optional_int_extension = 1;
optional ForeignMessage optional_message_extension = 2;
repeated int32 repeated_int_extension = 3;
repeated ForeignMessage repeated_message_extension = 4;
}

View File

@@ -0,0 +1,51 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: jasonh@google.com (Jason Hsueh)
//
// This file is used to test a corner case in the CPP implementation where the
// generated C++ type is available for the extendee, but the extension is
// defined in a file whose C++ type is not in the binary.
syntax = "proto2";
import "google/protobuf/internal/more_extensions.proto";
package google.protobuf.internal;
message DynamicMessageType {
optional int32 a = 1;
}
extend ExtendedMessage {
optional int32 dynamic_int32_extension = 100;
optional DynamicMessageType dynamic_message_extension = 101;
repeated DynamicMessageType repeated_dynamic_message_extension = 102;
}

View File

@@ -0,0 +1,360 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: robinson@google.com (Will Robinson)
// LINT: LEGACY_NAMES
syntax = "proto2";
package google.protobuf.internal;
// A message where tag numbers are listed out of order, to allow us to test our
// canonicalization of serialized output, which should always be in tag order.
// We also mix in some extensions for extra fun.
message OutOfOrderFields {
optional sint32 optional_sint32 = 5;
extensions 4 to 4;
optional uint32 optional_uint32 = 3;
extensions 2 to 2;
optional int32 optional_int32 = 1;
}
extend OutOfOrderFields {
optional uint64 optional_uint64 = 4;
optional int64 optional_int64 = 2;
}
enum is { // top level enum keyword
default = 0;
else = 1; // top level enum value keyword
}
message class { // message keyword
optional int32 int_field = 1 [json_name = "json_int"];
optional int32 if = 2; // field keyword
optional is as = 3; // enum field keyword
optional is enum_field = 4;
enum for { // nested enum keyword
default = 0;
True = 1; // nested enum value keyword
}
optional for nested_enum_field = 5;
message try {
optional int32 field = 1;
extensions 999 to 9999;
}
optional try
nested_message = 6;
extensions 999 to 9999;
}
extend class {
optional int32 continue = 1001; // top level extension keyword
}
extend class.try {
optional int32 with = 1001;
}
message ExtendClass {
extend class {
optional int32 return = 1002; // nested extension keyword
}
}
message TestFullKeyword {
optional google.protobuf.internal.OutOfOrderFields field1 = 1;
optional google.protobuf.internal.class field2 = 2;
}
// TODO(jieluo): Add keyword support for service.
// service False {
// rpc Bar(class) returns (class);
// }
message LotsNestedMessage {
message B0 {}
message B1 {}
message B2 {}
message B3 {}
message B4 {}
message B5 {}
message B6 {}
message B7 {}
message B8 {}
message B9 {}
message B10 {}
message B11 {}
message B12 {}
message B13 {}
message B14 {}
message B15 {}
message B16 {}
message B17 {}
message B18 {}
message B19 {}
message B20 {}
message B21 {}
message B22 {}
message B23 {}
message B24 {}
message B25 {}
message B26 {}
message B27 {}
message B28 {}
message B29 {}
message B30 {}
message B31 {}
message B32 {}
message B33 {}
message B34 {}
message B35 {}
message B36 {}
message B37 {}
message B38 {}
message B39 {}
message B40 {}
message B41 {}
message B42 {}
message B43 {}
message B44 {}
message B45 {}
message B46 {}
message B47 {}
message B48 {}
message B49 {}
message B50 {}
message B51 {}
message B52 {}
message B53 {}
message B54 {}
message B55 {}
message B56 {}
message B57 {}
message B58 {}
message B59 {}
message B60 {}
message B61 {}
message B62 {}
message B63 {}
message B64 {}
message B65 {}
message B66 {}
message B67 {}
message B68 {}
message B69 {}
message B70 {}
message B71 {}
message B72 {}
message B73 {}
message B74 {}
message B75 {}
message B76 {}
message B77 {}
message B78 {}
message B79 {}
message B80 {}
message B81 {}
message B82 {}
message B83 {}
message B84 {}
message B85 {}
message B86 {}
message B87 {}
message B88 {}
message B89 {}
message B90 {}
message B91 {}
message B92 {}
message B93 {}
message B94 {}
message B95 {}
message B96 {}
message B97 {}
message B98 {}
message B99 {}
message B100 {}
message B101 {}
message B102 {}
message B103 {}
message B104 {}
message B105 {}
message B106 {}
message B107 {}
message B108 {}
message B109 {}
message B110 {}
message B111 {}
message B112 {}
message B113 {}
message B114 {}
message B115 {}
message B116 {}
message B117 {}
message B118 {}
message B119 {}
message B120 {}
message B121 {}
message B122 {}
message B123 {}
message B124 {}
message B125 {}
message B126 {}
message B127 {}
message B128 {}
message B129 {}
message B130 {}
message B131 {}
message B132 {}
message B133 {}
message B134 {}
message B135 {}
message B136 {}
message B137 {}
message B138 {}
message B139 {}
message B140 {}
message B141 {}
message B142 {}
message B143 {}
message B144 {}
message B145 {}
message B146 {}
message B147 {}
message B148 {}
message B149 {}
message B150 {}
message B151 {}
message B152 {}
message B153 {}
message B154 {}
message B155 {}
message B156 {}
message B157 {}
message B158 {}
message B159 {}
message B160 {}
message B161 {}
message B162 {}
message B163 {}
message B164 {}
message B165 {}
message B166 {}
message B167 {}
message B168 {}
message B169 {}
message B170 {}
message B171 {}
message B172 {}
message B173 {}
message B174 {}
message B175 {}
message B176 {}
message B177 {}
message B178 {}
message B179 {}
message B180 {}
message B181 {}
message B182 {}
message B183 {}
message B184 {}
message B185 {}
message B186 {}
message B187 {}
message B188 {}
message B189 {}
message B190 {}
message B191 {}
message B192 {}
message B193 {}
message B194 {}
message B195 {}
message B196 {}
message B197 {}
message B198 {}
message B199 {}
message B200 {}
message B201 {}
message B202 {}
message B203 {}
message B204 {}
message B205 {}
message B206 {}
message B207 {}
message B208 {}
message B209 {}
message B210 {}
message B211 {}
message B212 {}
message B213 {}
message B214 {}
message B215 {}
message B216 {}
message B217 {}
message B218 {}
message B219 {}
message B220 {}
message B221 {}
message B222 {}
message B223 {}
message B224 {}
message B225 {}
message B226 {}
message B227 {}
message B228 {}
message B229 {}
message B230 {}
message B231 {}
message B232 {}
message B233 {}
message B234 {}
message B235 {}
message B236 {}
message B237 {}
message B238 {}
message B239 {}
message B240 {}
message B241 {}
message B242 {}
message B243 {}
message B244 {}
message B245 {}
message B246 {}
message B247 {}
message B248 {}
message B249 {}
message B250 {}
message B251 {}
message B252 {}
message B253 {}
message B254 {}
message B255 {}
}

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
enum NoPackageEnum {
NO_PACKAGE_VALUE_0 = 0;
NO_PACKAGE_VALUE_1 = 1;
}
message NoPackageMessage {
optional NoPackageEnum no_package_enum = 1;
}

View File

@@ -0,0 +1,215 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test use of numpy types with repeated and non-repeated scalar fields."""
import unittest
import numpy as np
from google.protobuf import unittest_pb2
from google.protobuf.internal import testing_refleaks
message = unittest_pb2.TestAllTypes()
np_float_scalar = np.float64(0.0)
np_1_float_array = np.zeros(shape=(1,), dtype=np.float64)
np_2_float_array = np.zeros(shape=(2,), dtype=np.float64)
np_11_float_array = np.zeros(shape=(1, 1), dtype=np.float64)
np_22_float_array = np.zeros(shape=(2, 2), dtype=np.float64)
np_int_scalar = np.int64(0)
np_1_int_array = np.zeros(shape=(1,), dtype=np.int64)
np_2_int_array = np.zeros(shape=(2,), dtype=np.int64)
np_11_int_array = np.zeros(shape=(1, 1), dtype=np.int64)
np_22_int_array = np.zeros(shape=(2, 2), dtype=np.int64)
np_uint_scalar = np.uint64(0)
np_1_uint_array = np.zeros(shape=(1,), dtype=np.uint64)
np_2_uint_array = np.zeros(shape=(2,), dtype=np.uint64)
np_11_uint_array = np.zeros(shape=(1, 1), dtype=np.uint64)
np_22_uint_array = np.zeros(shape=(2, 2), dtype=np.uint64)
np_bool_scalar = np.bool_(False)
np_1_bool_array = np.zeros(shape=(1,), dtype=np.bool_)
np_2_bool_array = np.zeros(shape=(2,), dtype=np.bool_)
np_11_bool_array = np.zeros(shape=(1, 1), dtype=np.bool_)
np_22_bool_array = np.zeros(shape=(2, 2), dtype=np.bool_)
@testing_refleaks.TestCase
class NumpyIntProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of ints to repeated field should pass
def testNumpyDim1IntArrayToRepeated_IsValid(self):
message.repeated_int64[:] = np_1_int_array
message.repeated_int64[:] = np_2_int_array
message.repeated_uint64[:] = np_1_uint_array
message.repeated_uint64[:] = np_2_uint_array
# Assigning dim 2 ndarray of ints to repeated field should fail
def testNumpyDim2IntArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_11_int_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_22_int_array
with self.assertRaises(TypeError):
message.repeated_uint64[:] = np_11_uint_array
with self.assertRaises(TypeError):
message.repeated_uint64[:] = np_22_uint_array
# Assigning any ndarray of floats to repeated int field should fail
def testNumpyFloatArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_1_float_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_11_float_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_22_float_array
# Assigning any np int to scalar field should pass
def testNumpyIntScalarToScalar_IsValid(self):
message.optional_int64 = np_int_scalar
message.optional_uint64 = np_uint_scalar
# Assigning any ndarray of ints to scalar field should fail
def testNumpyIntArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_int64 = np_1_int_array
with self.assertRaises(TypeError):
message.optional_int64 = np_11_int_array
with self.assertRaises(TypeError):
message.optional_int64 = np_22_int_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_1_uint_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_11_uint_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_22_uint_array
# Assigning any ndarray of floats to scalar field should fail
def testNumpyFloatArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_int64 = np_1_float_array
with self.assertRaises(TypeError):
message.optional_int64 = np_11_float_array
with self.assertRaises(TypeError):
message.optional_int64 = np_22_float_array
@testing_refleaks.TestCase
class NumpyFloatProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of floats to repeated field should pass
def testNumpyDim1FloatArrayToRepeated_IsValid(self):
message.repeated_float[:] = np_1_float_array
message.repeated_float[:] = np_2_float_array
# Assigning dim 2 ndarray of floats to repeated field should fail
def testNumpyDim2FloatArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_float[:] = np_11_float_array
with self.assertRaises(TypeError):
message.repeated_float[:] = np_22_float_array
# Assigning any np float to scalar field should pass
def testNumpyFloatScalarToScalar_IsValid(self):
message.optional_float = np_float_scalar
# Assigning any ndarray of float to scalar field should fail
def testNumpyFloatArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_float = np_1_float_array
with self.assertRaises(TypeError):
message.optional_float = np_11_float_array
with self.assertRaises(TypeError):
message.optional_float = np_22_float_array
@testing_refleaks.TestCase
class NumpyBoolProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of bool to repeated field should pass
def testNumpyDim1BoolArrayToRepeated_IsValid(self):
message.repeated_bool[:] = np_1_bool_array
message.repeated_bool[:] = np_2_bool_array
# Assigning dim 2 ndarray of bool to repeated field should fail
def testNumpyDim2BoolArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_bool[:] = np_11_bool_array
with self.assertRaises(TypeError):
message.repeated_bool[:] = np_22_bool_array
# Assigning any np bool to scalar field should pass
def testNumpyBoolScalarToScalar_IsValid(self):
message.optional_bool = np_bool_scalar
# Assigning any ndarray of bool to scalar field should fail
def testNumpyBoolArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_bool = np_1_bool_array
with self.assertRaises(TypeError):
message.optional_bool = np_11_bool_array
with self.assertRaises(TypeError):
message.optional_bool = np_22_bool_array
@testing_refleaks.TestCase
class NumpyProtoIndexingTest(unittest.TestCase):
def testNumpyIntScalarIndexing_Passes(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
self.assertEqual(0, data.repeated_int64[np.int64(0)])
def testNumpyNegative1IntScalarIndexing_Passes(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
self.assertEqual(2, data.repeated_int64[np.int64(-1)])
def testNumpyFloatScalarIndexing_Fails(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.float64(0.0)]
def testNumpyIntArrayIndexing_Fails(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.array([0])]
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.ndarray((1,), buffer=np.array([0]), dtype=int)]
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.ndarray((1, 1),
buffer=np.array([0]),
dtype=int)]
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,73 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto3";
package google.protobuf.python.internal;
message TestPackedTypes {
enum NestedEnum {
FOO = 0;
BAR = 1;
BAZ = 2;
}
repeated int32 repeated_int32 = 1;
repeated int64 repeated_int64 = 2;
repeated uint32 repeated_uint32 = 3;
repeated uint64 repeated_uint64 = 4;
repeated sint32 repeated_sint32 = 5;
repeated sint64 repeated_sint64 = 6;
repeated fixed32 repeated_fixed32 = 7;
repeated fixed64 repeated_fixed64 = 8;
repeated sfixed32 repeated_sfixed32 = 9;
repeated sfixed64 repeated_sfixed64 = 10;
repeated float repeated_float = 11;
repeated double repeated_double = 12;
repeated bool repeated_bool = 13;
repeated NestedEnum repeated_nested_enum = 14;
}
message TestUnpackedTypes {
repeated int32 repeated_int32 = 1 [packed = false];
repeated int64 repeated_int64 = 2 [packed = false];
repeated uint32 repeated_uint32 = 3 [packed = false];
repeated uint64 repeated_uint64 = 4 [packed = false];
repeated sint32 repeated_sint32 = 5 [packed = false];
repeated sint64 repeated_sint64 = 6 [packed = false];
repeated fixed32 repeated_fixed32 = 7 [packed = false];
repeated fixed64 repeated_fixed64 = 8 [packed = false];
repeated sfixed32 repeated_sfixed32 = 9 [packed = false];
repeated sfixed64 repeated_sfixed64 = 10 [packed = false];
repeated float repeated_float = 11 [packed = false];
repeated double repeated_double = 12 [packed = false];
repeated bool repeated_bool = 13 [packed = false];
repeated TestPackedTypes.NestedEnum repeated_nested_enum = 14 [packed = false];
}

View File

@@ -0,0 +1,106 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.proto_builder."""
import collections
import unittest
from google.protobuf import descriptor_pb2 # pylint: disable=g-import-not-at-top
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import proto_builder
from google.protobuf import text_format
class ProtoBuilderTest(unittest.TestCase):
def setUp(self):
self.ordered_fields = collections.OrderedDict([
('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
])
self._fields = dict(self.ordered_fields)
def testMakeSimpleProtoClass(self):
"""Test that we can create a proto class."""
proto_cls = proto_builder.MakeSimpleProtoClass(
self._fields,
full_name='net.proto2.python.public.proto_builder_test.Test')
proto = proto_cls()
proto.foo = 12345
proto.bar = 'asdf'
self.assertMultiLineEqual(
'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto))
def testOrderedFields(self):
"""Test that the field order is maintained when given an OrderedDict."""
proto_cls = proto_builder.MakeSimpleProtoClass(
self.ordered_fields,
full_name='net.proto2.python.public.proto_builder_test.OrderedTest')
proto = proto_cls()
proto.foo = 12345
proto.bar = 'asdf'
self.assertMultiLineEqual(
'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto))
def testMakeSameProtoClassTwice(self):
"""Test that the DescriptorPool is used."""
pool = descriptor_pool.DescriptorPool()
proto_cls1 = proto_builder.MakeSimpleProtoClass(
self._fields,
full_name='net.proto2.python.public.proto_builder_test.Test',
pool=pool)
proto_cls2 = proto_builder.MakeSimpleProtoClass(
self._fields,
full_name='net.proto2.python.public.proto_builder_test.Test',
pool=pool)
self.assertIs(proto_cls1.DESCRIPTOR, proto_cls2.DESCRIPTOR)
def testMakeLargeProtoClass(self):
"""Test that large created protos don't use reserved field numbers."""
num_fields = 123456
fields = {
'foo%d' % i: descriptor_pb2.FieldDescriptorProto.TYPE_INT64
for i in range(num_fields)
}
proto_cls = proto_builder.MakeSimpleProtoClass(
fields,
full_name='net.proto2.python.public.proto_builder_test.LargeProtoTest')
reserved_field_numbers = set(
range(descriptor.FieldDescriptor.FIRST_RESERVED_FIELD_NUMBER,
descriptor.FieldDescriptor.LAST_RESERVED_FIELD_NUMBER + 1))
proto_field_numbers = set(proto_cls.DESCRIPTOR.fields_by_number)
self.assertFalse(reserved_field_numbers.intersection(proto_field_numbers))
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,63 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: qrczak@google.com (Marcin Kowalczyk)
#include "google/protobuf/python/python_protobuf.h"
namespace google {
namespace protobuf {
namespace python {
static const Message* GetCProtoInsidePyProtoStub(PyObject* msg) {
return nullptr;
}
static Message* MutableCProtoInsidePyProtoStub(PyObject* msg) {
return nullptr;
}
// This is initialized with a default, stub implementation.
// If python-google.protobuf.cc is loaded, the function pointer is overridden
// with a full implementation.
const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg) =
GetCProtoInsidePyProtoStub;
Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg) =
MutableCProtoInsidePyProtoStub;
const Message* GetCProtoInsidePyProto(PyObject* msg) {
return GetCProtoInsidePyProtoPtr(msg);
}
Message* MutableCProtoInsidePyProto(PyObject* msg) {
return MutableCProtoInsidePyProtoPtr(msg);
}
} // namespace python
} // namespace protobuf
} // namespace google

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,139 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.internal.service_reflection."""
__author__ = 'petar@google.com (Petar Petrov)'
import unittest
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
class FooUnitTest(unittest.TestCase):
def testService(self):
class MockRpcChannel(service.RpcChannel):
def CallMethod(self, method, controller, request, response, callback):
self.method = method
self.controller = controller
self.request = request
callback(response)
class MockRpcController(service.RpcController):
def SetFailed(self, msg):
self.failure_message = msg
self.callback_response = None
class MyService(unittest_pb2.TestService):
pass
self.callback_response = None
def MyCallback(response):
self.callback_response = response
rpc_controller = MockRpcController()
channel = MockRpcChannel()
srvc = MyService()
srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
self.assertEqual('Method Foo not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
rpc_controller.failure_message = None
service_descriptor = unittest_pb2.TestService.GetDescriptor()
srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
unittest_pb2.BarRequest(), MyCallback)
self.assertTrue(srvc.GetRequestClass(service_descriptor.methods[1]) is
unittest_pb2.BarRequest)
self.assertTrue(srvc.GetResponseClass(service_descriptor.methods[1]) is
unittest_pb2.BarResponse)
self.assertEqual('Method Bar not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
class MyServiceImpl(unittest_pb2.TestService):
def Foo(self, rpc_controller, request, done):
self.foo_called = True
def Bar(self, rpc_controller, request, done):
self.bar_called = True
srvc = MyServiceImpl()
rpc_controller.failure_message = None
srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
self.assertEqual(None, rpc_controller.failure_message)
self.assertEqual(True, srvc.foo_called)
rpc_controller.failure_message = None
srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
unittest_pb2.BarRequest(), MyCallback)
self.assertEqual(None, rpc_controller.failure_message)
self.assertEqual(True, srvc.bar_called)
def testServiceStub(self):
class MockRpcChannel(service.RpcChannel):
def CallMethod(self, method, controller, request,
response_class, callback):
self.method = method
self.controller = controller
self.request = request
callback(response_class())
self.callback_response = None
def MyCallback(response):
self.callback_response = response
channel = MockRpcChannel()
stub = unittest_pb2.TestService_Stub(channel)
rpc_controller = 'controller'
request = 'request'
# GetDescriptor now static, still works as instance method for compatibility
self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
stub.GetDescriptor())
# Invoke method.
stub.Foo(rpc_controller, request, MyCallback)
self.assertIsInstance(self.callback_response, unittest_pb2.FooResponse)
self.assertEqual(request, channel.request)
self.assertEqual(rpc_controller, channel.controller)
self.assertEqual(stub.GetDescriptor().methods[0], channel.method)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,133 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.symbol_database."""
import unittest
from google.protobuf import unittest_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
class SymbolDatabaseTest(unittest.TestCase):
def _Database(self):
if descriptor._USE_C_DESCRIPTORS:
# The C++ implementation does not allow mixing descriptors from
# different pools.
db = symbol_database.SymbolDatabase(pool=descriptor_pool.Default())
else:
db = symbol_database.SymbolDatabase()
# Register representative types from unittest_pb2.
db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
db.RegisterMessage(unittest_pb2.TestAllTypes)
db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
db.RegisterServiceDescriptor(unittest_pb2._TESTSERVICE)
return db
def testGetPrototype(self):
instance = self._Database().GetPrototype(
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertTrue(instance is unittest_pb2.TestAllTypes)
def testGetMessages(self):
messages = self._Database().GetMessages(
['google/protobuf/unittest.proto'])
self.assertTrue(
unittest_pb2.TestAllTypes is
messages['protobuf_unittest.TestAllTypes'])
def testGetSymbol(self):
self.assertEqual(
unittest_pb2.TestAllTypes, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes'))
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage'))
self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup'))
self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self):
# Check registration of types in the pool.
self.assertEqual(
'protobuf_unittest.ForeignEnum',
self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name)
self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum',
self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindMessageTypeByName(self):
self.assertEqual(
'protobuf_unittest.TestAllTypes',
self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes').full_name)
self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedMessage',
self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
def testFindServiceByName(self):
self.assertEqual(
'protobuf_unittest.TestService',
self._Database().pool.FindServiceByName(
'protobuf_unittest.TestService').full_name)
def testFindFileContainingSymbol(self):
# Lookup based on either enum or message.
self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.NestedEnum').name)
self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes').name)
def testFindFileByName(self):
self.assertEqual(
'google/protobuf/unittest.proto',
self._Database().pool.FindFileByName(
'google/protobuf/unittest.proto').name)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,53 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: kenton@google.com (Kenton Varda)
syntax = "proto2";
package protobuf_unittest;
option py_generic_services = true;
message TestBadIdentifiers {
extensions 100 to max;
}
// Make sure these reasonable extension names don't conflict with internal
// variables.
extend TestBadIdentifiers {
optional string message = 100 [default = "foo"];
optional string descriptor = 101 [default = "bar"];
optional string reflection = 102 [default = "baz"];
optional string service = 103 [default = "qux"];
}
message AnotherMessage {}
service AnotherService {}

View File

@@ -0,0 +1,70 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto3";
package google.protobuf.python.internal;
message TestProto3Optional {
message NestedMessage {
// The field name "b" fails to compile in proto1 because it conflicts with
// a local variable named "b" in one of the generated methods. Doh.
// This file needs to compile in proto1 to test backwards-compatibility.
optional int32 bb = 1;
}
enum NestedEnum {
UNSPECIFIED = 0;
FOO = 1;
BAR = 2;
BAZ = 3;
NEG = -1; // Intentionally negative.
}
// Singular
optional int32 optional_int32 = 1;
optional int64 optional_int64 = 2;
optional uint32 optional_uint32 = 3;
optional uint64 optional_uint64 = 4;
optional sint32 optional_sint32 = 5;
optional sint64 optional_sint64 = 6;
optional fixed32 optional_fixed32 = 7;
optional fixed64 optional_fixed64 = 8;
optional sfixed32 optional_sfixed32 = 9;
optional sfixed64 optional_sfixed64 = 10;
optional float optional_float = 11;
optional double optional_double = 12;
optional bool optional_bool = 13;
optional string optional_string = 14;
optional bytes optional_bytes = 15;
optional NestedMessage optional_nested_message = 18;
optional NestedEnum optional_nested_enum = 21;
}

View File

@@ -0,0 +1,878 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for Python proto2 tests.
This is intentionally modeled on C++ code in
//google/protobuf/test_util.*.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import numbers
import operator
import os.path
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
try:
long # Python 2
except NameError:
long = int # Python 3
# Tests whether the given TestAllTypes message is proto2 or not.
# This is used to gate several fields/features that only exist
# for the proto2 version of the message.
def IsProto2(message):
return message.DESCRIPTOR.syntax == "proto2"
def SetAllNonLazyFields(message):
"""Sets every non-lazy field in the message to a unique value.
Args:
message: A TestAllTypes instance.
"""
#
# Optional fields.
#
message.optional_int32 = 101
message.optional_int64 = 102
message.optional_uint32 = 103
message.optional_uint64 = 104
message.optional_sint32 = 105
message.optional_sint64 = 106
message.optional_fixed32 = 107
message.optional_fixed64 = 108
message.optional_sfixed32 = 109
message.optional_sfixed64 = 110
message.optional_float = 111
message.optional_double = 112
message.optional_bool = True
message.optional_string = u'115'
message.optional_bytes = b'116'
if IsProto2(message):
message.optionalgroup.a = 117
message.optional_nested_message.bb = 118
message.optional_foreign_message.c = 119
message.optional_import_message.d = 120
message.optional_public_import_message.e = 126
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
if IsProto2(message):
message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
message.optional_string_piece = u'124'
message.optional_cord = u'125'
#
# Repeated fields.
#
message.repeated_int32.append(201)
message.repeated_int64.append(202)
message.repeated_uint32.append(203)
message.repeated_uint64.append(204)
message.repeated_sint32.append(205)
message.repeated_sint64.append(206)
message.repeated_fixed32.append(207)
message.repeated_fixed64.append(208)
message.repeated_sfixed32.append(209)
message.repeated_sfixed64.append(210)
message.repeated_float.append(211)
message.repeated_double.append(212)
message.repeated_bool.append(True)
message.repeated_string.append(u'215')
message.repeated_bytes.append(b'216')
if IsProto2(message):
message.repeatedgroup.add().a = 217
message.repeated_nested_message.add().bb = 218
message.repeated_foreign_message.add().c = 219
message.repeated_import_message.add().d = 220
message.repeated_lazy_message.add().bb = 227
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
if IsProto2(message):
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
# Add a second one of each field and set value by index.
message.repeated_int32.append(0)
message.repeated_int64.append(0)
message.repeated_uint32.append(0)
message.repeated_uint64.append(0)
message.repeated_sint32.append(0)
message.repeated_sint64.append(0)
message.repeated_fixed32.append(0)
message.repeated_fixed64.append(0)
message.repeated_sfixed32.append(0)
message.repeated_sfixed64.append(0)
message.repeated_float.append(0)
message.repeated_double.append(0)
message.repeated_bool.append(True)
message.repeated_string.append(u'0')
message.repeated_bytes.append(b'0')
message.repeated_int32[1] = 301
message.repeated_int64[1] = 302
message.repeated_uint32[1] = 303
message.repeated_uint64[1] = 304
message.repeated_sint32[1] = 305
message.repeated_sint64[1] = 306
message.repeated_fixed32[1] = 307
message.repeated_fixed64[1] = 308
message.repeated_sfixed32[1] = 309
message.repeated_sfixed64[1] = 310
message.repeated_float[1] = 311
message.repeated_double[1] = 312
message.repeated_bool[1] = False
message.repeated_string[1] = u'315'
message.repeated_bytes[1] = b'316'
if IsProto2(message):
message.repeatedgroup.add().a = 317
message.repeated_nested_message.add().bb = 318
message.repeated_foreign_message.add().c = 319
message.repeated_import_message.add().d = 320
message.repeated_lazy_message.add().bb = 327
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_nested_enum[1] = unittest_pb2.TestAllTypes.BAZ
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
if IsProto2(message):
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
message.repeated_string_piece.append(u'324')
message.repeated_cord.append(u'325')
#
# Fields that have defaults.
#
if IsProto2(message):
message.default_int32 = 401
message.default_int64 = 402
message.default_uint32 = 403
message.default_uint64 = 404
message.default_sint32 = 405
message.default_sint64 = 406
message.default_fixed32 = 407
message.default_fixed64 = 408
message.default_sfixed32 = 409
message.default_sfixed64 = 410
message.default_float = 411
message.default_double = 412
message.default_bool = False
message.default_string = '415'
message.default_bytes = b'416'
message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
message.default_import_enum = unittest_import_pb2.IMPORT_FOO
message.default_string_piece = '424'
message.default_cord = '425'
message.oneof_uint32 = 601
message.oneof_nested_message.bb = 602
message.oneof_string = '603'
message.oneof_bytes = b'604'
def SetAllFields(message):
SetAllNonLazyFields(message)
message.optional_lazy_message.bb = 127
message.optional_unverified_lazy_message.bb = 128
def SetAllExtensions(message):
"""Sets every extension in the message to a unique value.
Args:
message: A unittest_pb2.TestAllExtensions instance.
"""
extensions = message.Extensions
pb2 = unittest_pb2
import_pb2 = unittest_import_pb2
#
# Optional fields.
#
extensions[pb2.optional_int32_extension] = 101
extensions[pb2.optional_int64_extension] = 102
extensions[pb2.optional_uint32_extension] = 103
extensions[pb2.optional_uint64_extension] = 104
extensions[pb2.optional_sint32_extension] = 105
extensions[pb2.optional_sint64_extension] = 106
extensions[pb2.optional_fixed32_extension] = 107
extensions[pb2.optional_fixed64_extension] = 108
extensions[pb2.optional_sfixed32_extension] = 109
extensions[pb2.optional_sfixed64_extension] = 110
extensions[pb2.optional_float_extension] = 111
extensions[pb2.optional_double_extension] = 112
extensions[pb2.optional_bool_extension] = True
extensions[pb2.optional_string_extension] = u'115'
extensions[pb2.optional_bytes_extension] = b'116'
extensions[pb2.optionalgroup_extension].a = 117
extensions[pb2.optional_nested_message_extension].bb = 118
extensions[pb2.optional_foreign_message_extension].c = 119
extensions[pb2.optional_import_message_extension].d = 120
extensions[pb2.optional_public_import_message_extension].e = 126
extensions[pb2.optional_lazy_message_extension].bb = 127
extensions[pb2.optional_unverified_lazy_message_extension].bb = 128
extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ
extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ
extensions[pb2.optional_string_piece_extension] = u'124'
extensions[pb2.optional_cord_extension] = u'125'
#
# Repeated fields.
#
extensions[pb2.repeated_int32_extension].append(201)
extensions[pb2.repeated_int64_extension].append(202)
extensions[pb2.repeated_uint32_extension].append(203)
extensions[pb2.repeated_uint64_extension].append(204)
extensions[pb2.repeated_sint32_extension].append(205)
extensions[pb2.repeated_sint64_extension].append(206)
extensions[pb2.repeated_fixed32_extension].append(207)
extensions[pb2.repeated_fixed64_extension].append(208)
extensions[pb2.repeated_sfixed32_extension].append(209)
extensions[pb2.repeated_sfixed64_extension].append(210)
extensions[pb2.repeated_float_extension].append(211)
extensions[pb2.repeated_double_extension].append(212)
extensions[pb2.repeated_bool_extension].append(True)
extensions[pb2.repeated_string_extension].append(u'215')
extensions[pb2.repeated_bytes_extension].append(b'216')
extensions[pb2.repeatedgroup_extension].add().a = 217
extensions[pb2.repeated_nested_message_extension].add().bb = 218
extensions[pb2.repeated_foreign_message_extension].add().c = 219
extensions[pb2.repeated_import_message_extension].add().d = 220
extensions[pb2.repeated_lazy_message_extension].add().bb = 227
extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR)
extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR)
extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR)
extensions[pb2.repeated_string_piece_extension].append(u'224')
extensions[pb2.repeated_cord_extension].append(u'225')
# Append a second one of each field.
extensions[pb2.repeated_int32_extension].append(301)
extensions[pb2.repeated_int64_extension].append(302)
extensions[pb2.repeated_uint32_extension].append(303)
extensions[pb2.repeated_uint64_extension].append(304)
extensions[pb2.repeated_sint32_extension].append(305)
extensions[pb2.repeated_sint64_extension].append(306)
extensions[pb2.repeated_fixed32_extension].append(307)
extensions[pb2.repeated_fixed64_extension].append(308)
extensions[pb2.repeated_sfixed32_extension].append(309)
extensions[pb2.repeated_sfixed64_extension].append(310)
extensions[pb2.repeated_float_extension].append(311)
extensions[pb2.repeated_double_extension].append(312)
extensions[pb2.repeated_bool_extension].append(False)
extensions[pb2.repeated_string_extension].append(u'315')
extensions[pb2.repeated_bytes_extension].append(b'316')
extensions[pb2.repeatedgroup_extension].add().a = 317
extensions[pb2.repeated_nested_message_extension].add().bb = 318
extensions[pb2.repeated_foreign_message_extension].add().c = 319
extensions[pb2.repeated_import_message_extension].add().d = 320
extensions[pb2.repeated_lazy_message_extension].add().bb = 327
extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ)
extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ)
extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ)
extensions[pb2.repeated_string_piece_extension].append(u'324')
extensions[pb2.repeated_cord_extension].append(u'325')
#
# Fields with defaults.
#
extensions[pb2.default_int32_extension] = 401
extensions[pb2.default_int64_extension] = 402
extensions[pb2.default_uint32_extension] = 403
extensions[pb2.default_uint64_extension] = 404
extensions[pb2.default_sint32_extension] = 405
extensions[pb2.default_sint64_extension] = 406
extensions[pb2.default_fixed32_extension] = 407
extensions[pb2.default_fixed64_extension] = 408
extensions[pb2.default_sfixed32_extension] = 409
extensions[pb2.default_sfixed64_extension] = 410
extensions[pb2.default_float_extension] = 411
extensions[pb2.default_double_extension] = 412
extensions[pb2.default_bool_extension] = False
extensions[pb2.default_string_extension] = u'415'
extensions[pb2.default_bytes_extension] = b'416'
extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO
extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO
extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO
extensions[pb2.default_string_piece_extension] = u'424'
extensions[pb2.default_cord_extension] = '425'
extensions[pb2.oneof_uint32_extension] = 601
extensions[pb2.oneof_nested_message_extension].bb = 602
extensions[pb2.oneof_string_extension] = u'603'
extensions[pb2.oneof_bytes_extension] = b'604'
def SetAllFieldsAndExtensions(message):
"""Sets every field and extension in the message to a unique value.
Args:
message: A unittest_pb2.TestAllExtensions message.
"""
message.my_int = 1
message.my_string = 'foo'
message.my_float = 1.0
message.Extensions[unittest_pb2.my_extension_int] = 23
message.Extensions[unittest_pb2.my_extension_string] = 'bar'
def ExpectAllFieldsAndExtensionsInOrder(serialized):
"""Ensures that serialized is the serialization we expect for a message
filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the
serialization is in canonical, tag-number order).
"""
my_extension_int = unittest_pb2.my_extension_int
my_extension_string = unittest_pb2.my_extension_string
expected_strings = []
message = unittest_pb2.TestFieldOrderings()
message.my_int = 1 # Field 1.
expected_strings.append(message.SerializeToString())
message.Clear()
message.Extensions[my_extension_int] = 23 # Field 5.
expected_strings.append(message.SerializeToString())
message.Clear()
message.my_string = 'foo' # Field 11.
expected_strings.append(message.SerializeToString())
message.Clear()
message.Extensions[my_extension_string] = 'bar' # Field 50.
expected_strings.append(message.SerializeToString())
message.Clear()
message.my_float = 1.0
expected_strings.append(message.SerializeToString())
message.Clear()
expected = b''.join(expected_strings)
if expected != serialized:
raise ValueError('Expected %r, found %r' % (expected, serialized))
def ExpectAllFieldsSet(test_case, message):
"""Check all fields for correct values have after Set*Fields() is called."""
test_case.assertTrue(message.HasField('optional_int32'))
test_case.assertTrue(message.HasField('optional_int64'))
test_case.assertTrue(message.HasField('optional_uint32'))
test_case.assertTrue(message.HasField('optional_uint64'))
test_case.assertTrue(message.HasField('optional_sint32'))
test_case.assertTrue(message.HasField('optional_sint64'))
test_case.assertTrue(message.HasField('optional_fixed32'))
test_case.assertTrue(message.HasField('optional_fixed64'))
test_case.assertTrue(message.HasField('optional_sfixed32'))
test_case.assertTrue(message.HasField('optional_sfixed64'))
test_case.assertTrue(message.HasField('optional_float'))
test_case.assertTrue(message.HasField('optional_double'))
test_case.assertTrue(message.HasField('optional_bool'))
test_case.assertTrue(message.HasField('optional_string'))
test_case.assertTrue(message.HasField('optional_bytes'))
if IsProto2(message):
test_case.assertTrue(message.HasField('optionalgroup'))
test_case.assertTrue(message.HasField('optional_nested_message'))
test_case.assertTrue(message.HasField('optional_foreign_message'))
test_case.assertTrue(message.HasField('optional_import_message'))
test_case.assertTrue(message.optionalgroup.HasField('a'))
test_case.assertTrue(message.optional_nested_message.HasField('bb'))
test_case.assertTrue(message.optional_foreign_message.HasField('c'))
test_case.assertTrue(message.optional_import_message.HasField('d'))
test_case.assertTrue(message.HasField('optional_nested_enum'))
test_case.assertTrue(message.HasField('optional_foreign_enum'))
if IsProto2(message):
test_case.assertTrue(message.HasField('optional_import_enum'))
test_case.assertTrue(message.HasField('optional_string_piece'))
test_case.assertTrue(message.HasField('optional_cord'))
test_case.assertEqual(101, message.optional_int32)
test_case.assertEqual(102, message.optional_int64)
test_case.assertEqual(103, message.optional_uint32)
test_case.assertEqual(104, message.optional_uint64)
test_case.assertEqual(105, message.optional_sint32)
test_case.assertEqual(106, message.optional_sint64)
test_case.assertEqual(107, message.optional_fixed32)
test_case.assertEqual(108, message.optional_fixed64)
test_case.assertEqual(109, message.optional_sfixed32)
test_case.assertEqual(110, message.optional_sfixed64)
test_case.assertEqual(111, message.optional_float)
test_case.assertEqual(112, message.optional_double)
test_case.assertEqual(True, message.optional_bool)
test_case.assertEqual('115', message.optional_string)
test_case.assertEqual(b'116', message.optional_bytes)
if IsProto2(message):
test_case.assertEqual(117, message.optionalgroup.a)
test_case.assertEqual(118, message.optional_nested_message.bb)
test_case.assertEqual(119, message.optional_foreign_message.c)
test_case.assertEqual(120, message.optional_import_message.d)
test_case.assertEqual(126, message.optional_public_import_message.e)
test_case.assertEqual(127, message.optional_lazy_message.bb)
test_case.assertEqual(128, message.optional_unverified_lazy_message.bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.optional_nested_enum)
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.optional_foreign_enum)
if IsProto2(message):
test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.optional_import_enum)
# -----------------------------------------------------------------
test_case.assertEqual(2, len(message.repeated_int32))
test_case.assertEqual(2, len(message.repeated_int64))
test_case.assertEqual(2, len(message.repeated_uint32))
test_case.assertEqual(2, len(message.repeated_uint64))
test_case.assertEqual(2, len(message.repeated_sint32))
test_case.assertEqual(2, len(message.repeated_sint64))
test_case.assertEqual(2, len(message.repeated_fixed32))
test_case.assertEqual(2, len(message.repeated_fixed64))
test_case.assertEqual(2, len(message.repeated_sfixed32))
test_case.assertEqual(2, len(message.repeated_sfixed64))
test_case.assertEqual(2, len(message.repeated_float))
test_case.assertEqual(2, len(message.repeated_double))
test_case.assertEqual(2, len(message.repeated_bool))
test_case.assertEqual(2, len(message.repeated_string))
test_case.assertEqual(2, len(message.repeated_bytes))
if IsProto2(message):
test_case.assertEqual(2, len(message.repeatedgroup))
test_case.assertEqual(2, len(message.repeated_nested_message))
test_case.assertEqual(2, len(message.repeated_foreign_message))
test_case.assertEqual(2, len(message.repeated_import_message))
test_case.assertEqual(2, len(message.repeated_nested_enum))
test_case.assertEqual(2, len(message.repeated_foreign_enum))
if IsProto2(message):
test_case.assertEqual(2, len(message.repeated_import_enum))
test_case.assertEqual(2, len(message.repeated_string_piece))
test_case.assertEqual(2, len(message.repeated_cord))
test_case.assertEqual(201, message.repeated_int32[0])
test_case.assertEqual(202, message.repeated_int64[0])
test_case.assertEqual(203, message.repeated_uint32[0])
test_case.assertEqual(204, message.repeated_uint64[0])
test_case.assertEqual(205, message.repeated_sint32[0])
test_case.assertEqual(206, message.repeated_sint64[0])
test_case.assertEqual(207, message.repeated_fixed32[0])
test_case.assertEqual(208, message.repeated_fixed64[0])
test_case.assertEqual(209, message.repeated_sfixed32[0])
test_case.assertEqual(210, message.repeated_sfixed64[0])
test_case.assertEqual(211, message.repeated_float[0])
test_case.assertEqual(212, message.repeated_double[0])
test_case.assertEqual(True, message.repeated_bool[0])
test_case.assertEqual('215', message.repeated_string[0])
test_case.assertEqual(b'216', message.repeated_bytes[0])
if IsProto2(message):
test_case.assertEqual(217, message.repeatedgroup[0].a)
test_case.assertEqual(218, message.repeated_nested_message[0].bb)
test_case.assertEqual(219, message.repeated_foreign_message[0].c)
test_case.assertEqual(220, message.repeated_import_message[0].d)
test_case.assertEqual(227, message.repeated_lazy_message[0].bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
message.repeated_nested_enum[0])
test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
message.repeated_foreign_enum[0])
if IsProto2(message):
test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
message.repeated_import_enum[0])
test_case.assertEqual(301, message.repeated_int32[1])
test_case.assertEqual(302, message.repeated_int64[1])
test_case.assertEqual(303, message.repeated_uint32[1])
test_case.assertEqual(304, message.repeated_uint64[1])
test_case.assertEqual(305, message.repeated_sint32[1])
test_case.assertEqual(306, message.repeated_sint64[1])
test_case.assertEqual(307, message.repeated_fixed32[1])
test_case.assertEqual(308, message.repeated_fixed64[1])
test_case.assertEqual(309, message.repeated_sfixed32[1])
test_case.assertEqual(310, message.repeated_sfixed64[1])
test_case.assertEqual(311, message.repeated_float[1])
test_case.assertEqual(312, message.repeated_double[1])
test_case.assertEqual(False, message.repeated_bool[1])
test_case.assertEqual('315', message.repeated_string[1])
test_case.assertEqual(b'316', message.repeated_bytes[1])
if IsProto2(message):
test_case.assertEqual(317, message.repeatedgroup[1].a)
test_case.assertEqual(318, message.repeated_nested_message[1].bb)
test_case.assertEqual(319, message.repeated_foreign_message[1].c)
test_case.assertEqual(320, message.repeated_import_message[1].d)
test_case.assertEqual(327, message.repeated_lazy_message[1].bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.repeated_nested_enum[1])
test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.repeated_foreign_enum[1])
if IsProto2(message):
test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.repeated_import_enum[1])
# -----------------------------------------------------------------
if IsProto2(message):
test_case.assertTrue(message.HasField('default_int32'))
test_case.assertTrue(message.HasField('default_int64'))
test_case.assertTrue(message.HasField('default_uint32'))
test_case.assertTrue(message.HasField('default_uint64'))
test_case.assertTrue(message.HasField('default_sint32'))
test_case.assertTrue(message.HasField('default_sint64'))
test_case.assertTrue(message.HasField('default_fixed32'))
test_case.assertTrue(message.HasField('default_fixed64'))
test_case.assertTrue(message.HasField('default_sfixed32'))
test_case.assertTrue(message.HasField('default_sfixed64'))
test_case.assertTrue(message.HasField('default_float'))
test_case.assertTrue(message.HasField('default_double'))
test_case.assertTrue(message.HasField('default_bool'))
test_case.assertTrue(message.HasField('default_string'))
test_case.assertTrue(message.HasField('default_bytes'))
test_case.assertTrue(message.HasField('default_nested_enum'))
test_case.assertTrue(message.HasField('default_foreign_enum'))
test_case.assertTrue(message.HasField('default_import_enum'))
test_case.assertEqual(401, message.default_int32)
test_case.assertEqual(402, message.default_int64)
test_case.assertEqual(403, message.default_uint32)
test_case.assertEqual(404, message.default_uint64)
test_case.assertEqual(405, message.default_sint32)
test_case.assertEqual(406, message.default_sint64)
test_case.assertEqual(407, message.default_fixed32)
test_case.assertEqual(408, message.default_fixed64)
test_case.assertEqual(409, message.default_sfixed32)
test_case.assertEqual(410, message.default_sfixed64)
test_case.assertEqual(411, message.default_float)
test_case.assertEqual(412, message.default_double)
test_case.assertEqual(False, message.default_bool)
test_case.assertEqual('415', message.default_string)
test_case.assertEqual(b'416', message.default_bytes)
test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
message.default_nested_enum)
test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
message.default_foreign_enum)
test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
message.default_import_enum)
def GoldenFile(filename):
"""Finds the given golden file and returns a file object representing it."""
# Search up the directory tree looking for the C++ protobuf source code.
path = '.'
while os.path.exists(path):
if os.path.exists(os.path.join(path, 'src/google/protobuf')):
# Found it. Load the golden file from the testdata directory.
full_path = os.path.join(path, 'src/google/protobuf/testdata', filename)
return open(full_path, 'rb')
path = os.path.join(path, '..')
# Search internally.
path = '.'
full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata',
filename)
if os.path.exists(full_path):
# Found it. Load the golden file from the testdata directory.
return open(full_path, 'rb')
# Search for cross-repo path.
full_path = os.path.join('external/com_google_protobuf/src/google/protobuf/testdata',
filename)
if os.path.exists(full_path):
# Found it. Load the golden file from the testdata directory.
return open(full_path, 'rb')
raise RuntimeError(
'Could not find golden files. This test must be run from within the '
'protobuf source package so that it can read test data files from the '
'C++ source tree.')
def GoldenFileData(filename):
"""Finds the given golden file and returns its contents."""
with GoldenFile(filename) as f:
return f.read()
def SetAllPackedFields(message):
"""Sets every field in the message to a unique value.
Args:
message: A TestPackedTypes instance.
"""
message.packed_int32.extend([601, 701])
message.packed_int64.extend([602, 702])
message.packed_uint32.extend([603, 703])
message.packed_uint64.extend([604, 704])
message.packed_sint32.extend([605, 705])
message.packed_sint64.extend([606, 706])
message.packed_fixed32.extend([607, 707])
message.packed_fixed64.extend([608, 708])
message.packed_sfixed32.extend([609, 709])
message.packed_sfixed64.extend([610, 710])
message.packed_float.extend([611.0, 711.0])
message.packed_double.extend([612.0, 712.0])
message.packed_bool.extend([True, False])
message.packed_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
def SetAllPackedExtensions(message):
"""Sets every extension in the message to a unique value.
Args:
message: A unittest_pb2.TestPackedExtensions instance.
"""
extensions = message.Extensions
pb2 = unittest_pb2
extensions[pb2.packed_int32_extension].extend([601, 701])
extensions[pb2.packed_int64_extension].extend([602, 702])
extensions[pb2.packed_uint32_extension].extend([603, 703])
extensions[pb2.packed_uint64_extension].extend([604, 704])
extensions[pb2.packed_sint32_extension].extend([605, 705])
extensions[pb2.packed_sint64_extension].extend([606, 706])
extensions[pb2.packed_fixed32_extension].extend([607, 707])
extensions[pb2.packed_fixed64_extension].extend([608, 708])
extensions[pb2.packed_sfixed32_extension].extend([609, 709])
extensions[pb2.packed_sfixed64_extension].extend([610, 710])
extensions[pb2.packed_float_extension].extend([611.0, 711.0])
extensions[pb2.packed_double_extension].extend([612.0, 712.0])
extensions[pb2.packed_bool_extension].extend([True, False])
extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
def SetAllUnpackedFields(message):
"""Sets every field in the message to a unique value.
Args:
message: A unittest_pb2.TestUnpackedTypes instance.
"""
message.unpacked_int32.extend([601, 701])
message.unpacked_int64.extend([602, 702])
message.unpacked_uint32.extend([603, 703])
message.unpacked_uint64.extend([604, 704])
message.unpacked_sint32.extend([605, 705])
message.unpacked_sint64.extend([606, 706])
message.unpacked_fixed32.extend([607, 707])
message.unpacked_fixed64.extend([608, 708])
message.unpacked_sfixed32.extend([609, 709])
message.unpacked_sfixed64.extend([610, 710])
message.unpacked_float.extend([611.0, 711.0])
message.unpacked_double.extend([612.0, 712.0])
message.unpacked_bool.extend([True, False])
message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
class NonStandardInteger(numbers.Integral):
"""An integer object that does not subclass int.
This is used to verify that both C++ and regular proto systems can handle
integer others than int and long and that they handle them in predictable
ways.
NonStandardInteger is the minimal legal specification for a custom Integral.
As such, it does not support 0 < x < 5 and it is not hashable.
Note: This is added here instead of relying on numpy or a similar library
with custom integers to limit dependencies.
"""
def __init__(self, val, error_string_on_conversion=None):
assert isinstance(val, numbers.Integral)
if isinstance(val, NonStandardInteger):
val = val.val
self.val = val
self.error_string_on_conversion = error_string_on_conversion
def __long__(self):
if self.error_string_on_conversion:
raise RuntimeError(self.error_string_on_conversion)
return long(self.val)
def __abs__(self):
return NonStandardInteger(operator.abs(self.val))
def __add__(self, y):
return NonStandardInteger(operator.add(self.val, y))
def __div__(self, y):
return NonStandardInteger(operator.div(self.val, y))
def __eq__(self, y):
return operator.eq(self.val, y)
def __floordiv__(self, y):
return NonStandardInteger(operator.floordiv(self.val, y))
def __truediv__(self, y):
return NonStandardInteger(operator.truediv(self.val, y))
def __invert__(self):
return NonStandardInteger(operator.invert(self.val))
def __mod__(self, y):
return NonStandardInteger(operator.mod(self.val, y))
def __mul__(self, y):
return NonStandardInteger(operator.mul(self.val, y))
def __neg__(self):
return NonStandardInteger(operator.neg(self.val))
def __pos__(self):
return NonStandardInteger(operator.pos(self.val))
def __pow__(self, y):
return NonStandardInteger(operator.pow(self.val, y))
def __trunc__(self):
return int(self.val)
def __radd__(self, y):
return NonStandardInteger(operator.add(y, self.val))
def __rdiv__(self, y):
return NonStandardInteger(operator.div(y, self.val))
def __rmod__(self, y):
return NonStandardInteger(operator.mod(y, self.val))
def __rmul__(self, y):
return NonStandardInteger(operator.mul(y, self.val))
def __rpow__(self, y):
return NonStandardInteger(operator.pow(y, self.val))
def __rfloordiv__(self, y):
return NonStandardInteger(operator.floordiv(y, self.val))
def __rtruediv__(self, y):
return NonStandardInteger(operator.truediv(y, self.val))
def __lshift__(self, y):
return NonStandardInteger(operator.lshift(self.val, y))
def __rshift__(self, y):
return NonStandardInteger(operator.rshift(self.val, y))
def __rlshift__(self, y):
return NonStandardInteger(operator.lshift(y, self.val))
def __rrshift__(self, y):
return NonStandardInteger(operator.rshift(y, self.val))
def __le__(self, y):
if isinstance(y, NonStandardInteger):
y = y.val
return operator.le(self.val, y)
def __lt__(self, y):
if isinstance(y, NonStandardInteger):
y = y.val
return operator.lt(self.val, y)
def __and__(self, y):
return NonStandardInteger(operator.and_(self.val, y))
def __or__(self, y):
return NonStandardInteger(operator.or_(self.val, y))
def __xor__(self, y):
return NonStandardInteger(operator.xor(self.val, y))
def __rand__(self, y):
return NonStandardInteger(operator.and_(y, self.val))
def __ror__(self, y):
return NonStandardInteger(operator.or_(y, self.val))
def __rxor__(self, y):
return NonStandardInteger(operator.xor(y, self.val))
def __bool__(self):
return self.val
def __nonzero__(self):
return self.val
def __ceil__(self):
return self
def __floor__(self):
return self
def __int__(self):
if self.error_string_on_conversion:
raise RuntimeError(self.error_string_on_conversion)
return int(self.val)
def __round__(self):
return self
def __repr__(self):
return 'NonStandardInteger(%s)' % self.val

View File

@@ -0,0 +1,142 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""A subclass of unittest.TestCase which checks for reference leaks.
To use:
- Use testing_refleak.BaseTestCase instead of unittest.TestCase
- Configure and compile Python with --with-pydebug
If sys.gettotalrefcount() is not available (because Python was built without
the Py_DEBUG option), then this module is a no-op and tests will run normally.
"""
import copyreg
import gc
import sys
import unittest
class LocalTestResult(unittest.TestResult):
"""A TestResult which forwards events to a parent object, except for Skips."""
def __init__(self, parent_result):
unittest.TestResult.__init__(self)
self.parent_result = parent_result
def addError(self, test, error):
self.parent_result.addError(test, error)
def addFailure(self, test, error):
self.parent_result.addFailure(test, error)
def addSkip(self, test, reason):
pass
class ReferenceLeakCheckerMixin(object):
"""A mixin class for TestCase, which checks reference counts."""
NB_RUNS = 3
def run(self, result=None):
testMethod = getattr(self, self._testMethodName)
expecting_failure_method = getattr(testMethod, "__unittest_expecting_failure__", False)
expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
if expecting_failure_class or expecting_failure_method:
return
# python_message.py registers all Message classes to some pickle global
# registry, which makes the classes immortal.
# We save a copy of this registry, and reset it before we could references.
self._saved_pickle_registry = copyreg.dispatch_table.copy()
# Run the test twice, to warm up the instance attributes.
super(ReferenceLeakCheckerMixin, self).run(result=result)
super(ReferenceLeakCheckerMixin, self).run(result=result)
oldrefcount = 0
local_result = LocalTestResult(result)
num_flakes = 0
refcount_deltas = []
while len(refcount_deltas) < self.NB_RUNS:
oldrefcount = self._getRefcounts()
super(ReferenceLeakCheckerMixin, self).run(result=local_result)
newrefcount = self._getRefcounts()
# If the GC was able to collect some objects after the call to run() that
# it could not collect before the call, then the counts won't match.
if newrefcount < oldrefcount and num_flakes < 2:
# This result is (probably) a flake -- garbage collectors aren't very
# predictable, but a lower ending refcount is the opposite of the
# failure we are testing for. If the result is repeatable, then we will
# eventually report it, but not after trying to eliminate it.
num_flakes += 1
continue
num_flakes = 0
refcount_deltas.append(newrefcount - oldrefcount)
print(refcount_deltas, self)
try:
self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
except Exception: # pylint: disable=broad-except
result.addError(self, sys.exc_info())
def _getRefcounts(self):
copyreg.dispatch_table.clear()
copyreg.dispatch_table.update(self._saved_pickle_registry)
# It is sometimes necessary to gc.collect() multiple times, to ensure
# that all objects can be collected.
gc.collect()
gc.collect()
gc.collect()
return sys.gettotalrefcount()
if hasattr(sys, 'gettotalrefcount'):
def TestCase(test_class):
new_bases = (ReferenceLeakCheckerMixin,) + test_class.__bases__
new_class = type(test_class)(
test_class.__name__, new_bases, dict(test_class.__dict__))
return new_class
SkipReferenceLeakChecker = unittest.skip
else:
# When PyDEBUG is not enabled, run the tests normally.
def TestCase(test_class):
return test_class
def SkipReferenceLeakChecker(reason):
del reason # Don't skip, so don't need a reason.
def Same(func):
return func
return Same

View File

@@ -0,0 +1,67 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.text_encoding."""
import unittest
from google.protobuf import text_encoding
TEST_VALUES = [
("foo\\rbar\\nbaz\\t",
"foo\\rbar\\nbaz\\t",
b"foo\rbar\nbaz\t"),
("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'",
"\\'full of \\\"sound\\\" and \\\"fury\\\"\\'",
b"'full of \"sound\" and \"fury\"'"),
("signi\\\\fying\\\\ nothing\\\\",
"signi\\\\fying\\\\ nothing\\\\",
b"signi\\fying\\ nothing\\"),
("\\010\\t\\n\\013\\014\\r",
"\x08\\t\\n\x0b\x0c\\r",
b"\010\011\012\013\014\015")]
class TextEncodingTestCase(unittest.TestCase):
def testCEscape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES:
self.assertEqual(escaped,
text_encoding.CEscape(unescaped, as_utf8=False))
self.assertEqual(escaped_utf8,
text_encoding.CEscape(unescaped, as_utf8=True))
def testCUnescape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES:
self.assertEqual(unescaped, text_encoding.CUnescape(escaped))
self.assertEqual(unescaped, text_encoding.CUnescape(escaped_utf8))
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,435 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides type checking routines.
This module defines type checking utilities in the forms of dictionaries:
VALUE_CHECKERS: A dictionary of field types and a value validation object.
TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
function.
TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization
function.
FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their
corresponding wire types.
TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
function.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import ctypes
import numbers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import descriptor
_FieldDescriptor = descriptor.FieldDescriptor
def TruncateToFourByteFloat(original):
return ctypes.c_float(original).value
def ToShortestFloat(original):
"""Returns the shortest float that has same value in wire."""
# All 4 byte floats have between 6 and 9 significant digits, so we
# start with 6 as the lower bound.
# It has to be iterative because use '.9g' directly can not get rid
# of the noises for most values. For example if set a float_field=0.9
# use '.9g' will print 0.899999976.
precision = 6
rounded = float('{0:.{1}g}'.format(original, precision))
while TruncateToFourByteFloat(rounded) != original:
precision += 1
rounded = float('{0:.{1}g}'.format(original, precision))
return rounded
def SupportsOpenEnums(field_descriptor):
return field_descriptor.containing_type.syntax == 'proto3'
def GetTypeChecker(field):
"""Returns a type checker for a message field of the specified types.
Args:
field: FieldDescriptor object for this field.
Returns:
An instance of TypeChecker which can be used to verify the types
of values assigned to a field of the specified type.
"""
if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and
field.type == _FieldDescriptor.TYPE_STRING):
return UnicodeValueChecker()
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
if SupportsOpenEnums(field):
# When open enums are supported, any int32 can be assigned.
return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
else:
return EnumValueChecker(field.enum_type)
return _VALUE_CHECKERS[field.cpp_type]
# None of the typecheckers below make any attempt to guard against people
# subclassing builtin types and doing weird things. We're not trying to
# protect against malicious clients here, just people accidentally shooting
# themselves in the foot in obvious ways.
class TypeChecker(object):
"""Type checker used to catch type errors as early as possible
when the client is setting scalar fields in protocol messages.
"""
def __init__(self, *acceptable_types):
self._acceptable_types = acceptable_types
def CheckValue(self, proposed_value):
"""Type check the provided value and return it.
The returned value might have been normalized to another type.
"""
if not isinstance(proposed_value, self._acceptable_types):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), self._acceptable_types))
raise TypeError(message)
return proposed_value
class TypeCheckerWithDefault(TypeChecker):
def __init__(self, default_value, *acceptable_types):
TypeChecker.__init__(self, *acceptable_types)
self._default_value = default_value
def DefaultValue(self):
return self._default_value
class BoolValueChecker(object):
"""Type checker used for bool fields."""
def CheckValue(self, proposed_value):
if not hasattr(proposed_value, '__index__') or (
type(proposed_value).__module__ == 'numpy' and
type(proposed_value).__name__ == 'ndarray'):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (bool, int)))
raise TypeError(message)
return bool(proposed_value)
def DefaultValue(self):
return False
# IntValueChecker and its subclasses perform integer type-checks
# and bounds-checks.
class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value):
if not hasattr(proposed_value, '__index__') or (
type(proposed_value).__module__ == 'numpy' and
type(proposed_value).__name__ == 'ndarray'):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (int,)))
raise TypeError(message)
if not self._MIN <= int(proposed_value) <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
# We force all values to int to make alternate implementations where the
# distinction is more significant (e.g. the C++ implementation) simpler.
proposed_value = int(proposed_value)
return proposed_value
def DefaultValue(self):
return 0
class EnumValueChecker(object):
"""Checker used for enum fields. Performs type-check and range check."""
def __init__(self, enum_type):
self._enum_type = enum_type
def CheckValue(self, proposed_value):
if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (int,)))
raise TypeError(message)
if int(proposed_value) not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
def DefaultValue(self):
return self._enum_type.values[0].number
class UnicodeValueChecker(object):
"""Checker used for string fields.
Always returns a unicode value, even if the input is of type str.
"""
def CheckValue(self, proposed_value):
if not isinstance(proposed_value, (bytes, str)):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (bytes, str)))
raise TypeError(message)
# If the value is of type 'bytes' make sure that it is valid UTF-8 data.
if isinstance(proposed_value, bytes):
try:
proposed_value = proposed_value.decode('utf-8')
except UnicodeDecodeError:
raise ValueError('%.1024r has type bytes, but isn\'t valid UTF-8 '
'encoding. Non-UTF-8 strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
else:
try:
proposed_value.encode('utf8')
except UnicodeEncodeError:
raise ValueError('%.1024r isn\'t a valid unicode string and '
'can\'t be encoded in UTF-8.'%
(proposed_value))
return proposed_value
def DefaultValue(self):
return u""
class Int32ValueChecker(IntValueChecker):
# We're sure to use ints instead of longs here since comparison may be more
# efficient.
_MIN = -2147483648
_MAX = 2147483647
class Uint32ValueChecker(IntValueChecker):
_MIN = 0
_MAX = (1 << 32) - 1
class Int64ValueChecker(IntValueChecker):
_MIN = -(1 << 63)
_MAX = (1 << 63) - 1
class Uint64ValueChecker(IntValueChecker):
_MIN = 0
_MAX = (1 << 64) - 1
# The max 4 bytes float is about 3.4028234663852886e+38
_FLOAT_MAX = float.fromhex('0x1.fffffep+127')
_FLOAT_MIN = -_FLOAT_MAX
_INF = float('inf')
_NEG_INF = float('-inf')
class DoubleValueChecker(object):
"""Checker used for double fields.
Performs type-check and range check.
"""
def CheckValue(self, proposed_value):
"""Check and convert proposed_value to float."""
if (not hasattr(proposed_value, '__float__') and
not hasattr(proposed_value, '__index__')) or (
type(proposed_value).__module__ == 'numpy' and
type(proposed_value).__name__ == 'ndarray'):
message = ('%.1024r has type %s, but expected one of: int, float' %
(proposed_value, type(proposed_value)))
raise TypeError(message)
return float(proposed_value)
def DefaultValue(self):
return 0.0
class FloatValueChecker(DoubleValueChecker):
"""Checker used for float fields.
Performs type-check and range check.
Values exceeding a 32-bit float will be converted to inf/-inf.
"""
def CheckValue(self, proposed_value):
"""Check and convert proposed_value to float."""
converted_value = super().CheckValue(proposed_value)
# This inf rounding matches the C++ proto SafeDoubleToFloat logic.
if converted_value > _FLOAT_MAX:
return _INF
if converted_value < _FLOAT_MIN:
return _NEG_INF
return TruncateToFourByteFloat(converted_value)
# Type-checkers for all scalar CPPTYPEs.
_VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
_FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
_FieldDescriptor.CPPTYPE_DOUBLE: DoubleValueChecker(),
_FieldDescriptor.CPPTYPE_FLOAT: FloatValueChecker(),
_FieldDescriptor.CPPTYPE_BOOL: BoolValueChecker(),
_FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}
# Map from field type to a function F, such that F(field_num, value)
# gives the total byte size for a value of the given type. This
# byte size includes tag information and any other additional space
# associated with serializing "value".
TYPE_TO_BYTE_SIZE_FN = {
_FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
_FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
_FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
_FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
_FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
_FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
_FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
_FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
_FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
_FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
_FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
_FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
_FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
_FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
_FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
_FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
_FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
_FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize
}
# Maps from field types to encoder constructors.
TYPE_TO_ENCODER = {
_FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
_FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
_FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
_FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
_FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
_FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
_FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
_FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
_FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
_FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
_FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
_FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
_FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
_FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
_FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
_FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
_FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
_FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
}
# Maps from field types to sizer constructors.
TYPE_TO_SIZER = {
_FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
_FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
_FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
_FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
_FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
_FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
_FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
_FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
_FieldDescriptor.TYPE_STRING: encoder.StringSizer,
_FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
_FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
_FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
_FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
_FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
_FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
_FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
_FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
_FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
}
# Maps from field type to a decoder constructor.
TYPE_TO_DECODER = {
_FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
_FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
_FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
_FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
_FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
_FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
_FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
_FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
_FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
_FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
_FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
_FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
_FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
_FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
_FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
_FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
_FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
_FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
}
# Maps from field type to expected wiretype.
FIELD_TYPE_TO_WIRE_TYPE = {
_FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
_FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
_FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
_FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
_FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_STRING:
wire_format.WIRETYPE_LENGTH_DELIMITED,
_FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
_FieldDescriptor.TYPE_MESSAGE:
wire_format.WIRETYPE_LENGTH_DELIMITED,
_FieldDescriptor.TYPE_BYTES:
wire_format.WIRETYPE_LENGTH_DELIMITED,
_FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
_FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
_FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
}

View File

@@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for preservation of unknown fields in the pure Python implementation."""
__author__ = 'bohdank@google.com (Bohdan Koval)'
import sys
import unittest
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor
from google.protobuf import unknown_fields
try:
import tracemalloc # pylint: disable=g-import-not-at-top
except ImportError:
# Requires python 3.4+
pass
@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
self.all_fields = unittest_pb2.TestAllTypes()
test_util.SetAllFields(self.all_fields)
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
def testSerialize(self):
data = self.empty_message.SerializeToString()
# Don't use assertEqual because we don't want to dump raw binary data to
# stdout.
self.assertTrue(data == self.all_fields_data)
def testSerializeProto3(self):
# Verify proto3 unknown fields behavior.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
self.assertEqual(self.all_fields_data, message.SerializeToString())
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
def testListFields(self):
# Make sure ListFields doesn't return unknown fields.
self.assertEqual(0, len(self.empty_message.ListFields()))
def testSerializeMessageSetWireFormatUnknownExtension(self):
# Create a message using the message set wire format with an unknown
# message.
raw = unittest_mset_pb2.RawMessageSet()
# Add an unknown extension.
item = raw.item.add()
item.type_id = 98218603
message1 = message_set_extensions_pb2.TestMessageSetExtension1()
message1.i = 12345
item.message = message1.SerializeToString()
serialized = raw.SerializeToString()
# Parse message using the message set wire format.
proto = message_set_extensions_pb2.TestMessageSet()
proto.MergeFromString(serialized)
unknown_field_set = unknown_fields.UnknownFieldSet(proto)
self.assertEqual(len(unknown_field_set), 1)
# Unknown field should have wire format data which can be parsed back to
# original message.
self.assertEqual(unknown_field_set[0].field_number, item.type_id)
self.assertEqual(unknown_field_set[0].wire_type,
wire_format.WIRETYPE_LENGTH_DELIMITED)
d = unknown_field_set[0].data
message_new = message_set_extensions_pb2.TestMessageSetExtension1()
message_new.ParseFromString(d)
self.assertEqual(message1, message_new)
# Verify that the unknown extension is serialized unchanged
reserialized = proto.SerializeToString()
new_raw = unittest_mset_pb2.RawMessageSet()
new_raw.MergeFromString(reserialized)
self.assertEqual(raw, new_raw)
def testEquals(self):
message = unittest_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
self.assertEqual(self.empty_message, message)
self.all_fields.ClearField('optional_string')
message.ParseFromString(self.all_fields.SerializeToString())
self.assertNotEqual(self.empty_message, message)
def testDiscardUnknownFields(self):
self.empty_message.DiscardUnknownFields()
self.assertEqual(b'', self.empty_message.SerializeToString())
# Test message field and repeated message field.
message = unittest_pb2.TestAllTypes()
other_message = unittest_pb2.TestAllTypes()
other_message.optional_string = 'discard'
message.optional_nested_message.ParseFromString(
other_message.SerializeToString())
message.repeated_nested_message.add().ParseFromString(
other_message.SerializeToString())
self.assertNotEqual(
b'', message.optional_nested_message.SerializeToString())
self.assertNotEqual(
b'', message.repeated_nested_message[0].SerializeToString())
message.DiscardUnknownFields()
self.assertEqual(b'', message.optional_nested_message.SerializeToString())
self.assertEqual(
b'', message.repeated_nested_message[0].SerializeToString())
msg = map_unittest_pb2.TestMap()
msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
other_message.SerializeToString())
msg.map_string_string['1'] = 'test'
self.assertNotEqual(
b'',
msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
msg.DiscardUnknownFields()
self.assertEqual(
b'',
msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
self.all_fields = unittest_pb2.TestAllTypes()
test_util.SetAllFields(self.all_fields)
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
# InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
# TODO(jieluo): Remove message._unknown_fields.
def InternalCheckUnknownField(self, name, expected_value):
if api_implementation.Type() != 'python':
return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
def CheckUnknownField(self, name, unknown_field_set, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
field_descriptor.type]
for unknown_field in unknown_field_set:
if unknown_field.field_number == field_descriptor.number:
self.assertEqual(expected_type, unknown_field.wire_type)
if expected_type == 3:
# Check group
self.assertEqual(expected_value[0],
unknown_field.data[0].field_number)
self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
self.assertEqual(expected_value[2], unknown_field.data[0].data)
continue
if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
self.assertIn(type(unknown_field.data), (str, bytes))
if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
self.assertIn(unknown_field.data, expected_value)
else:
self.assertEqual(expected_value, unknown_field.data)
def testCheckUnknownFieldValue(self):
unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
# Test enum.
self.CheckUnknownField('optional_nested_enum',
unknown_field_set,
self.all_fields.optional_nested_enum)
self.InternalCheckUnknownField('optional_nested_enum',
self.all_fields.optional_nested_enum)
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
unknown_field_set,
self.all_fields.repeated_nested_enum)
self.InternalCheckUnknownField('repeated_nested_enum',
self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
unknown_field_set,
self.all_fields.optional_int32)
self.InternalCheckUnknownField('optional_int32',
self.all_fields.optional_int32)
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
unknown_field_set,
self.all_fields.optional_fixed32)
self.InternalCheckUnknownField('optional_fixed32',
self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
unknown_field_set,
self.all_fields.optional_fixed64)
self.InternalCheckUnknownField('optional_fixed64',
self.all_fields.optional_fixed64)
# Test length delimited.
self.CheckUnknownField('optional_string',
unknown_field_set,
self.all_fields.optional_string.encode('utf-8'))
self.InternalCheckUnknownField('optional_string',
self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
unknown_field_set,
(17, 0, 117))
self.InternalCheckUnknownField('optionalgroup',
self.all_fields.optionalgroup)
self.assertEqual(98, len(unknown_field_set))
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
message.CopyFrom(self.empty_message)
self.assertEqual(message.SerializeToString(), self.all_fields_data)
def testMergeFrom(self):
message = unittest_pb2.TestAllTypes()
message.optional_int32 = 1
message.optional_uint32 = 2
source = unittest_pb2.TestEmptyMessage()
source.ParseFromString(message.SerializeToString())
message.ClearField('optional_int32')
message.optional_int64 = 3
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
unknown_field_set = unknown_fields.UnknownFieldSet(destination)
self.assertEqual(0, len(unknown_field_set))
destination.ParseFromString(message.SerializeToString())
self.assertEqual(0, len(unknown_field_set))
unknown_field_set = unknown_fields.UnknownFieldSet(destination)
self.assertEqual(2, len(unknown_field_set))
destination.MergeFrom(source)
self.assertEqual(2, len(unknown_field_set))
# Check that the fields where correctly merged, even stored in the unknown
# fields set.
message.ParseFromString(destination.SerializeToString())
self.assertEqual(message.optional_int32, 1)
self.assertEqual(message.optional_uint32, 2)
self.assertEqual(message.optional_int64, 3)
def testClear(self):
unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
self.empty_message.Clear()
# All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'')
self.assertEqual(len(unknown_field_set), 98)
@unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4),
'tracemalloc requires python 3.4+')
def testUnknownFieldsNoMemoryLeak(self):
# Call to UnknownFields must not leak memory
nb_leaks = 1234
def leaking_function():
for _ in range(nb_leaks):
unknown_fields.UnknownFieldSet(self.empty_message)
tracemalloc.start()
snapshot1 = tracemalloc.take_snapshot()
leaking_function()
snapshot2 = tracemalloc.take_snapshot()
top_stats = snapshot2.compare_to(snapshot1, 'lineno')
tracemalloc.stop()
# There's no easy way to look for a precise leak source.
# Rely on a "marker" count value while checking allocated memory.
self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])
def testSubUnknownFields(self):
message = unittest_pb2.TestAllTypes()
message.optionalgroup.a = 123
destination = unittest_pb2.TestEmptyMessage()
destination.ParseFromString(message.SerializeToString())
sub_unknown_fields = unknown_fields.UnknownFieldSet(destination)[0].data
self.assertEqual(1, len(sub_unknown_fields))
self.assertEqual(sub_unknown_fields[0].data, 123)
destination.Clear()
self.assertEqual(1, len(sub_unknown_fields))
self.assertEqual(sub_unknown_fields[0].data, 123)
message.Clear()
message.optional_uint32 = 456
nested_message = unittest_pb2.NestedTestAllTypes()
nested_message.payload.optional_nested_message.ParseFromString(
message.SerializeToString())
unknown_field_set = unknown_fields.UnknownFieldSet(
nested_message.payload.optional_nested_message)
self.assertEqual(unknown_field_set[0].data, 456)
nested_message.ClearField('payload')
self.assertEqual(unknown_field_set[0].data, 456)
unknown_field_set = unknown_fields.UnknownFieldSet(
nested_message.payload.optional_nested_message)
self.assertEqual(0, len(unknown_field_set))
def testUnknownField(self):
message = unittest_pb2.TestAllTypes()
message.optional_int32 = 123
destination = unittest_pb2.TestEmptyMessage()
destination.ParseFromString(message.SerializeToString())
unknown_field = unknown_fields.UnknownFieldSet(destination)[0]
destination.Clear()
self.assertEqual(unknown_field.data, 123)
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
message.ParseFromString(self.all_fields_data)
self.assertEqual(len(unknown_fields.UnknownFieldSet(message)), 98)
self.assertEqual(message.SerializeToString(), self.all_fields_data)
@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
def setUp(self):
self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
self.message = missing_enum_values_pb2.TestEnumValues()
# TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
self.message.optional_nested_enum = (
missing_enum_values_pb2.TestEnumValues.ZERO)
self.message.repeated_nested_enum.extend([
missing_enum_values_pb2.TestEnumValues.ZERO,
missing_enum_values_pb2.TestEnumValues.ONE,
])
self.message.packed_nested_enum.extend([
missing_enum_values_pb2.TestEnumValues.ZERO,
missing_enum_values_pb2.TestEnumValues.ONE,
])
self.message_data = self.message.SerializeToString()
self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
self.missing_message.ParseFromString(self.message_data)
# CheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
self.assertIsInstance(unknown_field_set, unknown_fields.UnknownFieldSet)
count = 0
for field in unknown_field_set:
if field.field_number == field_descriptor.number:
count += 1
if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
self.assertIn(field.data, expected_value)
else:
self.assertEqual(expected_value, field.data)
if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
self.assertEqual(count, len(expected_value))
else:
self.assertEqual(count, 1)
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
just_string.dummy = 'blah'
missing = missing_enum_values_pb2.TestEnumValues()
# The parse is invalid, storing the string proto into the set of
# unknown fields.
missing.ParseFromString(just_string.SerializeToString())
# Fetching the enum field shouldn't crash, instead returning the
# default value.
self.assertEqual(missing.optional_nested_enum, 0)
def testUnknownEnumValue(self):
self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
self.assertEqual(self.missing_message.optional_nested_enum, 2)
# Clear does not do anything.
serialized = self.missing_message.SerializeToString()
self.missing_message.ClearField('optional_nested_enum')
self.assertEqual(self.missing_message.SerializeToString(), serialized)
def testUnknownRepeatedEnumValue(self):
self.assertEqual([], self.missing_message.repeated_nested_enum)
def testUnknownPackedEnumValue(self):
self.assertEqual([], self.missing_message.packed_nested_enum)
def testCheckUnknownFieldValueForEnum(self):
unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
self.assertEqual(len(unknown_field_set), 5)
self.CheckUnknownField('optional_nested_enum',
self.message.optional_nested_enum)
self.CheckUnknownField('repeated_nested_enum',
self.message.repeated_nested_enum)
self.CheckUnknownField('packed_nested_enum',
self.message.packed_nested_enum)
def testRoundTrip(self):
new_message = missing_enum_values_pb2.TestEnumValues()
new_message.ParseFromString(self.missing_message.SerializeToString())
self.assertEqual(self.message, new_message)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,582 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains well known classes.
This files defines well known classes which need extra maintenance including:
- Any
- Duration
- FieldMask
- Struct
- Timestamp
"""
__author__ = 'jieluo@google.com (Jie Luo)'
import calendar
import collections.abc
import datetime
from google.protobuf.internal import field_mask
FieldMask = field_mask.FieldMask
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_NANOS_PER_SECOND = 1000000000
_NANOS_PER_MILLISECOND = 1000000
_NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
_DURATION_SECONDS_MAX = 315576000000
class Any(object):
"""Class for Any Message type."""
__slots__ = ()
def Pack(self, msg, type_url_prefix='type.googleapis.com/',
deterministic=None):
"""Packs the specified message into current Any message."""
if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
else:
self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
self.value = msg.SerializeToString(deterministic=deterministic)
def Unpack(self, msg):
"""Unpacks the current Any message into specified message."""
descriptor = msg.DESCRIPTOR
if not self.Is(descriptor):
return False
msg.ParseFromString(self.value)
return True
def TypeName(self):
"""Returns the protobuf type name of the inner message."""
# Only last part is to be used: b/25630112
return self.type_url.split('/')[-1]
def Is(self, descriptor):
"""Checks if this Any represents the given protobuf type."""
return '/' in self.type_url and self.TypeName() == descriptor.full_name
_EPOCH_DATETIME_NAIVE = datetime.datetime.utcfromtimestamp(0)
_EPOCH_DATETIME_AWARE = datetime.datetime.fromtimestamp(
0, tz=datetime.timezone.utc)
class Timestamp(object):
"""Class for Timestamp message type."""
__slots__ = ()
def ToJsonString(self):
"""Converts Timestamp to RFC 3339 date string format.
Returns:
A string converted from timestamp. The string is always Z-normalized
and uses 3, 6 or 9 fractional digits as required to represent the
exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
"""
nanos = self.nanos % _NANOS_PER_SECOND
total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
seconds = total_sec % _SECONDS_PER_DAY
days = (total_sec - seconds) // _SECONDS_PER_DAY
dt = datetime.datetime(1970, 1, 1) + datetime.timedelta(days, seconds)
result = dt.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 'Z'
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09dZ' % nanos
def FromJsonString(self, value):
"""Parse a RFC 3339 date string format to Timestamp.
Args:
value: A date string. Any fractional digits (or none) and any offset are
accepted as long as they fit into nano-seconds precision.
Example of accepted format: '1972-01-01T10:00:20.021-05:00'
Raises:
ValueError: On parsing problems.
"""
if not isinstance(value, str):
raise ValueError('Timestamp JSON value not a string: {!r}'.format(value))
timezone_offset = value.find('Z')
if timezone_offset == -1:
timezone_offset = value.find('+')
if timezone_offset == -1:
timezone_offset = value.rfind('-')
if timezone_offset == -1:
raise ValueError(
'Failed to parse timestamp: missing valid timezone offset.')
time_value = value[0:timezone_offset]
# Parse datetime and nanos.
point_position = time_value.find('.')
if point_position == -1:
second_value = time_value
nano_value = ''
else:
second_value = time_value[:point_position]
nano_value = time_value[point_position + 1:]
if 't' in second_value:
raise ValueError(
'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
'lowercase \'t\' is not accepted'.format(second_value))
date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFOMAT)
td = date_object - datetime.datetime(1970, 1, 1)
seconds = td.seconds + td.days * _SECONDS_PER_DAY
if len(nano_value) > 9:
raise ValueError(
'Failed to parse Timestamp: nanos {0} more than '
'9 fractional digits.'.format(nano_value))
if nano_value:
nanos = round(float('0.' + nano_value) * 1e9)
else:
nanos = 0
# Parse timezone offsets.
if value[timezone_offset] == 'Z':
if len(value) != timezone_offset + 1:
raise ValueError('Failed to parse timestamp: invalid trailing'
' data {0}.'.format(value))
else:
timezone = value[timezone_offset:]
pos = timezone.find(':')
if pos == -1:
raise ValueError(
'Invalid timezone offset value: {0}.'.format(timezone))
if timezone[0] == '+':
seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
else:
seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
# Set seconds and nanos
self.seconds = int(seconds)
self.nanos = int(nanos)
def GetCurrentTime(self):
"""Get the current UTC into Timestamp."""
self.FromDatetime(datetime.datetime.utcnow())
def ToNanoseconds(self):
"""Converts Timestamp to nanoseconds since epoch."""
return self.seconds * _NANOS_PER_SECOND + self.nanos
def ToMicroseconds(self):
"""Converts Timestamp to microseconds since epoch."""
return (self.seconds * _MICROS_PER_SECOND +
self.nanos // _NANOS_PER_MICROSECOND)
def ToMilliseconds(self):
"""Converts Timestamp to milliseconds since epoch."""
return (self.seconds * _MILLIS_PER_SECOND +
self.nanos // _NANOS_PER_MILLISECOND)
def ToSeconds(self):
"""Converts Timestamp to seconds since epoch."""
return self.seconds
def FromNanoseconds(self, nanos):
"""Converts nanoseconds since epoch to Timestamp."""
self.seconds = nanos // _NANOS_PER_SECOND
self.nanos = nanos % _NANOS_PER_SECOND
def FromMicroseconds(self, micros):
"""Converts microseconds since epoch to Timestamp."""
self.seconds = micros // _MICROS_PER_SECOND
self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND
def FromMilliseconds(self, millis):
"""Converts milliseconds since epoch to Timestamp."""
self.seconds = millis // _MILLIS_PER_SECOND
self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND
def FromSeconds(self, seconds):
"""Converts seconds since epoch to Timestamp."""
self.seconds = seconds
self.nanos = 0
def ToDatetime(self, tzinfo=None):
"""Converts Timestamp to a datetime.
Args:
tzinfo: A datetime.tzinfo subclass; defaults to None.
Returns:
If tzinfo is None, returns a timezone-naive UTC datetime (with no timezone
information, i.e. not aware that it's UTC).
Otherwise, returns a timezone-aware datetime in the input timezone.
"""
delta = datetime.timedelta(
seconds=self.seconds,
microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND))
if tzinfo is None:
return _EPOCH_DATETIME_NAIVE + delta
else:
return _EPOCH_DATETIME_AWARE.astimezone(tzinfo) + delta
def FromDatetime(self, dt):
"""Converts datetime to Timestamp.
Args:
dt: A datetime. If it's timezone-naive, it's assumed to be in UTC.
"""
# Using this guide: http://wiki.python.org/moin/WorkingWithTime
# And this conversion guide: http://docs.python.org/library/time.html
# Turn the date parameter into a tuple (struct_time) that can then be
# manipulated into a long value of seconds. During the conversion from
# struct_time to long, the source date in UTC, and so it follows that the
# correct transformation is calendar.timegm()
self.seconds = calendar.timegm(dt.utctimetuple())
self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND
class Duration(object):
"""Class for Duration message type."""
__slots__ = ()
def ToJsonString(self):
"""Converts Duration to string format.
Returns:
A string converted from self. The string format will contains
3, 6, or 9 fractional digits depending on the precision required to
represent the exact Duration value. For example: "1s", "1.010s",
"1.000000100s", "-3.100s"
"""
_CheckDurationValid(self.seconds, self.nanos)
if self.seconds < 0 or self.nanos < 0:
result = '-'
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
nanos = (0 - self.nanos) % 1e9
else:
result = ''
seconds = self.seconds + int(self.nanos // 1e9)
nanos = self.nanos % 1e9
result += '%d' % seconds
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 's'
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03ds' % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06ds' % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09ds' % nanos
def FromJsonString(self, value):
"""Converts a string to Duration.
Args:
value: A string to be converted. The string must end with 's'. Any
fractional digits (or none) are accepted as long as they fit into
precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
Raises:
ValueError: On parsing problems.
"""
if not isinstance(value, str):
raise ValueError('Duration JSON value not a string: {!r}'.format(value))
if len(value) < 1 or value[-1] != 's':
raise ValueError(
'Duration must end with letter "s": {0}.'.format(value))
try:
pos = value.find('.')
if pos == -1:
seconds = int(value[:-1])
nanos = 0
else:
seconds = int(value[:pos])
if value[0] == '-':
nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
else:
nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
_CheckDurationValid(seconds, nanos)
self.seconds = seconds
self.nanos = nanos
except ValueError as e:
raise ValueError(
'Couldn\'t parse duration: {0} : {1}.'.format(value, e))
def ToNanoseconds(self):
"""Converts a Duration to nanoseconds."""
return self.seconds * _NANOS_PER_SECOND + self.nanos
def ToMicroseconds(self):
"""Converts a Duration to microseconds."""
micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
return self.seconds * _MICROS_PER_SECOND + micros
def ToMilliseconds(self):
"""Converts a Duration to milliseconds."""
millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
return self.seconds * _MILLIS_PER_SECOND + millis
def ToSeconds(self):
"""Converts a Duration to seconds."""
return self.seconds
def FromNanoseconds(self, nanos):
"""Converts nanoseconds to Duration."""
self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
nanos % _NANOS_PER_SECOND)
def FromMicroseconds(self, micros):
"""Converts microseconds to Duration."""
self._NormalizeDuration(
micros // _MICROS_PER_SECOND,
(micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
def FromMilliseconds(self, millis):
"""Converts milliseconds to Duration."""
self._NormalizeDuration(
millis // _MILLIS_PER_SECOND,
(millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
def FromSeconds(self, seconds):
"""Converts seconds to Duration."""
self.seconds = seconds
self.nanos = 0
def ToTimedelta(self):
"""Converts Duration to timedelta."""
return datetime.timedelta(
seconds=self.seconds, microseconds=_RoundTowardZero(
self.nanos, _NANOS_PER_MICROSECOND))
def FromTimedelta(self, td):
"""Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
def _NormalizeDuration(self, seconds, nanos):
"""Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
if seconds < 0 and nanos > 0:
seconds += 1
nanos -= _NANOS_PER_SECOND
self.seconds = seconds
self.nanos = nanos
def _CheckDurationValid(seconds, nanos):
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
raise ValueError(
'Duration is not valid: Seconds {0} must be in range '
'[-315576000000, 315576000000].'.format(seconds))
if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
raise ValueError(
'Duration is not valid: Nanos {0} must be in range '
'[-999999999, 999999999].'.format(nanos))
if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
raise ValueError(
'Duration is not valid: Sign mismatch.')
def _RoundTowardZero(value, divider):
"""Truncates the remainder part after division."""
# For some languages, the sign of the remainder is implementation
# dependent if any of the operands is negative. Here we enforce
# "rounded toward zero" semantics. For example, for (-5) / 2 an
# implementation may give -3 as the result with the remainder being
# 1. This function ensures we always return -2 (closer to zero).
result = value // divider
remainder = value % divider
if result < 0 and remainder > 0:
return result + 1
else:
return result
def _SetStructValue(struct_value, value):
if value is None:
struct_value.null_value = 0
elif isinstance(value, bool):
# Note: this check must come before the number check because in Python
# True and False are also considered numbers.
struct_value.bool_value = value
elif isinstance(value, str):
struct_value.string_value = value
elif isinstance(value, (int, float)):
struct_value.number_value = value
elif isinstance(value, (dict, Struct)):
struct_value.struct_value.Clear()
struct_value.struct_value.update(value)
elif isinstance(value, (list, ListValue)):
struct_value.list_value.Clear()
struct_value.list_value.extend(value)
else:
raise ValueError('Unexpected type')
def _GetStructValue(struct_value):
which = struct_value.WhichOneof('kind')
if which == 'struct_value':
return struct_value.struct_value
elif which == 'null_value':
return None
elif which == 'number_value':
return struct_value.number_value
elif which == 'string_value':
return struct_value.string_value
elif which == 'bool_value':
return struct_value.bool_value
elif which == 'list_value':
return struct_value.list_value
elif which is None:
raise ValueError('Value not set')
class Struct(object):
"""Class for Struct message type."""
__slots__ = ()
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
def __contains__(self, item):
return item in self.fields
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)
def __delitem__(self, key):
del self.fields[key]
def __len__(self):
return len(self.fields)
def __iter__(self):
return iter(self.fields)
def keys(self): # pylint: disable=invalid-name
return self.fields.keys()
def values(self): # pylint: disable=invalid-name
return [self[key] for key in self]
def items(self): # pylint: disable=invalid-name
return [(key, self[key]) for key in self]
def get_or_create_list(self, key):
"""Returns a list for this key, creating if it didn't exist already."""
if not self.fields[key].HasField('list_value'):
# Clear will mark list_value modified which will indeed create a list.
self.fields[key].list_value.Clear()
return self.fields[key].list_value
def get_or_create_struct(self, key):
"""Returns a struct for this key, creating if it didn't exist already."""
if not self.fields[key].HasField('struct_value'):
# Clear will mark struct_value modified which will indeed create a struct.
self.fields[key].struct_value.Clear()
return self.fields[key].struct_value
def update(self, dictionary): # pylint: disable=invalid-name
for key, value in dictionary.items():
_SetStructValue(self.fields[key], value)
collections.abc.MutableMapping.register(Struct)
class ListValue(object):
"""Class for ListValue message type."""
__slots__ = ()
def __len__(self):
return len(self.values)
def append(self, value):
_SetStructValue(self.values.add(), value)
def extend(self, elem_seq):
for value in elem_seq:
self.append(value)
def __getitem__(self, index):
"""Retrieves item by the specified index."""
return _GetStructValue(self.values.__getitem__(index))
def __setitem__(self, index, value):
_SetStructValue(self.values.__getitem__(index), value)
def __delitem__(self, key):
del self.values[key]
def items(self):
for i in range(len(self)):
yield self[i]
def add_struct(self):
"""Appends and returns a struct value as the next value in the list."""
struct_value = self.values.add().struct_value
# Clear will mark struct_value modified which will indeed create a struct.
struct_value.Clear()
return struct_value
def add_list(self):
"""Appends and returns a list value as the next value in the list."""
list_value = self.values.add().list_value
# Clear will mark list_value modified which will indeed create a list.
list_value.Clear()
return list_value
collections.abc.MutableSequence.register(ListValue)
# LINT.IfChange(wktbases)
WKTBASES = {
'google.protobuf.Any': Any,
'google.protobuf.Duration': Duration,
'google.protobuf.FieldMask': FieldMask,
'google.protobuf.ListValue': ListValue,
'google.protobuf.Struct': Struct,
'google.protobuf.Timestamp': Timestamp,
}
# LINT.ThenChange(//depot/google.protobuf/compiler/python/pyi_generator.cc:wktbases)

View File

@@ -0,0 +1,653 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.well_known_types."""
__author__ = 'jieluo@google.com (Jie Luo)'
import collections.abc as collections_abc
import datetime
import unittest
from google.protobuf import any_pb2
from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import unittest_pb2
from google.protobuf.internal import any_test_pb2
from google.protobuf.internal import well_known_types
from google.protobuf import text_format
from google.protobuf.internal import _parameterized
try:
# New module in Python 3.9:
import zoneinfo # pylint:disable=g-import-not-at-top
_TZ_JAPAN = zoneinfo.ZoneInfo('Japan')
_TZ_PACIFIC = zoneinfo.ZoneInfo('US/Pacific')
except ImportError:
_TZ_JAPAN = datetime.timezone(datetime.timedelta(hours=9), 'Japan')
_TZ_PACIFIC = datetime.timezone(datetime.timedelta(hours=-8), 'US/Pacific')
class TimeUtilTestBase(_parameterized.TestCase):
def CheckTimestampConversion(self, message, text):
self.assertEqual(text, message.ToJsonString())
parsed_message = timestamp_pb2.Timestamp()
parsed_message.FromJsonString(text)
self.assertEqual(message, parsed_message)
def CheckDurationConversion(self, message, text):
self.assertEqual(text, message.ToJsonString())
parsed_message = duration_pb2.Duration()
parsed_message.FromJsonString(text)
self.assertEqual(message, parsed_message)
class TimeUtilTest(TimeUtilTestBase):
def testTimestampSerializeAndParse(self):
message = timestamp_pb2.Timestamp()
# Generated output should contain 3, 6, or 9 fractional digits.
message.seconds = 0
message.nanos = 0
self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z')
message.nanos = 10000000
self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z')
message.nanos = 10000
self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z')
message.nanos = 10
self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z')
# Test min timestamps.
message.seconds = -62135596800
message.nanos = 0
self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z')
# Test max timestamps.
message.seconds = 253402300799
message.nanos = 999999999
self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z')
# Test negative timestamps.
message.seconds = -1
self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z')
# Parsing accepts an fractional digits as long as they fit into nano
# precision.
message.FromJsonString('1970-01-01T00:00:00.1Z')
self.assertEqual(0, message.seconds)
self.assertEqual(100000000, message.nanos)
# Parsing accepts offsets.
message.FromJsonString('1970-01-01T00:00:00-08:00')
self.assertEqual(8 * 3600, message.seconds)
self.assertEqual(0, message.nanos)
# It is not easy to check with current time. For test coverage only.
message.GetCurrentTime()
self.assertNotEqual(8 * 3600, message.seconds)
def testDurationSerializeAndParse(self):
message = duration_pb2.Duration()
# Generated output should contain 3, 6, or 9 fractional digits.
message.seconds = 0
message.nanos = 0
self.CheckDurationConversion(message, '0s')
message.nanos = 10000000
self.CheckDurationConversion(message, '0.010s')
message.nanos = 10000
self.CheckDurationConversion(message, '0.000010s')
message.nanos = 10
self.CheckDurationConversion(message, '0.000000010s')
# Test min and max
message.seconds = 315576000000
message.nanos = 999999999
self.CheckDurationConversion(message, '315576000000.999999999s')
message.seconds = -315576000000
message.nanos = -999999999
self.CheckDurationConversion(message, '-315576000000.999999999s')
# Parsing accepts an fractional digits as long as they fit into nano
# precision.
message.FromJsonString('0.1s')
self.assertEqual(100000000, message.nanos)
message.FromJsonString('0.0000001s')
self.assertEqual(100, message.nanos)
def testTimestampIntegerConversion(self):
message = timestamp_pb2.Timestamp()
message.FromNanoseconds(1)
self.assertEqual('1970-01-01T00:00:00.000000001Z',
message.ToJsonString())
self.assertEqual(1, message.ToNanoseconds())
message.FromNanoseconds(-1)
self.assertEqual('1969-12-31T23:59:59.999999999Z',
message.ToJsonString())
self.assertEqual(-1, message.ToNanoseconds())
message.FromMicroseconds(1)
self.assertEqual('1970-01-01T00:00:00.000001Z',
message.ToJsonString())
self.assertEqual(1, message.ToMicroseconds())
message.FromMicroseconds(-1)
self.assertEqual('1969-12-31T23:59:59.999999Z',
message.ToJsonString())
self.assertEqual(-1, message.ToMicroseconds())
message.FromMilliseconds(1)
self.assertEqual('1970-01-01T00:00:00.001Z',
message.ToJsonString())
self.assertEqual(1, message.ToMilliseconds())
message.FromMilliseconds(-1)
self.assertEqual('1969-12-31T23:59:59.999Z',
message.ToJsonString())
self.assertEqual(-1, message.ToMilliseconds())
message.FromSeconds(1)
self.assertEqual('1970-01-01T00:00:01Z',
message.ToJsonString())
self.assertEqual(1, message.ToSeconds())
message.FromSeconds(-1)
self.assertEqual('1969-12-31T23:59:59Z',
message.ToJsonString())
self.assertEqual(-1, message.ToSeconds())
message.FromNanoseconds(1999)
self.assertEqual(1, message.ToMicroseconds())
# For negative values, Timestamp will be rounded down.
# For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds
# will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than
# "1970-01-01T00:00:00Z" (i.e., 0s).
message.FromNanoseconds(-1999)
self.assertEqual(-2, message.ToMicroseconds())
def testDurationIntegerConversion(self):
message = duration_pb2.Duration()
message.FromNanoseconds(1)
self.assertEqual('0.000000001s',
message.ToJsonString())
self.assertEqual(1, message.ToNanoseconds())
message.FromNanoseconds(-1)
self.assertEqual('-0.000000001s',
message.ToJsonString())
self.assertEqual(-1, message.ToNanoseconds())
message.FromMicroseconds(1)
self.assertEqual('0.000001s',
message.ToJsonString())
self.assertEqual(1, message.ToMicroseconds())
message.FromMicroseconds(-1)
self.assertEqual('-0.000001s',
message.ToJsonString())
self.assertEqual(-1, message.ToMicroseconds())
message.FromMilliseconds(1)
self.assertEqual('0.001s',
message.ToJsonString())
self.assertEqual(1, message.ToMilliseconds())
message.FromMilliseconds(-1)
self.assertEqual('-0.001s',
message.ToJsonString())
self.assertEqual(-1, message.ToMilliseconds())
message.FromSeconds(1)
self.assertEqual('1s', message.ToJsonString())
self.assertEqual(1, message.ToSeconds())
message.FromSeconds(-1)
self.assertEqual('-1s',
message.ToJsonString())
self.assertEqual(-1, message.ToSeconds())
# Test truncation behavior.
message.FromNanoseconds(1999)
self.assertEqual(1, message.ToMicroseconds())
# For negative values, Duration will be rounded towards 0.
message.FromNanoseconds(-1999)
self.assertEqual(-1, message.ToMicroseconds())
def testTimezoneNaiveDatetimeConversion(self):
message = timestamp_pb2.Timestamp()
naive_utc_epoch = datetime.datetime(1970, 1, 1)
message.FromDatetime(naive_utc_epoch)
self.assertEqual(0, message.seconds)
self.assertEqual(0, message.nanos)
self.assertEqual(naive_utc_epoch, message.ToDatetime())
naive_epoch_morning = datetime.datetime(1970, 1, 1, 8, 0, 0, 1)
message.FromDatetime(naive_epoch_morning)
self.assertEqual(8 * 3600, message.seconds)
self.assertEqual(1000, message.nanos)
self.assertEqual(naive_epoch_morning, message.ToDatetime())
message.FromMilliseconds(1999)
self.assertEqual(1, message.seconds)
self.assertEqual(999_000_000, message.nanos)
self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 1, 999000),
message.ToDatetime())
naive_future = datetime.datetime(2555, 2, 22, 1, 2, 3, 456789)
message.FromDatetime(naive_future)
self.assertEqual(naive_future, message.ToDatetime())
naive_end_of_time = datetime.datetime.max
message.FromDatetime(naive_end_of_time)
self.assertEqual(naive_end_of_time, message.ToDatetime())
# Two hours after the Unix Epoch, around the world.
@_parameterized.named_parameters(
('London', [1970, 1, 1, 2], datetime.timezone.utc),
('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN),
('LA', [1969, 12, 31, 18], _TZ_PACIFIC),
)
def testTimezoneAwareDatetimeConversion(self, date_parts, tzinfo):
original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime
message = timestamp_pb2.Timestamp()
message.FromDatetime(original_datetime)
self.assertEqual(7200, message.seconds)
self.assertEqual(0, message.nanos)
# ToDatetime() with no parameters produces a naive UTC datetime, i.e. it not
# only loses the original timezone information (e.g. US/Pacific) as it's
# "normalised" to UTC, but also drops the information that the datetime
# represents a UTC one.
naive_datetime = message.ToDatetime()
self.assertEqual(datetime.datetime(1970, 1, 1, 2), naive_datetime)
self.assertIsNone(naive_datetime.tzinfo)
self.assertNotEqual(original_datetime, naive_datetime) # not even for UTC!
# In contrast, ToDatetime(tzinfo=) produces an aware datetime in the given
# timezone.
aware_datetime = message.ToDatetime(tzinfo=tzinfo)
self.assertEqual(original_datetime, aware_datetime)
self.assertEqual(
datetime.datetime(1970, 1, 1, 2, tzinfo=datetime.timezone.utc),
aware_datetime)
self.assertEqual(tzinfo, aware_datetime.tzinfo)
def testTimedeltaConversion(self):
message = duration_pb2.Duration()
message.FromNanoseconds(1999999999)
td = message.ToTimedelta()
self.assertEqual(1, td.seconds)
self.assertEqual(999999, td.microseconds)
message.FromNanoseconds(-1999999999)
td = message.ToTimedelta()
self.assertEqual(-1, td.days)
self.assertEqual(86398, td.seconds)
self.assertEqual(1, td.microseconds)
message.FromMicroseconds(-1)
td = message.ToTimedelta()
self.assertEqual(-1, td.days)
self.assertEqual(86399, td.seconds)
self.assertEqual(999999, td.microseconds)
converted_message = duration_pb2.Duration()
converted_message.FromTimedelta(td)
self.assertEqual(message, converted_message)
def testInvalidTimestamp(self):
message = timestamp_pb2.Timestamp()
self.assertRaisesRegex(
ValueError, 'Failed to parse timestamp: missing valid timezone offset.',
message.FromJsonString, '')
self.assertRaisesRegex(
ValueError, 'Failed to parse timestamp: invalid trailing data '
'1970-01-01T00:00:01Ztrail.', message.FromJsonString,
'1970-01-01T00:00:01Ztrail')
self.assertRaisesRegex(
ValueError, 'time data \'10000-01-01T00:00:00\' does not match'
' format \'%Y-%m-%dT%H:%M:%S\'', message.FromJsonString,
'10000-01-01T00:00:00.00Z')
self.assertRaisesRegex(
ValueError, 'nanos 0123456789012 more than 9 fractional digits.',
message.FromJsonString, '1970-01-01T00:00:00.0123456789012Z')
self.assertRaisesRegex(
ValueError,
(r'Invalid timezone offset value: \+08.'),
message.FromJsonString,
'1972-01-01T01:00:00.01+08',
)
self.assertRaisesRegex(ValueError, 'year (0 )?is out of range',
message.FromJsonString, '0000-01-01T00:00:00Z')
message.seconds = 253402300800
self.assertRaisesRegex(OverflowError, 'date value out of range',
message.ToJsonString)
def testInvalidDuration(self):
message = duration_pb2.Duration()
self.assertRaisesRegex(ValueError, 'Duration must end with letter "s": 1.',
message.FromJsonString, '1')
self.assertRaisesRegex(ValueError, 'Couldn\'t parse duration: 1...2s.',
message.FromJsonString, '1...2s')
text = '-315576000001.000000000s'
self.assertRaisesRegex(
ValueError,
r'Duration is not valid\: Seconds -315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', message.FromJsonString, text)
text = '315576000001.000000000s'
self.assertRaisesRegex(
ValueError,
r'Duration is not valid\: Seconds 315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', message.FromJsonString, text)
message.seconds = -315576000001
message.nanos = 0
self.assertRaisesRegex(
ValueError,
r'Duration is not valid\: Seconds -315576000001 must be in range'
r' \[-315576000000\, 315576000000\].', message.ToJsonString)
message.seconds = 0
message.nanos = 999999999 + 1
self.assertRaisesRegex(
ValueError, r'Duration is not valid\: Nanos 1000000000 must be in range'
r' \[-999999999\, 999999999\].', message.ToJsonString)
message.seconds = -1
message.nanos = 1
self.assertRaisesRegex(ValueError,
r'Duration is not valid\: Sign mismatch.',
message.ToJsonString)
class StructTest(unittest.TestCase):
def testStruct(self):
struct = struct_pb2.Struct()
self.assertIsInstance(struct, collections_abc.Mapping)
self.assertEqual(0, len(struct))
struct_class = struct.__class__
struct['key1'] = 5
struct['key2'] = 'abc'
struct['key3'] = True
struct.get_or_create_struct('key4')['subkey'] = 11.0
struct_list = struct.get_or_create_list('key5')
self.assertIsInstance(struct_list, collections_abc.Sequence)
struct_list.extend([6, 'seven', True, False, None])
struct_list.add_struct()['subkey2'] = 9
struct['key6'] = {'subkey': {}}
struct['key7'] = [2, False]
self.assertEqual(7, len(struct))
self.assertTrue(isinstance(struct, well_known_types.Struct))
self.assertEqual(5, struct['key1'])
self.assertEqual('abc', struct['key2'])
self.assertIs(True, struct['key3'])
self.assertEqual(11, struct['key4']['subkey'])
inner_struct = struct_class()
inner_struct['subkey2'] = 9
self.assertEqual([6, 'seven', True, False, None, inner_struct],
list(struct['key5'].items()))
self.assertEqual({}, dict(struct['key6']['subkey'].fields))
self.assertEqual([2, False], list(struct['key7'].items()))
serialized = struct.SerializeToString()
struct2 = struct_pb2.Struct()
struct2.ParseFromString(serialized)
self.assertEqual(struct, struct2)
for key, value in struct.items():
self.assertIn(key, struct)
self.assertIn(key, struct2)
self.assertEqual(value, struct2[key])
self.assertEqual(7, len(struct.keys()))
self.assertEqual(7, len(struct.values()))
for key in struct.keys():
self.assertIn(key, struct)
self.assertIn(key, struct2)
self.assertEqual(struct[key], struct2[key])
item = (next(iter(struct.keys())), next(iter(struct.values())))
self.assertEqual(item, next(iter(struct.items())))
self.assertTrue(isinstance(struct2, well_known_types.Struct))
self.assertEqual(5, struct2['key1'])
self.assertEqual('abc', struct2['key2'])
self.assertIs(True, struct2['key3'])
self.assertEqual(11, struct2['key4']['subkey'])
self.assertEqual([6, 'seven', True, False, None, inner_struct],
list(struct2['key5'].items()))
struct_list = struct2['key5']
self.assertEqual(6, struct_list[0])
self.assertEqual('seven', struct_list[1])
self.assertEqual(True, struct_list[2])
self.assertEqual(False, struct_list[3])
self.assertEqual(None, struct_list[4])
self.assertEqual(inner_struct, struct_list[5])
struct_list[1] = 7
self.assertEqual(7, struct_list[1])
struct_list.add_list().extend([1, 'two', True, False, None])
self.assertEqual([1, 'two', True, False, None],
list(struct_list[6].items()))
struct_list.extend([{'nested_struct': 30}, ['nested_list', 99], {}, []])
self.assertEqual(11, len(struct_list.values))
self.assertEqual(30, struct_list[7]['nested_struct'])
self.assertEqual('nested_list', struct_list[8][0])
self.assertEqual(99, struct_list[8][1])
self.assertEqual({}, dict(struct_list[9].fields))
self.assertEqual([], list(struct_list[10].items()))
struct_list[0] = {'replace': 'set'}
struct_list[1] = ['replace', 'set']
self.assertEqual('set', struct_list[0]['replace'])
self.assertEqual(['replace', 'set'], list(struct_list[1].items()))
text_serialized = str(struct)
struct3 = struct_pb2.Struct()
text_format.Merge(text_serialized, struct3)
self.assertEqual(struct, struct3)
struct.get_or_create_struct('key3')['replace'] = 12
self.assertEqual(12, struct['key3']['replace'])
# Tests empty list.
struct.get_or_create_list('empty_list')
empty_list = struct['empty_list']
self.assertEqual([], list(empty_list.items()))
list2 = struct_pb2.ListValue()
list2.add_list()
empty_list = list2[0]
self.assertEqual([], list(empty_list.items()))
# Tests empty struct.
struct.get_or_create_struct('empty_struct')
empty_struct = struct['empty_struct']
self.assertEqual({}, dict(empty_struct.fields))
list2.add_struct()
empty_struct = list2[1]
self.assertEqual({}, dict(empty_struct.fields))
self.assertEqual(9, len(struct))
del struct['key3']
del struct['key4']
self.assertEqual(7, len(struct))
self.assertEqual(6, len(struct['key5']))
del struct['key5'][1]
self.assertEqual(5, len(struct['key5']))
self.assertEqual([6, True, False, None, inner_struct],
list(struct['key5'].items()))
def testStructAssignment(self):
# Tests struct assignment from another struct
s1 = struct_pb2.Struct()
s2 = struct_pb2.Struct()
for value in [1, 'a', [1], ['a'], {'a': 'b'}]:
s1['x'] = value
s2['x'] = s1['x']
self.assertEqual(s1['x'], s2['x'])
def testMergeFrom(self):
struct = struct_pb2.Struct()
struct_class = struct.__class__
dictionary = {
'key1': 5,
'key2': 'abc',
'key3': True,
'key4': {'subkey': 11.0},
'key5': [6, 'seven', True, False, None, {'subkey2': 9}],
'key6': [['nested_list', True]],
'empty_struct': {},
'empty_list': []
}
struct.update(dictionary)
self.assertEqual(5, struct['key1'])
self.assertEqual('abc', struct['key2'])
self.assertIs(True, struct['key3'])
self.assertEqual(11, struct['key4']['subkey'])
inner_struct = struct_class()
inner_struct['subkey2'] = 9
self.assertEqual([6, 'seven', True, False, None, inner_struct],
list(struct['key5'].items()))
self.assertEqual(2, len(struct['key6'][0].values))
self.assertEqual('nested_list', struct['key6'][0][0])
self.assertEqual(True, struct['key6'][0][1])
empty_list = struct['empty_list']
self.assertEqual([], list(empty_list.items()))
empty_struct = struct['empty_struct']
self.assertEqual({}, dict(empty_struct.fields))
# According to documentation: "When parsing from the wire or when merging,
# if there are duplicate map keys the last key seen is used".
duplicate = {
'key4': {'replace': 20},
'key5': [[False, 5]]
}
struct.update(duplicate)
self.assertEqual(1, len(struct['key4'].fields))
self.assertEqual(20, struct['key4']['replace'])
self.assertEqual(1, len(struct['key5'].values))
self.assertEqual(False, struct['key5'][0][0])
self.assertEqual(5, struct['key5'][0][1])
class AnyTest(unittest.TestCase):
def testAnyMessage(self):
# Creates and sets message.
msg = any_test_pb2.TestAny()
msg_descriptor = msg.DESCRIPTOR
all_types = unittest_pb2.TestAllTypes()
all_descriptor = all_types.DESCRIPTOR
all_types.repeated_string.append(u'\u00fc\ua71f')
# Packs to Any.
msg.value.Pack(all_types)
self.assertEqual(msg.value.type_url,
'type.googleapis.com/%s' % all_descriptor.full_name)
self.assertEqual(msg.value.value,
all_types.SerializeToString())
# Tests Is() method.
self.assertTrue(msg.value.Is(all_descriptor))
self.assertFalse(msg.value.Is(msg_descriptor))
# Unpacks Any.
unpacked_message = unittest_pb2.TestAllTypes()
self.assertTrue(msg.value.Unpack(unpacked_message))
self.assertEqual(all_types, unpacked_message)
# Unpacks to different type.
self.assertFalse(msg.value.Unpack(msg))
# Only Any messages have Pack method.
try:
msg.Pack(all_types)
except AttributeError:
pass
else:
raise AttributeError('%s should not have Pack method.' %
msg_descriptor.full_name)
def testUnpackWithNoSlashInTypeUrl(self):
msg = any_test_pb2.TestAny()
all_types = unittest_pb2.TestAllTypes()
all_descriptor = all_types.DESCRIPTOR
msg.value.Pack(all_types)
# Reset type_url to part of type_url after '/'
msg.value.type_url = msg.value.TypeName()
self.assertFalse(msg.value.Is(all_descriptor))
unpacked_message = unittest_pb2.TestAllTypes()
self.assertFalse(msg.value.Unpack(unpacked_message))
def testMessageName(self):
# Creates and sets message.
submessage = any_test_pb2.TestAny()
submessage.int_value = 12345
msg = any_pb2.Any()
msg.Pack(submessage)
self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny')
def testPackWithCustomTypeUrl(self):
submessage = any_test_pb2.TestAny()
submessage.int_value = 12345
msg = any_pb2.Any()
# Pack with a custom type URL prefix.
msg.Pack(submessage, 'type.myservice.com')
self.assertEqual(msg.type_url,
'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
# Pack with a custom type URL prefix ending with '/'.
msg.Pack(submessage, 'type.myservice.com/')
self.assertEqual(msg.type_url,
'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
# Pack with an empty type URL prefix.
msg.Pack(submessage, '')
self.assertEqual(msg.type_url,
'/%s' % submessage.DESCRIPTOR.full_name)
# Test unpacking the type.
unpacked_message = any_test_pb2.TestAny()
self.assertTrue(msg.Unpack(unpacked_message))
self.assertEqual(submessage, unpacked_message)
def testPackDeterministic(self):
submessage = any_test_pb2.TestAny()
for i in range(10):
submessage.map_value[str(i)] = i * 2
msg = any_pb2.Any()
msg.Pack(submessage, deterministic=True)
serialized = msg.SerializeToString(deterministic=True)
golden = (b'\n4type.googleapis.com/google.protobuf.internal.TestAny\x12F'
b'\x1a\x05\n\x010\x10\x00\x1a\x05\n\x011\x10\x02\x1a\x05\n\x01'
b'2\x10\x04\x1a\x05\n\x013\x10\x06\x1a\x05\n\x014\x10\x08\x1a'
b'\x05\n\x015\x10\n\x1a\x05\n\x016\x10\x0c\x1a\x05\n\x017\x10'
b'\x0e\x1a\x05\n\x018\x10\x10\x1a\x05\n\x019\x10\x12')
self.assertEqual(golden, serialized)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,268 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Constants and static functions to support protocol buffer wire format."""
__author__ = 'robinson@google.com (Will Robinson)'
import struct
from google.protobuf import descriptor
from google.protobuf import message
TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
# These numbers identify the wire type of a protocol buffer value.
# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
# tag-and-type to store one of these WIRETYPE_* constants.
# These values must match WireType enum in google/protobuf/wire_format.h.
WIRETYPE_VARINT = 0
WIRETYPE_FIXED64 = 1
WIRETYPE_LENGTH_DELIMITED = 2
WIRETYPE_START_GROUP = 3
WIRETYPE_END_GROUP = 4
WIRETYPE_FIXED32 = 5
_WIRETYPE_MAX = 5
# Bounds for various integer types.
INT32_MAX = int((1 << 31) - 1)
INT32_MIN = int(-(1 << 31))
UINT32_MAX = (1 << 32) - 1
INT64_MAX = (1 << 63) - 1
INT64_MIN = -(1 << 63)
UINT64_MAX = (1 << 64) - 1
# "struct" format strings that will encode/decode the specified formats.
FORMAT_UINT32_LITTLE_ENDIAN = '<I'
FORMAT_UINT64_LITTLE_ENDIAN = '<Q'
FORMAT_FLOAT_LITTLE_ENDIAN = '<f'
FORMAT_DOUBLE_LITTLE_ENDIAN = '<d'
# We'll have to provide alternate implementations of AppendLittleEndian*() on
# any architectures where these checks fail.
if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4:
raise AssertionError('Format "I" is not a 32-bit number.')
if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8:
raise AssertionError('Format "Q" is not a 64-bit number.')
def PackTag(field_number, wire_type):
"""Returns an unsigned 32-bit integer that encodes the field number and
wire type information in standard protocol message wire format.
Args:
field_number: Expected to be an integer in the range [1, 1 << 29)
wire_type: One of the WIRETYPE_* constants.
"""
if not 0 <= wire_type <= _WIRETYPE_MAX:
raise message.EncodeError('Unknown wire type: %d' % wire_type)
return (field_number << TAG_TYPE_BITS) | wire_type
def UnpackTag(tag):
"""The inverse of PackTag(). Given an unsigned 32-bit number,
returns a (field_number, wire_type) tuple.
"""
return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
def ZigZagEncode(value):
"""ZigZag Transform: Encodes signed integers so that they can be
effectively used with varint encoding. See wire_format.h for
more details.
"""
if value >= 0:
return value << 1
return (value << 1) ^ (~0)
def ZigZagDecode(value):
"""Inverse of ZigZagEncode()."""
if not value & 0x1:
return value >> 1
return (value >> 1) ^ (~0)
# The *ByteSize() functions below return the number of bytes required to
# serialize "field number + type" information and then serialize the value.
def Int32ByteSize(field_number, int32):
return Int64ByteSize(field_number, int32)
def Int32ByteSizeNoTag(int32):
return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32)
def Int64ByteSize(field_number, int64):
# Have to convert to uint before calling UInt64ByteSize().
return UInt64ByteSize(field_number, 0xffffffffffffffff & int64)
def UInt32ByteSize(field_number, uint32):
return UInt64ByteSize(field_number, uint32)
def UInt64ByteSize(field_number, uint64):
return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
def SInt32ByteSize(field_number, int32):
return UInt32ByteSize(field_number, ZigZagEncode(int32))
def SInt64ByteSize(field_number, int64):
return UInt64ByteSize(field_number, ZigZagEncode(int64))
def Fixed32ByteSize(field_number, fixed32):
return TagByteSize(field_number) + 4
def Fixed64ByteSize(field_number, fixed64):
return TagByteSize(field_number) + 8
def SFixed32ByteSize(field_number, sfixed32):
return TagByteSize(field_number) + 4
def SFixed64ByteSize(field_number, sfixed64):
return TagByteSize(field_number) + 8
def FloatByteSize(field_number, flt):
return TagByteSize(field_number) + 4
def DoubleByteSize(field_number, double):
return TagByteSize(field_number) + 8
def BoolByteSize(field_number, b):
return TagByteSize(field_number) + 1
def EnumByteSize(field_number, enum):
return UInt32ByteSize(field_number, enum)
def StringByteSize(field_number, string):
return BytesByteSize(field_number, string.encode('utf-8'))
def BytesByteSize(field_number, b):
return (TagByteSize(field_number)
+ _VarUInt64ByteSizeNoTag(len(b))
+ len(b))
def GroupByteSize(field_number, message):
return (2 * TagByteSize(field_number) # START and END group.
+ message.ByteSize())
def MessageByteSize(field_number, message):
return (TagByteSize(field_number)
+ _VarUInt64ByteSizeNoTag(message.ByteSize())
+ message.ByteSize())
def MessageSetItemByteSize(field_number, msg):
# First compute the sizes of the tags.
# There are 2 tags for the beginning and ending of the repeated group, that
# is field number 1, one with field number 2 (type_id) and one with field
# number 3 (message).
total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3))
# Add the number of bytes for type_id.
total_size += _VarUInt64ByteSizeNoTag(field_number)
message_size = msg.ByteSize()
# The number of bytes for encoding the length of the message.
total_size += _VarUInt64ByteSizeNoTag(message_size)
# The size of the message.
total_size += message_size
return total_size
def TagByteSize(field_number):
"""Returns the bytes required to serialize a tag with this field number."""
# Just pass in type 0, since the type won't affect the tag+type size.
return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0))
# Private helper function for the *ByteSize() functions above.
def _VarUInt64ByteSizeNoTag(uint64):
"""Returns the number of bytes required to serialize a single varint
using boundary value comparisons. (unrolled loop optimization -WPierce)
uint64 must be unsigned.
"""
if uint64 <= 0x7f: return 1
if uint64 <= 0x3fff: return 2
if uint64 <= 0x1fffff: return 3
if uint64 <= 0xfffffff: return 4
if uint64 <= 0x7ffffffff: return 5
if uint64 <= 0x3ffffffffff: return 6
if uint64 <= 0x1ffffffffffff: return 7
if uint64 <= 0xffffffffffffff: return 8
if uint64 <= 0x7fffffffffffffff: return 9
if uint64 > UINT64_MAX:
raise message.EncodeError('Value out of range: %d' % uint64)
return 10
NON_PACKABLE_TYPES = (
descriptor.FieldDescriptor.TYPE_STRING,
descriptor.FieldDescriptor.TYPE_GROUP,
descriptor.FieldDescriptor.TYPE_MESSAGE,
descriptor.FieldDescriptor.TYPE_BYTES
)
def IsTypePackable(field_type):
"""Return true iff packable = true is valid for fields of this type.
Args:
field_type: a FieldDescriptor::Type value.
Returns:
True iff fields of this type are packable.
"""
return field_type not in NON_PACKABLE_TYPES

View File

@@ -0,0 +1,252 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.wire_format."""
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
from google.protobuf import message
from google.protobuf.internal import wire_format
class WireFormatTest(unittest.TestCase):
def testPackTag(self):
field_number = 0xabc
tag_type = 2
self.assertEqual((field_number << 3) | tag_type,
wire_format.PackTag(field_number, tag_type))
PackTag = wire_format.PackTag
# Number too high.
self.assertRaises(message.EncodeError, PackTag, field_number, 6)
# Number too low.
self.assertRaises(message.EncodeError, PackTag, field_number, -1)
def testUnpackTag(self):
# Test field numbers that will require various varint sizes.
for expected_field_number in (1, 15, 16, 2047, 2048):
for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
field_number, wire_type = wire_format.UnpackTag(
wire_format.PackTag(expected_field_number, expected_wire_type))
self.assertEqual(expected_field_number, field_number)
self.assertEqual(expected_wire_type, wire_type)
self.assertRaises(TypeError, wire_format.UnpackTag, None)
self.assertRaises(TypeError, wire_format.UnpackTag, 'abc')
self.assertRaises(TypeError, wire_format.UnpackTag, 0.0)
self.assertRaises(TypeError, wire_format.UnpackTag, object())
def testZigZagEncode(self):
Z = wire_format.ZigZagEncode
self.assertEqual(0, Z(0))
self.assertEqual(1, Z(-1))
self.assertEqual(2, Z(1))
self.assertEqual(3, Z(-2))
self.assertEqual(4, Z(2))
self.assertEqual(0xfffffffe, Z(0x7fffffff))
self.assertEqual(0xffffffff, Z(-0x80000000))
self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff))
self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000))
self.assertRaises(TypeError, Z, None)
self.assertRaises(TypeError, Z, 'abcd')
self.assertRaises(TypeError, Z, 0.0)
self.assertRaises(TypeError, Z, object())
def testZigZagDecode(self):
Z = wire_format.ZigZagDecode
self.assertEqual(0, Z(0))
self.assertEqual(-1, Z(1))
self.assertEqual(1, Z(2))
self.assertEqual(-2, Z(3))
self.assertEqual(2, Z(4))
self.assertEqual(0x7fffffff, Z(0xfffffffe))
self.assertEqual(-0x80000000, Z(0xffffffff))
self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe))
self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff))
self.assertRaises(TypeError, Z, None)
self.assertRaises(TypeError, Z, 'abcd')
self.assertRaises(TypeError, Z, 0.0)
self.assertRaises(TypeError, Z, object())
def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size):
# Use field numbers that cause various byte sizes for the tag information.
for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)):
expected_size = expected_value_size + tag_bytes
actual_size = byte_size_fn(field_number, value)
self.assertEqual(expected_size, actual_size,
'byte_size_fn: %s, field_number: %d, value: %r\n'
'Expected: %d, Actual: %d'% (
byte_size_fn, field_number, value, expected_size, actual_size))
def testByteSizeFunctions(self):
# Test all numeric *ByteSize() functions.
NUMERIC_ARGS = [
# Int32ByteSize().
[wire_format.Int32ByteSize, 0, 1],
[wire_format.Int32ByteSize, 127, 1],
[wire_format.Int32ByteSize, 128, 2],
[wire_format.Int32ByteSize, -1, 10],
# Int64ByteSize().
[wire_format.Int64ByteSize, 0, 1],
[wire_format.Int64ByteSize, 127, 1],
[wire_format.Int64ByteSize, 128, 2],
[wire_format.Int64ByteSize, -1, 10],
# UInt32ByteSize().
[wire_format.UInt32ByteSize, 0, 1],
[wire_format.UInt32ByteSize, 127, 1],
[wire_format.UInt32ByteSize, 128, 2],
[wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5],
# UInt64ByteSize().
[wire_format.UInt64ByteSize, 0, 1],
[wire_format.UInt64ByteSize, 127, 1],
[wire_format.UInt64ByteSize, 128, 2],
[wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10],
# SInt32ByteSize().
[wire_format.SInt32ByteSize, 0, 1],
[wire_format.SInt32ByteSize, -1, 1],
[wire_format.SInt32ByteSize, 1, 1],
[wire_format.SInt32ByteSize, -63, 1],
[wire_format.SInt32ByteSize, 63, 1],
[wire_format.SInt32ByteSize, -64, 1],
[wire_format.SInt32ByteSize, 64, 2],
# SInt64ByteSize().
[wire_format.SInt64ByteSize, 0, 1],
[wire_format.SInt64ByteSize, -1, 1],
[wire_format.SInt64ByteSize, 1, 1],
[wire_format.SInt64ByteSize, -63, 1],
[wire_format.SInt64ByteSize, 63, 1],
[wire_format.SInt64ByteSize, -64, 1],
[wire_format.SInt64ByteSize, 64, 2],
# Fixed32ByteSize().
[wire_format.Fixed32ByteSize, 0, 4],
[wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4],
# Fixed64ByteSize().
[wire_format.Fixed64ByteSize, 0, 8],
[wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8],
# SFixed32ByteSize().
[wire_format.SFixed32ByteSize, 0, 4],
[wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4],
[wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4],
# SFixed64ByteSize().
[wire_format.SFixed64ByteSize, 0, 8],
[wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8],
[wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8],
# FloatByteSize().
[wire_format.FloatByteSize, 0.0, 4],
[wire_format.FloatByteSize, 1000000000.0, 4],
[wire_format.FloatByteSize, -1000000000.0, 4],
# DoubleByteSize().
[wire_format.DoubleByteSize, 0.0, 8],
[wire_format.DoubleByteSize, 1000000000.0, 8],
[wire_format.DoubleByteSize, -1000000000.0, 8],
# BoolByteSize().
[wire_format.BoolByteSize, False, 1],
[wire_format.BoolByteSize, True, 1],
# EnumByteSize().
[wire_format.EnumByteSize, 0, 1],
[wire_format.EnumByteSize, 127, 1],
[wire_format.EnumByteSize, 128, 2],
[wire_format.EnumByteSize, wire_format.UINT32_MAX, 5],
]
for args in NUMERIC_ARGS:
self.NumericByteSizeTestHelper(*args)
# Test strings and bytes.
for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize):
# 1 byte for tag, 1 byte for length, 3 bytes for contents.
self.assertEqual(5, byte_size_fn(10, 'abc'))
# 2 bytes for tag, 1 byte for length, 3 bytes for contents.
self.assertEqual(6, byte_size_fn(16, 'abc'))
# 2 bytes for tag, 2 bytes for length, 128 bytes for contents.
self.assertEqual(132, byte_size_fn(16, 'a' * 128))
# Test UTF-8 string byte size calculation.
# 1 byte for tag, 1 byte for length, 8 bytes for content.
self.assertEqual(10, wire_format.StringByteSize(
5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8')))
class MockMessage(object):
def __init__(self, byte_size):
self.byte_size = byte_size
def ByteSize(self):
return self.byte_size
message_byte_size = 10
mock_message = MockMessage(byte_size=message_byte_size)
# Test groups.
# (2 * 1) bytes for begin and end tags, plus message_byte_size.
self.assertEqual(2 + message_byte_size,
wire_format.GroupByteSize(1, mock_message))
# (2 * 2) bytes for begin and end tags, plus message_byte_size.
self.assertEqual(4 + message_byte_size,
wire_format.GroupByteSize(16, mock_message))
# Test messages.
# 1 byte for tag, plus 1 byte for length, plus contents.
self.assertEqual(2 + mock_message.byte_size,
wire_format.MessageByteSize(1, mock_message))
# 2 bytes for tag, plus 1 byte for length, plus contents.
self.assertEqual(3 + mock_message.byte_size,
wire_format.MessageByteSize(16, mock_message))
# 2 bytes for tag, plus 2 bytes for length, plus contents.
mock_message.byte_size = 128
self.assertEqual(4 + mock_message.byte_size,
wire_format.MessageByteSize(16, mock_message))
# Test message set item byte size.
# 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id,
# plus contents.
mock_message.byte_size = 10
self.assertEqual(mock_message.byte_size + 6,
wire_format.MessageSetItemByteSize(1, mock_message))
# 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id,
# plus contents.
mock_message.byte_size = 128
self.assertEqual(mock_message.byte_size + 7,
wire_format.MessageSetItemByteSize(1, mock_message))
# 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id,
# plus contents.
self.assertEqual(mock_message.byte_size + 8,
wire_format.MessageSetItemByteSize(128, mock_message))
# Too-long varint.
self.assertRaises(message.EncodeError,
wire_format.UInt64ByteSize, 1, 1 << 128)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,913 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains routines for printing protocol messages in JSON format.
Simple usage example:
# Create a proto object and serialize it to a json format string.
message = my_proto_pb2.MyMessage(foo='bar')
json_string = json_format.MessageToJson(message)
# Parse a json format string to proto object.
message = json_format.Parse(json_string, my_proto_pb2.MyMessage())
"""
__author__ = 'jieluo@google.com (Jie Luo)'
import base64
from collections import OrderedDict
import json
import math
from operator import methodcaller
import re
import sys
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import symbol_database
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32,
descriptor.FieldDescriptor.CPPTYPE_UINT32,
descriptor.FieldDescriptor.CPPTYPE_INT64,
descriptor.FieldDescriptor.CPPTYPE_UINT64])
_INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64,
descriptor.FieldDescriptor.CPPTYPE_UINT64])
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
_INFINITY = 'Infinity'
_NEG_INFINITY = '-Infinity'
_NAN = 'NaN'
_UNPAIRED_SURROGATE_PATTERN = re.compile(
u'[\ud800-\udbff](?![\udc00-\udfff])|(?<![\ud800-\udbff])[\udc00-\udfff]')
_VALID_EXTENSION_NAME = re.compile(r'\[[a-zA-Z0-9\._]*\]$')
class Error(Exception):
"""Top-level module error for json_format."""
class SerializeToJsonError(Error):
"""Thrown if serialization to JSON fails."""
class ParseError(Error):
"""Thrown in case of parsing error."""
def MessageToJson(
message,
including_default_value_fields=False,
preserving_proto_field_name=False,
indent=2,
sort_keys=False,
use_integers_for_enums=False,
descriptor_pool=None,
float_precision=None,
ensure_ascii=True):
"""Converts protobuf message to JSON format.
Args:
message: The protocol buffers message instance to serialize.
including_default_value_fields: If True, singular primitive fields,
repeated fields, and map fields will always be serialized. If
False, only serialize non-empty fields. Singular message fields
and oneof fields are not affected by this option.
preserving_proto_field_name: If True, use the original proto field
names as defined in the .proto file. If False, convert the field
names to lowerCamelCase.
indent: The JSON object will be pretty-printed with this indent level.
An indent level of 0 or negative will only insert newlines. If the
indent level is None, no newlines will be inserted.
sort_keys: If True, then the output will be sorted by field names.
use_integers_for_enums: If true, print integers instead of enum names.
descriptor_pool: A Descriptor Pool for resolving types. If None use the
default.
float_precision: If set, use this to specify float field valid digits.
ensure_ascii: If True, strings with non-ASCII characters are escaped.
If False, Unicode strings are returned unchanged.
Returns:
A string containing the JSON formatted protocol buffer message.
"""
printer = _Printer(
including_default_value_fields,
preserving_proto_field_name,
use_integers_for_enums,
descriptor_pool,
float_precision=float_precision)
return printer.ToJsonString(message, indent, sort_keys, ensure_ascii)
def MessageToDict(
message,
including_default_value_fields=False,
preserving_proto_field_name=False,
use_integers_for_enums=False,
descriptor_pool=None,
float_precision=None):
"""Converts protobuf message to a dictionary.
When the dictionary is encoded to JSON, it conforms to proto3 JSON spec.
Args:
message: The protocol buffers message instance to serialize.
including_default_value_fields: If True, singular primitive fields,
repeated fields, and map fields will always be serialized. If
False, only serialize non-empty fields. Singular message fields
and oneof fields are not affected by this option.
preserving_proto_field_name: If True, use the original proto field
names as defined in the .proto file. If False, convert the field
names to lowerCamelCase.
use_integers_for_enums: If true, print integers instead of enum names.
descriptor_pool: A Descriptor Pool for resolving types. If None use the
default.
float_precision: If set, use this to specify float field valid digits.
Returns:
A dict representation of the protocol buffer message.
"""
printer = _Printer(
including_default_value_fields,
preserving_proto_field_name,
use_integers_for_enums,
descriptor_pool,
float_precision=float_precision)
# pylint: disable=protected-access
return printer._MessageToJsonObject(message)
def _IsMapEntry(field):
return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and
field.message_type.GetOptions().map_entry)
class _Printer(object):
"""JSON format printer for protocol message."""
def __init__(
self,
including_default_value_fields=False,
preserving_proto_field_name=False,
use_integers_for_enums=False,
descriptor_pool=None,
float_precision=None):
self.including_default_value_fields = including_default_value_fields
self.preserving_proto_field_name = preserving_proto_field_name
self.use_integers_for_enums = use_integers_for_enums
self.descriptor_pool = descriptor_pool
if float_precision:
self.float_format = '.{}g'.format(float_precision)
else:
self.float_format = None
def ToJsonString(self, message, indent, sort_keys, ensure_ascii):
js = self._MessageToJsonObject(message)
return json.dumps(
js, indent=indent, sort_keys=sort_keys, ensure_ascii=ensure_ascii)
def _MessageToJsonObject(self, message):
"""Converts message to an object according to Proto3 JSON Specification."""
message_descriptor = message.DESCRIPTOR
full_name = message_descriptor.full_name
if _IsWrapperMessage(message_descriptor):
return self._WrapperMessageToJsonObject(message)
if full_name in _WKTJSONMETHODS:
return methodcaller(_WKTJSONMETHODS[full_name][0], message)(self)
js = {}
return self._RegularMessageToJsonObject(message, js)
def _RegularMessageToJsonObject(self, message, js):
"""Converts normal message according to Proto3 JSON Specification."""
fields = message.ListFields()
try:
for field, value in fields:
if self.preserving_proto_field_name:
name = field.name
else:
name = field.json_name
if _IsMapEntry(field):
# Convert a map field.
v_field = field.message_type.fields_by_name['value']
js_map = {}
for key in value:
if isinstance(key, bool):
if key:
recorded_key = 'true'
else:
recorded_key = 'false'
else:
recorded_key = str(key)
js_map[recorded_key] = self._FieldToJsonObject(
v_field, value[key])
js[name] = js_map
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
# Convert a repeated field.
js[name] = [self._FieldToJsonObject(field, k)
for k in value]
elif field.is_extension:
name = '[%s]' % field.full_name
js[name] = self._FieldToJsonObject(field, value)
else:
js[name] = self._FieldToJsonObject(field, value)
# Serialize default value if including_default_value_fields is True.
if self.including_default_value_fields:
message_descriptor = message.DESCRIPTOR
for field in message_descriptor.fields:
# Singular message fields and oneof fields will not be affected.
if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
field.containing_oneof):
continue
if self.preserving_proto_field_name:
name = field.name
else:
name = field.json_name
if name in js:
# Skip the field which has been serialized already.
continue
if _IsMapEntry(field):
js[name] = {}
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
js[name] = []
else:
js[name] = self._FieldToJsonObject(field, field.default_value)
except ValueError as e:
raise SerializeToJsonError(
'Failed to serialize {0} field: {1}.'.format(field.name, e))
return js
def _FieldToJsonObject(self, field, value):
"""Converts field value according to Proto3 JSON Specification."""
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
return self._MessageToJsonObject(value)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
if self.use_integers_for_enums:
return value
if field.enum_type.full_name == 'google.protobuf.NullValue':
return None
enum_value = field.enum_type.values_by_number.get(value, None)
if enum_value is not None:
return enum_value.name
else:
if field.file.syntax == 'proto3':
return value
raise SerializeToJsonError('Enum field contains an integer value '
'which can not mapped to an enum value.')
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
# Use base64 Data encoding for bytes
return base64.b64encode(value).decode('utf-8')
else:
return value
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
return bool(value)
elif field.cpp_type in _INT64_TYPES:
return str(value)
elif field.cpp_type in _FLOAT_TYPES:
if math.isinf(value):
if value < 0.0:
return _NEG_INFINITY
else:
return _INFINITY
if math.isnan(value):
return _NAN
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_FLOAT:
if self.float_format:
return float(format(value, self.float_format))
else:
return type_checkers.ToShortestFloat(value)
return value
def _AnyMessageToJsonObject(self, message):
"""Converts Any message according to Proto3 JSON Specification."""
if not message.ListFields():
return {}
# Must print @type first, use OrderedDict instead of {}
js = OrderedDict()
type_url = message.type_url
js['@type'] = type_url
sub_message = _CreateMessageFromTypeUrl(type_url, self.descriptor_pool)
sub_message.ParseFromString(message.value)
message_descriptor = sub_message.DESCRIPTOR
full_name = message_descriptor.full_name
if _IsWrapperMessage(message_descriptor):
js['value'] = self._WrapperMessageToJsonObject(sub_message)
return js
if full_name in _WKTJSONMETHODS:
js['value'] = methodcaller(_WKTJSONMETHODS[full_name][0],
sub_message)(self)
return js
return self._RegularMessageToJsonObject(sub_message, js)
def _GenericMessageToJsonObject(self, message):
"""Converts message according to Proto3 JSON Specification."""
# Duration, Timestamp and FieldMask have ToJsonString method to do the
# convert. Users can also call the method directly.
return message.ToJsonString()
def _ValueMessageToJsonObject(self, message):
"""Converts Value message according to Proto3 JSON Specification."""
which = message.WhichOneof('kind')
# If the Value message is not set treat as null_value when serialize
# to JSON. The parse back result will be different from original message.
if which is None or which == 'null_value':
return None
if which == 'list_value':
return self._ListValueMessageToJsonObject(message.list_value)
if which == 'struct_value':
value = message.struct_value
else:
value = getattr(message, which)
oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
return self._FieldToJsonObject(oneof_descriptor, value)
def _ListValueMessageToJsonObject(self, message):
"""Converts ListValue message according to Proto3 JSON Specification."""
return [self._ValueMessageToJsonObject(value)
for value in message.values]
def _StructMessageToJsonObject(self, message):
"""Converts Struct message according to Proto3 JSON Specification."""
fields = message.fields
ret = {}
for key in fields:
ret[key] = self._ValueMessageToJsonObject(fields[key])
return ret
def _WrapperMessageToJsonObject(self, message):
return self._FieldToJsonObject(
message.DESCRIPTOR.fields_by_name['value'], message.value)
def _IsWrapperMessage(message_descriptor):
return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
def _DuplicateChecker(js):
result = {}
for name, value in js:
if name in result:
raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name))
result[name] = value
return result
def _CreateMessageFromTypeUrl(type_url, descriptor_pool):
"""Creates a message from a type URL."""
db = symbol_database.Default()
pool = db.pool if descriptor_pool is None else descriptor_pool
type_name = type_url.split('/')[-1]
try:
message_descriptor = pool.FindMessageTypeByName(type_name)
except KeyError:
raise TypeError(
'Can not find message descriptor by type_url: {0}'.format(type_url))
message_class = db.GetPrototype(message_descriptor)
return message_class()
def Parse(text,
message,
ignore_unknown_fields=False,
descriptor_pool=None,
max_recursion_depth=100):
"""Parses a JSON representation of a protocol message into a message.
Args:
text: Message JSON representation.
message: A protocol buffer message to merge into.
ignore_unknown_fields: If True, do not raise errors for unknown fields.
descriptor_pool: A Descriptor Pool for resolving types. If None use the
default.
max_recursion_depth: max recursion depth of JSON message to be
deserialized. JSON messages over this depth will fail to be
deserialized. Default value is 100.
Returns:
The same message passed as argument.
Raises::
ParseError: On JSON parsing problems.
"""
if not isinstance(text, str):
text = text.decode('utf-8')
try:
js = json.loads(text, object_pairs_hook=_DuplicateChecker)
except ValueError as e:
raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
return ParseDict(js, message, ignore_unknown_fields, descriptor_pool,
max_recursion_depth)
def ParseDict(js_dict,
message,
ignore_unknown_fields=False,
descriptor_pool=None,
max_recursion_depth=100):
"""Parses a JSON dictionary representation into a message.
Args:
js_dict: Dict representation of a JSON message.
message: A protocol buffer message to merge into.
ignore_unknown_fields: If True, do not raise errors for unknown fields.
descriptor_pool: A Descriptor Pool for resolving types. If None use the
default.
max_recursion_depth: max recursion depth of JSON message to be
deserialized. JSON messages over this depth will fail to be
deserialized. Default value is 100.
Returns:
The same message passed as argument.
"""
parser = _Parser(ignore_unknown_fields, descriptor_pool, max_recursion_depth)
parser.ConvertMessage(js_dict, message, '')
return message
_INT_OR_FLOAT = (int, float)
class _Parser(object):
"""JSON format parser for protocol message."""
def __init__(self, ignore_unknown_fields, descriptor_pool,
max_recursion_depth):
self.ignore_unknown_fields = ignore_unknown_fields
self.descriptor_pool = descriptor_pool
self.max_recursion_depth = max_recursion_depth
self.recursion_depth = 0
def ConvertMessage(self, value, message, path):
"""Convert a JSON object into a message.
Args:
value: A JSON object.
message: A WKT or regular protocol message to record the data.
path: parent path to log parse error info.
Raises:
ParseError: In case of convert problems.
"""
self.recursion_depth += 1
if self.recursion_depth > self.max_recursion_depth:
raise ParseError('Message too deep. Max recursion depth is {0}'.format(
self.max_recursion_depth))
message_descriptor = message.DESCRIPTOR
full_name = message_descriptor.full_name
if not path:
path = message_descriptor.name
if _IsWrapperMessage(message_descriptor):
self._ConvertWrapperMessage(value, message, path)
elif full_name in _WKTJSONMETHODS:
methodcaller(_WKTJSONMETHODS[full_name][1], value, message, path)(self)
else:
self._ConvertFieldValuePair(value, message, path)
self.recursion_depth -= 1
def _ConvertFieldValuePair(self, js, message, path):
"""Convert field value pairs into regular message.
Args:
js: A JSON object to convert the field value pairs.
message: A regular protocol message to record the data.
path: parent path to log parse error info.
Raises:
ParseError: In case of problems converting.
"""
names = []
message_descriptor = message.DESCRIPTOR
fields_by_json_name = dict((f.json_name, f)
for f in message_descriptor.fields)
for name in js:
try:
field = fields_by_json_name.get(name, None)
if not field:
field = message_descriptor.fields_by_name.get(name, None)
if not field and _VALID_EXTENSION_NAME.match(name):
if not message_descriptor.is_extendable:
raise ParseError(
'Message type {0} does not have extensions at {1}'.format(
message_descriptor.full_name, path))
identifier = name[1:-1] # strip [] brackets
# pylint: disable=protected-access
field = message.Extensions._FindExtensionByName(identifier)
# pylint: enable=protected-access
if not field:
# Try looking for extension by the message type name, dropping the
# field name following the final . separator in full_name.
identifier = '.'.join(identifier.split('.')[:-1])
# pylint: disable=protected-access
field = message.Extensions._FindExtensionByName(identifier)
# pylint: enable=protected-access
if not field:
if self.ignore_unknown_fields:
continue
raise ParseError(
('Message type "{0}" has no field named "{1}" at "{2}".\n'
' Available Fields(except extensions): "{3}"').format(
message_descriptor.full_name, name, path,
[f.json_name for f in message_descriptor.fields]))
if name in names:
raise ParseError('Message type "{0}" should not have multiple '
'"{1}" fields at "{2}".'.format(
message.DESCRIPTOR.full_name, name, path))
names.append(name)
value = js[name]
# Check no other oneof field is parsed.
if field.containing_oneof is not None and value is not None:
oneof_name = field.containing_oneof.name
if oneof_name in names:
raise ParseError('Message type "{0}" should not have multiple '
'"{1}" oneof fields at "{2}".'.format(
message.DESCRIPTOR.full_name, oneof_name,
path))
names.append(oneof_name)
if value is None:
if (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE
and field.message_type.full_name == 'google.protobuf.Value'):
sub_message = getattr(message, field.name)
sub_message.null_value = 0
elif (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM
and field.enum_type.full_name == 'google.protobuf.NullValue'):
setattr(message, field.name, 0)
else:
message.ClearField(field.name)
continue
# Parse field value.
if _IsMapEntry(field):
message.ClearField(field.name)
self._ConvertMapFieldValue(value, message, field,
'{0}.{1}'.format(path, name))
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
message.ClearField(field.name)
if not isinstance(value, list):
raise ParseError('repeated field {0} must be in [] which is '
'{1} at {2}'.format(name, value, path))
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
# Repeated message field.
for index, item in enumerate(value):
sub_message = getattr(message, field.name).add()
# None is a null_value in Value.
if (item is None and
sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
raise ParseError('null is not allowed to be used as an element'
' in a repeated field at {0}.{1}[{2}]'.format(
path, name, index))
self.ConvertMessage(item, sub_message,
'{0}.{1}[{2}]'.format(path, name, index))
else:
# Repeated scalar field.
for index, item in enumerate(value):
if item is None:
raise ParseError('null is not allowed to be used as an element'
' in a repeated field at {0}.{1}[{2}]'.format(
path, name, index))
getattr(message, field.name).append(
_ConvertScalarFieldValue(
item, field, '{0}.{1}[{2}]'.format(path, name, index)))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
if field.is_extension:
sub_message = message.Extensions[field]
else:
sub_message = getattr(message, field.name)
sub_message.SetInParent()
self.ConvertMessage(value, sub_message, '{0}.{1}'.format(path, name))
else:
if field.is_extension:
message.Extensions[field] = _ConvertScalarFieldValue(
value, field, '{0}.{1}'.format(path, name))
else:
setattr(
message, field.name,
_ConvertScalarFieldValue(value, field,
'{0}.{1}'.format(path, name)))
except ParseError as e:
if field and field.containing_oneof is None:
raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
else:
raise ParseError(str(e))
except ValueError as e:
raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
except TypeError as e:
raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
def _ConvertAnyMessage(self, value, message, path):
"""Convert a JSON representation into Any message."""
if isinstance(value, dict) and not value:
return
try:
type_url = value['@type']
except KeyError:
raise ParseError(
'@type is missing when parsing any message at {0}'.format(path))
try:
sub_message = _CreateMessageFromTypeUrl(type_url, self.descriptor_pool)
except TypeError as e:
raise ParseError('{0} at {1}'.format(e, path))
message_descriptor = sub_message.DESCRIPTOR
full_name = message_descriptor.full_name
if _IsWrapperMessage(message_descriptor):
self._ConvertWrapperMessage(value['value'], sub_message,
'{0}.value'.format(path))
elif full_name in _WKTJSONMETHODS:
methodcaller(_WKTJSONMETHODS[full_name][1], value['value'], sub_message,
'{0}.value'.format(path))(
self)
else:
del value['@type']
self._ConvertFieldValuePair(value, sub_message, path)
value['@type'] = type_url
# Sets Any message
message.value = sub_message.SerializeToString()
message.type_url = type_url
def _ConvertGenericMessage(self, value, message, path):
"""Convert a JSON representation into message with FromJsonString."""
# Duration, Timestamp, FieldMask have a FromJsonString method to do the
# conversion. Users can also call the method directly.
try:
message.FromJsonString(value)
except ValueError as e:
raise ParseError('{0} at {1}'.format(e, path))
def _ConvertValueMessage(self, value, message, path):
"""Convert a JSON representation into Value message."""
if isinstance(value, dict):
self._ConvertStructMessage(value, message.struct_value, path)
elif isinstance(value, list):
self._ConvertListValueMessage(value, message.list_value, path)
elif value is None:
message.null_value = 0
elif isinstance(value, bool):
message.bool_value = value
elif isinstance(value, str):
message.string_value = value
elif isinstance(value, _INT_OR_FLOAT):
message.number_value = value
else:
raise ParseError('Value {0} has unexpected type {1} at {2}'.format(
value, type(value), path))
def _ConvertListValueMessage(self, value, message, path):
"""Convert a JSON representation into ListValue message."""
if not isinstance(value, list):
raise ParseError('ListValue must be in [] which is {0} at {1}'.format(
value, path))
message.ClearField('values')
for index, item in enumerate(value):
self._ConvertValueMessage(item, message.values.add(),
'{0}[{1}]'.format(path, index))
def _ConvertStructMessage(self, value, message, path):
"""Convert a JSON representation into Struct message."""
if not isinstance(value, dict):
raise ParseError('Struct must be in a dict which is {0} at {1}'.format(
value, path))
# Clear will mark the struct as modified so it will be created even if
# there are no values.
message.Clear()
for key in value:
self._ConvertValueMessage(value[key], message.fields[key],
'{0}.{1}'.format(path, key))
return
def _ConvertWrapperMessage(self, value, message, path):
"""Convert a JSON representation into Wrapper message."""
field = message.DESCRIPTOR.fields_by_name['value']
setattr(
message, 'value',
_ConvertScalarFieldValue(value, field, path='{0}.value'.format(path)))
def _ConvertMapFieldValue(self, value, message, field, path):
"""Convert map field value for a message map field.
Args:
value: A JSON object to convert the map field value.
message: A protocol message to record the converted data.
field: The descriptor of the map field to be converted.
path: parent path to log parse error info.
Raises:
ParseError: In case of convert problems.
"""
if not isinstance(value, dict):
raise ParseError(
'Map field {0} must be in a dict which is {1} at {2}'.format(
field.name, value, path))
key_field = field.message_type.fields_by_name['key']
value_field = field.message_type.fields_by_name['value']
for key in value:
key_value = _ConvertScalarFieldValue(key, key_field,
'{0}.key'.format(path), True)
if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
self.ConvertMessage(value[key],
getattr(message, field.name)[key_value],
'{0}[{1}]'.format(path, key_value))
else:
getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
value[key], value_field, path='{0}[{1}]'.format(path, key_value))
def _ConvertScalarFieldValue(value, field, path, require_str=False):
"""Convert a single scalar field value.
Args:
value: A scalar value to convert the scalar field value.
field: The descriptor of the field to convert.
path: parent path to log parse error info.
require_str: If True, the field value must be a str.
Returns:
The converted scalar field value
Raises:
ParseError: In case of convert problems.
"""
try:
if field.cpp_type in _INT_TYPES:
return _ConvertInteger(value)
elif field.cpp_type in _FLOAT_TYPES:
return _ConvertFloat(value, field)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
return _ConvertBool(value, require_str)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
if isinstance(value, str):
encoded = value.encode('utf-8')
else:
encoded = value
# Add extra padding '='
padded_value = encoded + b'=' * (4 - len(encoded) % 4)
return base64.urlsafe_b64decode(padded_value)
else:
# Checking for unpaired surrogates appears to be unreliable,
# depending on the specific Python version, so we check manually.
if _UNPAIRED_SURROGATE_PATTERN.search(value):
raise ParseError('Unpaired surrogate')
return value
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
# Convert an enum value.
enum_value = field.enum_type.values_by_name.get(value, None)
if enum_value is None:
try:
number = int(value)
enum_value = field.enum_type.values_by_number.get(number, None)
except ValueError:
raise ParseError('Invalid enum value {0} for enum type {1}'.format(
value, field.enum_type.full_name))
if enum_value is None:
if field.file.syntax == 'proto3':
# Proto3 accepts unknown enums.
return number
raise ParseError('Invalid enum value {0} for enum type {1}'.format(
value, field.enum_type.full_name))
return enum_value.number
except ParseError as e:
raise ParseError('{0} at {1}'.format(e, path))
def _ConvertInteger(value):
"""Convert an integer.
Args:
value: A scalar value to convert.
Returns:
The integer value.
Raises:
ParseError: If an integer couldn't be consumed.
"""
if isinstance(value, float) and not value.is_integer():
raise ParseError('Couldn\'t parse integer: {0}'.format(value))
if isinstance(value, str) and value.find(' ') != -1:
raise ParseError('Couldn\'t parse integer: "{0}"'.format(value))
if isinstance(value, bool):
raise ParseError('Bool value {0} is not acceptable for '
'integer field'.format(value))
return int(value)
def _ConvertFloat(value, field):
"""Convert an floating point number."""
if isinstance(value, float):
if math.isnan(value):
raise ParseError('Couldn\'t parse NaN, use quoted "NaN" instead')
if math.isinf(value):
if value > 0:
raise ParseError('Couldn\'t parse Infinity or value too large, '
'use quoted "Infinity" instead')
else:
raise ParseError('Couldn\'t parse -Infinity or value too small, '
'use quoted "-Infinity" instead')
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_FLOAT:
# pylint: disable=protected-access
if value > type_checkers._FLOAT_MAX:
raise ParseError('Float value too large')
# pylint: disable=protected-access
if value < type_checkers._FLOAT_MIN:
raise ParseError('Float value too small')
if value == 'nan':
raise ParseError('Couldn\'t parse float "nan", use "NaN" instead')
try:
# Assume Python compatible syntax.
return float(value)
except ValueError:
# Check alternative spellings.
if value == _NEG_INFINITY:
return float('-inf')
elif value == _INFINITY:
return float('inf')
elif value == _NAN:
return float('nan')
else:
raise ParseError('Couldn\'t parse float: {0}'.format(value))
def _ConvertBool(value, require_str):
"""Convert a boolean value.
Args:
value: A scalar value to convert.
require_str: If True, value must be a str.
Returns:
The bool parsed.
Raises:
ParseError: If a boolean value couldn't be consumed.
"""
if require_str:
if value == 'true':
return True
elif value == 'false':
return False
else:
raise ParseError('Expected "true" or "false", not {0}'.format(value))
if not isinstance(value, bool):
raise ParseError('Expected true or false without quotes')
return value
_WKTJSONMETHODS = {
'google.protobuf.Any': ['_AnyMessageToJsonObject',
'_ConvertAnyMessage'],
'google.protobuf.Duration': ['_GenericMessageToJsonObject',
'_ConvertGenericMessage'],
'google.protobuf.FieldMask': ['_GenericMessageToJsonObject',
'_ConvertGenericMessage'],
'google.protobuf.ListValue': ['_ListValueMessageToJsonObject',
'_ConvertListValueMessage'],
'google.protobuf.Struct': ['_StructMessageToJsonObject',
'_ConvertStructMessage'],
'google.protobuf.Timestamp': ['_GenericMessageToJsonObject',
'_ConvertGenericMessage'],
'google.protobuf.Value': ['_ValueMessageToJsonObject',
'_ConvertValueMessage']
}

View File

@@ -0,0 +1,425 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# TODO(robinson): We should just make these methods all "pure-virtual" and move
# all implementation out, into reflection.py for now.
"""Contains an abstract base class for protocol messages."""
__author__ = 'robinson@google.com (Will Robinson)'
class Error(Exception):
"""Base error type for this module."""
pass
class DecodeError(Error):
"""Exception raised when deserializing messages."""
pass
class EncodeError(Error):
"""Exception raised when serializing messages."""
pass
class Message(object):
"""Abstract base class for protocol messages.
Protocol message classes are almost always generated by the protocol
compiler. These generated types subclass Message and implement the methods
shown below.
"""
# TODO(robinson): Link to an HTML document here.
# TODO(robinson): Document that instances of this class will also
# have an Extensions attribute with __getitem__ and __setitem__.
# Again, not sure how to best convey this.
# TODO(robinson): Document that the class must also have a static
# RegisterExtension(extension_field) method.
# Not sure how to best express at this point.
# TODO(robinson): Document these fields and methods.
__slots__ = []
#: The :class:`google.protobuf.Descriptor`
# for this message type.
DESCRIPTOR = None
def __deepcopy__(self, memo=None):
clone = type(self)()
clone.MergeFrom(self)
return clone
def __eq__(self, other_msg):
"""Recursively compares two messages by value and structure."""
raise NotImplementedError
def __ne__(self, other_msg):
# Can't just say self != other_msg, since that would infinitely recurse. :)
return not self == other_msg
def __hash__(self):
raise TypeError('unhashable object')
def __str__(self):
"""Outputs a human-readable representation of the message."""
raise NotImplementedError
def __unicode__(self):
"""Outputs a human-readable representation of the message."""
raise NotImplementedError
def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.
This method merges the contents of the specified message into the current
message. Singular fields that are set in the specified message overwrite
the corresponding fields in the current message. Repeated fields are
appended. Singular sub-messages and groups are recursively merged.
Args:
other_msg (Message): A message to merge into the current message.
"""
raise NotImplementedError
def CopyFrom(self, other_msg):
"""Copies the content of the specified message into the current message.
The method clears the current message and then merges the specified
message using MergeFrom.
Args:
other_msg (Message): A message to copy into the current one.
"""
if self is other_msg:
return
self.Clear()
self.MergeFrom(other_msg)
def Clear(self):
"""Clears all data that was set in the message."""
raise NotImplementedError
def SetInParent(self):
"""Mark this as present in the parent.
This normally happens automatically when you assign a field of a
sub-message, but sometimes you want to make the sub-message
present while keeping it empty. If you find yourself using this,
you may want to reconsider your design.
"""
raise NotImplementedError
def IsInitialized(self):
"""Checks if the message is initialized.
Returns:
bool: The method returns True if the message is initialized (i.e. all of
its required fields are set).
"""
raise NotImplementedError
# TODO(robinson): MergeFromString() should probably return None and be
# implemented in terms of a helper that returns the # of bytes read. Our
# deserialization routines would use the helper when recursively
# deserializing, but the end user would almost always just want the no-return
# MergeFromString().
def MergeFromString(self, serialized):
"""Merges serialized protocol buffer data into this message.
When we find a field in `serialized` that is already present
in this message:
- If it's a "repeated" field, we append to the end of our list.
- Else, if it's a scalar, we overwrite our field.
- Else, (it's a nonrepeated composite), we recursively merge
into the existing composite.
Args:
serialized (bytes): Any object that allows us to call
``memoryview(serialized)`` to access a string of bytes using the
buffer interface.
Returns:
int: The number of bytes read from `serialized`.
For non-group messages, this will always be `len(serialized)`,
but for messages which are actually groups, this will
generally be less than `len(serialized)`, since we must
stop when we reach an ``END_GROUP`` tag. Note that if
we *do* stop because of an ``END_GROUP`` tag, the number
of bytes returned does not include the bytes
for the ``END_GROUP`` tag information.
Raises:
DecodeError: if the input cannot be parsed.
"""
# TODO(robinson): Document handling of unknown fields.
# TODO(robinson): When we switch to a helper, this will return None.
raise NotImplementedError
def ParseFromString(self, serialized):
"""Parse serialized protocol buffer data into this message.
Like :func:`MergeFromString()`, except we clear the object first.
Raises:
message.DecodeError if the input cannot be parsed.
"""
self.Clear()
return self.MergeFromString(serialized)
def SerializeToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
Keyword Args:
deterministic (bool): If true, requests deterministic serialization
of the protobuf, with predictable ordering of map keys.
Returns:
A binary string representation of the message if all of the required
fields in the message are set (i.e. the message is initialized).
Raises:
EncodeError: if the message isn't initialized (see :func:`IsInitialized`).
"""
raise NotImplementedError
def SerializePartialToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
This method is similar to SerializeToString but doesn't check if the
message is initialized.
Keyword Args:
deterministic (bool): If true, requests deterministic serialization
of the protobuf, with predictable ordering of map keys.
Returns:
bytes: A serialized representation of the partial message.
"""
raise NotImplementedError
# TODO(robinson): Decide whether we like these better
# than auto-generated has_foo() and clear_foo() methods
# on the instances themselves. This way is less consistent
# with C++, but it makes reflection-type access easier and
# reduces the number of magically autogenerated things.
#
# TODO(robinson): Be sure to document (and test) exactly
# which field names are accepted here. Are we case-sensitive?
# What do we do with fields that share names with Python keywords
# like 'lambda' and 'yield'?
#
# nnorwitz says:
# """
# Typically (in python), an underscore is appended to names that are
# keywords. So they would become lambda_ or yield_.
# """
def ListFields(self):
"""Returns a list of (FieldDescriptor, value) tuples for present fields.
A message field is non-empty if HasField() would return true. A singular
primitive field is non-empty if HasField() would return true in proto2 or it
is non zero in proto3. A repeated field is non-empty if it contains at least
one element. The fields are ordered by field number.
Returns:
list[tuple(FieldDescriptor, value)]: field descriptors and values
for all fields in the message which are not empty. The values vary by
field type.
"""
raise NotImplementedError
def HasField(self, field_name):
"""Checks if a certain field is set for the message.
For a oneof group, checks if any field inside is set. Note that if the
field_name is not defined in the message descriptor, :exc:`ValueError` will
be raised.
Args:
field_name (str): The name of the field to check for presence.
Returns:
bool: Whether a value has been set for the named field.
Raises:
ValueError: if the `field_name` is not a member of this message.
"""
raise NotImplementedError
def ClearField(self, field_name):
"""Clears the contents of a given field.
Inside a oneof group, clears the field set. If the name neither refers to a
defined field or oneof group, :exc:`ValueError` is raised.
Args:
field_name (str): The name of the field to check for presence.
Raises:
ValueError: if the `field_name` is not a member of this message.
"""
raise NotImplementedError
def WhichOneof(self, oneof_group):
"""Returns the name of the field that is set inside a oneof group.
If no field is set, returns None.
Args:
oneof_group (str): the name of the oneof group to check.
Returns:
str or None: The name of the group that is set, or None.
Raises:
ValueError: no group with the given name exists
"""
raise NotImplementedError
def HasExtension(self, extension_handle):
"""Checks if a certain extension is present for this message.
Extensions are retrieved using the :attr:`Extensions` mapping (if present).
Args:
extension_handle: The handle for the extension to check.
Returns:
bool: Whether the extension is present for this message.
Raises:
KeyError: if the extension is repeated. Similar to repeated fields,
there is no separate notion of presence: a "not present" repeated
extension is an empty list.
"""
raise NotImplementedError
def ClearExtension(self, extension_handle):
"""Clears the contents of a given extension.
Args:
extension_handle: The handle for the extension to clear.
"""
raise NotImplementedError
def UnknownFields(self):
"""Returns the UnknownFieldSet.
Returns:
UnknownFieldSet: The unknown fields stored in this message.
"""
raise NotImplementedError
def DiscardUnknownFields(self):
"""Clears all fields in the :class:`UnknownFieldSet`.
This operation is recursive for nested message.
"""
raise NotImplementedError
def ByteSize(self):
"""Returns the serialized size of this message.
Recursively calls ByteSize() on all contained messages.
Returns:
int: The number of bytes required to serialize this message.
"""
raise NotImplementedError
@classmethod
def FromString(cls, s):
raise NotImplementedError
@staticmethod
def RegisterExtension(extension_handle):
raise NotImplementedError
def _SetListener(self, message_listener):
"""Internal method used by the protocol message implementation.
Clients should not call this directly.
Sets a listener that this message will call on certain state transitions.
The purpose of this method is to register back-edges from children to
parents at runtime, for the purpose of setting "has" bits and
byte-size-dirty bits in the parent and ancestor objects whenever a child or
descendant object is modified.
If the client wants to disconnect this Message from the object tree, she
explicitly sets callback to None.
If message_listener is None, unregisters any existing listener. Otherwise,
message_listener must implement the MessageListener interface in
internal/message_listener.py, and we discard any listener registered
via a previous _SetListener() call.
"""
raise NotImplementedError
def __getstate__(self):
"""Support the pickle protocol."""
return dict(serialized=self.SerializePartialToString())
def __setstate__(self, state):
"""Support the pickle protocol."""
self.__init__()
serialized = state['serialized']
# On Python 3, using encoding='latin1' is required for unpickling
# protos pickled by Python 2.
if not isinstance(serialized, bytes):
serialized = serialized.encode('latin1')
self.ParseFromString(serialized)
def __reduce__(self):
message_descriptor = self.DESCRIPTOR
if message_descriptor.containing_type is None:
return type(self), (), self.__getstate__()
# the message type must be nested.
# Python does not pickle nested classes; use the symbol_database on the
# receiving end.
container = message_descriptor
return (_InternalConstructMessage, (container.full_name,),
self.__getstate__())
def _InternalConstructMessage(full_name):
"""Constructs a nested message."""
from google.protobuf import symbol_database # pylint:disable=g-import-not-at-top
return symbol_database.Default().GetSymbol(full_name)()

View File

@@ -0,0 +1,176 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides a factory class for generating dynamic messages.
The easiest way to use this class is if you have access to the FileDescriptor
protos containing the messages you want to create you can just do the following:
message_classes = message_factory.GetMessages(iterable_of_file_descriptors)
my_proto_instance = message_classes['some.proto.package.MessageName']()
"""
__author__ = 'matthewtoia@google.com (Matt Toia)'
from google.protobuf.internal import api_implementation
from google.protobuf import descriptor_pool
from google.protobuf import message
if api_implementation.Type() == 'python':
from google.protobuf.internal import python_message as message_impl
else:
from google.protobuf.pyext import cpp_message as message_impl # pylint: disable=g-import-not-at-top
# The type of all Message classes.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
class MessageFactory(object):
"""Factory for creating Proto2 messages from descriptors in a pool."""
def __init__(self, pool=None):
"""Initializes a new factory."""
self.pool = pool or descriptor_pool.DescriptorPool()
def GetPrototype(self, descriptor):
"""Obtains a proto2 message class based on the passed in descriptor.
Passing a descriptor with a fully qualified name matching a previous
invocation will cause the same class to be returned.
Args:
descriptor: The descriptor to build from.
Returns:
A class describing the passed in descriptor.
"""
concrete_class = getattr(descriptor, '_concrete_class', None)
if concrete_class:
return concrete_class
result_class = self.CreatePrototype(descriptor)
return result_class
def CreatePrototype(self, descriptor):
"""Builds a proto2 message class based on the passed in descriptor.
Don't call this function directly, it always creates a new class. Call
GetPrototype() instead. This method is meant to be overridden in subblasses
to perform additional operations on the newly constructed class.
Args:
descriptor: The descriptor to build from.
Returns:
A class describing the passed in descriptor.
"""
descriptor_name = descriptor.name
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
descriptor_name,
(message.Message,),
{
'DESCRIPTOR': descriptor,
# If module not set, it wrongly points to message_factory module.
'__module__': None,
})
result_class._FACTORY = self # pylint: disable=protected-access
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
extended_class = self.GetPrototype(extension.containing_type)
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)
return result_class
def GetMessages(self, files):
"""Gets all the messages from a specified file.
This will find and resolve dependencies, failing if the descriptor
pool cannot satisfy them.
Args:
files: The file names to extract messages from.
Returns:
A dictionary mapping proto names to the message classes. This will include
any dependent messages as well as any messages defined in the same file as
a specified message.
"""
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
for desc in file_desc.message_types_by_name.values():
result[desc.full_name] = self.GetPrototype(desc)
# While the extension FieldDescriptors are created by the descriptor pool,
# the python classes created in the factory need them to be registered
# explicitly, which is done below.
#
# The call to RegisterExtension will specifically check if the
# extension was already registered on the object and either
# ignore the registration if the original was the same, or raise
# an error if they were different.
for extension in file_desc.extensions_by_name.values():
extended_class = self.GetPrototype(extension.containing_type)
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)
return result
_FACTORY = MessageFactory()
def GetMessages(file_protos):
"""Builds a dictionary of all the messages available in a set of files.
Args:
file_protos: Iterable of FileDescriptorProto to build messages out of.
Returns:
A dictionary mapping proto names to the message classes. This will include
any dependent messages as well as any messages defined in the same file as
a specified message.
"""
# The cpp implementation of the protocol buffer library requires to add the
# message in topological order of the dependency graph.
file_by_name = {file_proto.name: file_proto for file_proto in file_protos}
def _AddFile(file_proto):
for dependency in file_proto.dependency:
if dependency in file_by_name:
# Remove from elements to be visited, in order to cut cycles.
_AddFile(file_by_name.pop(dependency))
_FACTORY.pool.Add(file_proto)
while file_by_name:
_AddFile(file_by_name.popitem()[1])
return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])

View File

@@ -0,0 +1,144 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This file can be included by other C++ libraries, typically extension modules
// which want to interact with the Python Messages coming from the "cpp"
// implementation of protocol buffers.
//
// Usage:
// Declare a (probably static) variable to hold the API:
// const PyProto_API* py_proto_api;
// In some initialization function, write:
// py_proto_api = static_cast<const PyProto_API*>(PyCapsule_Import(
// PyProtoAPICapsuleName(), 0));
// if (!py_proto_api) { ...handle ImportError... }
// Then use the methods of the returned class:
// py_proto_api->GetMessagePointer(...);
#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/message.h"
namespace google {
namespace protobuf {
namespace python {
// Note on the implementation:
// This API is designed after
// https://docs.python.org/3/extending/extending.html#providing-a-c-api-for-an-extension-module
// The class below contains no mutable state, and all methods are "const";
// we use a C++ class instead of a C struct with functions pointers just because
// the code looks more readable.
struct PyProto_API {
// The API object is created at initialization time and never freed.
// This destructor is never called.
virtual ~PyProto_API() {}
// Operations on Messages.
// If the passed object is a Python Message, returns its internal pointer.
// Otherwise, returns NULL with an exception set.
virtual const Message* GetMessagePointer(PyObject* msg) const = 0;
// If the passed object is a Python Message, returns a mutable pointer.
// Otherwise, returns NULL with an exception set.
// This function will succeed only if there are no other Python objects
// pointing to the message, like submessages or repeated containers.
// With the current implementation, only empty messages are in this case.
virtual Message* GetMutableMessagePointer(PyObject* msg) const = 0;
// If the passed object is a Python Message Descriptor, returns its internal
// pointer.
// Otherwise, returns NULL with an exception set.
virtual const Descriptor* MessageDescriptor_AsDescriptor(
PyObject* desc) const = 0;
// If the passed object is a Python Enum Descriptor, returns its internal
// pointer.
// Otherwise, returns NULL with an exception set.
virtual const EnumDescriptor* EnumDescriptor_AsDescriptor(
PyObject* enum_desc) const = 0;
// Expose the underlying DescriptorPool and MessageFactory to enable C++ code
// to create Python-compatible message.
virtual const DescriptorPool* GetDefaultDescriptorPool() const = 0;
virtual MessageFactory* GetDefaultMessageFactory() const = 0;
// Allocate a new protocol buffer as a python object for the provided
// descriptor. This function works even if no Python module has been imported
// for the corresponding protocol buffer class.
// The factory is usually null; when provided, it is the MessageFactory which
// owns the Python class, and will be used to find and create Extensions for
// this message.
// When null is returned, a python error has already been set.
virtual PyObject* NewMessage(const Descriptor* descriptor,
PyObject* py_message_factory) const = 0;
// Allocate a new protocol buffer where the underlying object is owned by C++.
// The factory must currently be null. This function works even if no Python
// module has been imported for the corresponding protocol buffer class.
// When null is returned, a python error has already been set.
//
// Since this call returns a python object owned by C++, some operations
// are risky, and it must be used carefully. In particular:
// * Avoid modifying the returned object from the C++ side while there are
// existing python references to it or it's subobjects.
// * Avoid using python references to this object or any subobjects after the
// C++ object has been freed.
// * Calling this with the same C++ pointer will result in multiple distinct
// python objects referencing the same C++ object.
virtual PyObject* NewMessageOwnedExternally(
Message* msg, PyObject* py_message_factory) const = 0;
// Returns a new reference for the given DescriptorPool.
// The returned object does not manage the C++ DescriptorPool: it is the
// responsibility of the caller to keep it alive.
// As long as the returned Python DescriptorPool object is kept alive,
// functions that process C++ descriptors or messages created from this pool
// can work and return their Python counterparts.
virtual PyObject* DescriptorPool_FromPool(
const google::protobuf::DescriptorPool* pool) const = 0;
};
inline const char* PyProtoAPICapsuleName() {
static const char kCapsuleName[] = "google.protobuf.pyext._message.proto_API";
return kCapsuleName;
}
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__

View File

@@ -0,0 +1,134 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Dynamic Protobuf class creator."""
from collections import OrderedDict
import hashlib
import os
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message_factory
def _GetMessageFromFactory(factory, full_name):
"""Get a proto class from the MessageFactory by name.
Args:
factory: a MessageFactory instance.
full_name: str, the fully qualified name of the proto type.
Returns:
A class, for the type identified by full_name.
Raises:
KeyError, if the proto is not found in the factory's descriptor pool.
"""
proto_descriptor = factory.pool.FindMessageTypeByName(full_name)
proto_cls = factory.GetPrototype(proto_descriptor)
return proto_cls
def MakeSimpleProtoClass(fields, full_name=None, pool=None):
"""Create a Protobuf class whose fields are basic types.
Note: this doesn't validate field names!
Args:
fields: dict of {name: field_type} mappings for each field in the proto. If
this is an OrderedDict the order will be maintained, otherwise the
fields will be sorted by name.
full_name: optional str, the fully-qualified name of the proto type.
pool: optional DescriptorPool instance.
Returns:
a class, the new protobuf class with a FileDescriptor.
"""
factory = message_factory.MessageFactory(pool=pool)
if full_name is not None:
try:
proto_cls = _GetMessageFromFactory(factory, full_name)
return proto_cls
except KeyError:
# The factory's DescriptorPool doesn't know about this class yet.
pass
# Get a list of (name, field_type) tuples from the fields dict. If fields was
# an OrderedDict we keep the order, but otherwise we sort the field to ensure
# consistent ordering.
field_items = fields.items()
if not isinstance(fields, OrderedDict):
field_items = sorted(field_items)
# Use a consistent file name that is unlikely to conflict with any imported
# proto files.
fields_hash = hashlib.sha1()
for f_name, f_type in field_items:
fields_hash.update(f_name.encode('utf-8'))
fields_hash.update(str(f_type).encode('utf-8'))
proto_file_name = fields_hash.hexdigest() + '.proto'
# If the proto is anonymous, use the same hash to name it.
if full_name is None:
full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' +
fields_hash.hexdigest())
try:
proto_cls = _GetMessageFromFactory(factory, full_name)
return proto_cls
except KeyError:
# The factory's DescriptorPool doesn't know about this class yet.
pass
# This is the first time we see this proto: add a new descriptor to the pool.
factory.pool.Add(
_MakeFileDescriptorProto(proto_file_name, full_name, field_items))
return _GetMessageFromFactory(factory, full_name)
def _MakeFileDescriptorProto(proto_file_name, full_name, field_items):
"""Populate FileDescriptorProto for MessageFactory's DescriptorPool."""
package, name = full_name.rsplit('.', 1)
file_proto = descriptor_pb2.FileDescriptorProto()
file_proto.name = os.path.join(package.replace('.', '/'), proto_file_name)
file_proto.package = package
desc_proto = file_proto.message_type.add()
desc_proto.name = name
for f_number, (f_name, f_type) in enumerate(field_items, 1):
field_proto = desc_proto.field.add()
field_proto.name = f_name
# # If the number falls in the reserved range, reassign it to the correct
# # number after the range.
if f_number >= descriptor.FieldDescriptor.FIRST_RESERVED_FIELD_NUMBER:
f_number += (
descriptor.FieldDescriptor.LAST_RESERVED_FIELD_NUMBER -
descriptor.FieldDescriptor.FIRST_RESERVED_FIELD_NUMBER + 1)
field_proto.number = f_number
field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
field_proto.type = f_type
return file_proto

View File

@@ -0,0 +1,6 @@
This is the 'v2' C++ implementation for python proto2.
It is active when:
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2

View File

@@ -0,0 +1,72 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Protocol message implementation hooks for C++ implementation.
Contains helper functions used to create protocol message classes from
Descriptor objects at runtime backed by the protocol buffer C++ API.
"""
__author__ = 'tibell@google.com (Johan Tibell)'
from google.protobuf.internal import api_implementation
# pylint: disable=protected-access
_message = api_implementation._c_module
# TODO(jieluo): Remove this import after fix api_implementation
if _message is None:
from google.protobuf.pyext import _message
class GeneratedProtocolMessageType(_message.MessageMeta):
"""Metaclass for protocol message classes created at runtime from Descriptors.
The protocol compiler currently uses this metaclass to create protocol
message classes at runtime. Clients can also manually create their own
classes at runtime, as in this example:
mydescriptor = Descriptor(.....)
factory = symbol_database.Default()
factory.pool.AddDescriptor(mydescriptor)
MyProtoClass = factory.GetPrototype(mydescriptor)
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
The above example will not work for nested types. If you wish to include them,
use reflection.MakeClass() instead of manually instantiating the class in
order to create the appropriate class structure.
"""
# Must be consistent with the protocol-compiler code in
# proto2/compiler/internal/generator.*.
_DESCRIPTOR_KEY = 'DESCRIPTOR'

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: petar@google.com (Petar Petrov)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.h"
namespace google {
namespace protobuf {
namespace python {
extern PyTypeObject PyMessageDescriptor_Type;
extern PyTypeObject PyFieldDescriptor_Type;
extern PyTypeObject PyEnumDescriptor_Type;
extern PyTypeObject PyEnumValueDescriptor_Type;
extern PyTypeObject PyFileDescriptor_Type;
extern PyTypeObject PyOneofDescriptor_Type;
extern PyTypeObject PyServiceDescriptor_Type;
extern PyTypeObject PyMethodDescriptor_Type;
// Wraps a Descriptor in a Python object.
// The C++ pointer is usually borrowed from the global DescriptorPool.
// In any case, it must stay alive as long as the Python object.
// Returns a new reference.
PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor);
PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor);
PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor);
PyObject* PyEnumValueDescriptor_FromDescriptor(
const EnumValueDescriptor* descriptor);
PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor);
PyObject* PyFileDescriptor_FromDescriptor(
const FileDescriptor* file_descriptor);
PyObject* PyServiceDescriptor_FromDescriptor(
const ServiceDescriptor* descriptor);
PyObject* PyMethodDescriptor_FromDescriptor(
const MethodDescriptor* descriptor);
// Alternate constructor of PyFileDescriptor, used when we already have a
// serialized FileDescriptorProto that can be cached.
// Returns a new reference.
PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb(
const FileDescriptor* file_descriptor, PyObject* serialized_pb);
// Return the C++ descriptor pointer.
// This function checks the parameter type; on error, return NULL with a Python
// exception set.
const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj);
const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj);
const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj);
const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj);
const ServiceDescriptor* PyServiceDescriptor_AsDescriptor(PyObject* obj);
// Returns the raw C++ pointer.
const void* PyDescriptor_AsVoidPtr(PyObject* obj);
// Check that the calling Python code is the global scope of a _pb2.py module.
// This function is used to support the current code generated by the proto
// compiler, which insists on modifying descriptors after they have been
// created.
//
// stacklevel indicates which Python frame should be the _pb2.py module.
//
// Don't use this function outside descriptor classes.
bool _CalledFromGeneratedFile(int stacklevel);
bool InitDescriptor();
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,110 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
// Mappings and Sequences of descriptors.
// They implement containers like fields_by_name, EnumDescriptor.values...
// See descriptor_containers.cc for more description.
#define PY_SSIZE_T_CLEAN
#include <Python.h>
namespace google {
namespace protobuf {
class Descriptor;
class FileDescriptor;
class EnumDescriptor;
class OneofDescriptor;
class ServiceDescriptor;
namespace python {
// Initialize the various types and objects.
bool InitDescriptorMappingTypes();
// Each function below returns a Mapping, or a Sequence of descriptors.
// They all return a new reference.
namespace message_descriptor {
PyObject* NewMessageFieldsByName(const Descriptor* descriptor);
PyObject* NewMessageFieldsByCamelcaseName(const Descriptor* descriptor);
PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor);
PyObject* NewMessageFieldsSeq(const Descriptor* descriptor);
PyObject* NewMessageNestedTypesSeq(const Descriptor* descriptor);
PyObject* NewMessageNestedTypesByName(const Descriptor* descriptor);
PyObject* NewMessageEnumsByName(const Descriptor* descriptor);
PyObject* NewMessageEnumsSeq(const Descriptor* descriptor);
PyObject* NewMessageEnumValuesByName(const Descriptor* descriptor);
PyObject* NewMessageExtensionsByName(const Descriptor* descriptor);
PyObject* NewMessageExtensionsSeq(const Descriptor* descriptor);
PyObject* NewMessageOneofsByName(const Descriptor* descriptor);
PyObject* NewMessageOneofsSeq(const Descriptor* descriptor);
} // namespace message_descriptor
namespace enum_descriptor {
PyObject* NewEnumValuesByName(const EnumDescriptor* descriptor);
PyObject* NewEnumValuesByNumber(const EnumDescriptor* descriptor);
PyObject* NewEnumValuesSeq(const EnumDescriptor* descriptor);
} // namespace enum_descriptor
namespace oneof_descriptor {
PyObject* NewOneofFieldsSeq(const OneofDescriptor* descriptor);
} // namespace oneof_descriptor
namespace file_descriptor {
PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor);
PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor);
PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor);
PyObject* NewFileServicesByName(const FileDescriptor* descriptor);
PyObject* NewFileDependencies(const FileDescriptor* descriptor);
PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor);
} // namespace file_descriptor
namespace service_descriptor {
PyObject* NewServiceMethodsSeq(const ServiceDescriptor* descriptor);
PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor);
} // namespace service_descriptor
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__

View File

@@ -0,0 +1,189 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This file defines a C++ DescriptorDatabase, which wraps a Python Database
// and delegate all its operations to Python methods.
#include "google/protobuf/pyext/descriptor_database.h"
#include <cstdint>
#include <string>
#include <vector>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
namespace google {
namespace protobuf {
namespace python {
PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database)
: py_database_(py_database) {
Py_INCREF(py_database_);
}
PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); }
// Convert a Python object to a FileDescriptorProto pointer.
// Handles all kinds of Python errors, which are simply logged.
static bool GetFileDescriptorProto(PyObject* py_descriptor,
FileDescriptorProto* output) {
if (py_descriptor == nullptr) {
if (PyErr_ExceptionMatches(PyExc_KeyError)) {
// Expected error: item was simply not found.
PyErr_Clear();
} else {
GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error";
PyErr_Print();
}
return false;
}
if (py_descriptor == Py_None) {
return false;
}
const Descriptor* filedescriptor_descriptor =
FileDescriptorProto::default_instance().GetDescriptor();
CMessage* message = reinterpret_cast<CMessage*>(py_descriptor);
if (PyObject_TypeCheck(py_descriptor, CMessage_Type) &&
message->message->GetDescriptor() == filedescriptor_descriptor) {
// Fast path: Just use the pointer.
FileDescriptorProto* file_proto =
static_cast<FileDescriptorProto*>(message->message);
*output = *file_proto;
return true;
} else {
// Slow path: serialize the message. This allows to use databases which
// use a different implementation of FileDescriptorProto.
ScopedPyObjectPtr serialized_pb(
PyObject_CallMethod(py_descriptor, "SerializeToString", nullptr));
if (serialized_pb == nullptr) {
GOOGLE_LOG(ERROR)
<< "DescriptorDatabase method did not return a FileDescriptorProto";
PyErr_Print();
return false;
}
char* str;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) {
GOOGLE_LOG(ERROR)
<< "DescriptorDatabase method did not return a FileDescriptorProto";
PyErr_Print();
return false;
}
FileDescriptorProto file_proto;
if (!file_proto.ParseFromArray(str, len)) {
GOOGLE_LOG(ERROR)
<< "DescriptorDatabase method did not return a FileDescriptorProto";
return false;
}
*output = file_proto;
return true;
}
}
// Find a file by file name.
bool PyDescriptorDatabase::FindFileByName(const std::string& filename,
FileDescriptorProto* output) {
ScopedPyObjectPtr py_descriptor(PyObject_CallMethod(
py_database_, "FindFileByName", "s#", filename.c_str(), filename.size()));
return GetFileDescriptorProto(py_descriptor.get(), output);
}
// Find the file that declares the given fully-qualified symbol name.
bool PyDescriptorDatabase::FindFileContainingSymbol(
const std::string& symbol_name, FileDescriptorProto* output) {
ScopedPyObjectPtr py_descriptor(
PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#",
symbol_name.c_str(), symbol_name.size()));
return GetFileDescriptorProto(py_descriptor.get(), output);
}
// Find the file which defines an extension extending the given message type
// with the given field number.
// Python DescriptorDatabases are not required to implement this method.
bool PyDescriptorDatabase::FindFileContainingExtension(
const std::string& containing_type, int field_number,
FileDescriptorProto* output) {
ScopedPyObjectPtr py_method(
PyObject_GetAttrString(py_database_, "FindFileContainingExtension"));
if (py_method == nullptr) {
// This method is not implemented, returns without error.
PyErr_Clear();
return false;
}
ScopedPyObjectPtr py_descriptor(
PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(),
containing_type.size(), field_number));
return GetFileDescriptorProto(py_descriptor.get(), output);
}
// Finds the tag numbers used by all known extensions of
// containing_type, and appends them to output in an undefined
// order.
// Python DescriptorDatabases are not required to implement this method.
bool PyDescriptorDatabase::FindAllExtensionNumbers(
const std::string& containing_type, std::vector<int>* output) {
ScopedPyObjectPtr py_method(
PyObject_GetAttrString(py_database_, "FindAllExtensionNumbers"));
if (py_method == nullptr) {
// This method is not implemented, returns without error.
PyErr_Clear();
return false;
}
ScopedPyObjectPtr py_list(
PyObject_CallFunction(py_method.get(), "s#", containing_type.c_str(),
containing_type.size()));
if (py_list == nullptr) {
PyErr_Print();
return false;
}
Py_ssize_t size = PyList_Size(py_list.get());
int64_t item_value;
for (Py_ssize_t i = 0 ; i < size; ++i) {
ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i));
item_value = PyLong_AsLong(item.get());
if (item_value < 0) {
GOOGLE_LOG(ERROR)
<< "FindAllExtensionNumbers method did not return "
<< "valid extension numbers.";
PyErr_Print();
return false;
}
output->push_back(item_value);
}
return true;
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,86 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <string>
#include <vector>
#include "google/protobuf/descriptor_database.h"
namespace google {
namespace protobuf {
namespace python {
class PyDescriptorDatabase : public DescriptorDatabase {
public:
explicit PyDescriptorDatabase(PyObject* py_database);
~PyDescriptorDatabase() override;
// Implement the abstract interface. All these functions fill the output
// with a copy of FileDescriptorProto.
// Find a file by file name.
bool FindFileByName(const std::string& filename,
FileDescriptorProto* output) override;
// Find the file that declares the given fully-qualified symbol name.
bool FindFileContainingSymbol(const std::string& symbol_name,
FileDescriptorProto* output) override;
// Find the file which defines an extension extending the given message type
// with the given field number.
// Containing_type must be a fully-qualified type name.
// Python objects are not required to implement this method.
bool FindFileContainingExtension(const std::string& containing_type,
int field_number,
FileDescriptorProto* output) override;
// Finds the tag numbers used by all known extensions of
// containing_type, and appends them to output in an undefined
// order.
// Python objects are not required to implement this method.
bool FindAllExtensionNumbers(const std::string& containing_type,
std::vector<int>* output) override;
private:
// The python object that implements the database. The reference is owned.
PyObject* py_database_;
};
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__

View File

@@ -0,0 +1,820 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Implements the DescriptorPool, which collects all descriptors.
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_database.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
? ((*(charpp) = const_cast<char*>( \
PyUnicode_AsUTF8AndSize(ob, (sizep)))) == nullptr \
? -1 \
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
namespace google {
namespace protobuf {
namespace python {
// A map to cache Python Pools per C++ pointer.
// Pointers are not owned here, and belong to the PyDescriptorPool.
static std::unordered_map<const DescriptorPool*, PyDescriptorPool*>*
descriptor_pool_map;
namespace cdescriptor_pool {
// Collects errors that occur during proto file building to allow them to be
// propagated in the python exception instead of only living in ERROR logs.
class BuildFileErrorCollector : public DescriptorPool::ErrorCollector {
public:
BuildFileErrorCollector() : error_message(""), had_errors_(false) {}
void AddError(const std::string& filename, const std::string& element_name,
const Message* descriptor, ErrorLocation location,
const std::string& message) override {
// Replicates the logging behavior that happens in the C++ implementation
// when an error collector is not passed in.
if (!had_errors_) {
error_message +=
("Invalid proto descriptor for file \"" + filename + "\":\n");
had_errors_ = true;
}
// As this only happens on failure and will result in the program not
// running at all, no effort is made to optimize this string manipulation.
error_message += (" " + element_name + ": " + message + "\n");
}
void Clear() {
had_errors_ = false;
error_message = "";
}
std::string error_message;
private:
bool had_errors_;
};
// Create a Python DescriptorPool object, but does not fill the "pool"
// attribute.
static PyDescriptorPool* _CreateDescriptorPool() {
PyDescriptorPool* cpool = PyObject_GC_New(
PyDescriptorPool, &PyDescriptorPool_Type);
if (cpool == nullptr) {
return nullptr;
}
cpool->error_collector = nullptr;
cpool->underlay = nullptr;
cpool->database = nullptr;
cpool->is_owned = false;
cpool->is_mutable = false;
cpool->descriptor_options = new std::unordered_map<const void*, PyObject*>();
cpool->py_message_factory = message_factory::NewMessageFactory(
&PyMessageFactory_Type, cpool);
if (cpool->py_message_factory == nullptr) {
Py_DECREF(cpool);
return nullptr;
}
PyObject_GC_Track(cpool);
return cpool;
}
// Create a Python DescriptorPool, using the given pool as an underlay:
// new messages will be added to a custom pool, not to the underlay.
//
// Ownership of the underlay is not transferred, its pointer should
// stay alive.
static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay(
const DescriptorPool* underlay) {
PyDescriptorPool* cpool = _CreateDescriptorPool();
if (cpool == nullptr) {
return nullptr;
}
cpool->pool = new DescriptorPool(underlay);
cpool->is_owned = true;
cpool->is_mutable = true;
cpool->underlay = underlay;
if (!descriptor_pool_map->insert(
std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- would indicate an internal error / bug.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
return nullptr;
}
return cpool;
}
static PyDescriptorPool* PyDescriptorPool_NewWithDatabase(
DescriptorDatabase* database) {
PyDescriptorPool* cpool = _CreateDescriptorPool();
if (cpool == nullptr) {
return nullptr;
}
if (database != nullptr) {
cpool->error_collector = new BuildFileErrorCollector();
cpool->pool = new DescriptorPool(database, cpool->error_collector);
cpool->is_mutable = false;
cpool->database = database;
} else {
cpool->pool = new DescriptorPool();
cpool->is_mutable = true;
}
cpool->is_owned = true;
if (!descriptor_pool_map->insert(std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- would indicate an internal error / bug.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
return nullptr;
}
return cpool;
}
// The public DescriptorPool constructor.
static PyObject* New(PyTypeObject* type,
PyObject* args, PyObject* kwargs) {
static const char* kwlist[] = {"descriptor_db", nullptr};
PyObject* py_database = nullptr;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O",
const_cast<char**>(kwlist), &py_database)) {
return nullptr;
}
DescriptorDatabase* database = nullptr;
if (py_database && py_database != Py_None) {
database = new PyDescriptorDatabase(py_database);
}
return reinterpret_cast<PyObject*>(
PyDescriptorPool_NewWithDatabase(database));
}
static void Dealloc(PyObject* pself) {
PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
descriptor_pool_map->erase(self->pool);
Py_CLEAR(self->py_message_factory);
for (std::unordered_map<const void*, PyObject*>::iterator it =
self->descriptor_options->begin();
it != self->descriptor_options->end(); ++it) {
Py_DECREF(it->second);
}
delete self->descriptor_options;
delete self->database;
if (self->is_owned) {
delete self->pool;
}
delete self->error_collector;
Py_TYPE(self)->tp_free(pself);
}
static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
Py_VISIT(self->py_message_factory);
return 0;
}
static int GcClear(PyObject* pself) {
PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
Py_CLEAR(self->py_message_factory);
return 0;
}
PyObject* SetErrorFromCollector(DescriptorPool::ErrorCollector* self,
const char* name, const char* error_type) {
BuildFileErrorCollector* error_collector =
reinterpret_cast<BuildFileErrorCollector*>(self);
if (error_collector && !error_collector->error_message.empty()) {
PyErr_Format(PyExc_KeyError, "Couldn't build file for %s %.200s\n%s",
error_type, name, error_collector->error_message.c_str());
error_collector->Clear();
return nullptr;
}
PyErr_Format(PyExc_KeyError, "Couldn't find %s %.200s", error_type, name);
return nullptr;
}
static PyObject* FindMessageByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const Descriptor* message_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName(
absl::string_view(name, name_size));
if (message_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
"message");
}
return PyMessageDescriptor_FromDescriptor(message_descriptor);
}
static PyObject* FindFileByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
PyDescriptorPool* py_pool = reinterpret_cast<PyDescriptorPool*>(self);
const FileDescriptor* file_descriptor =
py_pool->pool->FindFileByName(absl::string_view(name, name_size));
if (file_descriptor == nullptr) {
return SetErrorFromCollector(py_pool->error_collector, name, "file");
}
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const FieldDescriptor* field_descriptor =
self->pool->FindFieldByName(absl::string_view(name, name_size));
if (field_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "field");
}
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
static PyObject* FindFieldByNameMethod(PyObject* self, PyObject* arg) {
return FindFieldByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
}
PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const FieldDescriptor* field_descriptor =
self->pool->FindExtensionByName(absl::string_view(name, name_size));
if (field_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name,
"extension field");
}
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
static PyObject* FindExtensionByNameMethod(PyObject* self, PyObject* arg) {
return FindExtensionByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
}
PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const EnumDescriptor* enum_descriptor =
self->pool->FindEnumTypeByName(absl::string_view(name, name_size));
if (enum_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "enum");
}
return PyEnumDescriptor_FromDescriptor(enum_descriptor);
}
static PyObject* FindEnumTypeByNameMethod(PyObject* self, PyObject* arg) {
return FindEnumTypeByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
}
PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const OneofDescriptor* oneof_descriptor =
self->pool->FindOneofByName(absl::string_view(name, name_size));
if (oneof_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "oneof");
}
return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
}
static PyObject* FindOneofByNameMethod(PyObject* self, PyObject* arg) {
return FindOneofByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
}
static PyObject* FindServiceByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const ServiceDescriptor* service_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName(
absl::string_view(name, name_size));
if (service_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
"service");
}
return PyServiceDescriptor_FromDescriptor(service_descriptor);
}
static PyObject* FindMethodByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const MethodDescriptor* method_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMethodByName(
absl::string_view(name, name_size));
if (method_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
"method");
}
return PyMethodDescriptor_FromDescriptor(method_descriptor);
}
static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
const FileDescriptor* file_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileContainingSymbol(
absl::string_view(name, name_size));
if (file_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
"symbol");
}
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
static PyObject* FindExtensionByNumber(PyObject* self, PyObject* args) {
PyObject* message_descriptor;
int number;
if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) {
return nullptr;
}
const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(
message_descriptor);
if (descriptor == nullptr) {
return nullptr;
}
const FieldDescriptor* extension_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByNumber(
descriptor, number);
if (extension_descriptor == nullptr) {
BuildFileErrorCollector* error_collector =
reinterpret_cast<BuildFileErrorCollector*>(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector);
if (error_collector && !error_collector->error_message.empty()) {
PyErr_Format(PyExc_KeyError, "Couldn't build file for Extension %.d\n%s",
number, error_collector->error_message.c_str());
error_collector->Clear();
return nullptr;
}
PyErr_Format(PyExc_KeyError, "Couldn't find Extension %d", number);
return nullptr;
}
return PyFieldDescriptor_FromDescriptor(extension_descriptor);
}
static PyObject* FindAllExtensions(PyObject* self, PyObject* arg) {
const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg);
if (descriptor == nullptr) {
return nullptr;
}
std::vector<const FieldDescriptor*> extensions;
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindAllExtensions(
descriptor, &extensions);
ScopedPyObjectPtr result(PyList_New(extensions.size()));
if (result == nullptr) {
return nullptr;
}
for (int i = 0; i < extensions.size(); i++) {
PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]);
if (extension == nullptr) {
return nullptr;
}
PyList_SET_ITEM(result.get(), i, extension); // Steals the reference.
}
return result.release();
}
// These functions should not exist -- the only valid way to create
// descriptors is to call Add() or AddSerializedFile().
// But these AddDescriptor() functions were created in Python and some people
// call them, so we support them for now for compatibility.
// However we do check that the existing descriptor already exists in the pool,
// which appears to always be true for existing calls -- but then why do people
// call a function that will just be a no-op?
// TODO(amauryfa): Need to investigate further.
static PyObject* AddFileDescriptor(PyObject* self, PyObject* descriptor) {
const FileDescriptor* file_descriptor =
PyFileDescriptor_AsDescriptor(descriptor);
if (!file_descriptor) {
return nullptr;
}
if (file_descriptor !=
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileByName(
file_descriptor->name())) {
PyErr_Format(PyExc_ValueError,
"The file descriptor %s does not belong to this pool",
file_descriptor->name().c_str());
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* AddDescriptor(PyObject* self, PyObject* descriptor) {
const Descriptor* message_descriptor =
PyMessageDescriptor_AsDescriptor(descriptor);
if (!message_descriptor) {
return nullptr;
}
if (message_descriptor !=
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName(
message_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The message descriptor %s does not belong to this pool",
message_descriptor->full_name().c_str());
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* AddEnumDescriptor(PyObject* self, PyObject* descriptor) {
const EnumDescriptor* enum_descriptor =
PyEnumDescriptor_AsDescriptor(descriptor);
if (!enum_descriptor) {
return nullptr;
}
if (enum_descriptor !=
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindEnumTypeByName(
enum_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The enum descriptor %s does not belong to this pool",
enum_descriptor->full_name().c_str());
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* AddExtensionDescriptor(PyObject* self, PyObject* descriptor) {
const FieldDescriptor* extension_descriptor =
PyFieldDescriptor_AsDescriptor(descriptor);
if (!extension_descriptor) {
return nullptr;
}
if (extension_descriptor !=
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByName(
extension_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The extension descriptor %s does not belong to this pool",
extension_descriptor->full_name().c_str());
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* AddServiceDescriptor(PyObject* self, PyObject* descriptor) {
const ServiceDescriptor* service_descriptor =
PyServiceDescriptor_AsDescriptor(descriptor);
if (!service_descriptor) {
return nullptr;
}
if (service_descriptor !=
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName(
service_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The service descriptor %s does not belong to this pool",
service_descriptor->full_name().c_str());
return nullptr;
}
Py_RETURN_NONE;
}
// The code below loads new Descriptors from a serialized FileDescriptorProto.
static PyObject* AddSerializedFile(PyObject* pself, PyObject* serialized_pb) {
PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
char* message_type;
Py_ssize_t message_len;
if (self->database != nullptr) {
PyErr_SetString(
PyExc_ValueError,
"Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. "
"Add your file to the underlying database.");
return nullptr;
}
if (!self->is_mutable) {
PyErr_SetString(
PyExc_ValueError,
"This DescriptorPool is not mutable and cannot add new definitions.");
return nullptr;
}
if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) {
return nullptr;
}
FileDescriptorProto file_proto;
if (!file_proto.ParseFromArray(message_type, message_len)) {
PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!");
return nullptr;
}
// If the file was already part of a C++ library, all its descriptors are in
// the underlying pool. No need to do anything else.
const FileDescriptor* generated_file = nullptr;
if (self->underlay) {
generated_file = self->underlay->FindFileByName(file_proto.name());
}
if (generated_file != nullptr) {
return PyFileDescriptor_FromDescriptorWithSerializedPb(
generated_file, serialized_pb);
}
BuildFileErrorCollector error_collector;
const FileDescriptor* descriptor =
// Pool is mutable, we can remove the "const".
const_cast<DescriptorPool*>(self->pool)
->BuildFileCollectingErrors(file_proto, &error_collector);
if (descriptor == nullptr) {
PyErr_Format(PyExc_TypeError,
"Couldn't build proto file into descriptor pool!\n%s",
error_collector.error_message.c_str());
return nullptr;
}
return PyFileDescriptor_FromDescriptorWithSerializedPb(
descriptor, serialized_pb);
}
static PyObject* Add(PyObject* self, PyObject* file_descriptor_proto) {
ScopedPyObjectPtr serialized_pb(
PyObject_CallMethod(file_descriptor_proto, "SerializeToString", nullptr));
if (serialized_pb == nullptr) {
return nullptr;
}
return AddSerializedFile(self, serialized_pb.get());
}
static PyMethodDef Methods[] = {
{"Add", Add, METH_O,
"Adds the FileDescriptorProto and its types to this pool."},
{"AddSerializedFile", AddSerializedFile, METH_O,
"Adds a serialized FileDescriptorProto to this pool."},
// TODO(amauryfa): Understand why the Python implementation differs from
// this one, ask users to use another API and deprecate these functions.
{"AddFileDescriptor", AddFileDescriptor, METH_O,
"No-op. Add() must have been called before."},
{"AddDescriptor", AddDescriptor, METH_O,
"No-op. Add() must have been called before."},
{"AddEnumDescriptor", AddEnumDescriptor, METH_O,
"No-op. Add() must have been called before."},
{"AddExtensionDescriptor", AddExtensionDescriptor, METH_O,
"No-op. Add() must have been called before."},
{"AddServiceDescriptor", AddServiceDescriptor, METH_O,
"No-op. Add() must have been called before."},
{"FindFileByName", FindFileByName, METH_O,
"Searches for a file descriptor by its .proto name."},
{"FindMessageTypeByName", FindMessageByName, METH_O,
"Searches for a message descriptor by full name."},
{"FindFieldByName", FindFieldByNameMethod, METH_O,
"Searches for a field descriptor by full name."},
{"FindExtensionByName", FindExtensionByNameMethod, METH_O,
"Searches for extension descriptor by full name."},
{"FindEnumTypeByName", FindEnumTypeByNameMethod, METH_O,
"Searches for enum type descriptor by full name."},
{"FindOneofByName", FindOneofByNameMethod, METH_O,
"Searches for oneof descriptor by full name."},
{"FindServiceByName", FindServiceByName, METH_O,
"Searches for service descriptor by full name."},
{"FindMethodByName", FindMethodByName, METH_O,
"Searches for method descriptor by full name."},
{"FindFileContainingSymbol", FindFileContainingSymbol, METH_O,
"Gets the FileDescriptor containing the specified symbol."},
{"FindExtensionByNumber", FindExtensionByNumber, METH_VARARGS,
"Gets the extension descriptor for the given number."},
{"FindAllExtensions", FindAllExtensions, METH_O,
"Gets all known extensions of the given message descriptor."},
{nullptr},
};
} // namespace cdescriptor_pool
PyTypeObject PyDescriptorPool_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".DescriptorPool", // tp_name
sizeof(PyDescriptorPool), // tp_basicsize
0, // tp_itemsize
cdescriptor_pool::Dealloc, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, // tp_flags
"A Descriptor Pool", // tp_doc
cdescriptor_pool::GcTraverse, // tp_traverse
cdescriptor_pool::GcClear, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
cdescriptor_pool::Methods, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
cdescriptor_pool::New, // tp_new
PyObject_GC_Del, // tp_free
};
// This is the DescriptorPool which contains all the definitions from the
// generated _pb2.py modules.
static PyDescriptorPool* python_generated_pool = nullptr;
bool InitDescriptorPool() {
if (PyType_Ready(&PyDescriptorPool_Type) < 0)
return false;
// The Pool of messages declared in Python libraries.
// generated_pool() contains all messages already linked in C++ libraries, and
// is used as underlay.
descriptor_pool_map =
new std::unordered_map<const DescriptorPool*, PyDescriptorPool*>;
python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay(
DescriptorPool::generated_pool());
if (python_generated_pool == nullptr) {
delete descriptor_pool_map;
return false;
}
// Register this pool to be found for C++-generated descriptors.
descriptor_pool_map->insert(
std::make_pair(DescriptorPool::generated_pool(),
python_generated_pool));
return true;
}
// The default DescriptorPool used everywhere in this module.
// Today it's the python_generated_pool.
// TODO(amauryfa): Remove all usages of this function: the pool should be
// derived from the context.
PyDescriptorPool* GetDefaultDescriptorPool() {
return python_generated_pool;
}
PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) {
// Fast path for standard descriptors.
if (pool == python_generated_pool->pool ||
pool == DescriptorPool::generated_pool()) {
return python_generated_pool;
}
std::unordered_map<const DescriptorPool*, PyDescriptorPool*>::iterator it =
descriptor_pool_map->find(pool);
if (it == descriptor_pool_map->end()) {
PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool");
return nullptr;
}
return it->second;
}
PyObject* PyDescriptorPool_FromPool(const DescriptorPool* pool) {
PyDescriptorPool* existing_pool = GetDescriptorPool_FromPool(pool);
if (existing_pool != nullptr) {
Py_INCREF(existing_pool);
return reinterpret_cast<PyObject*>(existing_pool);
} else {
PyErr_Clear();
}
PyDescriptorPool* cpool = cdescriptor_pool::_CreateDescriptorPool();
if (cpool == nullptr) {
return nullptr;
}
cpool->pool = const_cast<DescriptorPool*>(pool);
cpool->is_owned = false;
cpool->is_mutable = false;
cpool->underlay = nullptr;
if (!descriptor_pool_map->insert(std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- We already checked the existence above.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
return nullptr;
}
return reinterpret_cast<PyObject*>(cpool);
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,149 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <unordered_map>
#include "google/protobuf/descriptor.h"
namespace google {
namespace protobuf {
namespace python {
struct PyMessageFactory;
// The (meta) type of all Messages classes.
struct CMessageClass;
// Wraps operations to the global DescriptorPool which contains information
// about all messages and fields.
//
// There is normally one pool per process. We make it a Python object only
// because it contains many Python references.
//
// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool
// namespace.
typedef struct PyDescriptorPool {
PyObject_HEAD;
// The C++ pool containing Descriptors.
const DescriptorPool* pool;
// True if we should free the pointer above.
bool is_owned;
// True if this pool accepts new proto definitions.
// In this case it is allowed to const_cast<DescriptorPool*>(pool).
bool is_mutable;
// The error collector to store error info. Can be NULL. This pointer is
// owned.
DescriptorPool::ErrorCollector* error_collector;
// The C++ pool acting as an underlay. Can be NULL.
// This pointer is not owned and must stay alive.
const DescriptorPool* underlay;
// The C++ descriptor database used to fetch unknown protos. Can be NULL.
// This pointer is owned.
const DescriptorDatabase* database;
// The preferred MessageFactory to be used by descriptors.
// TODO(amauryfa): Don't create the Factory from the DescriptorPool, but
// use the one passed while creating message classes. And remove this member.
PyMessageFactory* py_message_factory;
// Cache the options for any kind of descriptor.
// Descriptor pointers are owned by the DescriptorPool above.
// Python objects are owned by the map.
std::unordered_map<const void*, PyObject*>* descriptor_options;
} PyDescriptorPool;
extern PyTypeObject PyDescriptorPool_Type;
namespace cdescriptor_pool {
// The functions below are also exposed as methods of the DescriptorPool type.
// Looks up a field by name. Returns a PyFieldDescriptor corresponding to
// the field on success, or NULL on failure.
//
// Returns a new reference.
PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name);
// Looks up an extension by name. Returns a PyFieldDescriptor corresponding
// to the field on success, or NULL on failure.
//
// Returns a new reference.
PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg);
// Looks up an enum type by name. Returns a PyEnumDescriptor corresponding
// to the field on success, or NULL on failure.
//
// Returns a new reference.
PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg);
// Looks up a oneof by name. Returns a COneofDescriptor corresponding
// to the oneof on success, or NULL on failure.
//
// Returns a new reference.
PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg);
} // namespace cdescriptor_pool
// Retrieves the global descriptor pool owned by the _message module.
// This is the one used by pb2.py generated modules.
// Returns a *borrowed* reference.
// "Default" pool used to register messages from _pb2.py modules.
PyDescriptorPool* GetDefaultDescriptorPool();
// Retrieves an existing python descriptor pool owning the C++ descriptor pool.
// Returns a *borrowed* reference.
PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool);
// Wraps a C++ descriptor pool in a Python object, creates it if necessary.
// Returns a new reference.
PyObject* PyDescriptorPool_FromPool(const DescriptorPool* pool);
// Initialize objects used by this module.
bool InitDescriptorPool();
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__

View File

@@ -0,0 +1,487 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/extension_dict.h"
#include <cstdint>
#include <memory>
#include <vector>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/repeated_composite_container.h"
#include "google/protobuf/pyext/repeated_scalar_container.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
? ((*(charpp) = const_cast<char*>( \
PyUnicode_AsUTF8AndSize(ob, (sizep)))) == nullptr \
? -1 \
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
namespace google {
namespace protobuf {
namespace python {
namespace extension_dict {
static Py_ssize_t len(ExtensionDict* self) {
Py_ssize_t size = 0;
std::vector<const FieldDescriptor*> fields;
self->parent->message->GetReflection()->ListFields(*self->parent->message,
&fields);
for (size_t i = 0; i < fields.size(); ++i) {
if (fields[i]->is_extension()) {
// With C++ descriptors, the field can always be retrieved, but for
// unknown extensions which have not been imported in Python code, there
// is no message class and we cannot retrieve the value.
// ListFields() has the same behavior.
if (fields[i]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
fields[i]->message_type()) == nullptr) {
PyErr_Clear();
continue;
}
++size;
}
}
return size;
}
struct ExtensionIterator {
PyObject_HEAD;
Py_ssize_t index;
std::vector<const FieldDescriptor*> fields;
// Owned reference, to keep the FieldDescriptors alive.
ExtensionDict* extension_dict;
};
PyObject* GetIter(PyObject* _self) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
ScopedPyObjectPtr obj(PyType_GenericAlloc(&ExtensionIterator_Type, 0));
if (obj == nullptr) {
return PyErr_Format(PyExc_MemoryError,
"Could not allocate extension iterator");
}
ExtensionIterator* iter = reinterpret_cast<ExtensionIterator*>(obj.get());
// Call "placement new" to initialize. So the constructor of
// std::vector<...> fields will be called.
new (iter) ExtensionIterator;
self->parent->message->GetReflection()->ListFields(*self->parent->message,
&iter->fields);
iter->index = 0;
Py_INCREF(self);
iter->extension_dict = self;
return obj.release();
}
static void DeallocExtensionIterator(PyObject* _self) {
ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
self->fields.clear();
Py_XDECREF(self->extension_dict);
freefunc tp_free = Py_TYPE(_self)->tp_free;
self->~ExtensionIterator();
(*tp_free)(_self);
}
PyObject* subscript(ExtensionDict* self, PyObject* key) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == nullptr) {
return nullptr;
}
if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return nullptr;
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
return cmessage::InternalGetScalar(self->parent->message, descriptor);
}
CMessage::CompositeFieldsMap::iterator iterator =
self->parent->composite_fields->find(descriptor);
if (iterator != self->parent->composite_fields->end()) {
Py_INCREF(iterator->second);
return iterator->second->AsPyObject();
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
// TODO(plabatut): consider building the class on the fly!
ContainerBase* sub_message = cmessage::InternalGetSubMessage(
self->parent, descriptor);
if (sub_message == nullptr) {
return nullptr;
}
(*self->parent->composite_fields)[descriptor] = sub_message;
return sub_message->AsPyObject();
}
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
// On the fly message class creation is needed to support the following
// situation:
// 1- add FileDescriptor to the pool that contains extensions of a message
// defined by another proto file. Do not create any message classes.
// 2- instantiate an extended message, and access the extension using
// the field descriptor.
// 3- the extension submessage fails to be returned, because no class has
// been created.
// It happens when deserializing text proto format, or when enumerating
// fields of a deserialized message.
CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
ScopedPyObjectPtr message_class_handler(
reinterpret_cast<PyObject*>(message_class));
if (message_class == nullptr) {
return nullptr;
}
ContainerBase* py_container = repeated_composite_container::NewContainer(
self->parent, descriptor, message_class);
if (py_container == nullptr) {
return nullptr;
}
(*self->parent->composite_fields)[descriptor] = py_container;
return py_container->AsPyObject();
} else {
ContainerBase* py_container = repeated_scalar_container::NewContainer(
self->parent, descriptor);
if (py_container == nullptr) {
return nullptr;
}
(*self->parent->composite_fields)[descriptor] = py_container;
return py_container->AsPyObject();
}
}
PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
return nullptr;
}
int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == nullptr) {
return -1;
}
if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return -1;
}
if (value == nullptr) {
return cmessage::ClearFieldByDescriptor(self->parent, descriptor);
}
if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
"type");
return -1;
}
cmessage::AssureWritable(self->parent);
if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
return -1;
}
return 0;
}
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
char* name;
Py_ssize_t name_size;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return nullptr;
}
PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
const FieldDescriptor* message_extension =
pool->pool->FindExtensionByName(absl::string_view(name, name_size));
if (message_extension == nullptr) {
// Is is the name of a message set extension?
const Descriptor* message_descriptor =
pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
if (message_descriptor && message_descriptor->extension_count() > 0) {
const FieldDescriptor* extension = message_descriptor->extension(0);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
message_extension = extension;
}
}
}
if (message_extension == nullptr) {
Py_RETURN_NONE;
}
return PyFieldDescriptor_FromDescriptor(message_extension);
}
PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
int64_t number = PyLong_AsLong(arg);
if (number == -1 && PyErr_Occurred()) {
return nullptr;
}
PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
self->parent->message->GetDescriptor(), number);
if (message_extension == nullptr) {
Py_RETURN_NONE;
}
return PyFieldDescriptor_FromDescriptor(message_extension);
}
static int Contains(PyObject* _self, PyObject* key) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(_self);
const FieldDescriptor* field_descriptor =
cmessage::GetExtensionDescriptor(key);
if (field_descriptor == nullptr) {
return -1;
}
if (!field_descriptor->is_extension()) {
PyErr_Format(PyExc_KeyError, "%s is not an extension",
field_descriptor->full_name().c_str());
return -1;
}
const Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
if (field_descriptor->is_repeated()) {
if (reflection->FieldSize(*message, field_descriptor) > 0) {
return 1;
}
} else {
if (reflection->HasField(*message, field_descriptor)) {
return 1;
}
}
return 0;
}
ExtensionDict* NewExtensionDict(CMessage *parent) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(
PyType_GenericAlloc(&ExtensionDict_Type, 0));
if (self == nullptr) {
return nullptr;
}
Py_INCREF(parent);
self->parent = parent;
return self;
}
void dealloc(PyObject* pself) {
ExtensionDict* self = reinterpret_cast<ExtensionDict*>(pself);
Py_CLEAR(self->parent);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
// Only equality comparisons are implemented.
if (opid != Py_EQ && opid != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
bool equals = false;
if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;
}
if (equals ^ (opid == Py_EQ)) {
Py_RETURN_FALSE;
} else {
Py_RETURN_TRUE;
}
}
static PySequenceMethods SeqMethods = {
(lenfunc)len, // sq_length
nullptr, // sq_concat
nullptr, // sq_repeat
nullptr, // sq_item
nullptr, // sq_slice
nullptr, // sq_ass_item
nullptr, // sq_ass_slice
(objobjproc)Contains, // sq_contains
};
static PyMappingMethods MpMethods = {
(lenfunc)len, /* mp_length */
(binaryfunc)subscript, /* mp_subscript */
(objobjargproc)ass_subscript,/* mp_ass_subscript */
};
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
static PyMethodDef Methods[] = {
EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."),
EDMETHOD(_FindExtensionByNumber, METH_O,
"Finds an extension by field number."),
{nullptr, nullptr},
};
} // namespace extension_dict
PyTypeObject ExtensionDict_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) //
FULL_MODULE_NAME ".ExtensionDict", // tp_name
sizeof(ExtensionDict), // tp_basicsize
0, // tp_itemsize
(destructor)extension_dict::dealloc, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
&extension_dict::SeqMethods, // tp_as_sequence
&extension_dict::MpMethods, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"An extension dict", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
(richcmpfunc)extension_dict::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
extension_dict::GetIter, // tp_iter
nullptr, // tp_iternext
extension_dict::Methods, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
PyObject* IterNext(PyObject* _self) {
extension_dict::ExtensionIterator* self =
reinterpret_cast<extension_dict::ExtensionIterator*>(_self);
Py_ssize_t total_size = self->fields.size();
Py_ssize_t index = self->index;
while (self->index < total_size) {
index = self->index;
++self->index;
if (self->fields[index]->is_extension()) {
// With C++ descriptors, the field can always be retrieved, but for
// unknown extensions which have not been imported in Python code, there
// is no message class and we cannot retrieve the value.
// ListFields() has the same behavior.
if (self->fields[index]->message_type() != nullptr &&
message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->extension_dict->parent),
self->fields[index]->message_type()) == nullptr) {
PyErr_Clear();
continue;
}
return PyFieldDescriptor_FromDescriptor(self->fields[index]);
}
}
return nullptr;
}
PyTypeObject ExtensionIterator_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) //
FULL_MODULE_NAME ".ExtensionIterator", // tp_name
sizeof(extension_dict::ExtensionIterator), // tp_basicsize
0, // tp_itemsize
extension_dict::DeallocExtensionIterator, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"A scalar map iterator", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
PyObject_SelfIter, // tp_iter
IterNext, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,70 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/pyext/message.h"
namespace google {
namespace protobuf {
class Message;
class FieldDescriptor;
namespace python {
typedef struct ExtensionDict {
PyObject_HEAD;
// Strong, owned reference to the parent message. Never NULL.
CMessage* parent;
} ExtensionDict;
extern PyTypeObject ExtensionDict_Type;
extern PyTypeObject ExtensionIterator_Type;
namespace extension_dict {
// Builds an Extensions dict for a specific message.
ExtensionDict* NewExtensionDict(CMessage *parent);
} // namespace extension_dict
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__

View File

@@ -0,0 +1,143 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "google/protobuf/pyext/field.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
namespace google {
namespace protobuf {
namespace python {
namespace field {
static PyObject* Repr(PyMessageFieldProperty* self) {
return PyUnicode_FromFormat("<field property '%s'>",
self->field_descriptor->full_name().c_str());
}
static PyObject* DescrGet(PyMessageFieldProperty* self, PyObject* obj,
PyObject* type) {
if (obj == nullptr) {
Py_INCREF(self);
return reinterpret_cast<PyObject*>(self);
}
return cmessage::GetFieldValue(reinterpret_cast<CMessage*>(obj),
self->field_descriptor);
}
static int DescrSet(PyMessageFieldProperty* self, PyObject* obj,
PyObject* value) {
if (value == nullptr) {
PyErr_SetString(PyExc_AttributeError, "Cannot delete field attribute");
return -1;
}
return cmessage::SetFieldValue(reinterpret_cast<CMessage*>(obj),
self->field_descriptor, value);
}
static PyObject* GetDescriptor(PyMessageFieldProperty* self, void* closure) {
return PyFieldDescriptor_FromDescriptor(self->field_descriptor);
}
static PyObject* GetDoc(PyMessageFieldProperty* self, void* closure) {
return PyUnicode_FromFormat("Field %s",
self->field_descriptor->full_name().c_str());
}
static PyGetSetDef Getters[] = {
{"DESCRIPTOR", (getter)GetDescriptor, nullptr, "Field descriptor"},
{"__doc__", (getter)GetDoc, nullptr, nullptr},
{nullptr},
};
} // namespace field
static PyTypeObject _CFieldProperty_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) // head
FULL_MODULE_NAME ".FieldProperty", // tp_name
sizeof(PyMessageFieldProperty), // tp_basicsize
0, // tp_itemsize
nullptr, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
(reprfunc)field::Repr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"Field property of a Message", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
field::Getters, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
(descrgetfunc)field::DescrGet, // tp_descr_get
(descrsetfunc)field::DescrSet, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
nullptr, // tp_new
};
PyTypeObject* CFieldProperty_Type = &_CFieldProperty_Type;
PyObject* NewFieldProperty(const FieldDescriptor* field_descriptor) {
// Create a new descriptor object
PyMessageFieldProperty* property =
PyObject_New(PyMessageFieldProperty, CFieldProperty_Type);
if (property == nullptr) {
return nullptr;
}
property->field_descriptor = field_descriptor;
return reinterpret_cast<PyObject*>(property);
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,60 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
namespace google {
namespace protobuf {
class FieldDescriptor;
namespace python {
// A data descriptor that represents a field in a Message class.
struct PyMessageFieldProperty {
PyObject_HEAD;
// This pointer is owned by the same pool as the Message class it belongs to.
const FieldDescriptor* field_descriptor;
};
extern PyTypeObject* CFieldProperty_Type;
PyObject* NewFieldProperty(const FieldDescriptor* field_descriptor);
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_FIELD_H__

View File

@@ -0,0 +1,931 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: haberman@google.com (Josh Haberman)
#include "google/protobuf/pyext/map_container.h"
#include <cstdint>
#include <memory>
#include <string>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/map.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/repeated_composite_container.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
namespace google {
namespace protobuf {
namespace python {
// Functions that need access to map reflection functionality.
// They need to be contained in this class because it is friended.
class MapReflectionFriend {
public:
// Methods that are in common between the map types.
static PyObject* Contains(PyObject* _self, PyObject* key);
static Py_ssize_t Length(PyObject* _self);
static PyObject* GetIterator(PyObject *_self);
static PyObject* IterNext(PyObject* _self);
static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
// Methods that differ between the map types.
static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
static PyObject* ScalarMapToStr(PyObject* _self);
static PyObject* MessageMapToStr(PyObject* _self);
};
struct MapIterator {
PyObject_HEAD;
std::unique_ptr<::google::protobuf::MapIterator> iter;
// A pointer back to the container, so we can notice changes to the version.
// We own a ref on this.
MapContainer* container;
// We need to keep a ref on the parent Message too, because
// MapIterator::~MapIterator() accesses it. Normally this would be ok because
// the ref on container (above) would guarantee outlive semantics. However in
// the case of ClearField(), the MapContainer points to a different message,
// a copy of the original. But our iterator still points to the original,
// which could now get deleted before us.
//
// To prevent this, we ensure that the Message will always stay alive as long
// as this iterator does. This is solely for the benefit of the MapIterator
// destructor -- we should never actually access the iterator in this state
// except to delete it.
CMessage* parent;
// The version of the map when we took the iterator to it.
//
// We store this so that if the map is modified during iteration we can throw
// an error.
uint64_t version;
};
Message* MapContainer::GetMutableMessage() {
cmessage::AssureWritable(parent);
return parent->message;
}
// Consumes a reference on the Python string object.
static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
char *value;
Py_ssize_t value_len;
if (!py_string) {
return false;
}
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
Py_DECREF(py_string);
return false;
} else {
stl_string->assign(value, value_len);
Py_DECREF(py_string);
return true;
}
}
static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key) {
const FieldDescriptor* field_descriptor =
self->parent_field_descriptor->message_type()->map_key();
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(obj, value, false);
key->SetInt32Value(value);
break;
}
case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(obj, value, false);
key->SetInt64Value(value);
break;
}
case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(obj, value, false);
key->SetUInt32Value(value);
break;
}
case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(obj, value, false);
key->SetUInt64Value(value);
break;
}
case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(obj, value, false);
key->SetBoolValue(value);
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
std::string str;
if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
return false;
}
key->SetStringValue(str);
break;
}
default:
PyErr_Format(
PyExc_SystemError, "Type %d cannot be a map key",
field_descriptor->cpp_type());
return false;
}
return true;
}
static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
const FieldDescriptor* field_descriptor =
self->parent_field_descriptor->message_type()->map_key();
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32:
return PyLong_FromLong(key.GetInt32Value());
case FieldDescriptor::CPPTYPE_INT64:
return PyLong_FromLongLong(key.GetInt64Value());
case FieldDescriptor::CPPTYPE_UINT32:
return PyLong_FromSize_t(key.GetUInt32Value());
case FieldDescriptor::CPPTYPE_UINT64:
return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
case FieldDescriptor::CPPTYPE_BOOL:
return PyBool_FromLong(key.GetBoolValue());
case FieldDescriptor::CPPTYPE_STRING:
return ToStringObject(field_descriptor, key.GetStringValue());
default:
PyErr_Format(
PyExc_SystemError, "Couldn't convert type %d to value",
field_descriptor->cpp_type());
return nullptr;
}
}
// This is only used for ScalarMap, so we don't need to handle the
// CPPTYPE_MESSAGE case.
PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
const FieldDescriptor* field_descriptor =
self->parent_field_descriptor->message_type()->map_value();
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32:
return PyLong_FromLong(value.GetInt32Value());
case FieldDescriptor::CPPTYPE_INT64:
return PyLong_FromLongLong(value.GetInt64Value());
case FieldDescriptor::CPPTYPE_UINT32:
return PyLong_FromSize_t(value.GetUInt32Value());
case FieldDescriptor::CPPTYPE_UINT64:
return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
case FieldDescriptor::CPPTYPE_FLOAT:
return PyFloat_FromDouble(value.GetFloatValue());
case FieldDescriptor::CPPTYPE_DOUBLE:
return PyFloat_FromDouble(value.GetDoubleValue());
case FieldDescriptor::CPPTYPE_BOOL:
return PyBool_FromLong(value.GetBoolValue());
case FieldDescriptor::CPPTYPE_STRING:
return ToStringObject(field_descriptor, value.GetStringValue());
case FieldDescriptor::CPPTYPE_ENUM:
return PyLong_FromLong(value.GetEnumValue());
default:
PyErr_Format(
PyExc_SystemError, "Couldn't convert type %d to value",
field_descriptor->cpp_type());
return nullptr;
}
}
// This is only used for ScalarMap, so we don't need to handle the
// CPPTYPE_MESSAGE case.
static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
bool allow_unknown_enum_values,
MapValueRef* value_ref) {
const FieldDescriptor* field_descriptor =
self->parent_field_descriptor->message_type()->map_value();
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(obj, value, false);
value_ref->SetInt32Value(value);
return true;
}
case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(obj, value, false);
value_ref->SetInt64Value(value);
return true;
}
case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(obj, value, false);
value_ref->SetUInt32Value(value);
return true;
}
case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(obj, value, false);
value_ref->SetUInt64Value(value);
return true;
}
case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(obj, value, false);
value_ref->SetFloatValue(value);
return true;
}
case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
value_ref->SetDoubleValue(value);
return true;
}
case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(obj, value, false);
value_ref->SetBoolValue(value);
return true;
}
case FieldDescriptor::CPPTYPE_STRING: {
std::string str;
if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
return false;
}
value_ref->SetStringValue(str);
return true;
}
case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(obj, value, false);
if (allow_unknown_enum_values) {
value_ref->SetEnumValue(value);
return true;
} else {
const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
const EnumValueDescriptor* enum_value =
enum_descriptor->FindValueByNumber(value);
if (enum_value != nullptr) {
value_ref->SetEnumValue(value);
return true;
} else {
PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
return false;
}
}
break;
}
default:
PyErr_Format(
PyExc_SystemError, "Setting value to a field of unknown type %d",
field_descriptor->cpp_type());
return false;
}
}
// Map methods common to ScalarMap and MessageMap //////////////////////////////
static MapContainer* GetMap(PyObject* obj) {
return reinterpret_cast<MapContainer*>(obj);
}
Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
MapContainer* self = GetMap(_self);
const google::protobuf::Message* message = self->parent->message;
return message->GetReflection()->MapSize(*message,
self->parent_field_descriptor);
}
PyObject* Clear(PyObject* _self) {
MapContainer* self = GetMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
reflection->ClearField(message, self->parent_field_descriptor);
Py_RETURN_NONE;
}
PyObject* GetEntryClass(PyObject* _self) {
MapContainer* self = GetMap(_self);
CMessageClass* message_class = message_factory::GetMessageClass(
cmessage::GetFactoryForMessage(self->parent),
self->parent_field_descriptor->message_type());
Py_XINCREF(message_class);
return reinterpret_cast<PyObject*>(message_class);
}
PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
MapContainer* self = GetMap(_self);
if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) &&
!PyObject_TypeCheck(arg, MessageMapContainer_Type)) {
PyErr_SetString(PyExc_AttributeError, "Not a map field");
return nullptr;
}
MapContainer* other_map = GetMap(arg);
Message* message = self->GetMutableMessage();
const Message* other_message = other_map->parent->message;
const Reflection* reflection = message->GetReflection();
const Reflection* other_reflection = other_message->GetReflection();
internal::MapFieldBase* field = reflection->MutableMapData(
message, self->parent_field_descriptor);
const internal::MapFieldBase* other_field = other_reflection->GetMapData(
*other_message, other_map->parent_field_descriptor);
field->MergeFrom(*other_field);
self->version++;
Py_RETURN_NONE;
}
PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
MapContainer* self = GetMap(_self);
const Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
MapKey map_key;
if (!PythonToMapKey(self, key, &map_key)) {
return nullptr;
}
if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
map_key)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
// ScalarMap ///////////////////////////////////////////////////////////////////
MapContainer* NewScalarMapContainer(
CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return nullptr;
}
PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
if (obj == nullptr) {
PyErr_Format(PyExc_RuntimeError,
"Could not allocate new container.");
return nullptr;
}
MapContainer* self = GetMap(obj);
Py_INCREF(parent);
self->parent = parent;
self->parent_field_descriptor = parent_field_descriptor;
self->version = 0;
return self;
}
PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
PyObject* key) {
MapContainer* self = GetMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
MapKey map_key;
MapValueRef value;
if (!PythonToMapKey(self, key, &map_key)) {
return nullptr;
}
if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
map_key, &value)) {
self->version++;
}
return MapValueRefToPython(self, value);
}
int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
PyObject* v) {
MapContainer* self = GetMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
MapKey map_key;
MapValueRef value;
if (!PythonToMapKey(self, key, &map_key)) {
return -1;
}
if (v) {
// Set item to v.
if (reflection->InsertOrLookupMapValue(
message, self->parent_field_descriptor, map_key, &value)) {
self->version++;
}
if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(),
&value)) {
return -1;
}
return 0;
} else {
// Delete key from map.
if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
map_key)) {
self->version++;
return 0;
} else {
PyErr_Format(PyExc_KeyError, "Key not present in map");
return -1;
}
}
}
static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
PyObject* kwargs) {
static const char* kwlist[] = {"key", "default", nullptr};
PyObject* key;
PyObject* default_value = nullptr;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
const_cast<char**>(kwlist), &key,
&default_value)) {
return nullptr;
}
ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present.get() == nullptr) {
return nullptr;
}
if (PyObject_IsTrue(is_present.get())) {
return MapReflectionFriend::ScalarMapGetItem(self, key);
} else {
if (default_value != nullptr) {
Py_INCREF(default_value);
return default_value;
} else {
Py_RETURN_NONE;
}
}
}
PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
ScopedPyObjectPtr dict(PyDict_New());
if (dict == nullptr) {
return nullptr;
}
ScopedPyObjectPtr key;
ScopedPyObjectPtr value;
MapContainer* self = GetMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
for (google::protobuf::MapIterator it = reflection->MapBegin(
message, self->parent_field_descriptor);
it != reflection->MapEnd(message, self->parent_field_descriptor);
++it) {
key.reset(MapKeyToPython(self, it.GetKey()));
if (key == nullptr) {
return nullptr;
}
value.reset(MapValueRefToPython(self, it.GetValueRef()));
if (value == nullptr) {
return nullptr;
}
if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
return nullptr;
}
}
return PyObject_Repr(dict.get());
}
static void ScalarMapDealloc(PyObject* _self) {
MapContainer* self = GetMap(_self);
self->RemoveFromParentCache();
PyTypeObject *type = Py_TYPE(_self);
type->tp_free(_self);
if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
// With Python3, the Map class is not static, and must be managed.
Py_DECREF(type);
}
}
static PyMethodDef ScalarMapMethods[] = {
{"__contains__", MapReflectionFriend::Contains, METH_O,
"Tests whether a key is a member of the map."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
"Return the class used to build Entries of (key, value) pairs."},
{"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
"Merges a map into the current map."},
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
{ "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
"Outputs picklable representation of the repeated field." },
*/
{nullptr, nullptr},
};
PyTypeObject* ScalarMapContainer_Type;
static PyType_Slot ScalarMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void*)ScalarMapDealloc},
{Py_mp_length, (void*)MapReflectionFriend::Length},
{Py_mp_subscript, (void*)MapReflectionFriend::ScalarMapGetItem},
{Py_mp_ass_subscript, (void*)MapReflectionFriend::ScalarMapSetItem},
{Py_tp_methods, (void*)ScalarMapMethods},
{Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
{Py_tp_repr, (void*)MapReflectionFriend::ScalarMapToStr},
{0, nullptr},
};
PyType_Spec ScalarMapContainer_Type_spec = {
FULL_MODULE_NAME ".ScalarMapContainer", sizeof(MapContainer), 0,
Py_TPFLAGS_DEFAULT, ScalarMapContainer_Type_slots};
// MessageMap //////////////////////////////////////////////////////////////////
static MessageMapContainer* GetMessageMap(PyObject* obj) {
return reinterpret_cast<MessageMapContainer*>(obj);
}
static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
// Get or create the CMessage object corresponding to this message.
return self->parent
->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
self->message_class)
->AsPyObject();
}
MessageMapContainer* NewMessageMapContainer(
CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
CMessageClass* message_class) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return nullptr;
}
PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
if (obj == nullptr) {
PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
return nullptr;
}
MessageMapContainer* self = GetMessageMap(obj);
Py_INCREF(parent);
self->parent = parent;
self->parent_field_descriptor = parent_field_descriptor;
self->version = 0;
Py_INCREF(message_class);
self->message_class = message_class;
return self;
}
int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
PyObject* v) {
if (v) {
PyErr_Format(PyExc_ValueError,
"Direct assignment of submessage not allowed");
return -1;
}
// Now we know that this is a delete, not a set.
MessageMapContainer* self = GetMessageMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
MapKey map_key;
MapValueRef value;
self->version++;
if (!PythonToMapKey(self, key, &map_key)) {
return -1;
}
// Delete key from map.
if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
map_key)) {
// Delete key from CMessage dict.
MapValueRef value;
reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
map_key, &value);
Message* sub_message = value.MutableMessageValue();
// If there is a living weak reference to an item, we "Release" it,
// otherwise we just discard the C++ value.
if (CMessage* released =
self->parent->MaybeReleaseSubMessage(sub_message)) {
Message* msg = released->message;
released->message = msg->New();
msg->GetReflection()->Swap(msg, released->message);
}
// Delete key from map.
reflection->DeleteMapValue(message, self->parent_field_descriptor,
map_key);
return 0;
} else {
PyErr_Format(PyExc_KeyError, "Key not present in map");
return -1;
}
}
PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
PyObject* key) {
MessageMapContainer* self = GetMessageMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
MapKey map_key;
MapValueRef value;
if (!PythonToMapKey(self, key, &map_key)) {
return nullptr;
}
if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
map_key, &value)) {
self->version++;
}
return GetCMessage(self, value.MutableMessageValue());
}
PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
ScopedPyObjectPtr dict(PyDict_New());
if (dict == nullptr) {
return nullptr;
}
ScopedPyObjectPtr key;
ScopedPyObjectPtr value;
MessageMapContainer* self = GetMessageMap(_self);
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
for (google::protobuf::MapIterator it = reflection->MapBegin(
message, self->parent_field_descriptor);
it != reflection->MapEnd(message, self->parent_field_descriptor);
++it) {
key.reset(MapKeyToPython(self, it.GetKey()));
if (key == nullptr) {
return nullptr;
}
value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
if (value == nullptr) {
return nullptr;
}
if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
return nullptr;
}
}
return PyObject_Repr(dict.get());
}
PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
static const char* kwlist[] = {"key", "default", nullptr};
PyObject* key;
PyObject* default_value = nullptr;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
const_cast<char**>(kwlist), &key,
&default_value)) {
return nullptr;
}
ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present.get() == nullptr) {
return nullptr;
}
if (PyObject_IsTrue(is_present.get())) {
return MapReflectionFriend::MessageMapGetItem(self, key);
} else {
if (default_value != nullptr) {
Py_INCREF(default_value);
return default_value;
} else {
Py_RETURN_NONE;
}
}
}
static void MessageMapDealloc(PyObject* _self) {
MessageMapContainer* self = GetMessageMap(_self);
self->RemoveFromParentCache();
Py_DECREF(self->message_class);
PyTypeObject *type = Py_TYPE(_self);
type->tp_free(_self);
if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
// With Python3, the Map class is not static, and must be managed.
Py_DECREF(type);
}
}
static PyMethodDef MessageMapMethods[] = {
{"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
"Tests whether the map contains this element."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
"Alias for getitem, useful to make explicit that the map is mutated."},
{"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
"Return the class used to build Entries of (key, value) pairs."},
{"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
"Merges a map into the current map."},
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
{ "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
"Outputs picklable representation of the repeated field." },
*/
{nullptr, nullptr},
};
PyTypeObject* MessageMapContainer_Type;
static PyType_Slot MessageMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void*)MessageMapDealloc},
{Py_mp_length, (void*)MapReflectionFriend::Length},
{Py_mp_subscript, (void*)MapReflectionFriend::MessageMapGetItem},
{Py_mp_ass_subscript, (void*)MapReflectionFriend::MessageMapSetItem},
{Py_tp_methods, (void*)MessageMapMethods},
{Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
{Py_tp_repr, (void*)MapReflectionFriend::MessageMapToStr},
{0, nullptr}};
PyType_Spec MessageMapContainer_Type_spec = {
FULL_MODULE_NAME ".MessageMapContainer", sizeof(MessageMapContainer), 0,
Py_TPFLAGS_DEFAULT, MessageMapContainer_Type_slots};
// MapIterator /////////////////////////////////////////////////////////////////
static MapIterator* GetIter(PyObject* obj) {
return reinterpret_cast<MapIterator*>(obj);
}
PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
MapContainer* self = GetMap(_self);
ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
if (obj == nullptr) {
return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
}
MapIterator* iter = GetIter(obj.get());
Py_INCREF(self);
iter->container = self;
iter->version = self->version;
Py_INCREF(self->parent);
iter->parent = self->parent;
if (MapReflectionFriend::Length(_self) > 0) {
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
iter->iter.reset(new ::google::protobuf::MapIterator(
reflection->MapBegin(message, self->parent_field_descriptor)));
}
return obj.release();
}
PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
MapIterator* self = GetIter(_self);
// This won't catch mutations to the map performed by MergeFrom(); no easy way
// to address that.
if (self->version != self->container->version) {
return PyErr_Format(PyExc_RuntimeError,
"Map modified during iteration.");
}
if (self->parent != self->container->parent) {
return PyErr_Format(PyExc_RuntimeError,
"Map cleared during iteration.");
}
if (self->iter.get() == nullptr) {
return nullptr;
}
Message* message = self->container->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
if (*self->iter ==
reflection->MapEnd(message, self->container->parent_field_descriptor)) {
return nullptr;
}
PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
++(*self->iter);
return ret;
}
static void DeallocMapIterator(PyObject* _self) {
MapIterator* self = GetIter(_self);
self->iter.reset();
Py_CLEAR(self->container);
Py_CLEAR(self->parent);
Py_TYPE(_self)->tp_free(_self);
}
PyTypeObject MapIterator_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".MapIterator", // tp_name
sizeof(MapIterator), // tp_basicsize
0, // tp_itemsize
DeallocMapIterator, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"A scalar map iterator", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
PyObject_SelfIter, // tp_iter
MapReflectionFriend::IterNext, // tp_iternext
nullptr, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
bool InitMapContainers() {
// ScalarMapContainer_Type derives from our MutableMapping type.
ScopedPyObjectPtr abc(PyImport_ImportModule("collections.abc"));
if (abc == nullptr) {
return false;
}
ScopedPyObjectPtr mutable_mapping(
PyObject_GetAttrString(abc.get(), "MutableMapping"));
if (mutable_mapping == nullptr) {
return false;
}
Py_INCREF(mutable_mapping.get());
ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
if (bases == nullptr) {
return false;
}
ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
if (PyType_Ready(&MapIterator_Type) < 0) {
return false;
}
MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
return true;
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,89 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cstdint>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/message.h"
namespace google {
namespace protobuf {
class Message;
namespace python {
struct CMessageClass;
// This struct is used directly for ScalarMap, and is the base class of
// MessageMapContainer, which is used for MessageMap.
struct MapContainer : public ContainerBase {
// Use to get a mutable message when necessary.
Message* GetMutableMessage();
// We bump this whenever we perform a mutation, to invalidate existing
// iterators.
uint64_t version;
};
struct MessageMapContainer : public MapContainer {
// The type used to create new child messages.
CMessageClass* message_class;
};
bool InitMapContainers();
extern PyTypeObject* MessageMapContainer_Type;
extern PyTypeObject* ScalarMapContainer_Type;
extern PyTypeObject MapIterator_Type; // Both map types use the same iterator.
// Builds a MapContainer object, from a parent message and a
// field descriptor.
extern MapContainer* NewScalarMapContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor);
// Builds a MessageMap object, from a parent message and a
// field descriptor.
extern MessageMapContainer* NewMessageMapContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor,
CMessageClass* message_class);
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,377 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include "google/protobuf/stubs/common.h"
namespace google {
namespace protobuf {
class Message;
class Reflection;
class FieldDescriptor;
class Descriptor;
class DescriptorPool;
class MessageFactory;
namespace python {
struct ExtensionDict;
struct PyMessageFactory;
struct CMessageClass;
// Most of the complexity of the Message class comes from the "Release"
// behavior:
//
// When a field is cleared, it is only detached from its message. Existing
// references to submessages, to repeated container etc. won't see any change,
// as if the data was effectively managed by these containers.
//
// ExtensionDicts and UnknownFields containers do NOT follow this rule. They
// don't store any data, and always refer to their parent message.
struct ContainerBase {
PyObject_HEAD;
// Strong reference to a parent message object. For a CMessage there are three
// cases:
// - For a top-level message, this pointer is NULL.
// - For a sub-message, this points to the parent message.
// - For a message managed externally, this is a owned reference to Py_None.
//
// For all other types: repeated containers, maps, it always point to a
// valid parent CMessage.
struct CMessage* parent;
// If this object belongs to a parent message, describes which field it comes
// from.
// The pointer is owned by the DescriptorPool (which is kept alive
// through the message's Python class)
const FieldDescriptor* parent_field_descriptor;
PyObject* AsPyObject() { return reinterpret_cast<PyObject*>(this); }
// The Three methods below are only used by Repeated containers, and Maps.
// This implementation works for all containers which have a parent.
PyObject* DeepCopy();
// Delete this container object from its parent. Does not work for messages.
void RemoveFromParentCache();
};
typedef struct CMessage : public ContainerBase {
// Pointer to the C++ Message object for this CMessage.
// - If this object has no parent, we own this pointer.
// - If this object has a parent message, the parent owns this pointer.
Message* message;
// Indicates this submessage is pointing to a default instance of a message.
// Submessages are always first created as read only messages and are then
// made writable, at which point this field is set to false.
bool read_only;
// A mapping indexed by field, containing weak references to contained objects
// which need to implement the "Release" mechanism:
// direct submessages, RepeatedCompositeContainer, RepeatedScalarContainer
// and MapContainer.
typedef std::unordered_map<const FieldDescriptor*, ContainerBase*>
CompositeFieldsMap;
CompositeFieldsMap* composite_fields;
// A mapping containing weak references to indirect child messages, accessed
// through containers: repeated messages, and values of message maps.
// This avoid the creation of similar maps in each of those containers.
typedef std::unordered_map<const Message*, CMessage*> SubMessagesMap;
SubMessagesMap* child_submessages;
// A reference to PyUnknownFields.
PyObject* unknown_field_set;
// Implements the "weakref" protocol for this object.
PyObject* weakreflist;
// Return a *borrowed* reference to the message class.
CMessageClass* GetMessageClass() {
return reinterpret_cast<CMessageClass*>(Py_TYPE(this));
}
// For container containing messages, return a Python object for the given
// pointer to a message.
CMessage* BuildSubMessageFromPointer(const FieldDescriptor* field_descriptor,
Message* sub_message,
CMessageClass* message_class);
CMessage* MaybeReleaseSubMessage(Message* sub_message);
} CMessage;
// The (meta) type of all Messages classes.
// It allows us to cache some C++ pointers in the class object itself, they are
// faster to extract than from the type's dictionary.
struct CMessageClass {
// This is how CPython subclasses C structures: the base structure must be
// the first member of the object.
PyHeapTypeObject super;
// C++ descriptor of this message.
const Descriptor* message_descriptor;
// Owned reference, used to keep the pointer above alive.
// This reference must stay alive until all message pointers are destructed.
PyObject* py_message_descriptor;
// The Python MessageFactory used to create the class. It is needed to resolve
// fields descriptors, including extensions fields; its C++ MessageFactory is
// used to instantiate submessages.
// This reference must stay alive until all message pointers are destructed.
PyMessageFactory* py_message_factory;
PyObject* AsPyObject() {
return reinterpret_cast<PyObject*>(this);
}
};
extern PyTypeObject* CMessageClass_Type;
extern PyTypeObject* CMessage_Type;
namespace cmessage {
// Internal function to create a new empty Message Python object, but with empty
// pointers to the C++ objects.
// The caller must fill self->message, self->owner and eventually self->parent.
CMessage* NewEmptyMessage(CMessageClass* type);
// Retrieves the C++ descriptor of a Python Extension descriptor.
// On error, return NULL with an exception set.
const FieldDescriptor* GetExtensionDescriptor(PyObject* extension);
// Initializes a new CMessage instance for a submessage. Only called once per
// submessage as the result is cached in composite_fields.
//
// Corresponds to reflection api method GetMessage.
CMessage* InternalGetSubMessage(
CMessage* self, const FieldDescriptor* field_descriptor);
// Deletes a range of items in a repeated field (following a
// removal in a RepeatedCompositeContainer).
//
// Corresponds to reflection api method RemoveLast.
int DeleteRepeatedField(CMessage* self,
const FieldDescriptor* field_descriptor,
PyObject* slice);
// Sets the specified scalar value to the message.
int InternalSetScalar(CMessage* self,
const FieldDescriptor* field_descriptor,
PyObject* value);
// Sets the specified scalar value to the message. Requires it is not a Oneof.
int InternalSetNonOneofScalar(Message* message,
const FieldDescriptor* field_descriptor,
PyObject* arg);
// Retrieves the specified scalar value from the message.
//
// Returns a new python reference.
PyObject* InternalGetScalar(const Message* message,
const FieldDescriptor* field_descriptor);
bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
ContainerBase* value);
bool SetSubmessage(CMessage* self, CMessage* submessage);
// Clears the message, removing all contained data. Extension dictionary and
// submessages are released first if there are remaining external references.
//
// Corresponds to message api method Clear.
PyObject* Clear(CMessage* self);
// Clears the data described by the given descriptor.
// Returns -1 on error.
//
// Corresponds to reflection api method ClearField.
int ClearFieldByDescriptor(CMessage* self, const FieldDescriptor* descriptor);
// Checks if the message has the field described by the descriptor. Used for
// extensions (which have no name).
// Returns 1 if true, 0 if false, and -1 on error.
//
// Corresponds to reflection api method HasField
int HasFieldByDescriptor(CMessage* self,
const FieldDescriptor* field_descriptor);
// Checks if the message has the named field.
//
// Corresponds to reflection api method HasField.
PyObject* HasField(CMessage* self, PyObject* arg);
// Initializes values of fields on a newly constructed message.
// Note that positional arguments are disallowed: 'args' must be NULL or the
// empty tuple.
int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs);
PyObject* MergeFrom(CMessage* self, PyObject* arg);
// This method does not do anything beyond checking that no other extension
// has been registered with the same field number on this class.
PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle);
// Get a field from a message.
PyObject* GetFieldValue(CMessage* self,
const FieldDescriptor* field_descriptor);
// Sets the value of a scalar field in a message.
// On error, return -1 with an extension set.
int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
PyObject* value);
PyObject* FindInitializationErrors(CMessage* self);
int AssureWritable(CMessage* self);
// Returns the message factory for the given message.
// This is equivalent to message.MESSAGE_FACTORY
//
// The returned factory is suitable for finding fields and building submessages,
// even in the case of extensions.
// Returns a *borrowed* reference, and never fails because we pass a CMessage.
PyMessageFactory* GetFactoryForMessage(CMessage* message);
PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
} // namespace cmessage
/* Is 64bit */
#define IS_64BIT (SIZEOF_LONG == 8)
#define FIELD_IS_REPEATED(field_descriptor) \
((field_descriptor)->label() == FieldDescriptor::LABEL_REPEATED)
#define GOOGLE_CHECK_GET_INT32(arg, value, err) \
int32_t value; \
if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_INT64(arg, value, err) \
int64_t value; \
if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \
uint32_t value; \
if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \
uint64_t value; \
if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_FLOAT(arg, value, err) \
float value; \
if (!CheckAndGetFloat(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_DOUBLE(arg, value, err) \
double value; \
if (!CheckAndGetDouble(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_BOOL(arg, value, err) \
bool value; \
if (!CheckAndGetBool(arg, &value)) { \
return err; \
}
#define FULL_MODULE_NAME "google.protobuf.pyext._message"
void FormatTypeError(PyObject* arg, const char* expected_types);
template<class T>
bool CheckAndGetInteger(PyObject* arg, T* value);
bool CheckAndGetDouble(PyObject* arg, double* value);
bool CheckAndGetFloat(PyObject* arg, float* value);
bool CheckAndGetBool(PyObject* arg, bool* value);
PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor);
bool CheckAndSetString(
PyObject* arg, Message* message,
const FieldDescriptor* descriptor,
const Reflection* reflection,
bool append,
int index);
PyObject* ToStringObject(const FieldDescriptor* descriptor,
const std::string& value);
// Check if the passed field descriptor belongs to the given message.
// If not, return false and set a Python exception (a KeyError)
bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
const Message* message);
extern PyObject* PickleError_class;
PyObject* PyMessage_New(const Descriptor* descriptor,
PyObject* py_message_factory);
const Message* PyMessage_GetMessagePointer(PyObject* msg);
Message* PyMessage_GetMutableMessagePointer(PyObject* msg);
PyObject* PyMessage_NewMessageOwnedExternally(Message* message,
PyObject* py_message_factory);
bool InitProto2MessageModule(PyObject *m);
// These are referenced by repeated_scalar_container, and must
// be explicitly instantiated.
extern template bool CheckAndGetInteger<int32>(PyObject*, int32*);
extern template bool CheckAndGetInteger<int64>(PyObject*, int64*);
extern template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
extern template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__

View File

@@ -0,0 +1,307 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <unordered_map>
#include <utility>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
? ((*(charpp) = const_cast<char*>( \
PyUnicode_AsUTF8AndSize(ob, (sizep)))) == nullptr \
? -1 \
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
namespace google {
namespace protobuf {
namespace python {
namespace message_factory {
PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
PyType_GenericAlloc(type, 0));
if (factory == nullptr) {
return nullptr;
}
DynamicMessageFactory* message_factory = new DynamicMessageFactory();
// This option might be the default some day.
message_factory->SetDelegateToGeneratedFactory(true);
factory->message_factory = message_factory;
factory->pool = pool;
Py_INCREF(pool);
factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
return factory;
}
PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
static const char* kwlist[] = {"pool", nullptr};
PyObject* pool = nullptr;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O",
const_cast<char**>(kwlist), &pool)) {
return nullptr;
}
ScopedPyObjectPtr owned_pool;
if (pool == nullptr || pool == Py_None) {
owned_pool.reset(PyObject_CallFunction(
reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), nullptr));
if (owned_pool == nullptr) {
return nullptr;
}
pool = owned_pool.get();
} else {
if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
pool->ob_type->tp_name);
return nullptr;
}
}
return reinterpret_cast<PyObject*>(
NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
}
static void Dealloc(PyObject* pself) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
for (iterator it = self->classes_by_descriptor->begin();
it != self->classes_by_descriptor->end(); ++it) {
Py_CLEAR(it->second);
}
delete self->classes_by_descriptor;
delete self->message_factory;
Py_CLEAR(self->pool);
Py_TYPE(self)->tp_free(pself);
}
static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
Py_VISIT(self->pool);
for (const auto& desc_and_class : *self->classes_by_descriptor) {
Py_VISIT(desc_and_class.second);
}
return 0;
}
static int GcClear(PyObject* pself) {
PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
// Here it's important to not clear self->pool, so that the C++ DescriptorPool
// is still alive when self->message_factory is destructed.
for (auto& desc_and_class : *self->classes_by_descriptor) {
Py_CLEAR(desc_and_class.second);
}
return 0;
}
// Add a message class to our database.
int RegisterMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor,
CMessageClass* message_class) {
Py_INCREF(message_class);
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
std::make_pair(message_descriptor, message_class));
if (!ret.second) {
// Update case: DECREF the previous value.
Py_DECREF(ret.first->second);
ret.first->second = message_class;
}
return 0;
}
CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
const Descriptor* descriptor) {
// This is the same implementation as MessageFactory.GetPrototype().
// Do not create a MessageClass that already exists.
std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
self->classes_by_descriptor->find(descriptor);
if (it != self->classes_by_descriptor->end()) {
Py_INCREF(it->second);
return it->second;
}
ScopedPyObjectPtr py_descriptor(
PyMessageDescriptor_FromDescriptor(descriptor));
if (py_descriptor == nullptr) {
return nullptr;
}
// Create a new message class.
ScopedPyObjectPtr args(Py_BuildValue(
"s(){sOsOsO}", descriptor->name().c_str(),
"DESCRIPTOR", py_descriptor.get(),
"__module__", Py_None,
"message_factory", self));
if (args == nullptr) {
return nullptr;
}
ScopedPyObjectPtr message_class(PyObject_CallObject(
reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
if (message_class == nullptr) {
return nullptr;
}
// Create messages class for the messages used by the fields, and registers
// all extensions for these messages during the recursion.
for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
const Descriptor* sub_descriptor =
descriptor->field(field_idx)->message_type();
// It is null if the field type is not a message.
if (sub_descriptor != nullptr) {
CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
if (result == nullptr) {
return nullptr;
}
Py_DECREF(result);
}
}
// Register extensions defined in this message.
for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
const FieldDescriptor* extension = descriptor->extension(ext_idx);
ScopedPyObjectPtr py_extended_class(
GetOrCreateMessageClass(self, extension->containing_type())
->AsPyObject());
if (py_extended_class == nullptr) {
return nullptr;
}
ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
if (py_extension == nullptr) {
return nullptr;
}
ScopedPyObjectPtr result(cmessage::RegisterExtension(
py_extended_class.get(), py_extension.get()));
if (result == nullptr) {
return nullptr;
}
}
return reinterpret_cast<CMessageClass*>(message_class.release());
}
// Retrieve the message class added to our database.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor) {
typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
iterator ret = self->classes_by_descriptor->find(message_descriptor);
if (ret == self->classes_by_descriptor->end()) {
PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
message_descriptor->full_name().c_str());
return nullptr;
} else {
return ret->second;
}
}
static PyMethodDef Methods[] = {
{nullptr},
};
static PyObject* GetPool(PyMessageFactory* self, void* closure) {
Py_INCREF(self->pool);
return reinterpret_cast<PyObject*>(self->pool);
}
static PyGetSetDef Getters[] = {
{"pool", (getter)GetPool, nullptr, "DescriptorPool"},
{nullptr},
};
} // namespace message_factory
PyTypeObject PyMessageFactory_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".MessageFactory", // tp_name
sizeof(PyMessageFactory), // tp_basicsize
0, // tp_itemsize
message_factory::Dealloc, // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
nullptr, // tp_print
#else
0, // tp_vectorcall_offset
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
nullptr, // tp_repr
nullptr, // tp_as_number
nullptr, // tp_as_sequence
nullptr, // tp_as_mapping
nullptr, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, // tp_flags
"A static Message Factory", // tp_doc
message_factory::GcTraverse, // tp_traverse
message_factory::GcClear, // tp_clear
nullptr, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
message_factory::Methods, // tp_methods
nullptr, // tp_members
message_factory::Getters, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
nullptr, // tp_alloc
message_factory::New, // tp_new
PyObject_GC_Del, // tp_free
};
bool InitMessageFactory() {
if (PyType_Ready(&PyMessageFactory_Type) < 0) {
return false;
}
return true;
}
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,104 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <unordered_map>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
namespace google {
namespace protobuf {
class MessageFactory;
namespace python {
// The (meta) type of all Messages classes.
struct CMessageClass;
struct PyMessageFactory {
PyObject_HEAD
// DynamicMessageFactory used to create C++ instances of messages.
// This object cache the descriptors that were used, so the DescriptorPool
// needs to get rid of it before it can delete itself.
//
// Note: A C++ MessageFactory is different from the PyMessageFactory.
// The C++ one creates messages, when the Python one creates classes.
MessageFactory* message_factory;
// Owned reference to a Python DescriptorPool.
// This reference must stay until the message_factory is destructed.
PyDescriptorPool* pool;
// Make our own mapping to retrieve Python classes from C++ descriptors.
//
// Descriptor pointers stored here are owned by the DescriptorPool above.
// Python references to classes are owned by this PyDescriptorPool.
typedef std::unordered_map<const Descriptor*, CMessageClass*>
ClassesByMessageMap;
ClassesByMessageMap* classes_by_descriptor;
};
extern PyTypeObject PyMessageFactory_Type;
namespace message_factory {
// Creates a new MessageFactory instance.
PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool);
// Registers a new Python class for the given message descriptor.
// On error, returns -1 with a Python exception set.
int RegisterMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor,
CMessageClass* message_class);
// Retrieves the Python class registered with the given message descriptor, or
// fail with a TypeError. Returns a *borrowed* reference.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor);
// Retrieves the Python class registered with the given message descriptor.
// The class is created if not done yet. Returns a *new* reference.
CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor);
} // namespace message_factory
// Initialize objects used by this module.
// On error, returns false with a Python exception set.
bool InitMessageFactory();
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__

View File

@@ -0,0 +1,134 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/message_lite.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/proto_api.h"
namespace {
// C++ API. Clients get at this via proto_api.h
struct ApiImplementation : google::protobuf::python::PyProto_API {
const google::protobuf::Message* GetMessagePointer(PyObject* msg) const override {
return google::protobuf::python::PyMessage_GetMessagePointer(msg);
}
google::protobuf::Message* GetMutableMessagePointer(PyObject* msg) const override {
return google::protobuf::python::PyMessage_GetMutableMessagePointer(msg);
}
const google::protobuf::Descriptor* MessageDescriptor_AsDescriptor(
PyObject* desc) const override {
return google::protobuf::python::PyMessageDescriptor_AsDescriptor(desc);
}
const google::protobuf::EnumDescriptor* EnumDescriptor_AsDescriptor(
PyObject* enum_desc) const override {
return google::protobuf::python::PyEnumDescriptor_AsDescriptor(enum_desc);
}
const google::protobuf::DescriptorPool* GetDefaultDescriptorPool() const override {
return google::protobuf::python::GetDefaultDescriptorPool()->pool;
}
google::protobuf::MessageFactory* GetDefaultMessageFactory() const override {
return google::protobuf::python::GetDefaultDescriptorPool()
->py_message_factory->message_factory;
}
PyObject* NewMessage(const google::protobuf::Descriptor* descriptor,
PyObject* py_message_factory) const override {
return google::protobuf::python::PyMessage_New(descriptor, py_message_factory);
}
PyObject* NewMessageOwnedExternally(
google::protobuf::Message* msg, PyObject* py_message_factory) const override {
return google::protobuf::python::PyMessage_NewMessageOwnedExternally(
msg, py_message_factory);
}
PyObject* DescriptorPool_FromPool(
const google::protobuf::DescriptorPool* pool) const override {
return google::protobuf::python::PyDescriptorPool_FromPool(pool);
}
};
} // namespace
static const char module_docstring[] =
"python-proto2 is a module that can be used to enhance proto2 Python API\n"
"performance.\n"
"\n"
"It provides access to the protocol buffers C++ reflection API that\n"
"implements the basic protocol buffer functions.";
static PyMethodDef ModuleMethods[] = {
{"SetAllowOversizeProtos",
(PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, METH_O,
"Enable/disable oversize proto parsing."},
// DO NOT USE: For migration and testing only.
{nullptr, nullptr}};
static struct PyModuleDef _module = {PyModuleDef_HEAD_INIT,
"_message",
module_docstring,
-1,
ModuleMethods, /* m_methods */
nullptr,
nullptr,
nullptr,
nullptr};
PyMODINIT_FUNC PyInit__message() {
PyObject* m;
m = PyModule_Create(&_module);
if (m == nullptr) {
return nullptr;
}
if (!google::protobuf::python::InitProto2MessageModule(m)) {
Py_DECREF(m);
return nullptr;
}
// Adds the C++ API
if (PyObject* api = PyCapsule_New(
new ApiImplementation(), google::protobuf::python::PyProtoAPICapsuleName(),
[](PyObject* o) {
delete (ApiImplementation*)PyCapsule_GetPointer(
o, google::protobuf::python::PyProtoAPICapsuleName());
})) {
PyModule_AddObject(m, "proto_API", api);
} else {
return nullptr;
}
return m;
}

View File

@@ -0,0 +1,40 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
syntax = "proto2";
package google.protobuf.python.internal;
import "google/protobuf/internal/cpp/proto1_api_test.proto";
message TestNestedProto1APIMessage {
optional int32 a = 1;
optional TestMessage.NestedMessage b = 2;
}

View File

@@ -0,0 +1,68 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: tibell@google.com (Johan Tibell)
//
// These message definitions are used to exercises known corner cases
// in the C++ implementation of the Python API.
syntax = "proto2";
package google.protobuf.python.internal;
// Protos optimized for SPEED use a strict superset of the generated code
// of equivalent ones optimized for CODE_SIZE, so we should optimize all our
// tests for speed unless explicitly testing code size optimization.
option optimize_for = SPEED;
message TestAllTypes {
message NestedMessage {
optional int32 bb = 1;
optional ForeignMessage cc = 2;
}
repeated NestedMessage repeated_nested_message = 1;
optional NestedMessage optional_nested_message = 2;
optional int32 optional_int32 = 3;
}
message ForeignMessage {
optional int32 c = 1;
repeated int32 d = 2;
}
message TestAllExtensions { // extension begin
extensions 1 to max;
} // extension end
extend TestAllExtensions { // extension begin
optional TestAllTypes.NestedMessage optional_nested_message_extension = 1;
repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2;
} // extension end

View File

@@ -0,0 +1,590 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/repeated_composite_container.h"
#include <memory>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/reflection.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
namespace google {
namespace protobuf {
namespace python {
namespace repeated_composite_container {
// ---------------------------------------------------------------------
// len()
static Py_ssize_t Length(PyObject* pself) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
Message* message = self->parent->message;
return message->GetReflection()->FieldSize(*message,
self->parent_field_descriptor);
}
// ---------------------------------------------------------------------
// add()
PyObject* Add(RepeatedCompositeContainer* self, PyObject* args,
PyObject* kwargs) {
if (cmessage::AssureWritable(self->parent) == -1) return nullptr;
Message* message = self->parent->message;
Message* sub_message = message->GetReflection()->AddMessage(
message, self->parent_field_descriptor,
self->child_message_class->py_message_factory->message_factory);
CMessage* cmsg = self->parent->BuildSubMessageFromPointer(
self->parent_field_descriptor, sub_message, self->child_message_class);
if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) {
message->GetReflection()->RemoveLast(message,
self->parent_field_descriptor);
Py_DECREF(cmsg);
return nullptr;
}
return cmsg->AsPyObject();
}
static PyObject* AddMethod(PyObject* self, PyObject* args, PyObject* kwargs) {
return Add(reinterpret_cast<RepeatedCompositeContainer*>(self), args, kwargs);
}
// ---------------------------------------------------------------------
// append()
static PyObject* AddMessage(RepeatedCompositeContainer* self, PyObject* value) {
cmessage::AssureWritable(self->parent);
PyObject* py_cmsg;
Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
py_cmsg = Add(self, nullptr, nullptr);
if (py_cmsg == nullptr) return nullptr;
CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
if (ScopedPyObjectPtr(cmessage::MergeFrom(cmsg, value)) == nullptr) {
reflection->RemoveLast(message, self->parent_field_descriptor);
Py_DECREF(cmsg);
return nullptr;
}
return py_cmsg;
}
static PyObject* AppendMethod(PyObject* pself, PyObject* value) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
ScopedPyObjectPtr py_cmsg(AddMessage(self, value));
if (py_cmsg == nullptr) {
return nullptr;
}
Py_RETURN_NONE;
}
// ---------------------------------------------------------------------
// insert()
static PyObject* Insert(PyObject* pself, PyObject* args) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
Py_ssize_t index;
PyObject* value;
if (!PyArg_ParseTuple(args, "nO", &index, &value)) {
return nullptr;
}
ScopedPyObjectPtr py_cmsg(AddMessage(self, value));
if (py_cmsg == nullptr) {
return nullptr;
}
// Swap the element to right position.
Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
Py_ssize_t length = reflection->FieldSize(*message, field_descriptor) - 1;
Py_ssize_t end_index = index;
if (end_index < 0) end_index += length;
if (end_index < 0) end_index = 0;
for (Py_ssize_t i = length; i > end_index; i--) {
reflection->SwapElements(message, field_descriptor, i, i - 1);
}
Py_RETURN_NONE;
}
// ---------------------------------------------------------------------
// extend()
PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
cmessage::AssureWritable(self->parent);
ScopedPyObjectPtr iter(PyObject_GetIter(value));
if (iter == nullptr) {
PyErr_SetString(PyExc_TypeError, "Value must be iterable");
return nullptr;
}
ScopedPyObjectPtr next;
while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
if (!PyObject_TypeCheck(next.get(), CMessage_Type)) {
PyErr_SetString(PyExc_TypeError, "Not a cmessage");
return nullptr;
}
ScopedPyObjectPtr new_message(Add(self, nullptr, nullptr));
if (new_message == nullptr) {
return nullptr;
}
CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get());
if (ScopedPyObjectPtr(cmessage::MergeFrom(new_cmessage, next.get())) ==
nullptr) {
return nullptr;
}
}
if (PyErr_Occurred()) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* ExtendMethod(PyObject* self, PyObject* value) {
return Extend(reinterpret_cast<RepeatedCompositeContainer*>(self), value);
}
PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) {
return Extend(self, other);
}
static PyObject* MergeFromMethod(PyObject* self, PyObject* other) {
return MergeFrom(reinterpret_cast<RepeatedCompositeContainer*>(self), other);
}
// This function does not check the bounds.
static PyObject* GetItem(RepeatedCompositeContainer* self, Py_ssize_t index,
Py_ssize_t length = -1) {
if (length == -1) {
Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
length = reflection->FieldSize(*message, self->parent_field_descriptor);
}
if (index < 0 || index >= length) {
PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index);
return nullptr;
}
Message* message = self->parent->message;
Message* sub_message = message->GetReflection()->MutableRepeatedMessage(
message, self->parent_field_descriptor, index);
return self->parent
->BuildSubMessageFromPointer(self->parent_field_descriptor, sub_message,
self->child_message_class)
->AsPyObject();
}
PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* item) {
Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
Py_ssize_t length =
reflection->FieldSize(*message, self->parent_field_descriptor);
if (PyIndex_Check(item)) {
Py_ssize_t index;
index = PyNumber_AsSsize_t(item, PyExc_IndexError);
if (index == -1 && PyErr_Occurred()) return nullptr;
if (index < 0) index += length;
return GetItem(self, index, length);
} else if (PySlice_Check(item)) {
Py_ssize_t from, to, step, slicelength, cur, i;
PyObject* result;
if (PySlice_GetIndicesEx(item, length, &from, &to, &step, &slicelength) ==
-1) {
return nullptr;
}
if (slicelength <= 0) {
return PyList_New(0);
} else {
result = PyList_New(slicelength);
if (!result) return nullptr;
for (cur = from, i = 0; i < slicelength; cur += step, i++) {
PyList_SET_ITEM(result, i, GetItem(self, cur, length));
}
return result;
}
} else {
PyErr_Format(PyExc_TypeError, "indices must be integers, not %.200s",
item->ob_type->tp_name);
return nullptr;
}
}
static PyObject* SubscriptMethod(PyObject* self, PyObject* slice) {
return Subscript(reinterpret_cast<RepeatedCompositeContainer*>(self), slice);
}
int AssignSubscript(RepeatedCompositeContainer* self, PyObject* slice,
PyObject* value) {
if (value != nullptr) {
PyErr_SetString(PyExc_TypeError, "does not support assignment");
return -1;
}
return cmessage::DeleteRepeatedField(self->parent,
self->parent_field_descriptor, slice);
}
static int AssignSubscriptMethod(PyObject* self, PyObject* slice,
PyObject* value) {
return AssignSubscript(reinterpret_cast<RepeatedCompositeContainer*>(self),
slice, value);
}
static PyObject* Remove(PyObject* pself, PyObject* value) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
Py_ssize_t len = Length(reinterpret_cast<PyObject*>(self));
for (Py_ssize_t i = 0; i < len; i++) {
ScopedPyObjectPtr item(GetItem(self, i, len));
if (item == nullptr) {
return nullptr;
}
int result = PyObject_RichCompareBool(item.get(), value, Py_EQ);
if (result < 0) {
return nullptr;
}
if (result) {
ScopedPyObjectPtr py_index(PyLong_FromSsize_t(i));
if (AssignSubscript(self, py_index.get(), nullptr) < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
}
PyErr_SetString(PyExc_ValueError, "Item to delete not in list");
return nullptr;
}
static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
if (!PyObject_TypeCheck(other, &RepeatedCompositeContainer_Type)) {
PyErr_SetString(PyExc_TypeError,
"Can only compare repeated composite fields "
"against other repeated composite fields.");
return nullptr;
}
if (opid == Py_EQ || opid == Py_NE) {
// TODO(anuraag): Don't make new lists just for this...
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
if (list == nullptr) {
return nullptr;
}
ScopedPyObjectPtr other_list(
Subscript(reinterpret_cast<RepeatedCompositeContainer*>(other),
full_slice.get()));
if (other_list == nullptr) {
return nullptr;
}
return PyObject_RichCompare(list.get(), other_list.get(), opid);
} else {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
}
static PyObject* ToStr(PyObject* pself) {
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr list(Subscript(
reinterpret_cast<RepeatedCompositeContainer*>(pself), full_slice.get()));
if (list == nullptr) {
return nullptr;
}
return PyObject_Repr(list.get());
}
// ---------------------------------------------------------------------
// sort()
static void ReorderAttached(RepeatedCompositeContainer* self,
PyObject* child_list) {
Message* message = self->parent->message;
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* descriptor = self->parent_field_descriptor;
const Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self));
// We need to rearrange things to match python's sort order.
for (Py_ssize_t i = 0; i < length; ++i) {
reflection->UnsafeArenaReleaseLast(message, descriptor);
}
for (Py_ssize_t i = 0; i < length; ++i) {
Message* child_message =
reinterpret_cast<CMessage*>(PyList_GET_ITEM(child_list, i))->message;
reflection->UnsafeArenaAddAllocatedMessage(message, descriptor,
child_message);
}
}
// Returns 0 if successful; returns -1 and sets an exception if
// unsuccessful.
static int SortPythonMessages(RepeatedCompositeContainer* self, PyObject* args,
PyObject* kwds) {
ScopedPyObjectPtr child_list(
PySequence_List(reinterpret_cast<PyObject*>(self)));
if (child_list == nullptr) {
return -1;
}
ScopedPyObjectPtr m(PyObject_GetAttrString(child_list.get(), "sort"));
if (m == nullptr) return -1;
if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == nullptr)
return -1;
ReorderAttached(self, child_list.get());
return 0;
}
static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
// Support the old sort_function argument for backwards
// compatibility.
if (kwds != nullptr) {
PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function");
if (sort_func != nullptr) {
// Must set before deleting as sort_func is a borrowed reference
// and kwds might be the only thing keeping it alive.
PyDict_SetItemString(kwds, "cmp", sort_func);
PyDict_DelItemString(kwds, "sort_function");
}
}
if (SortPythonMessages(self, args, kwds) < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
// ---------------------------------------------------------------------
// reverse()
// Returns 0 if successful; returns -1 and sets an exception if
// unsuccessful.
static int ReversePythonMessages(RepeatedCompositeContainer* self) {
ScopedPyObjectPtr child_list(
PySequence_List(reinterpret_cast<PyObject*>(self)));
if (child_list == nullptr) {
return -1;
}
if (ScopedPyObjectPtr(
PyObject_CallMethod(child_list.get(), "reverse", nullptr)) == nullptr)
return -1;
ReorderAttached(self, child_list.get());
return 0;
}
static PyObject* Reverse(PyObject* pself) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
if (ReversePythonMessages(self) < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
// ---------------------------------------------------------------------
static PyObject* Item(PyObject* pself, Py_ssize_t index) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
return GetItem(self, index);
}
static PyObject* Pop(PyObject* pself, PyObject* args) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
Py_ssize_t index = -1;
if (!PyArg_ParseTuple(args, "|n", &index)) {
return nullptr;
}
Py_ssize_t length = Length(pself);
if (index < 0) index += length;
PyObject* item = GetItem(self, index, length);
if (item == nullptr) {
return nullptr;
}
ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index));
if (AssignSubscript(self, py_index.get(), nullptr) < 0) {
return nullptr;
}
return item;
}
PyObject* DeepCopy(PyObject* pself, PyObject* arg) {
return reinterpret_cast<RepeatedCompositeContainer*>(pself)->DeepCopy();
}
// The private constructor of RepeatedCompositeContainer objects.
RepeatedCompositeContainer* NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor,
CMessageClass* child_message_class) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return nullptr;
}
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(
PyType_GenericAlloc(&RepeatedCompositeContainer_Type, 0));
if (self == nullptr) {
return nullptr;
}
Py_INCREF(parent);
self->parent = parent;
self->parent_field_descriptor = parent_field_descriptor;
Py_INCREF(child_message_class);
self->child_message_class = child_message_class;
return self;
}
static void Dealloc(PyObject* pself) {
RepeatedCompositeContainer* self =
reinterpret_cast<RepeatedCompositeContainer*>(pself);
self->RemoveFromParentCache();
Py_CLEAR(self->child_message_class);
Py_TYPE(self)->tp_free(pself);
}
static PySequenceMethods SqMethods = {
Length, /* sq_length */
nullptr, /* sq_concat */
nullptr, /* sq_repeat */
Item /* sq_item */
};
static PyMappingMethods MpMethods = {
Length, /* mp_length */
SubscriptMethod, /* mp_subscript */
AssignSubscriptMethod, /* mp_ass_subscript */
};
static PyMethodDef Methods[] = {
{"__deepcopy__", DeepCopy, METH_VARARGS, "Makes a deep copy of the class."},
{"add", reinterpret_cast<PyCFunction>(AddMethod),
METH_VARARGS | METH_KEYWORDS, "Adds an object to the repeated container."},
{"append", AppendMethod, METH_O,
"Appends a message to the end of the repeated container."},
{"insert", Insert, METH_VARARGS,
"Inserts a message before the specified index."},
{"extend", ExtendMethod, METH_O, "Adds objects to the repeated container."},
{"pop", Pop, METH_VARARGS,
"Removes an object from the repeated container and returns it."},
{"remove", Remove, METH_O,
"Removes an object from the repeated container."},
{"sort", reinterpret_cast<PyCFunction>(Sort), METH_VARARGS | METH_KEYWORDS,
"Sorts the repeated container."},
{"reverse", reinterpret_cast<PyCFunction>(Reverse), METH_NOARGS,
"Reverses elements order of the repeated container."},
{"MergeFrom", MergeFromMethod, METH_O,
"Adds objects to the repeated container."},
{nullptr, nullptr}};
} // namespace repeated_composite_container
PyTypeObject RepeatedCompositeContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".RepeatedCompositeContainer", // tp_name
sizeof(RepeatedCompositeContainer), // tp_basicsize
0, // tp_itemsize
repeated_composite_container::Dealloc, // tp_dealloc
#if PY_VERSION_HEX >= 0x03080000
0, // tp_vectorcall_offset
#else
nullptr, // tp_print
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
repeated_composite_container::ToStr, // tp_repr
nullptr, // tp_as_number
&repeated_composite_container::SqMethods, // tp_as_sequence
&repeated_composite_container::MpMethods, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"A Repeated scalar container", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
repeated_composite_container::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
repeated_composite_container::Methods, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,109 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/pyext/message.h"
namespace google {
namespace protobuf {
class FieldDescriptor;
class Message;
namespace python {
struct CMessageClass;
// A RepeatedCompositeContainer always has a parent message.
// The parent message also caches reference to items of the container.
typedef struct RepeatedCompositeContainer : public ContainerBase {
// The type used to create new child messages.
CMessageClass* child_message_class;
} RepeatedCompositeContainer;
extern PyTypeObject RepeatedCompositeContainer_Type;
namespace repeated_composite_container {
// Builds a RepeatedCompositeContainer object, from a parent message and a
// field descriptor.
RepeatedCompositeContainer* NewContainer(
CMessage* parent,
const FieldDescriptor* parent_field_descriptor,
CMessageClass *child_message_class);
// Appends a new CMessage to the container and returns it. The
// CMessage is initialized using the content of kwargs.
//
// Returns a new reference if successful; returns NULL and sets an
// exception if unsuccessful.
PyObject* Add(RepeatedCompositeContainer* self,
PyObject* args,
PyObject* kwargs);
// Appends all the CMessages in the input iterator to the container.
//
// Returns None if successful; returns NULL and sets an exception if
// unsuccessful.
PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value);
// Appends a new message to the container for each message in the
// input iterator, merging each data element in. Equivalent to extend.
//
// Returns None if successful; returns NULL and sets an exception if
// unsuccessful.
PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other);
// Accesses messages in the container.
//
// Returns a new reference to the message for an integer parameter.
// Returns a new reference to a list of messages for a slice.
PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice);
// Deletes items from the container (cannot be used for assignment).
//
// Returns 0 on success, -1 on failure.
int AssignSubscript(RepeatedCompositeContainer* self,
PyObject* slice,
PyObject* value);
} // namespace repeated_composite_container
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__

View File

@@ -0,0 +1,775 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/repeated_scalar_container.h"
#include <cstdint>
#include <memory>
#include <string>
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#define PyString_AsString(ob) \
(PyUnicode_Check(ob) ? PyUnicode_AsUTF8(ob) : PyBytes_AsString(ob))
namespace google {
namespace protobuf {
namespace python {
namespace repeated_scalar_container {
static int InternalAssignRepeatedField(RepeatedScalarContainer* self,
PyObject* list) {
Message* message = self->parent->message;
message->GetReflection()->ClearField(message, self->parent_field_descriptor);
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) {
PyObject* value = PyList_GET_ITEM(list, i);
if (ScopedPyObjectPtr(Append(self, value)) == nullptr) {
return -1;
}
}
return 0;
}
static Py_ssize_t Len(PyObject* pself) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);
Message* message = self->parent->message;
return message->GetReflection()->FieldSize(*message,
self->parent_field_descriptor);
}
static int AssignItem(PyObject* pself, Py_ssize_t index, PyObject* arg) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);
cmessage::AssureWritable(self->parent);
Message* message = self->parent->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
const Reflection* reflection = message->GetReflection();
int field_size = reflection->FieldSize(*message, field_descriptor);
if (index < 0) {
index = field_size + index;
}
if (index < 0 || index >= field_size) {
PyErr_Format(PyExc_IndexError, "list assignment index (%d) out of range",
static_cast<int>(index));
return -1;
}
if (arg == nullptr) {
ScopedPyObjectPtr py_index(PyLong_FromLong(index));
return cmessage::DeleteRepeatedField(self->parent, field_descriptor,
py_index.get());
}
if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) {
PyErr_SetString(PyExc_TypeError, "Value must be scalar");
return -1;
}
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
reflection->SetRepeatedInt32(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(arg, value, -1);
reflection->SetRepeatedInt64(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(arg, value, -1);
reflection->SetRepeatedUInt32(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(arg, value, -1);
reflection->SetRepeatedUInt64(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
reflection->SetRepeatedFloat(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
reflection->SetRepeatedDouble(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(arg, value, -1);
reflection->SetRepeatedBool(message, field_descriptor, index, value);
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
if (!CheckAndSetString(arg, message, field_descriptor, reflection, false,
index)) {
return -1;
}
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
if (reflection->SupportsUnknownEnumValues()) {
reflection->SetRepeatedEnumValue(message, field_descriptor, index,
value);
} else {
const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
const EnumValueDescriptor* enum_value =
enum_descriptor->FindValueByNumber(value);
if (enum_value != nullptr) {
reflection->SetRepeatedEnum(message, field_descriptor, index,
enum_value);
} else {
ScopedPyObjectPtr s(PyObject_Str(arg));
if (s != nullptr) {
PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
PyString_AsString(s.get()));
}
return -1;
}
}
break;
}
default:
PyErr_Format(PyExc_SystemError,
"Adding value to a field of unknown type %d",
field_descriptor->cpp_type());
return -1;
}
return 0;
}
static PyObject* Item(PyObject* pself, Py_ssize_t index) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);
Message* message = self->parent->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
const Reflection* reflection = message->GetReflection();
int field_size = reflection->FieldSize(*message, field_descriptor);
if (index < 0) {
index = field_size + index;
}
if (index < 0 || index >= field_size) {
PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index);
return nullptr;
}
PyObject* result = nullptr;
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
int32_t value =
reflection->GetRepeatedInt32(*message, field_descriptor, index);
result = PyLong_FromLong(value);
break;
}
case FieldDescriptor::CPPTYPE_INT64: {
int64_t value =
reflection->GetRepeatedInt64(*message, field_descriptor, index);
result = PyLong_FromLongLong(value);
break;
}
case FieldDescriptor::CPPTYPE_UINT32: {
uint32_t value =
reflection->GetRepeatedUInt32(*message, field_descriptor, index);
result = PyLong_FromLongLong(value);
break;
}
case FieldDescriptor::CPPTYPE_UINT64: {
uint64_t value =
reflection->GetRepeatedUInt64(*message, field_descriptor, index);
result = PyLong_FromUnsignedLongLong(value);
break;
}
case FieldDescriptor::CPPTYPE_FLOAT: {
float value =
reflection->GetRepeatedFloat(*message, field_descriptor, index);
result = PyFloat_FromDouble(value);
break;
}
case FieldDescriptor::CPPTYPE_DOUBLE: {
double value =
reflection->GetRepeatedDouble(*message, field_descriptor, index);
result = PyFloat_FromDouble(value);
break;
}
case FieldDescriptor::CPPTYPE_BOOL: {
bool value =
reflection->GetRepeatedBool(*message, field_descriptor, index);
result = PyBool_FromLong(value ? 1 : 0);
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
const EnumValueDescriptor* enum_value =
message->GetReflection()->GetRepeatedEnum(*message, field_descriptor,
index);
result = PyLong_FromLong(enum_value->number());
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
std::string scratch;
const std::string& value = reflection->GetRepeatedStringReference(
*message, field_descriptor, index, &scratch);
result = ToStringObject(field_descriptor, value);
break;
}
default:
PyErr_Format(PyExc_SystemError,
"Getting value from a repeated field of unknown type %d",
field_descriptor->cpp_type());
}
return result;
}
static PyObject* Subscript(PyObject* pself, PyObject* slice) {
Py_ssize_t from;
Py_ssize_t to;
Py_ssize_t step;
Py_ssize_t length;
Py_ssize_t slicelength;
bool return_list = false;
if (PyLong_Check(slice)) {
from = to = PyLong_AsLong(slice);
} else if (PyIndex_Check(slice)) {
from = to = PyNumber_AsSsize_t(slice, PyExc_ValueError);
if (from == -1 && PyErr_Occurred()) {
return nullptr;
}
} else if (PySlice_Check(slice)) {
length = Len(pself);
if (PySlice_GetIndicesEx(slice, length, &from, &to, &step, &slicelength) ==
-1) {
return nullptr;
}
return_list = true;
} else {
PyErr_SetString(PyExc_TypeError, "list indices must be integers");
return nullptr;
}
if (!return_list) {
return Item(pself, from);
}
PyObject* list = PyList_New(0);
if (list == nullptr) {
return nullptr;
}
if (from <= to) {
if (step < 0) {
return list;
}
for (Py_ssize_t index = from; index < to; index += step) {
if (index < 0 || index >= length) {
break;
}
ScopedPyObjectPtr s(Item(pself, index));
PyList_Append(list, s.get());
}
} else {
if (step > 0) {
return list;
}
for (Py_ssize_t index = from; index > to; index += step) {
if (index < 0 || index >= length) {
break;
}
ScopedPyObjectPtr s(Item(pself, index));
PyList_Append(list, s.get());
}
}
return list;
}
PyObject* Append(RepeatedScalarContainer* self, PyObject* item) {
cmessage::AssureWritable(self->parent);
Message* message = self->parent->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
const Reflection* reflection = message->GetReflection();
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(item, value, nullptr);
reflection->AddInt32(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_INT64: {
GOOGLE_CHECK_GET_INT64(item, value, nullptr);
reflection->AddInt64(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_UINT32: {
GOOGLE_CHECK_GET_UINT32(item, value, nullptr);
reflection->AddUInt32(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_UINT64: {
GOOGLE_CHECK_GET_UINT64(item, value, nullptr);
reflection->AddUInt64(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_FLOAT: {
GOOGLE_CHECK_GET_FLOAT(item, value, nullptr);
reflection->AddFloat(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_DOUBLE: {
GOOGLE_CHECK_GET_DOUBLE(item, value, nullptr);
reflection->AddDouble(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_BOOL: {
GOOGLE_CHECK_GET_BOOL(item, value, nullptr);
reflection->AddBool(message, field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
if (!CheckAndSetString(item, message, field_descriptor, reflection, true,
-1)) {
return nullptr;
}
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
GOOGLE_CHECK_GET_INT32(item, value, nullptr);
if (reflection->SupportsUnknownEnumValues()) {
reflection->AddEnumValue(message, field_descriptor, value);
} else {
const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
const EnumValueDescriptor* enum_value =
enum_descriptor->FindValueByNumber(value);
if (enum_value != nullptr) {
reflection->AddEnum(message, field_descriptor, enum_value);
} else {
ScopedPyObjectPtr s(PyObject_Str(item));
if (s != nullptr) {
PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
PyString_AsString(s.get()));
}
return nullptr;
}
}
break;
}
default:
PyErr_Format(PyExc_SystemError,
"Adding value to a field of unknown type %d",
field_descriptor->cpp_type());
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* AppendMethod(PyObject* self, PyObject* item) {
return Append(reinterpret_cast<RepeatedScalarContainer*>(self), item);
}
static int AssSubscript(PyObject* pself, PyObject* slice, PyObject* value) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);
Py_ssize_t from;
Py_ssize_t to;
Py_ssize_t step;
Py_ssize_t length;
Py_ssize_t slicelength;
bool create_list = false;
cmessage::AssureWritable(self->parent);
Message* message = self->parent->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
if (PyLong_Check(slice)) {
from = to = PyLong_AsLong(slice);
} else if (PySlice_Check(slice)) {
const Reflection* reflection = message->GetReflection();
length = reflection->FieldSize(*message, field_descriptor);
if (PySlice_GetIndicesEx(slice, length, &from, &to, &step, &slicelength) ==
-1) {
return -1;
}
create_list = true;
} else {
PyErr_SetString(PyExc_TypeError, "list indices must be integers");
return -1;
}
if (value == nullptr) {
return cmessage::DeleteRepeatedField(self->parent, field_descriptor, slice);
}
if (!create_list) {
return AssignItem(pself, from, value);
}
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return -1;
}
ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get()));
if (new_list == nullptr) {
return -1;
}
if (PySequence_SetSlice(new_list.get(), from, to, value) < 0) {
return -1;
}
return InternalAssignRepeatedField(self, new_list.get());
}
PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) {
cmessage::AssureWritable(self->parent);
// TODO(ptucker): Deprecate this behavior. b/18413862
if (value == Py_None) {
Py_RETURN_NONE;
}
if ((Py_TYPE(value)->tp_as_sequence == nullptr) && PyObject_Not(value)) {
Py_RETURN_NONE;
}
ScopedPyObjectPtr iter(PyObject_GetIter(value));
if (iter == nullptr) {
PyErr_SetString(PyExc_TypeError, "Value must be iterable");
return nullptr;
}
ScopedPyObjectPtr next;
while ((next.reset(PyIter_Next(iter.get()))) != nullptr) {
if (ScopedPyObjectPtr(Append(self, next.get())) == nullptr) {
return nullptr;
}
}
if (PyErr_Occurred()) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* Insert(PyObject* pself, PyObject* args) {
RepeatedScalarContainer* self =
reinterpret_cast<RepeatedScalarContainer*>(pself);
Py_ssize_t index;
PyObject* value;
if (!PyArg_ParseTuple(args, "lO", &index, &value)) {
return nullptr;
}
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get()));
if (PyList_Insert(new_list.get(), index, value) < 0) {
return nullptr;
}
int ret = InternalAssignRepeatedField(self, new_list.get());
if (ret < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* Remove(PyObject* pself, PyObject* value) {
Py_ssize_t match_index = -1;
for (Py_ssize_t i = 0; i < Len(pself); ++i) {
ScopedPyObjectPtr elem(Item(pself, i));
if (PyObject_RichCompareBool(elem.get(), value, Py_EQ)) {
match_index = i;
break;
}
}
if (match_index == -1) {
PyErr_SetString(PyExc_ValueError, "remove(x): x not in container");
return nullptr;
}
if (AssignItem(pself, match_index, nullptr) < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* ExtendMethod(PyObject* self, PyObject* value) {
return Extend(reinterpret_cast<RepeatedScalarContainer*>(self), value);
}
static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) {
if (opid != Py_EQ && opid != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
// Copy the contents of this repeated scalar container, and other if it is
// also a repeated scalar container, into Python lists so we can delegate
// to the list's compare method.
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr other_list_deleter;
if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) {
other_list_deleter.reset(Subscript(other, full_slice.get()));
other = other_list_deleter.get();
}
ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == nullptr) {
return nullptr;
}
return PyObject_RichCompare(list.get(), other, opid);
}
PyObject* Reduce(PyObject* unused_self, PyObject* unused_other) {
PyErr_Format(PickleError_class,
"can't pickle repeated message fields, convert to list first");
return nullptr;
}
static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
// Support the old sort_function argument for backwards
// compatibility.
if (kwds != nullptr) {
PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function");
if (sort_func != nullptr) {
// Must set before deleting as sort_func is a borrowed reference
// and kwds might be the only thing keeping it alive.
if (PyDict_SetItemString(kwds, "cmp", sort_func) == -1) return nullptr;
if (PyDict_DelItemString(kwds, "sort_function") == -1) return nullptr;
}
}
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == nullptr) {
return nullptr;
}
ScopedPyObjectPtr m(PyObject_GetAttrString(list.get(), "sort"));
if (m == nullptr) {
return nullptr;
}
ScopedPyObjectPtr res(PyObject_Call(m.get(), args, kwds));
if (res == nullptr) {
return nullptr;
}
int ret = InternalAssignRepeatedField(
reinterpret_cast<RepeatedScalarContainer*>(pself), list.get());
if (ret < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* Reverse(PyObject* pself) {
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == nullptr) {
return nullptr;
}
ScopedPyObjectPtr res(PyObject_CallMethod(list.get(), "reverse", nullptr));
if (res == nullptr) {
return nullptr;
}
int ret = InternalAssignRepeatedField(
reinterpret_cast<RepeatedScalarContainer*>(pself), list.get());
if (ret < 0) {
return nullptr;
}
Py_RETURN_NONE;
}
static PyObject* Pop(PyObject* pself, PyObject* args) {
Py_ssize_t index = -1;
if (!PyArg_ParseTuple(args, "|n", &index)) {
return nullptr;
}
PyObject* item = Item(pself, index);
if (item == nullptr) {
PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index);
return nullptr;
}
if (AssignItem(pself, index, nullptr) < 0) {
return nullptr;
}
return item;
}
static PyObject* ToStr(PyObject* pself) {
ScopedPyObjectPtr full_slice(PySlice_New(nullptr, nullptr, nullptr));
if (full_slice == nullptr) {
return nullptr;
}
ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == nullptr) {
return nullptr;
}
return PyObject_Repr(list.get());
}
static PyObject* MergeFrom(PyObject* pself, PyObject* arg) {
return Extend(reinterpret_cast<RepeatedScalarContainer*>(pself), arg);
}
// The private constructor of RepeatedScalarContainer objects.
RepeatedScalarContainer* NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return nullptr;
}
RepeatedScalarContainer* self = reinterpret_cast<RepeatedScalarContainer*>(
PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0));
if (self == nullptr) {
return nullptr;
}
Py_INCREF(parent);
self->parent = parent;
self->parent_field_descriptor = parent_field_descriptor;
return self;
}
PyObject* DeepCopy(PyObject* pself, PyObject* arg) {
return reinterpret_cast<RepeatedScalarContainer*>(pself)->DeepCopy();
}
static void Dealloc(PyObject* pself) {
reinterpret_cast<RepeatedScalarContainer*>(pself)->RemoveFromParentCache();
Py_TYPE(pself)->tp_free(pself);
}
static PySequenceMethods SqMethods = {
Len, /* sq_length */
nullptr, /* sq_concat */
nullptr, /* sq_repeat */
Item, /* sq_item */
nullptr, /* sq_slice */
AssignItem /* sq_ass_item */
};
static PyMappingMethods MpMethods = {
Len, /* mp_length */
Subscript, /* mp_subscript */
AssSubscript, /* mp_ass_subscript */
};
static PyMethodDef Methods[] = {
{"__deepcopy__", DeepCopy, METH_VARARGS, "Makes a deep copy of the class."},
{"__reduce__", Reduce, METH_NOARGS,
"Outputs picklable representation of the repeated field."},
{"append", AppendMethod, METH_O,
"Appends an object to the repeated container."},
{"extend", ExtendMethod, METH_O,
"Appends objects to the repeated container."},
{"insert", Insert, METH_VARARGS,
"Inserts an object at the specified position in the container."},
{"pop", Pop, METH_VARARGS,
"Removes an object from the repeated container and returns it."},
{"remove", Remove, METH_O,
"Removes an object from the repeated container."},
{"sort", reinterpret_cast<PyCFunction>(Sort), METH_VARARGS | METH_KEYWORDS,
"Sorts the repeated container."},
{"reverse", reinterpret_cast<PyCFunction>(Reverse), METH_NOARGS,
"Reverses elements order of the repeated container."},
{"MergeFrom", static_cast<PyCFunction>(MergeFrom), METH_O,
"Merges a repeated container into the current container."},
{nullptr, nullptr}};
} // namespace repeated_scalar_container
PyTypeObject RepeatedScalarContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
".RepeatedScalarContainer", // tp_name
sizeof(RepeatedScalarContainer), // tp_basicsize
0, // tp_itemsize
repeated_scalar_container::Dealloc, // tp_dealloc
#if PY_VERSION_HEX >= 0x03080000
0, // tp_vectorcall_offset
#else
nullptr, // tp_print
#endif
nullptr, // tp_getattr
nullptr, // tp_setattr
nullptr, // tp_compare
repeated_scalar_container::ToStr, // tp_repr
nullptr, // tp_as_number
&repeated_scalar_container::SqMethods, // tp_as_sequence
&repeated_scalar_container::MpMethods, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
nullptr, // tp_call
nullptr, // tp_str
nullptr, // tp_getattro
nullptr, // tp_setattro
nullptr, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
"A Repeated scalar container", // tp_doc
nullptr, // tp_traverse
nullptr, // tp_clear
repeated_scalar_container::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
nullptr, // tp_iter
nullptr, // tp_iternext
repeated_scalar_container::Methods, // tp_methods
nullptr, // tp_members
nullptr, // tp_getset
nullptr, // tp_base
nullptr, // tp_dict
nullptr, // tp_descr_get
nullptr, // tp_descr_set
0, // tp_dictoffset
nullptr, // tp_init
};
} // namespace python
} // namespace protobuf
} // namespace google

View File

@@ -0,0 +1,76 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/message.h"
namespace google {
namespace protobuf {
namespace python {
typedef struct RepeatedScalarContainer : public ContainerBase {
} RepeatedScalarContainer;
extern PyTypeObject RepeatedScalarContainer_Type;
namespace repeated_scalar_container {
// Builds a RepeatedScalarContainer object, from a parent message and a
// field descriptor.
extern RepeatedScalarContainer* NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor);
// Appends the scalar 'item' to the end of the container 'self'.
//
// Returns None if successful; returns NULL and sets an exception if
// unsuccessful.
PyObject* Append(RepeatedScalarContainer* self, PyObject* item);
// Appends all the elements in the input iterator to the container.
//
// Returns None if successful; returns NULL and sets an exception if
// unsuccessful.
PyObject* Extend(RepeatedScalarContainer* self, PyObject* value);
} // namespace repeated_scalar_container
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__

View File

@@ -0,0 +1,164 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
// Copied from chromium with only changes to the namespace.
#include <limits>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
namespace google {
namespace protobuf {
namespace python {
template <bool SameSize, bool DestLarger,
bool DestIsSigned, bool SourceIsSigned>
struct IsValidNumericCastImpl;
#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \
template <> struct IsValidNumericCastImpl<A, B, C, D> { \
template <class Source, class DestBounds> static inline bool Test( \
Source source, DestBounds min, DestBounds max) { \
return Code; \
} \
}
#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \
BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
true, true, DestSigned, SourceSigned, Code); \
BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
true, false, DestSigned, SourceSigned, Code)
#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \
BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
false, false, DestSigned, SourceSigned, Code); \
#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \
BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
false, true, DestSigned, SourceSigned, Code); \
// The three top level cases are:
// - Same size
// - Source larger
// - Dest larger
// And for each of those three cases, we handle the 4 different possibilities
// of signed and unsigned. This gives 12 cases to handle, which we enumerate
// below.
//
// The last argument in each of the macros is the actual comparison code. It
// has three arguments available, source (the value), and min/max which are
// the ranges of the destination.
// These are the cases where both types have the same size.
// Both signed.
BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true);
// Both unsigned.
BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true);
// Dest unsigned, Source signed.
BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0);
// Dest signed, Source unsigned.
// This cast is OK because Dest's max must be less than Source's.
BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false,
source <= static_cast<Source>(max));
// These are the cases where Source is larger.
// Both unsigned.
BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max);
// Both signed.
BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true,
source >= min && source <= max);
// Dest is unsigned, Source is signed.
BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true,
source >= 0 && source <= max);
// Dest is signed, Source is unsigned.
// This cast is OK because Dest's max must be less than Source's.
BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false,
source <= static_cast<Source>(max));
// These are the cases where Dest is larger.
// Both unsigned.
BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true);
// Both signed.
BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true);
// Dest is unsigned, Source is signed.
BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0);
// Dest is signed, Source is unsigned.
BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true);
#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION
#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE
#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER
#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER
// The main test for whether the conversion will under or overflow.
template <class Dest, class Source>
inline bool IsValidNumericCast(Source source) {
typedef std::numeric_limits<Source> SourceLimits;
typedef std::numeric_limits<Dest> DestLimits;
static_assert(SourceLimits::is_specialized, "argument must be numeric");
static_assert(SourceLimits::is_integer, "argument must be integral");
static_assert(DestLimits::is_specialized, "result must be numeric");
static_assert(DestLimits::is_integer, "result must be integral");
return IsValidNumericCastImpl<
sizeof(Dest) == sizeof(Source),
(sizeof(Dest) > sizeof(Source)),
DestLimits::is_signed,
SourceLimits::is_signed>::Test(
source,
DestLimits::min(),
DestLimits::max());
}
// checked_numeric_cast<> is analogous to static_cast<> for numeric types,
// except that it CHECKs that the specified numeric conversion will not
// overflow or underflow. Floating point arguments are not currently allowed
// (this is static_asserted), though this could be supported if necessary.
template <class Dest, class Source>
inline Dest checked_numeric_cast(Source source) {
GOOGLE_CHECK(IsValidNumericCast<Dest>(source));
return static_cast<Dest>(source);
}
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__

View File

@@ -0,0 +1,99 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Author: tibell@google.com (Johan Tibell)
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
#define PY_SSIZE_T_CLEAN
#include <Python.h>
namespace google {
namespace protobuf {
namespace python {
// Owns a python object and decrements the reference count on destruction.
// This class is not threadsafe.
template <typename PyObjectStruct>
class ScopedPythonPtr {
public:
// Takes the ownership of the specified object to ScopedPythonPtr.
// The reference count of the specified py_object is not incremented.
explicit ScopedPythonPtr(PyObjectStruct* py_object = nullptr)
: ptr_(py_object) {}
ScopedPythonPtr(const ScopedPythonPtr&) = delete;
ScopedPythonPtr& operator=(const ScopedPythonPtr&) = delete;
// If a PyObject is owned, decrement its reference count.
~ScopedPythonPtr() { Py_XDECREF(ptr_); }
// Deletes the current owned object, if any.
// Then takes ownership of a new object without incrementing the reference
// count.
// This function must be called with a reference that you own.
// this->reset(this->get()) is wrong!
// this->reset(this->release()) is OK.
PyObjectStruct* reset(PyObjectStruct* p = nullptr) {
Py_XDECREF(ptr_);
ptr_ = p;
return ptr_;
}
// Releases ownership of the object without decrementing the reference count.
// The caller now owns the returned reference.
PyObjectStruct* release() {
PyObject* p = ptr_;
ptr_ = nullptr;
return p;
}
PyObjectStruct* get() const { return ptr_; }
PyObject* as_pyobject() const { return reinterpret_cast<PyObject*>(ptr_); }
// Increments the reference count of the current object.
// Should not be called when no object is held.
void inc() const { Py_INCREF(ptr_); }
// True when a ScopedPyObjectPtr and a raw pointer refer to the same object.
// Comparison operators are non reflexive.
bool operator==(const PyObjectStruct* p) const { return ptr_ == p; }
bool operator!=(const PyObjectStruct* p) const { return ptr_ != p; }
private:
PyObjectStruct* ptr_;
};
typedef ScopedPythonPtr<PyObject> ScopedPyObjectPtr;
} // namespace python
} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__

Some files were not shown because too many files have changed in this diff Show More