| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- # Copyright 2015 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Tests for yapf.pytree.pytree_visitor."""
- import unittest
- from io import StringIO
- from yapf.pytree import pytree_utils
- from yapf.pytree import pytree_visitor
class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
  """Visitor that records the name of every node and leaf it visits.

  Attributes:
    all_node_names: names of all visited nodes and leaves, in the order they
      were visited; complete once the traversal has finished.
    name_node_values: values of the NAME leaves that were visited, in visit
      order. Those leaves are recorded in all_node_names as well.
  """

  def __init__(self):
    self.name_node_values = []
    self.all_node_names = []

  def _RecordName(self, node_or_leaf):
    """Appends the pytree name of node_or_leaf to all_node_names."""
    name = pytree_utils.NodeName(node_or_leaf)
    self.all_node_names.append(name)

  def DefaultNodeVisit(self, node):
    # Record the node itself first, then hand over to the base visitor.
    self._RecordName(node)
    super(_NodeNameCollector, self).DefaultNodeVisit(node)

  def DefaultLeafVisit(self, leaf):
    self._RecordName(leaf)

  def Visit_NAME(self, leaf):
    # NAME leaves are tracked separately, on top of the default bookkeeping.
    self.name_node_values.append(leaf.value)
    self.DefaultLeafVisit(leaf)
- _VISITOR_TEST_SIMPLE_CODE = r"""
- foo = bar
- baz = x
- """
- _VISITOR_TEST_NESTED_CODE = r"""
- if x:
- if y:
- return z
- """
- class PytreeVisitorTest(unittest.TestCase):
- def testCollectAllNodeNamesSimpleCode(self):
- tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
- collector = _NodeNameCollector()
- collector.Visit(tree)
- expected_names = [
- 'file_input',
- 'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
- 'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
- 'ENDMARKER',
- ] # yapf: disable
- self.assertEqual(expected_names, collector.all_node_names)
- expected_name_node_values = ['foo', 'bar', 'baz', 'x']
- self.assertEqual(expected_name_node_values, collector.name_node_values)
- def testCollectAllNodeNamesNestedCode(self):
- tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_NESTED_CODE)
- collector = _NodeNameCollector()
- collector.Visit(tree)
- expected_names = [
- 'file_input',
- 'if_stmt', 'NAME', 'NAME', 'COLON',
- 'suite', 'NEWLINE',
- 'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
- 'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
- 'DEDENT', 'DEDENT', 'ENDMARKER',
- ] # yapf: disable
- self.assertEqual(expected_names, collector.all_node_names)
- expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
- self.assertEqual(expected_name_node_values, collector.name_node_values)
- def testDumper(self):
- # PyTreeDumper is mainly a debugging utility, so only do basic sanity
- # checking.
- tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
- stream = StringIO()
- pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
- dump_output = stream.getvalue()
- self.assertIn('file_input [3 children]', dump_output)
- self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
- self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
- def testDumpPyTree(self):
- # Similar sanity checking for the convenience wrapper DumpPyTree
- tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
- stream = StringIO()
- pytree_visitor.DumpPyTree(tree, target_stream=stream)
- dump_output = stream.getvalue()
- self.assertIn('file_input [3 children]', dump_output)
- self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
- self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
# Run the tests only when this file is executed as a script, not on import.
if __name__ == '__main__':
  unittest.main()
|