pytree_visitor_test.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright 2015 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for yapf.pytree_visitor."""
  15. import unittest
  16. from io import StringIO
  17. from yapf.pytree import pytree_utils
  18. from yapf.pytree import pytree_visitor
  19. class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  20. """A tree visitor that collects the names of all tree nodes into a list.
  21. Attributes:
  22. all_node_names: collected list of the names, available when the traversal
  23. is over.
  24. name_node_values: collects a list of NAME leaves (in addition to those going
  25. into all_node_names).
  26. """
  27. def __init__(self):
  28. self.all_node_names = []
  29. self.name_node_values = []
  30. def DefaultNodeVisit(self, node):
  31. self.all_node_names.append(pytree_utils.NodeName(node))
  32. super(_NodeNameCollector, self).DefaultNodeVisit(node)
  33. def DefaultLeafVisit(self, leaf):
  34. self.all_node_names.append(pytree_utils.NodeName(leaf))
  35. def Visit_NAME(self, leaf):
  36. self.name_node_values.append(leaf.value)
  37. self.DefaultLeafVisit(leaf)
  38. _VISITOR_TEST_SIMPLE_CODE = r"""
  39. foo = bar
  40. baz = x
  41. """
  42. _VISITOR_TEST_NESTED_CODE = r"""
  43. if x:
  44. if y:
  45. return z
  46. """
  47. class PytreeVisitorTest(unittest.TestCase):
  48. def testCollectAllNodeNamesSimpleCode(self):
  49. tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
  50. collector = _NodeNameCollector()
  51. collector.Visit(tree)
  52. expected_names = [
  53. 'file_input',
  54. 'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
  55. 'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
  56. 'ENDMARKER',
  57. ] # yapf: disable
  58. self.assertEqual(expected_names, collector.all_node_names)
  59. expected_name_node_values = ['foo', 'bar', 'baz', 'x']
  60. self.assertEqual(expected_name_node_values, collector.name_node_values)
  61. def testCollectAllNodeNamesNestedCode(self):
  62. tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_NESTED_CODE)
  63. collector = _NodeNameCollector()
  64. collector.Visit(tree)
  65. expected_names = [
  66. 'file_input',
  67. 'if_stmt', 'NAME', 'NAME', 'COLON',
  68. 'suite', 'NEWLINE',
  69. 'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
  70. 'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
  71. 'DEDENT', 'DEDENT', 'ENDMARKER',
  72. ] # yapf: disable
  73. self.assertEqual(expected_names, collector.all_node_names)
  74. expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
  75. self.assertEqual(expected_name_node_values, collector.name_node_values)
  76. def testDumper(self):
  77. # PyTreeDumper is mainly a debugging utility, so only do basic sanity
  78. # checking.
  79. tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
  80. stream = StringIO()
  81. pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
  82. dump_output = stream.getvalue()
  83. self.assertIn('file_input [3 children]', dump_output)
  84. self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
  85. self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
  86. def testDumpPyTree(self):
  87. # Similar sanity checking for the convenience wrapper DumpPyTree
  88. tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
  89. stream = StringIO()
  90. pytree_visitor.DumpPyTree(tree, target_stream=stream)
  91. dump_output = stream.getvalue()
  92. self.assertIn('file_input [3 children]', dump_output)
  93. self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
  94. self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
  95. if __name__ == '__main__':
  96. unittest.main()