ADD: added new version of protobuf

This commit is contained in:
Henry Winkel
2022-12-20 10:09:28 +01:00
parent 4a79559129
commit 1e2b3dda7b
1513 changed files with 123720 additions and 83381 deletions

View File

@@ -30,4 +30,4 @@
# Copyright 2007 Google Inc. All Rights Reserved.
__version__ = '4.21.8'
__version__ = '4.21.12'

View File

@@ -873,14 +873,11 @@ class ServiceDescriptor(_NestedDescriptorBase):
Args:
name (str): Name of the method.
Returns:
MethodDescriptor: The descriptor for the requested method.
Raises:
KeyError: if the method cannot be found in the service.
MethodDescriptor or None: the descriptor for the requested method, if
found.
"""
return self.methods_by_name[name]
return self.methods_by_name.get(name, None)
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.ServiceDescriptorProto.
@@ -1021,7 +1018,13 @@ class FileDescriptor(DescriptorBase):
# FileDescriptor() is called from various places, not only from generated
# files, to register dynamic proto files and messages.
# pylint: disable=g-explicit-bool-comparison
if serialized_pb:
if serialized_pb == b'':
# Cpp generated code must be linked in if serialized_pb is ''
try:
return _message.default_pool.FindFileByName(name)
except KeyError:
raise RuntimeError('Please link in cpp generated lib for %s' % (name))
elif serialized_pb:
return _message.default_pool.AddSerializedFile(serialized_pb)
else:
return super(FileDescriptor, cls).__new__(cls)

View File

@@ -144,6 +144,9 @@ class DescriptorPool(object):
self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
# TODO(jieluo): Remove _file_desc_by_toplevel_extension after
# maybe year 2020 for compatibility issue (with 3.4.1 only).
self._file_desc_by_toplevel_extension = {}
self._top_enum_values = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
@@ -217,7 +220,7 @@ class DescriptorPool(object):
file_desc.serialized_pb = serialized_file_desc_proto
return file_desc
# Add Descriptor to descriptor pool is deprecated. Please use Add()
# Add Descriptor to descriptor pool is dreprecated. Please use Add()
# or AddSerializedFile() to add a FileDescriptorProto instead.
@_Deprecated
def AddDescriptor(self, desc):
@@ -242,7 +245,7 @@ class DescriptorPool(object):
self._descriptors[desc.full_name] = desc
self._AddFileDescriptor(desc.file)
# Add EnumDescriptor to descriptor pool is deprecated. Please use Add()
# Add EnumDescriptor to descriptor pool is dreprecated. Please use Add()
# or AddSerializedFile() to add a FileDescriptorProto instead.
@_Deprecated
def AddEnumDescriptor(self, enum_desc):
@@ -283,7 +286,7 @@ class DescriptorPool(object):
self._top_enum_values[full_name] = enum_value
self._AddFileDescriptor(enum_desc.file)
# Add ServiceDescriptor to descriptor pool is deprecated. Please use Add()
# Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add()
# or AddSerializedFile() to add a FileDescriptorProto instead.
@_Deprecated
def AddServiceDescriptor(self, service_desc):
@@ -304,7 +307,7 @@ class DescriptorPool(object):
service_desc.file.name)
self._service_descriptors[service_desc.full_name] = service_desc
# Add ExtensionDescriptor to descriptor pool is deprecated. Please use Add()
# Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add()
# or AddSerializedFile() to add a FileDescriptorProto instead.
@_Deprecated
def AddExtensionDescriptor(self, extension):
@@ -328,8 +331,6 @@ class DescriptorPool(object):
raise TypeError('Expected an extension descriptor.')
if extension.extension_scope is None:
self._CheckConflictRegister(
extension, extension.full_name, extension.file.name)
self._toplevel_extensions[extension.full_name] = extension
try:
@@ -371,6 +372,12 @@ class DescriptorPool(object):
"""
self._AddFileDescriptor(file_desc)
# TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
# FieldDescriptor.file is added in code gen. Remove this solution after
# maybe 2020 for compatibility reason (with 3.4.1 only).
for extension in file_desc.extensions_by_name.values():
self._file_desc_by_toplevel_extension[
extension.full_name] = file_desc
def _AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -476,7 +483,7 @@ class DescriptorPool(object):
pass
try:
return self._toplevel_extensions[symbol].file
return self._file_desc_by_toplevel_extension[symbol]
except KeyError:
pass
@@ -785,6 +792,8 @@ class DescriptorPool(object):
file_descriptor.package, scope)
file_descriptor.extensions_by_name[extension_desc.name] = (
extension_desc)
self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
file_descriptor)
for desc_proto in file_proto.message_type:
self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

View File

@@ -151,6 +151,12 @@ def Type():
return _implementation_type
def _SetType(implementation_type):
"""Never use! Only for protobuf benchmark."""
global _implementation_type
_implementation_type = implementation_type
# See comment on 'Type' above.
# TODO(jieluo): Remove the API, it returns a constant. b/228102101
def Version():

View File

@@ -33,6 +33,7 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import copy
import os
import unittest
import warnings
@@ -414,19 +415,6 @@ class DescriptorPoolTestBase(object):
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
self.assertEqual(field.json_name, 'json_int')
def testAddSerializedFileTwice(self):
if isinstance(self, SecondaryDescriptorFromDescriptorDB):
if api_implementation.Type() != 'python':
# Cpp extension cannot call Add on a DescriptorPool
# that uses a DescriptorDatabase.
# TODO(jieluo): Fix python and cpp extension diff.
return
self.pool = descriptor_pool.DescriptorPool()
file1_first = self.pool.AddSerializedFile(
self.factory_test1_fd.SerializeToString())
file1_again = self.pool.AddSerializedFile(
self.factory_test1_fd.SerializeToString())
self.assertIs(file1_first, file1_again)
def testEnumDefaultValue(self):
"""Test the default value of enums which don't start at zero."""

View File

@@ -118,30 +118,6 @@ class DescriptorTest(unittest.TestCase):
def GetDescriptorPool(self):
return symbol_database.Default().pool
def testMissingPackage(self):
file_proto = descriptor_pb2.FileDescriptorProto(
name='some/filename/some.proto')
serialized = file_proto.SerializeToString()
pool = descriptor_pool.DescriptorPool()
file_descriptor = pool.AddSerializedFile(serialized)
self.assertEqual('', file_descriptor.package)
def testEmptyPackage(self):
file_proto = descriptor_pb2.FileDescriptorProto(
name='some/filename/some.proto', package='')
serialized = file_proto.SerializeToString()
pool = descriptor_pool.DescriptorPool()
file_descriptor = pool.AddSerializedFile(serialized)
self.assertEqual('', file_descriptor.package)
def testFindMethodByName(self):
service_descriptor = (unittest_custom_options_pb2.
TestServiceWithCustomOptions.DESCRIPTOR)
method_descriptor = service_descriptor.FindMethodByName('Foo')
self.assertEqual(method_descriptor.name, 'Foo')
with self.assertRaises(KeyError):
service_descriptor.FindMethodByName('MethodDoesNotExist')
def testEnumValueName(self):
self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4),
'FOREIGN_FOO')
@@ -619,12 +595,6 @@ class GeneratedDescriptorTest(unittest.TestCase):
def CheckDescriptorMapping(self, mapping):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Mapping.
iterated_keys = []
for key in mapping:
iterated_keys.append(key)
self.assertEqual(len(iterated_keys), len(mapping))
self.assertEqual(set(iterated_keys), set(mapping.keys()))
self.assertNotEqual(
mapping, unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name)
self.assertNotEqual(mapping, {})
@@ -641,15 +611,10 @@ class GeneratedDescriptorTest(unittest.TestCase):
with self.assertRaises(TypeError):
mapping.get()
# TODO(jieluo): Fix python and cpp extension diff.
if api_implementation.Type() == 'cpp':
self.assertEqual(None, mapping.get([]))
else:
if api_implementation.Type() == 'python':
self.assertRaises(TypeError, mapping.get, [])
with self.assertRaises(TypeError):
if [] in mapping:
pass
with self.assertRaises(TypeError):
_ = mapping[[]]
else:
self.assertEqual(None, mapping.get([]))
# keys(), iterkeys() &co
item = (next(iter(mapping.keys())), next(iter(mapping.values())))
self.assertEqual(item, next(iter(mapping.items())))
@@ -661,12 +626,10 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.assertRaises(KeyError, mapping.__getitem__, 'key_error')
self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1)
# TODO(jieluo): Add __repr__ support for DescriptorMapping.
if api_implementation.Type() == 'cpp':
self.assertEqual(str(mapping)[0], '<')
else:
print(str(dict(mapping.items()))[:100])
print(str(mapping)[:100])
if api_implementation.Type() == 'python':
self.assertEqual(len(str(dict(mapping.items()))), len(str(mapping)))
else:
self.assertEqual(str(mapping)[0], '<')
def testDescriptor(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR

View File

@@ -1,333 +0,0 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Contains FieldMask class."""
from google.protobuf.descriptor import FieldDescriptor
class FieldMask(object):
"""Class for FieldMask message type."""
__slots__ = ()
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
camelcase_paths = []
for path in self.paths:
camelcase_paths.append(_SnakeCaseToCamelCase(path))
return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
if not isinstance(value, str):
raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
self.Clear()
if value:
for path in value.split(','):
self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
for path in self.paths:
if not _IsValidPath(message_descriptor, path):
return False
return True
def AllFieldsFromDescriptor(self, message_descriptor):
"""Gets all direct fields of Message Descriptor to FieldMask."""
self.Clear()
for field in message_descriptor.fields:
self.paths.append(field.name)
def CanonicalFormFromMask(self, mask):
"""Converts a FieldMask to the canonical form.
Removes paths that are covered by another path. For example,
"foo.bar" is covered by "foo" and will be removed if "foo"
is also in the FieldMask. Then sorts all paths in alphabetical order.
Args:
mask: The original FieldMask to be converted.
"""
tree = _FieldMaskTree(mask)
tree.ToFieldMask(self)
def Union(self, mask1, mask2):
"""Merges mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
tree.MergeFromFieldMask(mask2)
tree.ToFieldMask(self)
def Intersect(self, mask1, mask2):
"""Intersects mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
intersection = _FieldMaskTree()
for path in mask2.paths:
tree.IntersectPath(path, intersection)
intersection.ToFieldMask(self)
def MergeMessage(
self, source, destination,
replace_message_field=False, replace_repeated_field=False):
"""Merges fields specified in FieldMask from source to destination.
Args:
source: Source message.
destination: The destination message to be merged into.
replace_message_field: Replace message field if True. Merge message
field if False.
replace_repeated_field: Replace repeated field if True. Append
elements of repeated field if False.
"""
tree = _FieldMaskTree(self)
tree.MergeMessage(
source, destination, replace_message_field, replace_repeated_field)
def _IsValidPath(message_descriptor, path):
"""Checks whether the path is valid for Message Descriptor."""
parts = path.split('.')
last = parts.pop()
for name in parts:
field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
return False
message_descriptor = field.message_type
return last in message_descriptor.fields_by_name
def _CheckFieldMaskMessage(message):
"""Raises ValueError if message is not a FieldMask."""
message_descriptor = message.DESCRIPTOR
if (message_descriptor.name != 'FieldMask' or
message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
raise ValueError('Message {0} is not a FieldMask.'.format(
message_descriptor.full_name))
def _SnakeCaseToCamelCase(path_name):
"""Converts a path name from snake_case to camelCase."""
result = []
after_underscore = False
for c in path_name:
if c.isupper():
raise ValueError(
'Fail to print FieldMask to Json string: Path name '
'{0} must not contain uppercase letters.'.format(path_name))
if after_underscore:
if c.islower():
result.append(c.upper())
after_underscore = False
else:
raise ValueError(
'Fail to print FieldMask to Json string: The '
'character after a "_" must be a lowercase letter '
'in path name {0}.'.format(path_name))
elif c == '_':
after_underscore = True
else:
result += c
if after_underscore:
raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
'in path name {0}.'.format(path_name))
return ''.join(result)
def _CamelCaseToSnakeCase(path_name):
"""Converts a field name from camelCase to snake_case."""
result = []
for c in path_name:
if c == '_':
raise ValueError('Fail to parse FieldMask: Path name '
'{0} must not contain "_"s.'.format(path_name))
if c.isupper():
result += '_'
result += c.lower()
else:
result += c
return ''.join(result)
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
the FieldMaskTree will be:
[_root] -+- foo -+- bar
| |
| +- baz
|
+- bar --- baz
In the tree, each leaf node represents a field path.
"""
__slots__ = ('_root',)
def __init__(self, field_mask=None):
"""Initializes the tree by FieldMask."""
self._root = {}
if field_mask:
self.MergeFromFieldMask(field_mask)
def MergeFromFieldMask(self, field_mask):
"""Merges a FieldMask to the tree."""
for path in field_mask.paths:
self.AddPath(path)
def AddPath(self, path):
"""Adds a field path into the tree.
If the field path to add is a sub-path of an existing field path
in the tree (i.e., a leaf node), it means the tree already matches
the given path so nothing will be added to the tree. If the path
matches an existing non-leaf node in the tree, that non-leaf node
will be turned into a leaf node with all its children removed because
the path matches all the node's children. Otherwise, a new path will
be added.
Args:
path: The field path to add.
"""
node = self._root
for name in path.split('.'):
if name not in node:
node[name] = {}
elif not node[name]:
# Pre-existing empty node implies we already have this entire tree.
return
node = node[name]
# Remove any sub-trees we might have had.
node.clear()
def ToFieldMask(self, field_mask):
"""Converts the tree to a FieldMask."""
field_mask.Clear()
_AddFieldPaths(self._root, '', field_mask)
def IntersectPath(self, path, intersection):
"""Calculates the intersection part of a field path with this tree.
Args:
path: The field path to calculates.
intersection: The out tree to record the intersection part.
"""
node = self._root
for name in path.split('.'):
if name not in node:
return
elif not node[name]:
intersection.AddPath(path)
return
node = node[name]
intersection.AddLeafNodes(path, node)
def AddLeafNodes(self, prefix, node):
"""Adds leaf nodes begin with prefix to this tree."""
if not node:
self.AddPath(prefix)
for name in node:
child_path = prefix + '.' + name
self.AddLeafNodes(child_path, node[name])
def MergeMessage(
self, source, destination,
replace_message, replace_repeated):
"""Merge all fields specified by this tree from source to destination."""
_MergeMessage(
self._root, source, destination, replace_message, replace_repeated)
def _StrConvert(value):
"""Converts value to str if it is not."""
# This file is imported by c extension and some methods like ClearField
# requires string for the field name. py2/py3 has different text
# type and may use unicode.
if not isinstance(value, str):
return value.encode('utf-8')
return value
def _MergeMessage(
node, source, destination, replace_message, replace_repeated):
"""Merge all fields specified by a sub-tree from source to destination."""
source_descriptor = source.DESCRIPTOR
for name in node:
child = node[name]
field = source_descriptor.fields_by_name[name]
if field is None:
raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
name, source_descriptor.full_name))
if child:
# Sub-paths are only allowed for singular message fields.
if (field.label == FieldDescriptor.LABEL_REPEATED or
field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
raise ValueError('Error: Field {0} in message {1} is not a singular '
'message field and cannot have sub-fields.'.format(
name, source_descriptor.full_name))
if source.HasField(name):
_MergeMessage(
child, getattr(source, name), getattr(destination, name),
replace_message, replace_repeated)
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
if replace_repeated:
destination.ClearField(_StrConvert(name))
repeated_source = getattr(source, name)
repeated_destination = getattr(destination, name)
repeated_destination.MergeFrom(repeated_source)
else:
if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
if replace_message:
destination.ClearField(_StrConvert(name))
if source.HasField(name):
getattr(destination, name).MergeFrom(getattr(source, name))
else:
setattr(destination, name, getattr(source, name))
def _AddFieldPaths(node, prefix, field_mask):
"""Adds the field paths descended from node to field_mask."""
if not node and prefix:
field_mask.paths.append(prefix)
return
for name in sorted(node):
if prefix:
child_path = prefix + '.' + name
else:
child_path = name
_AddFieldPaths(node[name], child_path, field_mask)

View File

@@ -1,400 +0,0 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.well_known_types."""
import unittest
from google.protobuf import field_mask_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf.internal import field_mask
from google.protobuf.internal import test_util
from google.protobuf import descriptor
class FieldMaskTest(unittest.TestCase):
def testStringFormat(self):
mask = field_mask_pb2.FieldMask()
self.assertEqual('', mask.ToJsonString())
mask.paths.append('foo')
self.assertEqual('foo', mask.ToJsonString())
mask.paths.append('bar')
self.assertEqual('foo,bar', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
mask.FromJsonString('foo')
self.assertEqual(['foo'], mask.paths)
mask.FromJsonString('foo,bar')
self.assertEqual(['foo', 'bar'], mask.paths)
# Test camel case
mask.Clear()
mask.paths.append('foo_bar')
self.assertEqual('fooBar', mask.ToJsonString())
mask.paths.append('bar_quz')
self.assertEqual('fooBar,barQuz', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
self.assertEqual([], mask.paths)
mask.FromJsonString('fooBar')
self.assertEqual(['foo_bar'], mask.paths)
mask.FromJsonString('fooBar,barQuz')
self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
def testDescriptorToFieldMask(self):
mask = field_mask_pb2.FieldMask()
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertEqual(76, len(mask.paths))
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
for field in msg_descriptor.fields:
self.assertTrue(field.name in mask.paths)
def testIsValidForDescriptor(self):
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
# Empty mask
mask = field_mask_pb2.FieldMask()
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# All fields from descriptor
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Child under optional message
mask.paths.append('optional_nested_message.bb')
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Repeated field is only allowed in the last position of path
mask.paths.append('repeated_nested_message.bb')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid top level field
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in root
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in internal node
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in leaf
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
def testCanonicalFrom(self):
mask = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Paths will be sorted.
mask.FromJsonString('baz.quz,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
# Duplicated paths will be removed.
mask.FromJsonString('foo,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo', out_mask.ToJsonString())
# Sub-paths of other paths will be removed.
mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
# Test more deeply nested cases.
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar', out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo', out_mask.ToJsonString())
def testUnion(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
src = unittest_pb2.TestAllTypes()
with self.assertRaises(ValueError):
out_mask.Union(src, mask2)
def testIntersect(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Test cases without overlapping.
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('', out_mask.ToJsonString())
self.assertEqual(len(out_mask.paths), 0)
self.assertEqual(out_mask.paths, [])
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('baz.bb', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
mask1.FromJsonString('foo.bar,bar')
mask2.FromJsonString('foo.bar.baz,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
# Intersect '' with ''
mask1.Clear()
mask2.Clear()
mask1.paths.append('')
mask2.paths.append('')
self.assertEqual(mask1.paths, [''])
self.assertEqual('', mask1.ToJsonString())
out_mask.Intersect(mask1, mask2)
self.assertEqual(out_mask.paths, [])
def testMergeMessageWithoutMapFields(self):
# Test merge one field.
src = unittest_pb2.TestAllTypes()
test_util.SetAllFields(src)
for field in src.DESCRIPTOR.fields:
if field.containing_oneof:
continue
field_name = field.name
dst = unittest_pb2.TestAllTypes()
# Only set one path to mask.
mask = field_mask_pb2.FieldMask()
mask.paths.append(field_name)
mask.MergeMessage(src, dst)
# The expected result message.
msg = unittest_pb2.TestAllTypes()
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
repeated_src = getattr(src, field_name)
repeated_msg = getattr(msg, field_name)
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
for item in repeated_src:
repeated_msg.add().CopyFrom(item)
else:
repeated_msg.extend(repeated_src)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
getattr(msg, field_name).CopyFrom(getattr(src, field_name))
else:
setattr(msg, field_name, getattr(src, field_name))
# Only field specified in mask is merged.
self.assertEqual(msg, dst)
# Test merge nested fields.
nested_src = unittest_pb2.NestedTestAllTypes()
nested_dst = unittest_pb2.NestedTestAllTypes()
nested_src.child.payload.optional_int32 = 1234
nested_src.child.child.payload.optional_int32 = 5678
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(0, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
# Test MergeOptions.
nested_dst.Clear()
nested_dst.child.payload.optional_int64 = 4321
# Message fields will be merged by default.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(4321, nested_dst.child.payload.optional_int64)
# Change the behavior to replace message fields.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.payload.optional_int64)
# By default, fields missing in source are not cleared in destination.
nested_dst.payload.optional_int32 = 1234
self.assertTrue(nested_dst.HasField('payload'))
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertTrue(nested_dst.HasField('payload'))
# But they are cleared when replacing message fields.
nested_dst.Clear()
nested_dst.payload.optional_int32 = 1234
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertFalse(nested_dst.HasField('payload'))
nested_src.payload.repeated_int32.append(1234)
nested_dst.payload.repeated_int32.append(5678)
# Repeated fields will be appended by default.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(2, len(nested_dst.payload.repeated_int32))
self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
# Change the behavior to replace repeated fields.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst, False, True)
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
# Test Merge oneof field.
new_msg = unittest_pb2.TestOneof2()
dst = unittest_pb2.TestOneof2()
dst.foo_message.moo_int = 1
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('fooMessage,fooLazyMessage.mooInt')
mask.MergeMessage(new_msg, dst)
self.assertTrue(dst.HasField('foo_message'))
self.assertFalse(dst.HasField('foo_lazy_message'))
def testMergeMessageWithMapField(self):
empty_map = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2.a['src level 2'].CopyFrom(empty_map)
src = map_unittest_pb2.TestRecursiveMapMessage()
src.a['common key'].CopyFrom(src_level_2)
src.a['src level 1'].CopyFrom(src_level_2)
dst_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
dst_level_2.a['dst level 2'].CopyFrom(empty_map)
dst = map_unittest_pb2.TestRecursiveMapMessage()
dst.a['common key'].CopyFrom(dst_level_2)
dst.a['dst level 1'].CopyFrom(empty_map)
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('a')
mask.MergeMessage(src, dst)
# map from dst is replaced with map from src.
self.assertEqual(dst.a['common key'], src_level_2)
self.assertEqual(dst.a['src level 1'], src_level_2)
self.assertEqual(dst.a['dst level 1'], empty_map)
def testMergeErrors(self):
src = unittest_pb2.TestAllTypes()
dst = unittest_pb2.TestAllTypes()
mask = field_mask_pb2.FieldMask()
test_util.SetAllFields(src)
mask.FromJsonString('optionalInt32.field')
with self.assertRaises(ValueError) as e:
mask.MergeMessage(src, dst)
self.assertEqual('Error: Field optional_int32 in message '
'protobuf_unittest.TestAllTypes is not a singular '
'message field and cannot have sub-fields.',
str(e.exception))
def testSnakeCaseToCamelCase(self):
self.assertEqual('fooBar',
field_mask._SnakeCaseToCamelCase('foo_bar'))
self.assertEqual('FooBar',
field_mask._SnakeCaseToCamelCase('_foo_bar'))
self.assertEqual('foo3Bar',
field_mask._SnakeCaseToCamelCase('foo3_bar'))
# No uppercase letter is allowed.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Path name Foo must '
'not contain uppercase letters.',
field_mask._SnakeCaseToCamelCase, 'Foo')
# Any character after a "_" must be a lowercase letter.
# 1. "_" cannot be followed by another "_".
# 2. "_" cannot be followed by a digit.
# 3. "_" cannot appear as the last character.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo__bar.',
field_mask._SnakeCaseToCamelCase, 'foo__bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo_3bar.',
field_mask._SnakeCaseToCamelCase, 'foo_3bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Trailing "_" in path '
'name foo_bar_.', field_mask._SnakeCaseToCamelCase, 'foo_bar_')
def testCamelCaseToSnakeCase(self):
self.assertEqual('foo_bar',
field_mask._CamelCaseToSnakeCase('fooBar'))
self.assertEqual('_foo_bar',
field_mask._CamelCaseToSnakeCase('FooBar'))
self.assertEqual('foo3_bar',
field_mask._CamelCaseToSnakeCase('foo3Bar'))
self.assertRaisesRegex(
ValueError,
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
field_mask._CamelCaseToSnakeCase, 'foo_bar')
if __name__ == '__main__':
unittest.main()

View File

@@ -118,16 +118,6 @@ class MessageFactoryTest(unittest.TestCase):
'google.protobuf.python.internal.Factory2Message'))
self.assertTrue(hasattr(cls, 'additional_field'))
def testGetExistingPrototype(self):
factory = message_factory.MessageFactory()
# Get Existing Prototype should not create a new class.
cls = factory.GetPrototype(
descriptor=factory_test2_pb2.Factory2Message.DESCRIPTOR)
msg = factory_test2_pb2.Factory2Message()
self.assertIsInstance(msg, cls)
self.assertIsInstance(msg.factory_1_message,
factory_test1_pb2.Factory1Message)
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):

View File

@@ -34,6 +34,10 @@
Note that the golden messages exercise every known field type, thus this
test ends up exercising and verifying nearly all of the parsing and
serialization code in the whole library.
TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
sense to call this a test of the "message" module, which only declares an
abstract interface.
"""
__author__ = 'gps@google.com (Gregory P. Smith)'
@@ -472,12 +476,6 @@ class MessageTest(unittest.TestCase):
'}\n')
self.assertEqual(sub_msg.bb, 1)
def testAssignRepeatedField(self, message_module):
msg = message_module.NestedTestAllTypes()
msg.payload.repeated_int32[:] = [1, 2, 3, 4]
self.assertEqual(4, len(msg.payload.repeated_int32))
self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)
def testMergeFromRepeatedField(self, message_module):
msg = message_module.TestAllTypes()
msg.repeated_int32.append(1)
@@ -889,7 +887,6 @@ class MessageTest(unittest.TestCase):
def testOneofClearField(self, message_module):
m = message_module.TestAllTypes()
m.ClearField('oneof_field')
m.oneof_uint32 = 11
m.ClearField('oneof_field')
if message_module is unittest_pb2:
@@ -1769,19 +1766,6 @@ class Proto3Test(unittest.TestCase):
with self.assertRaises(TypeError):
123 in msg.map_string_string
def testScalarMapComparison(self):
msg1 = map_unittest_pb2.TestMap()
msg2 = map_unittest_pb2.TestMap()
self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32)
def testMessageMapComparison(self):
msg1 = map_unittest_pb2.TestMap()
msg2 = map_unittest_pb2.TestMap()
self.assertEqual(msg1.map_int32_foreign_message,
msg2.map_int32_foreign_message)
def testMapGet(self):
# Need to test that get() properly returns the default, even though the dict
# has defaultdict-like semantics.
@@ -2469,26 +2453,6 @@ class Proto3Test(unittest.TestCase):
with self.assertRaises(ValueError):
unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801')
def testCrashNullAA(self):
self.assertEqual(
unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(),
unittest_proto3_arena_pb2.TestAllTypes.NestedMessage())
def testCrashNullAB(self):
self.assertEqual(
unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(),
unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message)
def testCrashNullBA(self):
self.assertEqual(
unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message,
unittest_proto3_arena_pb2.TestAllTypes.NestedMessage())
def testCrashNullBB(self):
self.assertEqual(
unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message,
unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message)

View File

@@ -30,6 +30,7 @@
syntax = "proto2";
package google.protobuf.python.internal;
message TestEnumValues {
@@ -52,3 +53,4 @@ message TestMissingEnumValues {
message JustString {
required string dummy = 1;
}

View File

@@ -1,215 +0,0 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test use of numpy types with repeated and non-repeated scalar fields."""
import unittest
import numpy as np
from google.protobuf import unittest_pb2
from google.protobuf.internal import testing_refleaks
message = unittest_pb2.TestAllTypes()
np_float_scalar = np.float64(0.0)
np_1_float_array = np.zeros(shape=(1,), dtype=np.float64)
np_2_float_array = np.zeros(shape=(2,), dtype=np.float64)
np_11_float_array = np.zeros(shape=(1, 1), dtype=np.float64)
np_22_float_array = np.zeros(shape=(2, 2), dtype=np.float64)
np_int_scalar = np.int64(0)
np_1_int_array = np.zeros(shape=(1,), dtype=np.int64)
np_2_int_array = np.zeros(shape=(2,), dtype=np.int64)
np_11_int_array = np.zeros(shape=(1, 1), dtype=np.int64)
np_22_int_array = np.zeros(shape=(2, 2), dtype=np.int64)
np_uint_scalar = np.uint64(0)
np_1_uint_array = np.zeros(shape=(1,), dtype=np.uint64)
np_2_uint_array = np.zeros(shape=(2,), dtype=np.uint64)
np_11_uint_array = np.zeros(shape=(1, 1), dtype=np.uint64)
np_22_uint_array = np.zeros(shape=(2, 2), dtype=np.uint64)
np_bool_scalar = np.bool_(False)
np_1_bool_array = np.zeros(shape=(1,), dtype=np.bool_)
np_2_bool_array = np.zeros(shape=(2,), dtype=np.bool_)
np_11_bool_array = np.zeros(shape=(1, 1), dtype=np.bool_)
np_22_bool_array = np.zeros(shape=(2, 2), dtype=np.bool_)
@testing_refleaks.TestCase
class NumpyIntProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of ints to repeated field should pass
def testNumpyDim1IntArrayToRepeated_IsValid(self):
message.repeated_int64[:] = np_1_int_array
message.repeated_int64[:] = np_2_int_array
message.repeated_uint64[:] = np_1_uint_array
message.repeated_uint64[:] = np_2_uint_array
# Assigning dim 2 ndarray of ints to repeated field should fail
def testNumpyDim2IntArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_11_int_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_22_int_array
with self.assertRaises(TypeError):
message.repeated_uint64[:] = np_11_uint_array
with self.assertRaises(TypeError):
message.repeated_uint64[:] = np_22_uint_array
# Assigning any ndarray of floats to repeated int field should fail
def testNumpyFloatArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_1_float_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_11_float_array
with self.assertRaises(TypeError):
message.repeated_int64[:] = np_22_float_array
# Assigning any np int to scalar field should pass
def testNumpyIntScalarToScalar_IsValid(self):
message.optional_int64 = np_int_scalar
message.optional_uint64 = np_uint_scalar
# Assigning any ndarray of ints to scalar field should fail
def testNumpyIntArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_int64 = np_1_int_array
with self.assertRaises(TypeError):
message.optional_int64 = np_11_int_array
with self.assertRaises(TypeError):
message.optional_int64 = np_22_int_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_1_uint_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_11_uint_array
with self.assertRaises(TypeError):
message.optional_uint64 = np_22_uint_array
# Assigning any ndarray of floats to scalar field should fail
def testNumpyFloatArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_int64 = np_1_float_array
with self.assertRaises(TypeError):
message.optional_int64 = np_11_float_array
with self.assertRaises(TypeError):
message.optional_int64 = np_22_float_array
@testing_refleaks.TestCase
class NumpyFloatProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of floats to repeated field should pass
def testNumpyDim1FloatArrayToRepeated_IsValid(self):
message.repeated_float[:] = np_1_float_array
message.repeated_float[:] = np_2_float_array
# Assigning dim 2 ndarray of floats to repeated field should fail
def testNumpyDim2FloatArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_float[:] = np_11_float_array
with self.assertRaises(TypeError):
message.repeated_float[:] = np_22_float_array
# Assigning any np float to scalar field should pass
def testNumpyFloatScalarToScalar_IsValid(self):
message.optional_float = np_float_scalar
# Assigning any ndarray of float to scalar field should fail
def testNumpyFloatArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_float = np_1_float_array
with self.assertRaises(TypeError):
message.optional_float = np_11_float_array
with self.assertRaises(TypeError):
message.optional_float = np_22_float_array
@testing_refleaks.TestCase
class NumpyBoolProtoTest(unittest.TestCase):
# Assigning dim 1 ndarray of bool to repeated field should pass
def testNumpyDim1BoolArrayToRepeated_IsValid(self):
message.repeated_bool[:] = np_1_bool_array
message.repeated_bool[:] = np_2_bool_array
# Assigning dim 2 ndarray of bool to repeated field should fail
def testNumpyDim2BoolArrayToRepeated_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.repeated_bool[:] = np_11_bool_array
with self.assertRaises(TypeError):
message.repeated_bool[:] = np_22_bool_array
# Assigning any np bool to scalar field should pass
def testNumpyBoolScalarToScalar_IsValid(self):
message.optional_bool = np_bool_scalar
# Assigning any ndarray of bool to scalar field should fail
def testNumpyBoolArrayToScalar_RaisesTypeError(self):
with self.assertRaises(TypeError):
message.optional_bool = np_1_bool_array
with self.assertRaises(TypeError):
message.optional_bool = np_11_bool_array
with self.assertRaises(TypeError):
message.optional_bool = np_22_bool_array
@testing_refleaks.TestCase
class NumpyProtoIndexingTest(unittest.TestCase):
def testNumpyIntScalarIndexing_Passes(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
self.assertEqual(0, data.repeated_int64[np.int64(0)])
def testNumpyNegative1IntScalarIndexing_Passes(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
self.assertEqual(2, data.repeated_int64[np.int64(-1)])
def testNumpyFloatScalarIndexing_Fails(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.float64(0.0)]
def testNumpyIntArrayIndexing_Fails(self):
data = unittest_pb2.TestAllTypes(repeated_int64=[0, 1, 2])
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.array([0])]
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.ndarray((1,), buffer=np.array([0]), dtype=int)]
with self.assertRaises(TypeError):
_ = data.repeated_int64[np.ndarray((1, 1),
buffer=np.array([0]),
dtype=int)]
if __name__ == '__main__':
unittest.main()

View File

@@ -30,7 +30,7 @@
// Author: qrczak@google.com (Marcin Kowalczyk)
#include "google/protobuf/python/python_protobuf.h"
#include <google/protobuf/python/python_protobuf.h>
namespace google {
namespace protobuf {

View File

@@ -1598,47 +1598,6 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEqual('foo', message.message_set.Extensions[ext2].str)
# Handle Any messages inside unknown extensions.
message = any_test_pb2.TestAny()
text = ('any_value {\n'
' [type.googleapis.com/google.protobuf.internal.TestAny] {\n'
' [unknown_extension] {\n'
' str: "string"\n'
' any_value {\n'
' [type.googleapis.com/protobuf_unittest.OneString] {\n'
' data: "string"\n'
' }\n'
' }\n'
' }\n'
' }\n'
'}\n'
'int32_value: 123')
text_format.Parse(text, message, allow_unknown_extension=True)
self.assertEqual(123, message.int32_value)
# Fail if invalid Any message type url inside unknown extensions.
message = any_test_pb2.TestAny()
text = ('any_value {\n'
' [type.googleapis.com.invalid/google.protobuf.internal.TestAny] {\n'
' [unknown_extension] {\n'
' str: "string"\n'
' any_value {\n'
' [type.googleapis.com/protobuf_unittest.OneString] {\n'
' data: "string"\n'
' }\n'
' }\n'
' }\n'
' }\n'
'}\n'
'int32_value: 123')
self.assertRaisesRegex(
text_format.ParseError,
'[type.googleapis.com.invalid/google.protobuf.internal.TestAny]',
text_format.Parse,
text,
message,
allow_unknown_extension=True)
def testParseBadIdentifier(self):
message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { "bb": 1 }')
@@ -2484,6 +2443,5 @@ class OptionalColonMessageToStringTest(unittest.TestCase):
self.assertEqual('repeated_int32: [1]\n', output)
if __name__ == '__main__':
unittest.main()

View File

@@ -44,9 +44,7 @@ import calendar
import collections.abc
import datetime
from google.protobuf.internal import field_mask
FieldMask = field_mask.FieldMask
from google.protobuf.descriptor import FieldDescriptor
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_NANOS_PER_SECOND = 1000000000
@@ -432,6 +430,306 @@ def _RoundTowardZero(value, divider):
return result
class FieldMask(object):
"""Class for FieldMask message type."""
__slots__ = ()
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
camelcase_paths = []
for path in self.paths:
camelcase_paths.append(_SnakeCaseToCamelCase(path))
return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
if not isinstance(value, str):
raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
self.Clear()
if value:
for path in value.split(','):
self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
for path in self.paths:
if not _IsValidPath(message_descriptor, path):
return False
return True
def AllFieldsFromDescriptor(self, message_descriptor):
"""Gets all direct fields of Message Descriptor to FieldMask."""
self.Clear()
for field in message_descriptor.fields:
self.paths.append(field.name)
def CanonicalFormFromMask(self, mask):
"""Converts a FieldMask to the canonical form.
Removes paths that are covered by another path. For example,
"foo.bar" is covered by "foo" and will be removed if "foo"
is also in the FieldMask. Then sorts all paths in alphabetical order.
Args:
mask: The original FieldMask to be converted.
"""
tree = _FieldMaskTree(mask)
tree.ToFieldMask(self)
def Union(self, mask1, mask2):
"""Merges mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
tree.MergeFromFieldMask(mask2)
tree.ToFieldMask(self)
def Intersect(self, mask1, mask2):
"""Intersects mask1 and mask2 into this FieldMask."""
_CheckFieldMaskMessage(mask1)
_CheckFieldMaskMessage(mask2)
tree = _FieldMaskTree(mask1)
intersection = _FieldMaskTree()
for path in mask2.paths:
tree.IntersectPath(path, intersection)
intersection.ToFieldMask(self)
def MergeMessage(
self, source, destination,
replace_message_field=False, replace_repeated_field=False):
"""Merges fields specified in FieldMask from source to destination.
Args:
source: Source message.
destination: The destination message to be merged into.
replace_message_field: Replace message field if True. Merge message
field if False.
replace_repeated_field: Replace repeated field if True. Append
elements of repeated field if False.
"""
tree = _FieldMaskTree(self)
tree.MergeMessage(
source, destination, replace_message_field, replace_repeated_field)
def _IsValidPath(message_descriptor, path):
"""Checks whether the path is valid for Message Descriptor."""
parts = path.split('.')
last = parts.pop()
for name in parts:
field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
return False
message_descriptor = field.message_type
return last in message_descriptor.fields_by_name
def _CheckFieldMaskMessage(message):
"""Raises ValueError if message is not a FieldMask."""
message_descriptor = message.DESCRIPTOR
if (message_descriptor.name != 'FieldMask' or
message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
raise ValueError('Message {0} is not a FieldMask.'.format(
message_descriptor.full_name))
def _SnakeCaseToCamelCase(path_name):
"""Converts a path name from snake_case to camelCase."""
result = []
after_underscore = False
for c in path_name:
if c.isupper():
raise ValueError(
'Fail to print FieldMask to Json string: Path name '
'{0} must not contain uppercase letters.'.format(path_name))
if after_underscore:
if c.islower():
result.append(c.upper())
after_underscore = False
else:
raise ValueError(
'Fail to print FieldMask to Json string: The '
'character after a "_" must be a lowercase letter '
'in path name {0}.'.format(path_name))
elif c == '_':
after_underscore = True
else:
result += c
if after_underscore:
raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
'in path name {0}.'.format(path_name))
return ''.join(result)
def _CamelCaseToSnakeCase(path_name):
"""Converts a field name from camelCase to snake_case."""
result = []
for c in path_name:
if c == '_':
raise ValueError('Fail to parse FieldMask: Path name '
'{0} must not contain "_"s.'.format(path_name))
if c.isupper():
result += '_'
result += c.lower()
else:
result += c
return ''.join(result)
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
the FieldMaskTree will be:
[_root] -+- foo -+- bar
| |
| +- baz
|
+- bar --- baz
In the tree, each leaf node represents a field path.
"""
__slots__ = ('_root',)
def __init__(self, field_mask=None):
"""Initializes the tree by FieldMask."""
self._root = {}
if field_mask:
self.MergeFromFieldMask(field_mask)
def MergeFromFieldMask(self, field_mask):
"""Merges a FieldMask to the tree."""
for path in field_mask.paths:
self.AddPath(path)
def AddPath(self, path):
"""Adds a field path into the tree.
If the field path to add is a sub-path of an existing field path
in the tree (i.e., a leaf node), it means the tree already matches
the given path so nothing will be added to the tree. If the path
matches an existing non-leaf node in the tree, that non-leaf node
will be turned into a leaf node with all its children removed because
the path matches all the node's children. Otherwise, a new path will
be added.
Args:
path: The field path to add.
"""
node = self._root
for name in path.split('.'):
if name not in node:
node[name] = {}
elif not node[name]:
# Pre-existing empty node implies we already have this entire tree.
return
node = node[name]
# Remove any sub-trees we might have had.
node.clear()
def ToFieldMask(self, field_mask):
"""Converts the tree to a FieldMask."""
field_mask.Clear()
_AddFieldPaths(self._root, '', field_mask)
def IntersectPath(self, path, intersection):
"""Calculates the intersection part of a field path with this tree.
Args:
path: The field path to calculates.
intersection: The out tree to record the intersection part.
"""
node = self._root
for name in path.split('.'):
if name not in node:
return
elif not node[name]:
intersection.AddPath(path)
return
node = node[name]
intersection.AddLeafNodes(path, node)
def AddLeafNodes(self, prefix, node):
"""Adds leaf nodes begin with prefix to this tree."""
if not node:
self.AddPath(prefix)
for name in node:
child_path = prefix + '.' + name
self.AddLeafNodes(child_path, node[name])
def MergeMessage(
self, source, destination,
replace_message, replace_repeated):
"""Merge all fields specified by this tree from source to destination."""
_MergeMessage(
self._root, source, destination, replace_message, replace_repeated)
def _StrConvert(value):
"""Converts value to str if it is not."""
# This file is imported by c extension and some methods like ClearField
# requires string for the field name. py2/py3 has different text
# type and may use unicode.
if not isinstance(value, str):
return value.encode('utf-8')
return value
def _MergeMessage(
node, source, destination, replace_message, replace_repeated):
"""Merge all fields specified by a sub-tree from source to destination."""
source_descriptor = source.DESCRIPTOR
for name in node:
child = node[name]
field = source_descriptor.fields_by_name[name]
if field is None:
raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
name, source_descriptor.full_name))
if child:
# Sub-paths are only allowed for singular message fields.
if (field.label == FieldDescriptor.LABEL_REPEATED or
field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
raise ValueError('Error: Field {0} in message {1} is not a singular '
'message field and cannot have sub-fields.'.format(
name, source_descriptor.full_name))
if source.HasField(name):
_MergeMessage(
child, getattr(source, name), getattr(destination, name),
replace_message, replace_repeated)
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
if replace_repeated:
destination.ClearField(_StrConvert(name))
repeated_source = getattr(source, name)
repeated_destination = getattr(destination, name)
repeated_destination.MergeFrom(repeated_source)
else:
if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
if replace_message:
destination.ClearField(_StrConvert(name))
if source.HasField(name):
getattr(destination, name).MergeFrom(getattr(source, name))
else:
setattr(destination, name, getattr(source, name))
def _AddFieldPaths(node, prefix, field_mask):
"""Adds the field paths descended from node to field_mask."""
if not node and prefix:
field_mask.paths.append(prefix)
return
for name in sorted(node):
if prefix:
child_path = prefix + '.' + name
else:
child_path = name
_AddFieldPaths(node[name], child_path, field_mask)
def _SetStructValue(struct_value, value):
if value is None:
struct_value.null_value = 0

View File

@@ -38,11 +38,15 @@ import unittest
from google.protobuf import any_pb2
from google.protobuf import duration_pb2
from google.protobuf import field_mask_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf.internal import any_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import well_known_types
from google.protobuf import descriptor
from google.protobuf import text_format
from google.protobuf.internal import _parameterized
@@ -386,6 +390,362 @@ class TimeUtilTest(TimeUtilTestBase):
message.ToJsonString)
class FieldMaskTest(unittest.TestCase):
def testStringFormat(self):
mask = field_mask_pb2.FieldMask()
self.assertEqual('', mask.ToJsonString())
mask.paths.append('foo')
self.assertEqual('foo', mask.ToJsonString())
mask.paths.append('bar')
self.assertEqual('foo,bar', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
mask.FromJsonString('foo')
self.assertEqual(['foo'], mask.paths)
mask.FromJsonString('foo,bar')
self.assertEqual(['foo', 'bar'], mask.paths)
# Test camel case
mask.Clear()
mask.paths.append('foo_bar')
self.assertEqual('fooBar', mask.ToJsonString())
mask.paths.append('bar_quz')
self.assertEqual('fooBar,barQuz', mask.ToJsonString())
mask.FromJsonString('')
self.assertEqual('', mask.ToJsonString())
self.assertEqual([], mask.paths)
mask.FromJsonString('fooBar')
self.assertEqual(['foo_bar'], mask.paths)
mask.FromJsonString('fooBar,barQuz')
self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
def testDescriptorToFieldMask(self):
mask = field_mask_pb2.FieldMask()
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertEqual(76, len(mask.paths))
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
for field in msg_descriptor.fields:
self.assertTrue(field.name in mask.paths)
def testIsValidForDescriptor(self):
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
# Empty mask
mask = field_mask_pb2.FieldMask()
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# All fields from descriptor
mask.AllFieldsFromDescriptor(msg_descriptor)
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Child under optional message
mask.paths.append('optional_nested_message.bb')
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
# Repeated field is only allowed in the last position of path
mask.paths.append('repeated_nested_message.bb')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid top level field
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in root
mask = field_mask_pb2.FieldMask()
mask.paths.append('xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in internal node
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx.zzz')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
# Invalid field in leaf
mask = field_mask_pb2.FieldMask()
mask.paths.append('optional_nested_message.xxx')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
def testCanonicalFrom(self):
mask = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Paths will be sorted.
mask.FromJsonString('baz.quz,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
# Duplicated paths will be removed.
mask.FromJsonString('foo,bar,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo', out_mask.ToJsonString())
# Sub-paths of other paths will be removed.
mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
# Test more deeply nested cases.
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo.bar', out_mask.ToJsonString())
mask.FromJsonString(
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
out_mask.CanonicalFormFromMask(mask)
self.assertEqual('foo', out_mask.ToJsonString())
def testUnion(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Union(mask1, mask2)
self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
src = unittest_pb2.TestAllTypes()
with self.assertRaises(ValueError):
out_mask.Union(src, mask2)
def testIntersect(self):
mask1 = field_mask_pb2.FieldMask()
mask2 = field_mask_pb2.FieldMask()
out_mask = field_mask_pb2.FieldMask()
# Test cases without overlapping.
mask1.FromJsonString('foo,baz')
mask2.FromJsonString('bar,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('', out_mask.ToJsonString())
self.assertEqual(len(out_mask.paths), 0)
self.assertEqual(out_mask.paths, [])
# Overlap with duplicated paths.
mask1.FromJsonString('foo,baz.bb')
mask2.FromJsonString('baz.bb,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('baz.bb', out_mask.ToJsonString())
# Overlap with paths covering some other paths.
mask1.FromJsonString('foo.bar.baz,quz')
mask2.FromJsonString('foo.bar,bar')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
mask1.FromJsonString('foo.bar,bar')
mask2.FromJsonString('foo.bar.baz,quz')
out_mask.Intersect(mask1, mask2)
self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
# Intersect '' with ''
mask1.Clear()
mask2.Clear()
mask1.paths.append('')
mask2.paths.append('')
self.assertEqual(mask1.paths, [''])
self.assertEqual('', mask1.ToJsonString())
out_mask.Intersect(mask1, mask2)
self.assertEqual(out_mask.paths, [])
def testMergeMessageWithoutMapFields(self):
# Test merge one field.
src = unittest_pb2.TestAllTypes()
test_util.SetAllFields(src)
for field in src.DESCRIPTOR.fields:
if field.containing_oneof:
continue
field_name = field.name
dst = unittest_pb2.TestAllTypes()
# Only set one path to mask.
mask = field_mask_pb2.FieldMask()
mask.paths.append(field_name)
mask.MergeMessage(src, dst)
# The expected result message.
msg = unittest_pb2.TestAllTypes()
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
repeated_src = getattr(src, field_name)
repeated_msg = getattr(msg, field_name)
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
for item in repeated_src:
repeated_msg.add().CopyFrom(item)
else:
repeated_msg.extend(repeated_src)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
getattr(msg, field_name).CopyFrom(getattr(src, field_name))
else:
setattr(msg, field_name, getattr(src, field_name))
# Only field specified in mask is merged.
self.assertEqual(msg, dst)
# Test merge nested fields.
nested_src = unittest_pb2.NestedTestAllTypes()
nested_dst = unittest_pb2.NestedTestAllTypes()
nested_src.child.payload.optional_int32 = 1234
nested_src.child.child.payload.optional_int32 = 5678
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child.child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(0, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
nested_dst.Clear()
mask.FromJsonString('child')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
# Test MergeOptions.
nested_dst.Clear()
nested_dst.child.payload.optional_int64 = 4321
# Message fields will be merged by default.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(4321, nested_dst.child.payload.optional_int64)
# Change the behavior to replace message fields.
mask.FromJsonString('child.payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertEqual(1234, nested_dst.child.payload.optional_int32)
self.assertEqual(0, nested_dst.child.payload.optional_int64)
# By default, fields missing in source are not cleared in destination.
nested_dst.payload.optional_int32 = 1234
self.assertTrue(nested_dst.HasField('payload'))
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst)
self.assertTrue(nested_dst.HasField('payload'))
# But they are cleared when replacing message fields.
nested_dst.Clear()
nested_dst.payload.optional_int32 = 1234
mask.FromJsonString('payload')
mask.MergeMessage(nested_src, nested_dst, True, False)
self.assertFalse(nested_dst.HasField('payload'))
nested_src.payload.repeated_int32.append(1234)
nested_dst.payload.repeated_int32.append(5678)
# Repeated fields will be appended by default.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(2, len(nested_dst.payload.repeated_int32))
self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
# Change the behavior to replace repeated fields.
mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst, False, True)
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
# Test Merge oneof field.
new_msg = unittest_pb2.TestOneof2()
dst = unittest_pb2.TestOneof2()
dst.foo_message.moo_int = 1
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('fooMessage,fooLazyMessage.mooInt')
mask.MergeMessage(new_msg, dst)
self.assertTrue(dst.HasField('foo_message'))
self.assertFalse(dst.HasField('foo_lazy_message'))
def testMergeMessageWithMapField(self):
empty_map = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
src_level_2.a['src level 2'].CopyFrom(empty_map)
src = map_unittest_pb2.TestRecursiveMapMessage()
src.a['common key'].CopyFrom(src_level_2)
src.a['src level 1'].CopyFrom(src_level_2)
dst_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
dst_level_2.a['dst level 2'].CopyFrom(empty_map)
dst = map_unittest_pb2.TestRecursiveMapMessage()
dst.a['common key'].CopyFrom(dst_level_2)
dst.a['dst level 1'].CopyFrom(empty_map)
mask = field_mask_pb2.FieldMask()
mask.FromJsonString('a')
mask.MergeMessage(src, dst)
# map from dst is replaced with map from src.
self.assertEqual(dst.a['common key'], src_level_2)
self.assertEqual(dst.a['src level 1'], src_level_2)
self.assertEqual(dst.a['dst level 1'], empty_map)
def testMergeErrors(self):
src = unittest_pb2.TestAllTypes()
dst = unittest_pb2.TestAllTypes()
mask = field_mask_pb2.FieldMask()
test_util.SetAllFields(src)
mask.FromJsonString('optionalInt32.field')
with self.assertRaises(ValueError) as e:
mask.MergeMessage(src, dst)
self.assertEqual('Error: Field optional_int32 in message '
'protobuf_unittest.TestAllTypes is not a singular '
'message field and cannot have sub-fields.',
str(e.exception))
def testSnakeCaseToCamelCase(self):
self.assertEqual('fooBar',
well_known_types._SnakeCaseToCamelCase('foo_bar'))
self.assertEqual('FooBar',
well_known_types._SnakeCaseToCamelCase('_foo_bar'))
self.assertEqual('foo3Bar',
well_known_types._SnakeCaseToCamelCase('foo3_bar'))
# No uppercase letter is allowed.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Path name Foo must '
'not contain uppercase letters.',
well_known_types._SnakeCaseToCamelCase, 'Foo')
# Any character after a "_" must be a lowercase letter.
# 1. "_" cannot be followed by another "_".
# 2. "_" cannot be followed by a digit.
# 3. "_" cannot appear as the last character.
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo__bar.',
well_known_types._SnakeCaseToCamelCase, 'foo__bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: The character after a '
'"_" must be a lowercase letter in path name foo_3bar.',
well_known_types._SnakeCaseToCamelCase, 'foo_3bar')
self.assertRaisesRegex(
ValueError,
'Fail to print FieldMask to Json string: Trailing "_" in path '
'name foo_bar_.', well_known_types._SnakeCaseToCamelCase, 'foo_bar_')
def testCamelCaseToSnakeCase(self):
self.assertEqual('foo_bar',
well_known_types._CamelCaseToSnakeCase('fooBar'))
self.assertEqual('_foo_bar',
well_known_types._CamelCaseToSnakeCase('FooBar'))
self.assertEqual('foo3_bar',
well_known_types._CamelCaseToSnakeCase('foo3Bar'))
self.assertRaisesRegex(
ValueError,
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
well_known_types._CamelCaseToSnakeCase, 'foo_bar')
class StructTest(unittest.TestCase):
def testStruct(self):

View File

@@ -109,8 +109,7 @@ def MessageToJson(
names as defined in the .proto file. If False, convert the field
names to lowerCamelCase.
indent: The JSON object will be pretty-printed with this indent level.
An indent level of 0 or negative will only insert newlines. If the
indent level is None, no newlines will be inserted.
An indent level of 0 or negative will only insert newlines.
sort_keys: If True, then the output will be sorted by field names.
use_integers_for_enums: If true, print integers instead of enum names.
descriptor_pool: A Descriptor Pool for resolving types. If None use the

View File

@@ -74,8 +74,7 @@ class Message(object):
__slots__ = []
#: The :class:`google.protobuf.Descriptor`
# for this message type.
#: The :class:`google.protobuf.descriptor.Descriptor` for this message type.
DESCRIPTOR = None
def __deepcopy__(self, memo=None):

View File

@@ -60,6 +60,9 @@ class MessageFactory(object):
"""Initializes a new factory."""
self.pool = pool or descriptor_pool.DescriptorPool()
# local cache of all classes built from protobuf descriptors
self._classes = {}
def GetPrototype(self, descriptor):
"""Obtains a proto2 message class based on the passed in descriptor.
@@ -72,11 +75,14 @@ class MessageFactory(object):
Returns:
A class describing the passed in descriptor.
"""
concrete_class = getattr(descriptor, '_concrete_class', None)
if concrete_class:
return concrete_class
result_class = self.CreatePrototype(descriptor)
return result_class
if descriptor not in self._classes:
result_class = self.CreatePrototype(descriptor)
# The assignment to _classes is redundant for the base implementation, but
# might avoid confusion in cases where CreatePrototype gets overridden and
# does not call the base implementation.
self._classes[descriptor] = result_class
return result_class
return self._classes[descriptor]
def CreatePrototype(self, descriptor):
"""Builds a proto2 message class based on the passed in descriptor.
@@ -101,11 +107,16 @@ class MessageFactory(object):
'__module__': None,
})
result_class._FACTORY = self # pylint: disable=protected-access
# Assign in _classes before doing recursive calls to avoid infinite
# recursion.
self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
extended_class = self.GetPrototype(extension.containing_type)
if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)
@@ -141,7 +152,9 @@ class MessageFactory(object):
# an error if they were different.
for extension in file_desc.extensions_by_name.values():
extended_class = self.GetPrototype(extension.containing_type)
if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
if extension.message_type:
self.GetPrototype(extension.message_type)

View File

@@ -48,8 +48,8 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/message.h"
#include <google/protobuf/descriptor_database.h>
#include <google/protobuf/message.h>
namespace google {
namespace protobuf {
@@ -133,7 +133,8 @@ struct PyProto_API {
};
inline const char* PyProtoAPICapsuleName() {
static const char kCapsuleName[] = "google.protobuf.pyext._message.proto_API";
static const char kCapsuleName[] =
"google.protobuf.pyext._message.proto_API";
return kCapsuleName;
}

View File

@@ -30,7 +30,7 @@
// Author: petar@google.com (Petar Petrov)
#include "google/protobuf/pyext/descriptor.h"
#include <google/protobuf/pyext/descriptor.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
@@ -40,15 +40,15 @@
#include <string>
#include <unordered_map>
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/pyext/descriptor_containers.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/io/coded_stream.h"
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor_containers.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/hash.h>
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
@@ -58,37 +58,6 @@
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
static PyCodeObject* PyFrame_GetCode(PyFrameObject *frame)
{
Py_INCREF(frame->f_code);
return frame->f_code;
}
static PyFrameObject* PyFrame_GetBack(PyFrameObject *frame)
{
Py_XINCREF(frame->f_back);
return frame->f_back;
}
#endif
#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION)
static PyObject* PyFrame_GetLocals(PyFrameObject *frame)
{
if (PyFrame_FastToLocalsWithError(frame) < 0) {
return NULL;
}
Py_INCREF(frame->f_locals);
return frame->f_locals;
}
static PyObject* PyFrame_GetGlobals(PyFrameObject *frame)
{
Py_INCREF(frame->f_globals);
return frame->f_globals;
}
#endif
namespace google {
namespace protobuf {
namespace python {
@@ -127,66 +96,48 @@ bool _CalledFromGeneratedFile(int stacklevel) {
// This check is not critical and is somewhat difficult to implement correctly
// in PyPy.
PyFrameObject* frame = PyEval_GetFrame();
PyCodeObject* frame_code = nullptr;
PyObject* frame_globals = nullptr;
PyObject* frame_locals = nullptr;
bool result = false;
if (frame == nullptr) {
goto exit;
return false;
}
Py_INCREF(frame);
while (stacklevel-- > 0) {
PyFrameObject* next_frame = PyFrame_GetBack(frame);
Py_DECREF(frame);
frame = next_frame;
frame = frame->f_back;
if (frame == nullptr) {
goto exit;
return false;
}
}
frame_code = PyFrame_GetCode(frame);
if (frame_code->co_filename == nullptr) {
goto exit;
if (frame->f_code->co_filename == nullptr) {
return false;
}
char* filename;
Py_ssize_t filename_size;
if (PyString_AsStringAndSize(frame_code->co_filename,
if (PyString_AsStringAndSize(frame->f_code->co_filename,
&filename, &filename_size) < 0) {
// filename is not a string.
PyErr_Clear();
goto exit;
return false;
}
if ((filename_size < 3) ||
(strcmp(&filename[filename_size - 3], ".py") != 0)) {
// Cython's stack does not have .py file name and is not at global module
// scope.
result = true;
goto exit;
return true;
}
if (filename_size < 7) {
// filename is too short.
goto exit;
return false;
}
if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) {
// Filename is not ending with _pb2.
goto exit;
return false;
}
frame_globals = PyFrame_GetGlobals(frame);
frame_locals = PyFrame_GetLocals(frame);
if (frame_globals != frame_locals) {
if (frame->f_globals != frame->f_locals) {
// Not at global module scope
goto exit;
return false;
}
#endif
result = true;
exit:
Py_XDECREF(frame_globals);
Py_XDECREF(frame_locals);
Py_XDECREF(frame_code);
Py_XDECREF(frame);
return result;
return true;
}
// If the calling code is not a _pb2.py file, raise AttributeError.
@@ -539,12 +490,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
GetDescriptorPool_FromPool(
_GetDescriptor(self)->file()->pool())->py_message_factory,
_GetDescriptor(self)));
if (concrete_class == nullptr) {
PyErr_Clear();
return nullptr;
}
Py_XINCREF(concrete_class);
return concrete_class->AsPyObject();
}
@@ -1796,8 +1741,7 @@ static PyObject* FindMethodByName(PyBaseDescriptor *self, PyObject* arg) {
}
const MethodDescriptor* method_descriptor =
_GetDescriptor(self)->FindMethodByName(
absl::string_view(name, name_size));
_GetDescriptor(self)->FindMethodByName(StringParam(name, name_size));
if (method_descriptor == nullptr) {
PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name);
return nullptr;

View File

@@ -36,12 +36,15 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.h"
#include <google/protobuf/descriptor.h>
namespace google {
namespace protobuf {
namespace python {
// Should match the type of ConstStringParam.
using StringParam = std::string;
extern PyTypeObject PyMessageDescriptor_Type;
extern PyTypeObject PyFieldDescriptor_Type;
extern PyTypeObject PyEnumDescriptor_Type;

View File

@@ -49,21 +49,14 @@
// because the Python API is based on C, and does not play well with C++
// inheritance.
// clang-format off
#define PY_SSIZE_T_CLEAN
// This inclusion must appear before all the others.
#include <Python.h>
#include <string>
#include "google/protobuf/pyext/descriptor_containers.h"
// clang-format on
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/pyext/descriptor_containers.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
@@ -82,9 +75,9 @@ struct PyContainer;
typedef int (*CountMethod)(PyContainer* self);
typedef const void* (*GetByIndexMethod)(PyContainer* self, int index);
typedef const void* (*GetByNameMethod)(PyContainer* self,
absl::string_view name);
ConstStringParam name);
typedef const void* (*GetByCamelcaseNameMethod)(PyContainer* self,
absl::string_view name);
ConstStringParam name);
typedef const void* (*GetByNumberMethod)(PyContainer* self, int index);
typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor);
typedef const std::string& (*GetItemNameMethod)(const void* descriptor);
@@ -182,8 +175,8 @@ static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) {
}
return false;
}
*item = self->container_def->get_by_name_fn(
self, absl::string_view(name, name_size));
*item = self->container_def->get_by_name_fn(self,
StringParam(name, name_size));
return true;
}
case PyContainer::KIND_BYCAMELCASENAME: {
@@ -199,7 +192,7 @@ static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) {
return false;
}
*item = self->container_def->get_by_camelcase_name_fn(
self, absl::string_view(camelcase_name, name_size));
self, StringParam(camelcase_name, name_size));
return true;
}
case PyContainer::KIND_BYNUMBER: {
@@ -965,12 +958,12 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->field_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindFieldByName(name);
}
static const void* GetByCamelcaseName(PyContainer* self,
absl::string_view name) {
ConstStringParam name) {
return GetDescriptor(self)->FindFieldByCamelcaseName(name);
}
@@ -1035,7 +1028,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->nested_type_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindNestedTypeByName(name);
}
@@ -1087,7 +1080,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->enum_type_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindEnumTypeByName(name);
}
@@ -1150,7 +1143,7 @@ static int Count(PyContainer* self) {
return count;
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindEnumValueByName(name);
}
@@ -1201,7 +1194,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->extension_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindExtensionByName(name);
}
@@ -1253,7 +1246,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->oneof_decl_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindOneofByName(name);
}
@@ -1311,7 +1304,7 @@ static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->value(index);
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindValueByName(name);
}
@@ -1415,7 +1408,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->method_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindMethodByName(name);
}
@@ -1469,7 +1462,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->message_type_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindMessageTypeByName(name);
}
@@ -1509,7 +1502,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->enum_type_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindEnumTypeByName(name);
}
@@ -1549,7 +1542,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->extension_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindExtensionByName(name);
}
@@ -1589,7 +1582,7 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->service_count();
}
static const void* GetByName(PyContainer* self, absl::string_view name) {
static const void* GetByName(PyContainer* self, ConstStringParam name) {
return GetDescriptor(self)->FindServiceByName(name);
}

View File

@@ -31,17 +31,15 @@
// This file defines a C++ DescriptorDatabase, which wraps a Python Database
// and delegate all its operations to Python methods.
#include "google/protobuf/pyext/descriptor_database.h"
#include <google/protobuf/pyext/descriptor_database.h>
#include <cstdint>
#include <string>
#include <vector>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
namespace google {
namespace protobuf {

View File

@@ -34,10 +34,7 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <string>
#include <vector>
#include "google/protobuf/descriptor_database.h"
#include <google/protobuf/descriptor_database.h>
namespace google {
namespace protobuf {

View File

@@ -30,22 +30,19 @@
// Implements the DescriptorPool, which collects all descriptors.
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_database.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_database.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/hash.h>
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
@@ -249,7 +246,7 @@ static PyObject* FindMessageByName(PyObject* self, PyObject* arg) {
const Descriptor* message_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName(
absl::string_view(name, name_size));
StringParam(name, name_size));
if (message_descriptor == nullptr) {
return SetErrorFromCollector(
@@ -273,7 +270,7 @@ static PyObject* FindFileByName(PyObject* self, PyObject* arg) {
PyDescriptorPool* py_pool = reinterpret_cast<PyDescriptorPool*>(self);
const FileDescriptor* file_descriptor =
py_pool->pool->FindFileByName(absl::string_view(name, name_size));
py_pool->pool->FindFileByName(StringParam(name, name_size));
if (file_descriptor == nullptr) {
return SetErrorFromCollector(py_pool->error_collector, name, "file");
@@ -289,7 +286,7 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
}
const FieldDescriptor* field_descriptor =
self->pool->FindFieldByName(absl::string_view(name, name_size));
self->pool->FindFieldByName(StringParam(name, name_size));
if (field_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "field");
}
@@ -310,7 +307,7 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
}
const FieldDescriptor* field_descriptor =
self->pool->FindExtensionByName(absl::string_view(name, name_size));
self->pool->FindExtensionByName(StringParam(name, name_size));
if (field_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name,
"extension field");
@@ -332,7 +329,7 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
}
const EnumDescriptor* enum_descriptor =
self->pool->FindEnumTypeByName(absl::string_view(name, name_size));
self->pool->FindEnumTypeByName(StringParam(name, name_size));
if (enum_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "enum");
}
@@ -353,7 +350,7 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
}
const OneofDescriptor* oneof_descriptor =
self->pool->FindOneofByName(absl::string_view(name, name_size));
self->pool->FindOneofByName(StringParam(name, name_size));
if (oneof_descriptor == nullptr) {
return SetErrorFromCollector(self->error_collector, name, "oneof");
}
@@ -375,7 +372,7 @@ static PyObject* FindServiceByName(PyObject* self, PyObject* arg) {
const ServiceDescriptor* service_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName(
absl::string_view(name, name_size));
StringParam(name, name_size));
if (service_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
@@ -395,7 +392,7 @@ static PyObject* FindMethodByName(PyObject* self, PyObject* arg) {
const MethodDescriptor* method_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMethodByName(
absl::string_view(name, name_size));
StringParam(name, name_size));
if (method_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,
@@ -415,7 +412,7 @@ static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) {
const FileDescriptor* file_descriptor =
reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileContainingSymbol(
absl::string_view(name, name_size));
StringParam(name, name_size));
if (file_descriptor == nullptr) {
return SetErrorFromCollector(
reinterpret_cast<PyDescriptorPool*>(self)->error_collector, name,

View File

@@ -35,7 +35,7 @@
#include <Python.h>
#include <unordered_map>
#include "google/protobuf/descriptor.h"
#include <google/protobuf/descriptor.h>
namespace google {
namespace protobuf {

View File

@@ -31,25 +31,23 @@
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/extension_dict.h"
#include <google/protobuf/pyext/extension_dict.h>
#include <cstdint>
#include <memory>
#include <vector>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/repeated_composite_container.h"
#include "google/protobuf/pyext/repeated_scalar_container.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "absl/strings/string_view.h"
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \
@@ -127,9 +125,8 @@ static void DeallocExtensionIterator(PyObject* _self) {
ExtensionIterator* self = reinterpret_cast<ExtensionIterator*>(_self);
self->fields.clear();
Py_XDECREF(self->extension_dict);
freefunc tp_free = Py_TYPE(_self)->tp_free;
self->~ExtensionIterator();
(*tp_free)(_self);
Py_TYPE(_self)->tp_free(_self);
}
PyObject* subscript(ExtensionDict* self, PyObject* key) {
@@ -241,11 +238,11 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
const FieldDescriptor* message_extension =
pool->pool->FindExtensionByName(absl::string_view(name, name_size));
pool->pool->FindExtensionByName(StringParam(name, name_size));
if (message_extension == nullptr) {
// Is is the name of a message set extension?
const Descriptor* message_descriptor =
pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
pool->pool->FindMessageTypeByName(StringParam(name, name_size));
if (message_descriptor && message_descriptor->extension_count() > 0) {
const FieldDescriptor* extension = message_descriptor->extension(0);
if (extension->is_extension() &&

View File

@@ -37,7 +37,7 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -28,11 +28,11 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "google/protobuf/pyext/field.h"
#include <google/protobuf/pyext/field.h>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -30,21 +30,21 @@
// Author: haberman@google.com (Josh Haberman)
#include "google/protobuf/pyext/map_container.h"
#include <google/protobuf/pyext/map_container.h>
#include <cstdint>
#include <memory>
#include <string>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/map.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/repeated_composite_container.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/map.h>
#include <google/protobuf/map_field.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/map_util.h>
namespace google {
namespace protobuf {

View File

@@ -36,9 +36,9 @@
#include <cstdint>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -31,18 +31,17 @@
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/pyext/message.h>
#include <structmember.h> // A Python header file.
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "absl/strings/match.h"
#include <google/protobuf/stubs/strutil.h>
#ifndef PyVarObject_HEAD_INIT
#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
@@ -50,34 +49,33 @@
#ifndef Py_TYPE
#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
#endif
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/extension_dict.h"
#include "google/protobuf/pyext/field.h"
#include "google/protobuf/pyext/map_container.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/repeated_composite_container.h"
#include "google/protobuf/pyext/repeated_scalar_container.h"
#include "google/protobuf/pyext/safe_numerics.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "google/protobuf/pyext/unknown_field_set.h"
#include "google/protobuf/pyext/unknown_fields.h"
#include "google/protobuf/util/message_differencer.h"
#include "google/protobuf/stubs/strutil.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/extension_dict.h>
#include <google/protobuf/pyext/field.h>
#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/pyext/unknown_field_set.h>
#include <google/protobuf/pyext/unknown_fields.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/io/strtod.h>
#include <google/protobuf/stubs/map_util.h>
// clang-format off
#include "google/protobuf/port_def.inc"
#include <google/protobuf/port_def.inc>
// clang-format on
#define PyString_AsString(ob) \
@@ -90,9 +88,6 @@
: 0) \
: PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#define PROTOBUF_PYTHON_PUBLIC "google.protobuf"
#define PROTOBUF_PYTHON_INTERNAL "google.protobuf.internal"
namespace google {
namespace protobuf {
namespace python {
@@ -251,8 +246,8 @@ static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
ScopedPyObjectPtr new_args;
if (WKT_classes == nullptr) {
ScopedPyObjectPtr well_known_types(
PyImport_ImportModule(PROTOBUF_PYTHON_INTERNAL ".well_known_types"));
ScopedPyObjectPtr well_known_types(PyImport_ImportModule(
"google.protobuf.internal.well_known_types"));
GOOGLE_DCHECK(well_known_types != nullptr);
WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES");
@@ -410,7 +405,7 @@ static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
Py_ssize_t attr_size;
static const char kSuffix[] = "_FIELD_NUMBER";
if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
absl::EndsWith(absl::string_view(attr, attr_size), kSuffix)) {
HasSuffixString(StringPiece(attr, attr_size), kSuffix)) {
std::string field_name(attr, attr_size - sizeof(kSuffix) + 1);
LowerString(&field_name);
@@ -915,7 +910,7 @@ static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
return nullptr;
}
const EnumValueDescriptor* enum_value_descriptor =
enum_descriptor->FindValueByName(absl::string_view(enum_label, size));
enum_descriptor->FindValueByName(StringParam(enum_label, size));
if (enum_value_descriptor == nullptr) {
PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
return nullptr;
@@ -1342,7 +1337,7 @@ int HasFieldByDescriptor(CMessage* self,
}
const FieldDescriptor* FindFieldWithOneofs(const Message* message,
absl::string_view field_name,
ConstStringParam field_name,
bool* in_oneof) {
*in_oneof = false;
const Descriptor* descriptor = message->GetDescriptor();
@@ -1391,8 +1386,8 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
Message* message = self->message;
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
message, absl::string_view(field_name, size), &is_in_oneof);
const FieldDescriptor* field_descriptor =
FindFieldWithOneofs(message, StringParam(field_name, size), &is_in_oneof);
if (field_descriptor == nullptr) {
if (!is_in_oneof) {
PyErr_Format(PyExc_ValueError, "Protocol message %s has no field %s.",
@@ -1576,7 +1571,7 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
AssureWritable(self);
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
self->message, absl::string_view(field_name, field_size), &is_in_oneof);
self->message, StringParam(field_name, field_size), &is_in_oneof);
if (field_descriptor == nullptr) {
if (is_in_oneof) {
// We gave the name of a oneof, and none of its fields are set.
@@ -1886,7 +1881,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
const char* ptr;
internal::ParseContext ctx(
depth, false, &ptr,
absl::string_view(static_cast<const char*>(data.buf), data.len));
StringPiece(static_cast<const char*>(data.buf), data.len));
PyBuffer_Release(&data);
ctx.data().pool = factory->pool->pool;
ctx.data().factory = factory->message_factory;
@@ -1978,7 +1973,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0) return nullptr;
const OneofDescriptor* oneof_desc =
self->message->GetDescriptor()->FindOneofByName(
absl::string_view(name_data, name_size));
StringParam(name_data, name_size));
if (oneof_desc == nullptr) {
PyErr_Format(PyExc_ValueError,
"Protocol message has no oneof \"%s\" field.", name_data);
@@ -2377,7 +2372,7 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) {
PyObject* ToUnicode(CMessage* self) {
// Lazy import to prevent circular dependencies
ScopedPyObjectPtr text_format(
PyImport_ImportModule(PROTOBUF_PYTHON_PUBLIC ".text_format"));
PyImport_ImportModule("google.protobuf.text_format"));
if (text_format == nullptr) {
return nullptr;
}
@@ -2678,22 +2673,22 @@ CMessage* CMessage::BuildSubMessageFromPointer(
if (!this->child_submessages) {
this->child_submessages = new CMessage::SubMessagesMap();
}
auto it = this->child_submessages->find(sub_message);
if (it != this->child_submessages->end()) {
Py_INCREF(it->second);
return it->second;
}
CMessage* cmsg = FindPtrOrNull(
*this->child_submessages, sub_message);
if (cmsg) {
Py_INCREF(cmsg);
} else {
cmsg = cmessage::NewEmptyMessage(message_class);
CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
if (cmsg == nullptr) {
return nullptr;
if (cmsg == nullptr) {
return nullptr;
}
cmsg->message = sub_message;
Py_INCREF(this);
cmsg->parent = this;
cmsg->parent_field_descriptor = field_descriptor;
cmessage::SetSubmessage(this, cmsg);
}
cmsg->message = sub_message;
Py_INCREF(this);
cmsg->parent = this;
cmsg->parent_field_descriptor = field_descriptor;
cmessage::SetSubmessage(this, cmsg);
return cmsg;
}
@@ -2701,10 +2696,11 @@ CMessage* CMessage::MaybeReleaseSubMessage(Message* sub_message) {
if (!this->child_submessages) {
return nullptr;
}
auto it = this->child_submessages->find(sub_message);
if (it == this->child_submessages->end()) return nullptr;
CMessage* released = it->second;
CMessage* released = FindPtrOrNull(
*this->child_submessages, sub_message);
if (!released) {
return nullptr;
}
// The target message will now own its content.
Py_CLEAR(released->parent);
released->parent_field_descriptor = nullptr;
@@ -3038,8 +3034,8 @@ bool InitProto2MessageModule(PyObject *m) {
PyModule_AddObject(m, "MethodDescriptor",
reinterpret_cast<PyObject*>(&PyMethodDescriptor_Type));
PyObject* enum_type_wrapper =
PyImport_ImportModule(PROTOBUF_PYTHON_INTERNAL ".enum_type_wrapper");
PyObject* enum_type_wrapper = PyImport_ImportModule(
"google.protobuf.internal.enum_type_wrapper");
if (enum_type_wrapper == nullptr) {
return false;
}
@@ -3047,8 +3043,8 @@ bool InitProto2MessageModule(PyObject *m) {
PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
Py_DECREF(enum_type_wrapper);
PyObject* message_module =
PyImport_ImportModule(PROTOBUF_PYTHON_PUBLIC ".message");
PyObject* message_module = PyImport_ImportModule(
"google.protobuf.message");
if (message_module == nullptr) {
return false;
}

View File

@@ -42,7 +42,7 @@
#include <string>
#include <unordered_map>
#include "google/protobuf/stubs/common.h"
#include <google/protobuf/stubs/common.h>
namespace google {
namespace protobuf {

View File

@@ -29,16 +29,15 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <unordered_map>
#include <utility>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#define PyString_AsStringAndSize(ob, charpp, sizep) \
(PyUnicode_Check(ob) \

View File

@@ -35,8 +35,8 @@
#include <Python.h>
#include <unordered_map>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
namespace google {
namespace protobuf {

View File

@@ -31,12 +31,12 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/message_lite.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/proto_api.h"
#include <google/protobuf/message_lite.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/proto_api.h>
namespace {

View File

@@ -31,21 +31,22 @@
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/repeated_composite_container.h"
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <memory>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/reflection.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/message_factory.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/reflection.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/map_util.h>
namespace google {
namespace protobuf {

View File

@@ -37,7 +37,7 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -31,21 +31,20 @@
// Author: anuraag@google.com (Anuraag Agrawal)
// Author: tibell@google.com (Johan Tibell)
#include "google/protobuf/pyext/repeated_scalar_container.h"
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <cstdint>
#include <memory>
#include <string>
#include "google/protobuf/stubs/common.h"
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/descriptor.h"
#include "google/protobuf/pyext/descriptor_pool.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#define PyString_AsString(ob) \
(PyUnicode_Check(ob) ? PyUnicode_AsUTF8(ob) : PyBytes_AsString(ob))

View File

@@ -37,8 +37,8 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -34,8 +34,8 @@
#include <limits>
#include "google/protobuf/stubs/logging.h"
#include "google/protobuf/stubs/common.h"
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
namespace google {
namespace protobuf {

View File

@@ -33,6 +33,8 @@
#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
#include <google/protobuf/stubs/common.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
namespace google {
@@ -48,8 +50,6 @@ class ScopedPythonPtr {
// The reference count of the specified py_object is not incremented.
explicit ScopedPythonPtr(PyObjectStruct* py_object = nullptr)
: ptr_(py_object) {}
ScopedPythonPtr(const ScopedPythonPtr&) = delete;
ScopedPythonPtr& operator=(const ScopedPythonPtr&) = delete;
// If a PyObject is owned, decrement its reference count.
~ScopedPythonPtr() { Py_XDECREF(ptr_); }
@@ -89,6 +89,8 @@ class ScopedPythonPtr {
private:
PyObjectStruct* ptr_;
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPythonPtr);
};
typedef ScopedPythonPtr<PyObject> ScopedPyObjectPtr;

View File

@@ -28,7 +28,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "google/protobuf/pyext/unknown_field_set.h"
#include <google/protobuf/pyext/unknown_field_set.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
@@ -36,11 +36,11 @@
#include <memory>
#include <set>
#include "google/protobuf/message.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/wire_format_lite.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include <google/protobuf/message.h>
#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/wire_format_lite.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
namespace google {
namespace protobuf {

View File

@@ -37,7 +37,7 @@
#include <memory>
#include <set>
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -28,18 +28,18 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "google/protobuf/pyext/unknown_fields.h"
#include <google/protobuf/pyext/unknown_fields.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <set>
#include <memory>
#include "google/protobuf/message.h"
#include "google/protobuf/pyext/message.h"
#include "google/protobuf/pyext/scoped_pyobject_ptr.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/wire_format_lite.h"
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/wire_format_lite.h>
namespace google {
namespace protobuf {

View File

@@ -37,7 +37,7 @@
#include <memory>
#include <set>
#include "google/protobuf/pyext/message.h"
#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {

View File

@@ -66,9 +66,6 @@ from google.protobuf import message_factory
class SymbolDatabase(message_factory.MessageFactory):
"""A database of Python generated symbols."""
# local cache of registered classes.
_classes = {}
def RegisterMessage(self, message):
"""Registers the given message type in the local database.

View File

@@ -53,7 +53,8 @@ for byte, string in _cescape_chr_to_symbol_map.items():
del byte, string
def CEscape(text, as_utf8) -> str:
def CEscape(text, as_utf8):
# type: (...) -> str
"""Escape a bytes string for use in an text protocol buffer.
Args:
@@ -82,7 +83,8 @@ def CEscape(text, as_utf8) -> str:
_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])')
def CUnescape(text: str) -> bytes:
def CUnescape(text):
# type: (str) -> bytes
"""Unescape a text string with C-style escape sequences to UTF-8 bytes.
Args:

View File

@@ -67,7 +67,6 @@ _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE)
_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE)
_QUOTES = frozenset(("'", '"'))
_ANY_FULL_TYPE_NAME = 'google.protobuf.Any'
_DEBUG_STRING_SILENT_MARKER = '\t '
class Error(Exception):
@@ -126,7 +125,8 @@ def MessageToString(
indent=0,
message_formatter=None,
print_unknown_fields=False,
force_colon=False) -> str:
force_colon=False):
# type: (...) -> str
"""Convert protobuf message to text format.
Double values can be formatted compactly with 15 digits of
@@ -191,7 +191,8 @@ def MessageToString(
return result
def MessageToBytes(message, **kwargs) -> bytes:
def MessageToBytes(message, **kwargs):
# type: (...) -> bytes
"""Convert protobuf message to encoded text format. See MessageToString."""
text = MessageToString(message, **kwargs)
if isinstance(text, bytes):
@@ -557,7 +558,7 @@ class _Printer(object):
# For groups, use the capitalized name.
out.write(field.message_type.name)
else:
out.write(field.name)
out.write(field.name)
if (self.force_colon or
field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE):
@@ -855,15 +856,10 @@ class _Parser(object):
ParseError: On text parsing problems.
"""
# Tokenize expects native str lines.
try:
str_lines = (
line if isinstance(line, str) else line.decode('utf-8')
for line in lines)
except UnicodeDecodeError as e:
raise self._StringParseError(e)
str_lines = (
line if isinstance(line, str) else line.decode('utf-8')
for line in lines)
tokenizer = Tokenizer(str_lines)
if message:
self.root_type = message.DESCRIPTOR.full_name
while not tokenizer.AtEnd():
self._MergeField(tokenizer, message)
@@ -883,8 +879,6 @@ class _Parser(object):
type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
tokenizer.Consume(']')
tokenizer.TryConsume(':')
self._DetectSilentMarker(tokenizer,
type_url_prefix + '/' + packed_type_name)
if tokenizer.TryConsume('<'):
expanded_any_end_token = '>'
else:
@@ -923,6 +917,8 @@ class _Parser(object):
# pylint: disable=protected-access
field = message.Extensions._FindExtensionByName(name)
# pylint: enable=protected-access
if not field:
if self.allow_unknown_extension:
field = None
@@ -982,11 +978,9 @@ class _Parser(object):
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
tokenizer.TryConsume(':')
self._DetectSilentMarker(tokenizer, field.full_name)
merger = self._MergeMessageField
else:
tokenizer.Consume(':')
self._DetectSilentMarker(tokenizer, field.full_name)
merger = self._MergeScalarField
if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
@@ -1004,19 +998,13 @@ class _Parser(object):
else: # Proto field is unknown.
assert (self.allow_unknown_extension or self.allow_unknown_field)
self._SkipFieldContents(tokenizer, name)
_SkipFieldContents(tokenizer)
# For historical reasons, fields may optionally be separated by commas or
# semicolons.
if not tokenizer.TryConsume(','):
tokenizer.TryConsume(';')
def _LogSilentMarker(self, field_name):
pass
def _DetectSilentMarker(self, tokenizer, field_name):
if tokenizer.contains_silent_marker_before_current_token:
self._LogSilentMarker(field_name)
def _ConsumeAnyTypeUrl(self, tokenizer):
"""Consumes a google.protobuf.Any type URL and returns the type name."""
@@ -1172,111 +1160,105 @@ class _Parser(object):
else:
setattr(message, field.name, value)
def _SkipFieldContents(self, tokenizer, field_name):
"""Skips over contents (value or message) of a field.
Args:
tokenizer: A tokenizer to parse the field name and values.
field_name: The field name currently being parsed.
"""
# Try to guess the type of this field.
# If this field is not a message, there should be a ":" between the
# field name and the field value and also the field value should not
# start with "{" or "<" which indicates the beginning of a message body.
# If there is no ":" or there is a "{" or "<" after ":", this field has
# to be a message or the input is ill-formed.
if tokenizer.TryConsume(
':') and not tokenizer.LookingAt('{') and not tokenizer.LookingAt('<'):
self._DetectSilentMarker(tokenizer, field_name)
if tokenizer.LookingAt('['):
self._SkipRepeatedFieldValue(tokenizer)
else:
self._SkipFieldValue(tokenizer)
def _SkipFieldContents(tokenizer):
"""Skips over contents (value or message) of a field.
Args:
tokenizer: A tokenizer to parse the field name and values.
"""
# Try to guess the type of this field.
# If this field is not a message, there should be a ":" between the
# field name and the field value and also the field value should not
# start with "{" or "<" which indicates the beginning of a message body.
# If there is no ":" or there is a "{" or "<" after ":", this field has
# to be a message or the input is ill-formed.
if tokenizer.TryConsume(
':') and not tokenizer.LookingAt('{') and not tokenizer.LookingAt('<'):
if tokenizer.LookingAt('['):
_SkipRepeatedFieldValue(tokenizer)
else:
self._DetectSilentMarker(tokenizer, field_name)
self._SkipFieldMessage(tokenizer)
_SkipFieldValue(tokenizer)
else:
_SkipFieldMessage(tokenizer)
def _SkipField(self, tokenizer):
"""Skips over a complete field (name and value/message).
Args:
tokenizer: A tokenizer to parse the field name and values.
"""
field_name = ''
if tokenizer.TryConsume('['):
# Consume extension or google.protobuf.Any type URL
field_name += '[' + tokenizer.ConsumeIdentifier()
num_identifiers = 1
while tokenizer.TryConsume('.'):
field_name += '.' + tokenizer.ConsumeIdentifier()
num_identifiers += 1
# This is possibly a type URL for an Any message.
if num_identifiers == 3 and tokenizer.TryConsume('/'):
field_name += '/' + tokenizer.ConsumeIdentifier()
while tokenizer.TryConsume('.'):
field_name += '.' + tokenizer.ConsumeIdentifier()
tokenizer.Consume(']')
field_name += ']'
else:
field_name += tokenizer.ConsumeIdentifierOrNumber()
def _SkipField(tokenizer):
"""Skips over a complete field (name and value/message).
self._SkipFieldContents(tokenizer, field_name)
# For historical reasons, fields may optionally be separated by commas or
# semicolons.
if not tokenizer.TryConsume(','):
tokenizer.TryConsume(';')
def _SkipFieldMessage(self, tokenizer):
"""Skips over a field message.
Args:
tokenizer: A tokenizer to parse the field name and values.
"""
if tokenizer.TryConsume('<'):
delimiter = '>'
else:
tokenizer.Consume('{')
delimiter = '}'
while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'):
self._SkipField(tokenizer)
tokenizer.Consume(delimiter)
def _SkipFieldValue(self, tokenizer):
"""Skips over a field value.
Args:
tokenizer: A tokenizer to parse the field name and values.
Raises:
ParseError: In case an invalid field value is found.
"""
# String/bytes tokens can come in multiple adjacent string literals.
# If we can consume one, consume as many as we can.
if tokenizer.TryConsumeByteString():
while tokenizer.TryConsumeByteString():
pass
return
if (not tokenizer.TryConsumeIdentifier() and
not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and
not tokenizer.TryConsumeFloat()):
raise ParseError('Invalid field value: ' + tokenizer.token)
def _SkipRepeatedFieldValue(self, tokenizer):
"""Skips over a repeated field value.
Args:
tokenizer: A tokenizer to parse the field value.
"""
tokenizer.Consume('[')
if not tokenizer.LookingAt(']'):
self._SkipFieldValue(tokenizer)
while tokenizer.TryConsume(','):
self._SkipFieldValue(tokenizer)
Args:
tokenizer: A tokenizer to parse the field name and values.
"""
if tokenizer.TryConsume('['):
# Consume extension name.
tokenizer.ConsumeIdentifier()
while tokenizer.TryConsume('.'):
tokenizer.ConsumeIdentifier()
tokenizer.Consume(']')
else:
tokenizer.ConsumeIdentifierOrNumber()
_SkipFieldContents(tokenizer)
# For historical reasons, fields may optionally be separated by commas or
# semicolons.
if not tokenizer.TryConsume(','):
tokenizer.TryConsume(';')
def _SkipFieldMessage(tokenizer):
"""Skips over a field message.
Args:
tokenizer: A tokenizer to parse the field name and values.
"""
if tokenizer.TryConsume('<'):
delimiter = '>'
else:
tokenizer.Consume('{')
delimiter = '}'
while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'):
_SkipField(tokenizer)
tokenizer.Consume(delimiter)
def _SkipFieldValue(tokenizer):
"""Skips over a field value.
Args:
tokenizer: A tokenizer to parse the field name and values.
Raises:
ParseError: In case an invalid field value is found.
"""
# String/bytes tokens can come in multiple adjacent string literals.
# If we can consume one, consume as many as we can.
if tokenizer.TryConsumeByteString():
while tokenizer.TryConsumeByteString():
pass
return
if (not tokenizer.TryConsumeIdentifier() and
not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and
not tokenizer.TryConsumeFloat()):
raise ParseError('Invalid field value: ' + tokenizer.token)
def _SkipRepeatedFieldValue(tokenizer):
"""Skips over a repeated field value.
Args:
tokenizer: A tokenizer to parse the field value.
"""
tokenizer.Consume('[')
if not tokenizer.LookingAt(']'):
_SkipFieldValue(tokenizer)
while tokenizer.TryConsume(','):
_SkipFieldValue(tokenizer)
tokenizer.Consume(']')
class Tokenizer(object):
@@ -1317,8 +1299,6 @@ class Tokenizer(object):
self._skip_comments = skip_comments
self._whitespace_pattern = (skip_comments and self._WHITESPACE_OR_COMMENT
or self._WHITESPACE)
self.contains_silent_marker_before_current_token = False
self._SkipWhitespace()
self.NextToken()
@@ -1351,8 +1331,6 @@ class Tokenizer(object):
match = self._whitespace_pattern.match(self._current_line, self._column)
if not match:
break
self.contains_silent_marker_before_current_token = match.group(0) == (
' ' + _DEBUG_STRING_SILENT_MARKER)
length = len(match.group(0))
self._column += length
@@ -1605,7 +1583,6 @@ class Tokenizer(object):
"""Reads the next meaningful token."""
self._previous_line = self._line
self._previous_column = self._column
self.contains_silent_marker_before_current_token = False
self._column += len(self.token)
self._SkipWhitespace()