Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Doc/library/argparse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2105,6 +2105,38 @@ Mutual exclusion
exposed through inheritance.


.. method:: ArgumentParser.add_mutually_inclusive_group(required=False)

Create a mutually inclusive group. :mod:`!argparse` will make sure that
either all or none of the arguments in the group are present on the command
line::

>>> parser = argparse.ArgumentParser(prog='PROG')
>>> group = parser.add_mutually_inclusive_group()
>>> group.add_argument('--foo')
>>> group.add_argument('--bar')
>>> parser.parse_args(['--foo', 'X', '--bar', 'Y'])
Namespace(foo='X', bar='Y')
>>> parser.parse_args([])
Namespace(foo=None, bar=None)
>>> parser.parse_args(['--foo', 'X'])
usage: PROG [-h] [--foo FOO & --bar BAR]
PROG: error: the following arguments must be used together: --foo --bar

The *required* argument indicates that the group must be used; providing
none of the arguments is an error::

>>> parser = argparse.ArgumentParser(prog='PROG')
>>> group = parser.add_mutually_inclusive_group(required=True)
>>> group.add_argument('--foo')
>>> group.add_argument('--bar')
>>> parser.parse_args([])
usage: PROG [-h] (--foo FOO & --bar BAR)
PROG: error: the following arguments are required: --foo --bar

.. versionadded:: next


Parser defaults
^^^^^^^^^^^^^^^

Expand Down
9 changes: 9 additions & 0 deletions Doc/whatsnew/3.16.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ New modules
Improved modules
================

argparse
--------

* Add :meth:`~argparse.ArgumentParser.add_mutually_inclusive_group`, the
inverse of :meth:`~argparse.ArgumentParser.add_mutually_exclusive_group`: it
enforces that a set of arguments are either all provided together or not at
all (or, with ``required=True``, that all of them are provided).
(Contributed by Savannah Ostrowski in :gh:`150981`.)

gzip
----

Expand Down
75 changes: 66 additions & 9 deletions Lib/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,9 @@ def _get_actions_usage_parts(self, actions, groups):
if action2.option_strings and
action_groups.pop(action2, None)
] + [action]
positionals.append((group.required, group_actions))
positionals.append((group.required, group_actions, group._separator))
else:
positionals.append((None, [action]))
positionals.append((None, [action], None))
# the remaining optional arguments are sorted by the position of
# the first option in the group
optionals = []
Expand All @@ -482,15 +482,15 @@ def _get_actions_usage_parts(self, actions, groups):
if action2.option_strings and
action_groups.pop(action2, None)
]
optionals.append((group.required, group_actions))
optionals.append((group.required, group_actions, group._separator))
else:
optionals.append((None, [action]))
optionals.append((None, [action], None))

# collect all actions format strings
parts = []
t = self._theme
pos_start = None
for i, (required, group) in enumerate(optionals + positionals):
for i, (required, group, separator) in enumerate(optionals + positionals):
start = len(parts)
if i == len(optionals):
pos_start = start
Expand Down Expand Up @@ -540,7 +540,7 @@ def _get_actions_usage_parts(self, actions, groups):
if in_group:
parts[start] = ('(' if required else '[') + parts[start]
for i in range(start, len(parts) - 1):
parts[i] += ' |'
parts[i] += separator
parts[-1] += ')' if required else ']'

if pos_start is None:
Expand Down Expand Up @@ -1562,6 +1562,7 @@ def __init__(self,
# groups
self._action_groups = []
self._mutually_exclusive_groups = []
self._mutually_inclusive_groups = []

# defaults storage
self._defaults = {}
Expand Down Expand Up @@ -1677,6 +1678,11 @@ def add_mutually_exclusive_group(self, **kwargs):
self._mutually_exclusive_groups.append(group)
return group

def add_mutually_inclusive_group(self, **kwargs):
group = _MutuallyInclusiveGroup(self, **kwargs)
self._mutually_inclusive_groups.append(group)
return group

def _add_action(self, action):
# resolve any conflicts
self._check_conflict(action)
Expand Down Expand Up @@ -1743,6 +1749,19 @@ def _add_container_actions(self, container):
for action in group._group_actions:
group_map[action] = mutex_group

# add container's mutually inclusive groups
for group in container._mutually_inclusive_groups:
if group._container is container:
cont = self
else:
cont = title_group_map[group._container.title]
inc_group = cont.add_mutually_inclusive_group(
required=group.required)

# map the actions to their new inclusive group
for action in group._group_actions:
group_map[action] = inc_group

# add all actions to this container or their group
for action in container._actions:
group_map.get(action, self)._add_action(action)
Expand Down Expand Up @@ -1892,6 +1911,7 @@ def __init__(self, container, title=None, description=None, **kwargs):
self._has_negative_number_optionals = \
container._has_negative_number_optionals
self._mutually_exclusive_groups = container._mutually_exclusive_groups
self._mutually_inclusive_groups = container._mutually_inclusive_groups

def _add_action(self, action):
action = super(_ArgumentGroup, self)._add_action(action)
Expand All @@ -1911,6 +1931,7 @@ def __init__(self, container, required=False):
super(_MutuallyExclusiveGroup, self).__init__(container)
self.required = required
self._container = container
self._separator = ' |'

def _add_action(self, action):
if action.required:
Expand All @@ -1927,6 +1948,26 @@ def _remove_action(self, action):
def add_mutually_exclusive_group(self, **kwargs):
raise ValueError('mutually exclusive groups cannot be nested')

class _MutuallyInclusiveGroup(_ArgumentGroup):

def __init__(self, container, required=False):
super(_MutuallyInclusiveGroup, self).__init__(container)
self.required = required
self._container = container
self._separator = ' &'

def _add_action(self, action):
action = self._container._add_action(action)
self._group_actions.append(action)
return action

def _remove_action(self, action):
self._container._remove_action(action)
self._group_actions.remove(action)

def add_mutually_inclusive_group(self, **kwargs):
raise ValueError('mutually inclusive groups cannot be nested')

def _prog_name(prog=None):
if prog is not None:
return prog
Expand Down Expand Up @@ -2079,7 +2120,7 @@ def add_subparsers(self, **kwargs):
positionals = self._get_positional_actions()
required_optionals = [action for action in self._get_optional_actions()
if action.required]
groups = self._mutually_exclusive_groups
groups = self._mutually_exclusive_groups + self._mutually_inclusive_groups
formatter.add_usage(None, required_optionals + positionals, groups, '')
kwargs['prog'] = formatter.format_help().strip()

Expand Down Expand Up @@ -2469,6 +2510,22 @@ def consume_positionals(start_index):
msg = _('one of the arguments %s is required')
raise ArgumentError(None, msg % ' '.join(names))

# make sure mutually inclusive groups were used together
for group in self._mutually_inclusive_groups:
seen_in_group = seen_non_default_actions.intersection(group._group_actions)
if seen_in_group and len(seen_in_group) != len(group._group_actions):
names = [_get_action_name(action)
for action in group._group_actions
if action.help is not SUPPRESS]
msg = _('the following arguments must be used together: %s')
raise ArgumentError(None, msg % ' '.join(names))
if group.required and not seen_in_group:
names = [_get_action_name(action)
for action in group._group_actions
if action.help is not SUPPRESS]
msg = _('the following arguments are required: %s')
raise ArgumentError(None, msg % ' '.join(names))

# return the updated namespace and the extra arguments
return namespace, extras

Expand Down Expand Up @@ -2813,7 +2870,7 @@ def format_usage(self, formatter=None):
if formatter is None:
formatter = self._get_formatter()
formatter.add_usage(self.usage, self._actions,
self._mutually_exclusive_groups)
self._mutually_exclusive_groups + self._mutually_inclusive_groups)
return formatter.format_help()

def format_help(self, formatter=None):
Expand All @@ -2822,7 +2879,7 @@ def format_help(self, formatter=None):

# usage
formatter.add_usage(self.usage, self._actions,
self._mutually_exclusive_groups)
self._mutually_exclusive_groups + self._mutually_inclusive_groups)

# description
formatter.add_text(self.description)
Expand Down
148 changes: 148 additions & 0 deletions Lib/test/test_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,6 +4072,154 @@ class TestMutuallyExclusiveOptionalsAndPositionalsMixedParent(
MEPBase, TestMutuallyExclusiveOptionalsAndPositionalsMixed):
pass

# =============================
# Mutually inclusive group tests
# =============================

@force_not_colorized_test_class
class TestMutuallyInclusiveGroupErrors(TestCase):

def test_nested_inclusive_groups(self):
parser = argparse.ArgumentParser(prog='PROG')
g = parser.add_mutually_inclusive_group()
g.add_argument('--spam')
self.assertRaisesRegex(ValueError,
'mutually inclusive groups cannot be nested',
g.add_mutually_inclusive_group)

@force_not_colorized
def test_help(self):
parser = ErrorRaisingArgumentParser(prog='PROG')
group1 = parser.add_mutually_inclusive_group()
group1.add_argument('--foo', action='store_true')
group1.add_argument('--bar', action='store_false')
group2 = parser.add_mutually_inclusive_group()
group2.add_argument('--soup', action='store_true')
group2.add_argument('--nuts', action='store_false')
expected = '''\
usage: PROG [-h] [--foo & --bar] [--soup & --nuts]

options:
-h, --help show this help message and exit
--foo
--bar
--soup
--nuts
'''
self.assertEqual(parser.format_help(), textwrap.dedent(expected))

def test_usage_empty_group(self):
parser = ErrorRaisingArgumentParser(prog='PROG')
group = parser.add_mutually_inclusive_group()
self.assertEqual(parser.format_usage(), 'usage: PROG [-h]\n')


class MIMixin(object):

def test_failures_when_not_required(self):
parse_args = self.get_parser(required=False).parse_args
error = ArgumentParserError
for args_string in self.failures:
with self.subTest(args=args_string):
self.assertRaises(error, parse_args, args_string.split())

def test_failures_when_required(self):
parse_args = self.get_parser(required=True).parse_args
error = ArgumentParserError
for args_string in self.failures + ['']:
with self.subTest(args=args_string):
self.assertRaises(error, parse_args, args_string.split())

def test_successes_when_not_required(self):
parse_args = self.get_parser(required=False).parse_args
successes = self.successes + self.successes_when_not_required
for args_string, expected_ns in successes:
with self.subTest(args=args_string):
actual_ns = parse_args(args_string.split())
self.assertEqual(actual_ns, expected_ns)

def test_successes_when_required(self):
parse_args = self.get_parser(required=True).parse_args
for args_string, expected_ns in self.successes:
with self.subTest(args=args_string):
actual_ns = parse_args(args_string.split())
self.assertEqual(actual_ns, expected_ns)

@force_not_colorized
def test_usage_when_not_required(self):
format_usage = self.get_parser(required=False).format_usage
self.assertEqual(format_usage(), textwrap.dedent(self.usage_when_not_required))

@force_not_colorized
def test_usage_when_required(self):
format_usage = self.get_parser(required=True).format_usage
self.assertEqual(format_usage(), textwrap.dedent(self.usage_when_required))

@force_not_colorized
def test_help_when_not_required(self):
format_help = self.get_parser(required=False).format_help
self.assertEqual(format_help(), textwrap.dedent(self.usage_when_not_required + self.help))

@force_not_colorized
def test_help_when_required(self):
format_help = self.get_parser(required=True).format_help
self.assertEqual(format_help(), textwrap.dedent(self.usage_when_required + self.help))


class TestMutuallyInclusiveSimple(MIMixin, TestCase):

def get_parser(self, required=None):
parser = ErrorRaisingArgumentParser(prog='PROG')
group = parser.add_mutually_inclusive_group(required=required)
group.add_argument('--bar', help='bar help')
group.add_argument('--baz', help='baz help')
return parser

failures = ['--bar X', '--baz Y']
successes = [
('--bar X --baz Y', NS(bar='X', baz='Y')),
]
successes_when_not_required = [
('', NS(bar=None, baz=None)),
]

usage_when_not_required = '''\
usage: PROG [-h] [--bar BAR & --baz BAZ]
'''
usage_when_required = '''\
usage: PROG [-h] (--bar BAR & --baz BAZ)
'''
help = '''\

options:
-h, --help show this help message and exit
--bar BAR bar help
--baz BAZ baz help
'''


# =====================================================
# Mutually inclusive group in parent parser tests
# =====================================================

class MIPBase(object):

def get_parser(self, required=None):
parent = super(MIPBase, self).get_parser(required=required)
parser = ErrorRaisingArgumentParser(
prog=parent.prog, add_help=False, parents=[parent])
return parser


class TestMutuallyInclusiveGroupErrorsParent(
MIPBase, TestMutuallyInclusiveGroupErrors):
pass


class TestMutuallyInclusiveSimpleParent(
MIPBase, TestMutuallyInclusiveSimple):
pass

# =================
# Set default tests
# =================
Expand Down
1 change: 1 addition & 0 deletions Lib/test/translationdata/argparse/msgids.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ show program's version number and exit
show this help message and exit
subcommands
the following arguments are required: %s
the following arguments must be used together: %s
unexpected option string: %s
unknown parser %(parser_name)r (choices: %(choices)s)
unrecognized arguments: %s
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add :meth:`argparse.ArgumentParser.add_mutually_inclusive_group` to support
argument groups where all arguments must be used together or not at all.
Loading