blob: d3d6fe5bf56bc4e6dd041c21951297cd4c20545f [file] [log] [blame]
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16use std::error::Error;
17use std::fmt;
18use std::fmt::Write;
19use std::iter::Peekable;
20use std::mem::take;
21
/// Errors reported by [`Flags::parse`] when the command line does not match
/// the defined flags.
#[derive(Debug, Clone)]
pub(crate) enum FlagParseError {
    /// A token that does not start with `--`, or a `--` flag that was never defined.
    UnknownFlag(String),
    /// A flag that was not followed by any value.
    ValueMissing(String),
    /// A single-valued flag that received more than one value, either in one
    /// occurrence or across several.
    ProvidedMultipleTimes(String),
    /// The argument vector was empty, so not even `argv[0]` was present.
    ProgramNameMissing,
}

impl fmt::Display for FlagParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::ProgramNameMissing => f.write_str("program name (argv[0]) missing"),
            Self::UnknownFlag(flag) => write!(f, "unknown flag \"{}\"", flag),
            Self::ValueMissing(flag) => {
                write!(f, "flag \"{}\" missing parameter(s)", flag)
            }
            Self::ProvidedMultipleTimes(flag) => {
                write!(f, "flag \"{}\" can only appear once", flag)
            }
        }
    }
}
impl Error for FlagParseError {}
45
/// A single flag definition: its name, its help text, and the caller-owned
/// slot that receives whatever value(s) are parsed for it.
struct FlagDef<'a, T> {
    name: String,
    help: String,
    output_storage: &'a mut Option<T>,
}

/// Renders the flag as `name<TAB>help`.
impl<T> fmt::Display for FlagDef<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name)?;
        f.write_str("\t")?;
        f.write_str(&self.help)
    }
}

/// Shows only `name` and `help`; `output_storage` is omitted, so this impl
/// needs no `Debug` bound on `T`.
impl<T> fmt::Debug for FlagDef<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("FlagDef");
        builder.field("name", &self.name);
        builder.field("help", &self.help);
        builder.finish()
    }
}
66
67#[derive(Debug)]
68pub(crate) struct Flags<'a> {
69 single: BTreeMap<String, FlagDef<'a, String>>,
70 repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>,
71}
72
/// What a successful [`Flags::parse`] produced.
#[derive(Debug)]
pub(crate) enum ParseOutcome {
    /// `--help` was encountered; holds the rendered help text.
    Help(String),
    /// Parsing completed; holds the arguments that followed a bare `--`
    /// (empty when there was none).
    Parsed(Vec<String>),
}
78
79impl<'a> Flags<'a> {
80 pub(crate) fn new() -> Flags<'a> {
81 Flags {
82 single: BTreeMap::new(),
83 repeated: BTreeMap::new(),
84 }
85 }
86
87 pub(crate) fn define_flag(
88 &mut self,
89 name: impl Into<String>,
90 help: impl Into<String>,
91 output_storage: &'a mut Option<String>,
92 ) {
93 let name = name.into();
94 if self.repeated.contains_key(&name) {
95 panic!("argument \"{}\" already defined as repeated flag", name)
96 }
97 self.single.insert(
98 name.clone(),
99 FlagDef::<'a, String> {
100 name,
101 help: help.into(),
102 output_storage,
103 },
104 );
105 }
106
107 pub(crate) fn define_repeated_flag(
108 &mut self,
109 name: impl Into<String>,
110 help: impl Into<String>,
111 output_storage: &'a mut Option<Vec<String>>,
112 ) {
113 let name = name.into();
114 if self.single.contains_key(&name) {
115 panic!("argument \"{}\" already defined as flag", name)
116 }
117 self.repeated.insert(
118 name.clone(),
119 FlagDef::<'a, Vec<String>> {
120 name,
121 help: help.into(),
122 output_storage,
123 },
124 );
125 }
126
127 fn help(&self, program_name: String) -> String {
128 let single = self.single.values().map(|fd| fd.to_string());
129 let repeated = self.repeated.values().map(|fd| fd.to_string());
130 let mut all: Vec<String> = single.chain(repeated).collect();
131 all.sort();
132
133 let mut help_text = String::new();
134 writeln!(
135 &mut help_text,
136 "Help for {}: [options] -- [extra arguments]",
137 program_name
138 )
139 .unwrap();
140 for line in all {
141 writeln!(&mut help_text, "\t{}", line).unwrap();
142 }
143 help_text
144 }
145
146 pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> {
147 let mut argv_iter = argv.into_iter().peekable();
148 let program_name = argv_iter.next().ok_or(FlagParseError::ProgramNameMissing)?;
149
150 // To check if a non-repeated flag has been set already.
151 let mut seen_single_flags = HashSet::<String>::new();
152
153 while let Some(flag) = argv_iter.next() {
154 if flag == "--help" {
155 return Ok(ParseOutcome::Help(self.help(program_name)));
156 }
157 if !flag.starts_with("--") {
158 return Err(FlagParseError::UnknownFlag(flag));
159 }
160 let mut args = consume_args(&flag, &mut argv_iter);
161 if flag == "--" {
162 return Ok(ParseOutcome::Parsed(args));
163 }
164 if args.is_empty() {
165 return Err(FlagParseError::ValueMissing(flag.clone()));
166 }
167 if let Some(flag_def) = self.single.get_mut(&flag) {
168 if args.len() > 1 || seen_single_flags.contains(&flag) {
169 return Err(FlagParseError::ProvidedMultipleTimes(flag.clone()));
170 }
171 let arg = args.first_mut().unwrap();
172 seen_single_flags.insert(flag);
173 *flag_def.output_storage = Some(take(arg));
174 continue;
175 }
176 if let Some(flag_def) = self.repeated.get_mut(&flag) {
177 flag_def
178 .output_storage
179 .get_or_insert_with(Vec::new)
180 .append(&mut args);
181 continue;
182 }
183 return Err(FlagParseError::UnknownFlag(flag));
184 }
185 Ok(ParseOutcome::Parsed(vec![]))
186 }
187}
188
/// Takes the values that belong to `flag` off the front of `argv_iter`.
///
/// For a bare `--`, every remaining argument is returned as-is. For any other
/// flag, arguments are taken up to, but not including, the next token that starts
/// with `--`. That token stays in the iterator for the caller.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    if flag == "--" {
        return argv_iter.collect();
    }
    std::iter::from_fn(|| argv_iter.next_if(|token| !token.starts_with("--"))).collect()
}
204
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an argv vector: program name `foo` followed by `args`.
    fn args(args: &[&str]) -> Vec<String> {
        ["foo"].iter().chain(args).map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--help"])).unwrap();
        if let ParseOutcome::Help(h) = result {
            assert!(h.contains("Help for foo"));
            assert!(h.contains("--bar\tbar help"));
        } else {
            panic!("expected that --help would invoke help, instead parsed arguments")
        }
    }

    #[test]
    fn test_flag_single_repeated() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_repeated_flags() {
        // Test case 1) --bar something something_else should work as a repeated flag.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        // Test case 2) --bar something --bar something_else should also work as a repeated flag.
        bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
    }

    #[test]
    fn test_extra_args() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--", "bb"])).unwrap();
        if let ParseOutcome::Parsed(got) = result {
            assert_eq!(got, vec!["bb".to_owned()])
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
    }

    #[test]
    fn test_flags_then_extra_args() {
        // Flags before `--` are parsed; everything after it is passed through verbatim.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser
            .parse(args(&["--bar", "aa", "--", "--bar", "bb"]))
            .unwrap();
        if let ParseOutcome::Parsed(got) = result {
            assert_eq!(got, vec!["--bar".to_owned(), "bb".to_owned()])
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
        assert_eq!(bar, Some("aa".to_owned()));
    }

    #[test]
    fn test_program_name_missing() {
        let parser = Flags::new();
        let result = parser.parse(vec![]);
        assert!(
            matches!(result, Err(FlagParseError::ProgramNameMissing)),
            "expected ProgramNameMissing, got {:?}",
            result
        );
    }

    #[test]
    fn test_value_missing() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar"]));
        if let Err(FlagParseError::ValueMissing(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_unknown_flag() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--baz", "x"]));
        if let Err(FlagParseError::UnknownFlag(f)) = result {
            assert_eq!(f, "--baz");
        } else {
            panic!("expected error, got {:?}", result)
        }
        // A bare token that is not a `--` flag is rejected as unknown too.
        let parser = Flags::new();
        let result = parser.parse(args(&["baz"]));
        if let Err(FlagParseError::UnknownFlag(f)) = result {
            assert_eq!(f, "baz");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }
}