311 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			311 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Protocol Buffers - Google's data interchange format
 | |
| # Copyright 2008 Google Inc.  All rights reserved.
 | |
| #
 | |
| # Use of this source code is governed by a BSD-style
 | |
| # license that can be found in the LICENSE file or at
 | |
| # https://developers.google.com/open-source/licenses/bsd
 | |
| 
 | |
| """Contains FieldMask class."""
 | |
| 
 | |
| from google.protobuf.descriptor import FieldDescriptor
 | |
| 
 | |
| 
 | |
class FieldMask(object):
  """Helper methods for the FieldMask message type.

  The methods below use `self.paths` and `self.Clear()`, neither of which is
  defined in this class; presumably they are supplied by the generated
  FieldMask message class this one is combined with (TODO confirm).
  """

  __slots__ = ()

  def ToJsonString(self):
    """Returns the paths as one comma-separated string of camelCase names.

    Raises:
      ValueError: If a path cannot be converted from snake_case.
    """
    return ','.join(_SnakeCaseToCamelCase(path) for path in self.paths)

  def FromJsonString(self, value):
    """Replaces the paths with those parsed from a proto3 JSON string.

    Args:
      value: Comma-separated camelCase paths. An empty string yields an empty
        mask.

    Raises:
      ValueError: If value is not a str, or a path contains an underscore.
    """
    if not isinstance(value, str):
      raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
    self.Clear()
    if not value:
      return
    # Paths are appended one by one, so a conversion error part-way through
    # leaves the paths converted so far in place.
    for json_path in value.split(','):
      self.paths.append(_CamelCaseToSnakeCase(json_path))

  def IsValidForDescriptor(self, message_descriptor):
    """Returns True if every path is valid for the given Message Descriptor."""
    return all(
        _IsValidPath(message_descriptor, path) for path in self.paths)

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Resets the mask to the names of the descriptor's direct fields."""
    self.Clear()
    for field_descriptor in message_descriptor.fields:
      self.paths.append(field_descriptor.name)

  def CanonicalFormFromMask(self, mask):
    """Sets this FieldMask to the canonical form of another mask.

    A path that is covered by another path is dropped: "foo.bar" is covered
    by "foo", so it disappears when "foo" is present too. The remaining paths
    are written in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    canonical_tree = _FieldMaskTree(mask)
    canonical_tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Stores the union of mask1 and mask2 in this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    merged = _FieldMaskTree(mask1)
    merged.MergeFromFieldMask(mask2)
    merged.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Stores the intersection of mask1 and mask2 in this FieldMask.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    first_tree = _FieldMaskTree(mask1)
    common = _FieldMaskTree()
    # Each path of mask2 contributes the part of it that mask1 also covers.
    for candidate in mask2.paths:
      first_tree.IntersectPath(candidate, common)
    common.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Copies the fields named by this FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: If True, a message field is replaced; if False,
          it is merged.
      replace_repeated_field: If True, a repeated field is replaced; if
          False, the source elements are appended.
    """
    mask_tree = _FieldMaskTree(self)
    mask_tree.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
 | |
| 
 | |
| 
 | |
def _IsValidPath(message_descriptor, path):
  """Returns True if the dotted path names a field reachable from the descriptor.

  Every component except the last must name a singular (non-repeated) message
  field; the last component only has to exist in the message reached so far.
  """
  names = path.split('.')
  descriptor = message_descriptor
  for parent_name in names[:-1]:
    parent = descriptor.fields_by_name.get(parent_name)
    if parent is None:
      return False
    # Only a singular message field can be descended into.
    is_singular_message = (
        parent.label != FieldDescriptor.LABEL_REPEATED and
        parent.type == FieldDescriptor.TYPE_MESSAGE)
    if not is_singular_message:
      return False
    descriptor = parent.message_type
  return names[-1] in descriptor.fields_by_name
 | |
| 
 | |
| 
 | |
def _CheckFieldMaskMessage(message):
  """Verifies that message is a FieldMask.

  The check compares the descriptor's name with 'FieldMask' and the name of
  its defining file with 'google/protobuf/field_mask.proto'.

  Raises:
    ValueError: If either comparison fails.
  """
  descriptor = message.DESCRIPTOR
  is_field_mask = (
      descriptor.name == 'FieldMask' and
      descriptor.file.name == 'google/protobuf/field_mask.proto')
  if is_field_mask:
    return
  raise ValueError(
      'Message {0} is not a FieldMask.'.format(descriptor.full_name))
 | |
| 
 | |
| 
 | |
def _SnakeCaseToCamelCase(path_name):
  """Returns path_name with each "_x" replaced by "X".

  Characters other than underscores and the letter following one (dots
  included) are copied through unchanged.

  Raises:
    ValueError: If path_name contains an uppercase letter, an underscore that
      is not followed by a lowercase letter, or a trailing underscore.
  """
  pieces = []
  capitalize_next = False
  for char in path_name:
    if char.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if capitalize_next:
      if not char.islower():
        raise ValueError(
            'Fail to print FieldMask to Json string: The '
            'character after a "_" must be a lowercase letter '
            'in path name {0}.'.format(path_name))
      pieces.append(char.upper())
      capitalize_next = False
      continue
    if char == '_':
      # Drop the underscore itself; it only marks the next letter.
      capitalize_next = True
    else:
      pieces.append(char)

  if capitalize_next:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(pieces)
 | |
| 
 | |
| 
 | |
def _CamelCaseToSnakeCase(path_name):
  """Returns path_name with each uppercase letter "X" replaced by "_x".

  Raises:
    ValueError: If path_name already contains an underscore.
  """
  if '_' in path_name:
    raise ValueError('Fail to parse FieldMask: Path name '
                     '{0} must not contain "_"s.'.format(path_name))
  return ''.join(
      '_' + char.lower() if char.isupper() else char for char in path_name)
 | |
| 
 | |
| 
 | |
class _FieldMaskTree(object):
  """Tree (nested dict) view of the paths in a FieldMask.

  The FieldMask "foo.bar,foo.baz,bar.baz" is stored as:

      [_root] -+- foo -+- bar
               |       |
               |       +- baz
               |
               +- bar --- baz

  Every node is a dict keyed by field name; a leaf (an empty dict) stands for
  one field path.
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Creates an empty tree, then merges in field_mask if it is truthy."""
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Adds every path of field_mask to the tree."""
    for mask_path in field_mask.paths:
      self.AddPath(mask_path)

  def AddPath(self, path):
    """Inserts a dotted field path.

    Three cases:
      * An ancestor of the path is already a leaf: the tree already covers
        the path, so nothing changes.
      * The path ends on an existing inner node: that node becomes a leaf and
        loses its children, because the path covers all of them.
      * Otherwise the missing nodes are created.

    Args:
      path: The field path to add.
    """
    current = self._root
    for segment in path.split('.'):
      child = current.get(segment)
      if child is None:
        child = current[segment] = {}
      elif not child:
        # An existing leaf already covers everything below it.
        return
      current = child
    # The final node becomes a leaf; drop whatever sub-trees it had.
    current.clear()

  def ToFieldMask(self, field_mask):
    """Clears field_mask and fills it with the paths held by this tree."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Records in `intersection` the part of path that this tree covers.

    Args:
      path: The field path to intersect with this tree.
      intersection: The output tree that receives the common part.
    """
    current = self._root
    for segment in path.split('.'):
      if segment not in current:
        # The tree has nothing under this name: no overlap.
        return
      current = current[segment]
      if not current:
        # A leaf of this tree covers the whole requested path.
        intersection.AddPath(path)
        return
    # The path ended on an inner node: every leaf below it is shared.
    intersection.AddLeafNodes(path, current)

  def AddLeafNodes(self, prefix, node):
    """Adds to this tree each leaf below node, with prefix prepended."""
    if not node:
      self.AddPath(prefix)
      return
    for name, subtree in node.items():
      self.AddLeafNodes(prefix + '.' + name, subtree)

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merges the fields selected by this tree from source into destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)
 | |
| 
 | |
| 
 | |
def _StrConvert(value):
  """Returns value unchanged if it is a str, else its UTF-8 encoding."""
  # This file is imported by the C extension, and some methods such as
  # ClearField require a string for the field name. py2 and py3 have
  # different text types, so the name may arrive as unicode.
  return value if isinstance(value, str) else value.encode('utf-8')
 | |
| 
 | |
| 
 | |
def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merges the fields selected by a mask sub-tree from source to destination.

  Args:
    node: Dict mapping field names to child dicts. An empty child selects the
      whole field; a non-empty child selects sub-fields of a singular message
      field and is handled by recursion.
    source: Message to read from; its DESCRIPTOR is used to look up fields.
    destination: Message to write to.
    replace_message: If True, a selected singular message field is cleared in
      destination before merging; if False, it is merged.
    replace_repeated: If True, a selected repeated field is cleared in
      destination before the source elements are merged in; if False, they
      are appended.

  Raises:
    ValueError: If a name in node is not a field of source, or if sub-fields
      are requested on a field that is not a singular message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name, child in node.items():
    # Look the field up with .get(): indexing would raise KeyError for an
    # unknown name and make the ValueError below unreachable.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      # Descend only when the source sub-message is actually set.
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        # Non-message singular field: copy the source value over.
        setattr(destination, name, getattr(source, name))
 | |
| 
 | |
| 
 | |
def _AddFieldPaths(node, prefix, field_mask):
  """Appends the path of every leaf below node to field_mask.paths.

  Children are visited in sorted name order. A leaf with an empty prefix (an
  empty root) adds nothing.

  Args:
    node: Dict mapping field names to child dicts; an empty dict is a leaf.
    prefix: Dotted path leading to node, or '' for the root.
    field_mask: Object whose `paths` receives the resulting paths.
  """
  is_leaf = not node
  if is_leaf and prefix:
    field_mask.paths.append(prefix)
    return
  for name in sorted(node):
    child_path = prefix + '.' + name if prefix else name
    _AddFieldPaths(node[name], child_path, field_mask)
 |