Brian Silverman | cc09f18 | 2022-03-09 15:40:20 -0800 | [diff] [blame^] | 1 | // Copyright 2020 The Bazel Authors. All rights reserved. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | use std::collections::{BTreeMap, HashSet}; |
| 16 | use std::error::Error; |
| 17 | use std::fmt; |
| 18 | use std::fmt::Write; |
| 19 | use std::iter::Peekable; |
| 20 | use std::mem::take; |
| 21 | |
/// Errors returned by `Flags::parse`.
#[derive(Clone, Debug)]
pub(crate) enum FlagParseError {
    /// An argv entry is not a defined flag; carries the offending entry.
    UnknownFlag(String),
    /// A defined flag was given no values; carries the flag name.
    ValueMissing(String),
    /// A single-value flag appeared more than once or was given more than one value;
    /// carries the flag name.
    ProvidedMultipleTimes(String),
    /// `argv` was empty, so there was no program name (`argv[0]`).
    ProgramNameMissing,
}
| 29 | |
| 30 | impl fmt::Display for FlagParseError { |
| 31 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| 32 | match self { |
| 33 | Self::UnknownFlag(ref flag) => write!(f, "unknown flag \"{}\"", flag), |
| 34 | Self::ValueMissing(ref flag) => write!(f, "flag \"{}\" missing parameter(s)", flag), |
| 35 | Self::ProvidedMultipleTimes(ref flag) => { |
| 36 | write!(f, "flag \"{}\" can only appear once", flag) |
| 37 | } |
| 38 | Self::ProgramNameMissing => { |
| 39 | write!(f, "program name (argv[0]) missing") |
| 40 | } |
| 41 | } |
| 42 | } |
| 43 | } |
| 44 | impl Error for FlagParseError {} |
| 45 | |
/// One flag definition: its name, its help text, and the caller-owned slot that receives
/// the parsed value(s).
struct FlagDef<'a, T> {
    /// Full flag name as compared against argv entries (e.g. `--out`).
    name: String,
    /// Description printed next to the name in `--help` output.
    help: String,
    /// Caller-provided storage the parser writes values into.
    output_storage: &'a mut Option<T>,
}

/// Renders the definition as `name<TAB>help`, one row of the help listing.
impl<'a, T> fmt::Display for FlagDef<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name)?;
        f.write_str("\t")?;
        f.write_str(&self.help)
    }
}

/// Written by hand so there is no `T: Debug` bound; the output slot is not printed.
impl<'a, T> fmt::Debug for FlagDef<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("FlagDef");
        builder.field("name", &self.name);
        builder.field("help", &self.help);
        builder.finish()
    }
}
| 66 | |
| 67 | #[derive(Debug)] |
| 68 | pub(crate) struct Flags<'a> { |
| 69 | single: BTreeMap<String, FlagDef<'a, String>>, |
| 70 | repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>, |
| 71 | } |
| 72 | |
/// Successful result of `Flags::parse`.
#[derive(Debug)]
pub(crate) enum ParseOutcome {
    /// `--help` was given; carries the rendered help text.
    Help(String),
    /// Parsing finished; carries every argument after a bare `--` (empty if there was none).
    Parsed(Vec<String>),
}
| 78 | |
| 79 | impl<'a> Flags<'a> { |
| 80 | pub(crate) fn new() -> Flags<'a> { |
| 81 | Flags { |
| 82 | single: BTreeMap::new(), |
| 83 | repeated: BTreeMap::new(), |
| 84 | } |
| 85 | } |
| 86 | |
| 87 | pub(crate) fn define_flag( |
| 88 | &mut self, |
| 89 | name: impl Into<String>, |
| 90 | help: impl Into<String>, |
| 91 | output_storage: &'a mut Option<String>, |
| 92 | ) { |
| 93 | let name = name.into(); |
| 94 | if self.repeated.contains_key(&name) { |
| 95 | panic!("argument \"{}\" already defined as repeated flag", name) |
| 96 | } |
| 97 | self.single.insert( |
| 98 | name.clone(), |
| 99 | FlagDef::<'a, String> { |
| 100 | name, |
| 101 | help: help.into(), |
| 102 | output_storage, |
| 103 | }, |
| 104 | ); |
| 105 | } |
| 106 | |
| 107 | pub(crate) fn define_repeated_flag( |
| 108 | &mut self, |
| 109 | name: impl Into<String>, |
| 110 | help: impl Into<String>, |
| 111 | output_storage: &'a mut Option<Vec<String>>, |
| 112 | ) { |
| 113 | let name = name.into(); |
| 114 | if self.single.contains_key(&name) { |
| 115 | panic!("argument \"{}\" already defined as flag", name) |
| 116 | } |
| 117 | self.repeated.insert( |
| 118 | name.clone(), |
| 119 | FlagDef::<'a, Vec<String>> { |
| 120 | name, |
| 121 | help: help.into(), |
| 122 | output_storage, |
| 123 | }, |
| 124 | ); |
| 125 | } |
| 126 | |
| 127 | fn help(&self, program_name: String) -> String { |
| 128 | let single = self.single.values().map(|fd| fd.to_string()); |
| 129 | let repeated = self.repeated.values().map(|fd| fd.to_string()); |
| 130 | let mut all: Vec<String> = single.chain(repeated).collect(); |
| 131 | all.sort(); |
| 132 | |
| 133 | let mut help_text = String::new(); |
| 134 | writeln!( |
| 135 | &mut help_text, |
| 136 | "Help for {}: [options] -- [extra arguments]", |
| 137 | program_name |
| 138 | ) |
| 139 | .unwrap(); |
| 140 | for line in all { |
| 141 | writeln!(&mut help_text, "\t{}", line).unwrap(); |
| 142 | } |
| 143 | help_text |
| 144 | } |
| 145 | |
| 146 | pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> { |
| 147 | let mut argv_iter = argv.into_iter().peekable(); |
| 148 | let program_name = argv_iter.next().ok_or(FlagParseError::ProgramNameMissing)?; |
| 149 | |
| 150 | // To check if a non-repeated flag has been set already. |
| 151 | let mut seen_single_flags = HashSet::<String>::new(); |
| 152 | |
| 153 | while let Some(flag) = argv_iter.next() { |
| 154 | if flag == "--help" { |
| 155 | return Ok(ParseOutcome::Help(self.help(program_name))); |
| 156 | } |
| 157 | if !flag.starts_with("--") { |
| 158 | return Err(FlagParseError::UnknownFlag(flag)); |
| 159 | } |
| 160 | let mut args = consume_args(&flag, &mut argv_iter); |
| 161 | if flag == "--" { |
| 162 | return Ok(ParseOutcome::Parsed(args)); |
| 163 | } |
| 164 | if args.is_empty() { |
| 165 | return Err(FlagParseError::ValueMissing(flag.clone())); |
| 166 | } |
| 167 | if let Some(flag_def) = self.single.get_mut(&flag) { |
| 168 | if args.len() > 1 || seen_single_flags.contains(&flag) { |
| 169 | return Err(FlagParseError::ProvidedMultipleTimes(flag.clone())); |
| 170 | } |
| 171 | let arg = args.first_mut().unwrap(); |
| 172 | seen_single_flags.insert(flag); |
| 173 | *flag_def.output_storage = Some(take(arg)); |
| 174 | continue; |
| 175 | } |
| 176 | if let Some(flag_def) = self.repeated.get_mut(&flag) { |
| 177 | flag_def |
| 178 | .output_storage |
| 179 | .get_or_insert_with(Vec::new) |
| 180 | .append(&mut args); |
| 181 | continue; |
| 182 | } |
| 183 | return Err(FlagParseError::UnknownFlag(flag)); |
| 184 | } |
| 185 | Ok(ParseOutcome::Parsed(vec![])) |
| 186 | } |
| 187 | } |
| 188 | |
/// Pulls the values belonging to `flag` out of `argv_iter`.
///
/// For a bare `--`, every remaining entry is returned unchanged, even entries that look
/// like flags. For any other flag, entries are taken up to (but not including) the next one
/// starting with `--`. That entry stays in the iterator for the caller to treat as the next
/// flag.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    if flag == "--" {
        return argv_iter.by_ref().collect();
    }
    std::iter::from_fn(|| argv_iter.next_if(|candidate| !candidate.starts_with("--"))).collect()
}
| 204 | |
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an argv vector whose program name is `foo`, followed by `args`.
    fn args(args: &[&str]) -> Vec<String> {
        ["foo"].iter().chain(args).map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--help"])).unwrap();
        if let ParseOutcome::Help(h) = result {
            assert!(h.contains("Help for foo"));
            assert!(h.contains("--bar\tbar help"));
        } else {
            panic!("expected that --help would invoke help, instead parsed arguments")
        }
    }

    #[test]
    fn test_flag_single_repeated() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_repeated_flags() {
        // Test case 1) --bar something something_else should work as a repeated flag.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        // Test case 2) --bar something --bar something_else should also work as a repeated flag.
        bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
    }

    #[test]
    fn test_extra_args() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--", "bb"])).unwrap();
        if let ParseOutcome::Parsed(got) = result {
            assert_eq!(got, vec!["bb".to_owned()])
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
    }

    #[test]
    fn test_program_name_missing() {
        let parser = Flags::new();
        let result = parser.parse(vec![]);
        assert!(
            matches!(result, Err(FlagParseError::ProgramNameMissing)),
            "expected error, got {:?}",
            result
        );
    }

    #[test]
    fn test_value_missing() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar"]));
        if let Err(FlagParseError::ValueMissing(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_unknown_flag() {
        // An undefined flag that does carry a value is rejected as unknown.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--baz", "aa"]));
        if let Err(FlagParseError::UnknownFlag(f)) = result {
            assert_eq!(f, "--baz");
        } else {
            panic!("expected error, got {:?}", result)
        }
        // A value with no preceding flag is rejected as well.
        let parser = Flags::new();
        let result = parser.parse(args(&["aa"]));
        if let Err(FlagParseError::UnknownFlag(f)) = result {
            assert_eq!(f, "aa");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_single_flag_value_with_extra_args() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser
            .parse(args(&["--bar", "aa", "--", "--cc", "dd"]))
            .unwrap();
        if let ParseOutcome::Parsed(got) = result {
            // Entries after `--` pass through even when they look like flags.
            assert_eq!(got, vec!["--cc".to_owned(), "dd".to_owned()]);
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
        assert_eq!(bar, Some("aa".to_owned()));
    }
}