blob: 2a0b5553a4e8c557ded22ef442fa4beb53cb787d [file] [log] [blame]
Adam Snaider1c095c92023-07-08 02:09:58 -04001//! A process wrapper for running a Protobuf compiler configured for Prost or Tonic output in a Bazel rule.
2
3use std::collections::BTreeMap;
4use std::collections::BTreeSet;
5use std::fmt::{Display, Formatter, Write};
6use std::fs;
7use std::path::Path;
8use std::path::PathBuf;
9use std::process;
10use std::{env, fmt};
11
12use heck::ToSnakeCase;
13use prost::Message;
14use prost_types::{
15 DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
16 OneofDescriptorProto,
17};
18
/// Recursively collect every `.rs` file that prost/tonic wrote under `out_dir`.
///
/// A file named exactly `_` (no extension) is what protoc emits for a proto
/// with an empty package; such a file is renamed in place to `_.rs` and then
/// collected like any other output.
///
/// # Panics
///
/// Panics if a directory or entry cannot be read, or if a `_` file cannot be
/// renamed.
fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
    let mut found: BTreeSet<PathBuf> = BTreeSet::new();

    for dir_entry in fs::read_dir(out_dir).expect("Failed to read directory") {
        let path = dir_entry.expect("Failed to read entry").path();

        if path.is_dir() {
            // Package directories are searched recursively.
            found.extend(find_generated_rust_files(&path));
            continue;
        }

        match path.extension() {
            Some(ext) => {
                if ext == "rs" {
                    found.insert(path);
                }
            }
            None => {
                // The filename is set to `_` when the package name is empty.
                let is_placeholder = path.file_name().map_or(false, |name| name == "_");
                if is_placeholder {
                    let rs_name = path.parent().expect("Failed to get parent").join("_.rs");
                    fs::rename(&path, &rs_name).unwrap_or_else(|err| {
                        panic!("Failed to rename file: {err:?}: {path:?} -> {rs_name:?}")
                    });
                    found.insert(rs_name);
                }
            }
        }
    }

    found
}
47
48fn snake_cased_package_name(package: &str) -> String {
49 if package == "_" {
50 return package.to_owned();
51 }
52
53 package
54 .split('.')
55 .map(|s| s.to_snake_case())
56 .collect::<Vec<_>>()
57 .join(".")
58}
59
/// A Rust module that will be written into the generated `lib.rs`.
#[derive(Default, Debug)]
struct Module {
    /// The bare (last-segment) name of the module. A name of `_` means the
    /// contents are written without a surrounding `pub mod` wrapper.
    name: String,

    /// Generated Rust source placed directly inside this module.
    contents: String,

    /// Last-segment names of this module's direct child modules.
    submodules: BTreeSet<String>,
}
72
73/// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
74/// mirror the proto packages. For the example proto file we would expect to see
75/// the Rust output that follows it.
76///
77/// ```proto
78/// syntax = "proto3";
79/// package examples.prost.helloworld;
80///
81/// message HelloRequest {
82/// // Request message contains the name to be greeted
83/// string name = 1;
84/// }
85//
86/// message HelloReply {
87/// // Reply contains the greeting message
88/// string message = 1;
89/// }
90/// ```
91///
92/// This is expected to render out to something like the following. Note that
93/// formatting is not applied so indentation may be missing in the actual output.
94///
95/// ```ignore
96/// pub mod examples {
97/// pub mod prost {
98/// pub mod helloworld {
99/// // @generated
100/// #[allow(clippy::derive_partial_eq_without_eq)]
101/// #[derive(Clone, PartialEq, ::prost::Message)]
102/// pub struct HelloRequest {
103/// /// Request message contains the name to be greeted
104/// #[prost(string, tag = "1")]
105/// pub name: ::prost::alloc::string::String,
106/// }
107/// #[allow(clippy::derive_partial_eq_without_eq)]
108/// #[derive(Clone, PartialEq, ::prost::Message)]
109/// pub struct HelloReply {
110/// /// Reply contains the greeting message
111/// #[prost(string, tag = "1")]
112/// pub message: ::prost::alloc::string::String,
113/// }
114/// // @protoc_insertion_point(module)
115/// }
116/// }
117/// }
118/// ```
119fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String {
120 let mut module_info = BTreeMap::new();
121
122 for path in prost_outputs.iter() {
123 let mut package = path
124 .file_stem()
125 .expect("Failed to get file stem")
126 .to_str()
127 .expect("Failed to convert to str")
128 .to_string();
129
130 if is_tonic {
131 package = package
132 .strip_suffix(".tonic")
133 .expect("Failed to strip suffix")
134 .to_string()
135 };
136
137 if package.is_empty() {
138 continue;
139 }
140
141 let name = if package == "_" {
142 package.clone()
143 } else if package.contains('.') {
144 package
145 .rsplit_once('.')
146 .expect("Failed to split on '.'")
147 .1
148 .to_snake_case()
149 .to_string()
150 } else {
151 package.to_snake_case()
152 };
153
154 // Avoid a stack overflow by skipping a known bad package name
155 let module_name = snake_cased_package_name(&package);
156
157 module_info.insert(
158 module_name.clone(),
159 Module {
160 name,
161 contents: fs::read_to_string(path).expect("Failed to read file"),
162 submodules: BTreeSet::new(),
163 },
164 );
165
166 let module_parts = module_name.split('.').collect::<Vec<_>>();
167 for parent_module_index in 0..module_parts.len() {
168 let child_module_index = parent_module_index + 1;
169 if child_module_index >= module_parts.len() {
170 break;
171 }
172 let full_parent_module_name = module_parts[0..parent_module_index + 1].join(".");
173 let parent_module_name = module_parts[parent_module_index];
174 let child_module_name = module_parts[child_module_index];
175
176 module_info
177 .entry(full_parent_module_name.clone())
178 .and_modify(|parent_module| {
179 parent_module
180 .submodules
181 .insert(child_module_name.to_string());
182 })
183 .or_insert(Module {
184 name: parent_module_name.to_string(),
185 contents: "".to_string(),
186 submodules: [child_module_name.to_string()].iter().cloned().collect(),
187 });
188 }
189 }
190
191 let mut content = "// @generated\n\n".to_string();
192 write_module(&mut content, &module_info, "", 0);
193 content
194}
195
196/// Write out a rust module and all of its submodules.
197fn write_module(
198 content: &mut String,
199 module_info: &BTreeMap<String, Module>,
200 module_name: &str,
201 depth: usize,
202) {
203 if module_name.is_empty() {
204 for submodule_name in module_info.keys() {
205 write_module(content, module_info, submodule_name, depth + 1);
206 }
207 return;
208 }
209 let module = module_info.get(module_name).expect("Failed to get module");
210 let indent = " ".repeat(depth);
211 let is_rust_module = module.name != "_";
212
213 if is_rust_module {
214 let rust_module_name = escape_keyword(module.name.clone());
215 content
216 .write_str(&format!("{}pub mod {} {{\n", indent, rust_module_name))
217 .expect("Failed to write string");
218 }
219
220 content
221 .write_str(&module.contents)
222 .expect("Failed to write string");
223
224 for submodule_name in module.submodules.iter() {
225 write_module(
226 content,
227 module_info,
228 [module_name, submodule_name].join(".").as_str(),
229 depth + 1,
230 );
231 }
232
233 if is_rust_module {
234 content
235 .write_str(&format!("{}}}\n", indent))
236 .expect("Failed to write string");
237 }
238}
239
/// ProtoPath is a dotted path to a proto message, enum, or oneof.
///
/// Example: `helloworld.Greeter.HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct ProtoPath(String);

impl ProtoPath {
    /// Return a new path with `component` appended after a `.` separator.
    ///
    /// An empty side is treated as absent, so the result never gains a
    /// leading or trailing `.`.
    fn join(&self, component: &str) -> ProtoPath {
        let joined = match (self.0.as_str(), component) {
            ("", tail) => tail.to_string(),
            (head, "") => head.to_string(),
            (head, tail) => format!("{head}.{tail}"),
        };
        ProtoPath(joined)
    }
}

impl Display for ProtoPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for ProtoPath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
271
/// RustModulePath is a `::`-separated path to a rust module or item.
///
/// Example: `helloworld::greeter::HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct RustModulePath(String);

impl RustModulePath {
    /// Return a new path with `path` appended after a `::` separator.
    ///
    /// If either side is empty the other side is returned as-is, so no
    /// dangling `::` is produced.
    fn join(&self, path: &str) -> RustModulePath {
        let base = &self.0;
        if !base.is_empty() && !path.is_empty() {
            return RustModulePath([base.as_str(), path].join("::"));
        }
        if base.is_empty() {
            RustModulePath(path.to_string())
        } else {
            self.clone()
        }
    }
}

impl Display for RustModulePath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}

impl From<&str> for RustModulePath {
    fn from(path: &str) -> Self {
        RustModulePath(String::from(path))
    }
}
303
304/// Compute the `--extern_path` flags for a list of proto files. This is
305/// expected to convert proto files into a BTreeMap of
306/// `example.prost.helloworld`: `crate_name::example::prost::helloworld`.
307fn get_extern_paths(
308 descriptor_set: &FileDescriptorSet,
309 crate_name: &str,
310) -> Result<BTreeMap<ProtoPath, RustModulePath>, String> {
311 let mut extern_paths = BTreeMap::new();
312 let rust_path = RustModulePath(crate_name.to_string());
313
314 for file in descriptor_set.file.iter() {
315 descriptor_set_file_to_extern_paths(&mut extern_paths, &rust_path, file);
316 }
317
318 Ok(extern_paths)
319}
320
321/// Add the extern_path pairs for a file descriptor type.
322fn descriptor_set_file_to_extern_paths(
323 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
324 rust_path: &RustModulePath,
325 file: &FileDescriptorProto,
326) {
327 let package = file.package.clone().unwrap_or_default();
328 let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
329 let proto_path = ProtoPath(package);
330
331 for message_type in file.message_type.iter() {
332 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, message_type);
333 }
334
335 for enum_type in file.enum_type.iter() {
336 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
337 }
338}
339
340/// Add the extern_path pairs for a message descriptor type.
341fn message_type_to_extern_paths(
342 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
343 proto_path: &ProtoPath,
344 rust_path: &RustModulePath,
345 message_type: &DescriptorProto,
346) {
347 let message_type_name = message_type
348 .name
349 .as_ref()
350 .expect("Failed to get message type name");
351
352 extern_paths.insert(
353 proto_path.join(message_type_name),
354 rust_path.join(message_type_name),
355 );
356
357 let name_lower = message_type_name.to_lowercase();
358 let proto_path = proto_path.join(&name_lower);
359 let rust_path = rust_path.join(&name_lower);
360
361 for nested_type in message_type.nested_type.iter() {
362 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, nested_type)
363 }
364
365 for enum_type in message_type.enum_type.iter() {
366 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
367 }
368
369 for oneof_type in message_type.oneof_decl.iter() {
370 oneof_type_to_extern_paths(extern_paths, &proto_path, &rust_path, oneof_type);
371 }
372}
373
374/// Add the extern_path pairs for an enum type.
375fn enum_type_to_extern_paths(
376 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
377 proto_path: &ProtoPath,
378 rust_path: &RustModulePath,
379 enum_type: &EnumDescriptorProto,
380) {
381 let enum_type_name = enum_type
382 .name
383 .as_ref()
384 .expect("Failed to get enum type name");
385 extern_paths.insert(
386 proto_path.join(enum_type_name),
387 rust_path.join(enum_type_name),
388 );
389}
390
391fn oneof_type_to_extern_paths(
392 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
393 proto_path: &ProtoPath,
394 rust_path: &RustModulePath,
395 oneof_type: &OneofDescriptorProto,
396) {
397 let oneof_type_name = oneof_type
398 .name
399 .as_ref()
400 .expect("Failed to get oneof type name");
401 extern_paths.insert(
402 proto_path.join(oneof_type_name),
403 rust_path.join(oneof_type_name),
404 );
405}
406
/// The parsed command-line arguments.
struct Args {
    /// The path to the protoc binary.
    protoc: PathBuf,

    /// The directory protoc writes prost/tonic outputs into (`--prost_out`).
    out_dir: PathBuf,

    /// The name of the crate, from `--package_info_output=<crate>=<path>`.
    crate_name: String,

    /// The bazel label of the target being built.
    label: String,

    /// The file that receives this crate's `extern_path` mappings, from
    /// `--package_info_output=<crate>=<path>`.
    package_info_file: PathBuf,

    /// The proto files to compile.
    proto_files: Vec<PathBuf>,

    /// Include directories passed as `-I<dir>`.
    includes: Vec<String>,

    /// Path to the serialized `FileDescriptorSet` used to compute this crate's
    /// extern paths (`--descriptor_set`).
    descriptor_set: PathBuf,

    /// The path to write the generated lib.rs file to.
    out_librs: PathBuf,

    /// Extra `--proto_path` values forwarded to protoc.
    proto_paths: Vec<String>,

    /// The optional path to a rustfmt binary used to format lib.rs.
    rustfmt: Option<PathBuf>,

    /// Whether to generate tonic code (`--is_tonic`).
    is_tonic: bool,

    /// Remaining arguments, forwarded to protoc verbatim, plus the
    /// `--prost_opt`/`--tonic_opt` extern paths derived from `--deps_info`.
    extra_args: Vec<String>,
}

impl Args {
    /// Parse the process's command-line arguments (excluding the program name).
    ///
    /// # Errors
    ///
    /// See [`Args::parse_from`].
    fn parse() -> Result<Args, String> {
        Self::parse_from(env::args().skip(1))
    }

    /// Parse an argument list that does not include the program name.
    ///
    /// Arguments not starting with `-` are proto files, `-I<dir>` adds an
    /// include directory, `--is_tonic` enables tonic, and the recognized
    /// `--key=value` flags fill in the matching field. Anything else is
    /// forwarded to protoc. Every `extern_path` found through `--deps_info` is
    /// forwarded as `--prost_opt` (and also as `--tonic_opt` with tonic).
    ///
    /// # Errors
    ///
    /// Returns an error if a required flag is missing or if
    /// `--package_info_output` is not of the form `<crate_name>=<path>`.
    ///
    /// # Panics
    ///
    /// Panics if a `--deps_info` file, or a file listed in it, cannot be read.
    fn parse_from<I>(args: I) -> Result<Args, String>
    where
        I: IntoIterator<Item = String>,
    {
        let mut protoc: Option<PathBuf> = None;
        let mut out_dir: Option<PathBuf> = None;
        let mut crate_name: Option<String> = None;
        let mut package_info_file: Option<PathBuf> = None;
        let mut proto_files: Vec<PathBuf> = Vec::new();
        let mut includes: Vec<String> = Vec::new();
        let mut descriptor_set: Option<PathBuf> = None;
        let mut out_librs: Option<PathBuf> = None;
        let mut rustfmt: Option<PathBuf> = None;
        let mut proto_paths: Vec<String> = Vec::new();
        let mut label: Option<String> = None;
        let mut tonic_or_prost_opts: Vec<String> = Vec::new();
        let mut is_tonic = false;
        let mut extra_args: Vec<String> = Vec::new();

        // Split the arguments between this wrapper and protoc, collecting the
        // extra plugin options prost/tonic need along the way.
        for arg in args {
            if !arg.starts_with('-') {
                proto_files.push(PathBuf::from(arg));
                continue;
            }

            if let Some(include) = arg.strip_prefix("-I") {
                includes.push(include.to_string());
                continue;
            }

            if arg == "--is_tonic" {
                is_tonic = true;
                continue;
            }

            let Some((key, value)) = arg.split_once('=') else {
                extra_args.push(arg);
                continue;
            };

            match key {
                "--protoc" => protoc = Some(PathBuf::from(value)),
                "--prost_out" => out_dir = Some(PathBuf::from(value)),
                "--package_info_output" => {
                    let (name, path) = value.split_once('=').ok_or_else(|| {
                        format!(
                            "Failed to parse package info output `{value}`: expected `<crate_name>=<path>`."
                        )
                    })?;
                    crate_name = Some(name.to_string());
                    package_info_file = Some(PathBuf::from(path));
                }
                "--deps_info" => {
                    let deps_info = fs::read_to_string(value).expect("Failed to read file");
                    // Each non-blank line names a dependency's package info file,
                    // whose non-blank lines are `extern_path` mappings.
                    for dep_file in deps_info
                        .lines()
                        .map(str::trim)
                        .filter(|line| !line.is_empty())
                    {
                        let flags = fs::read_to_string(dep_file).expect("Failed to read file");
                        tonic_or_prost_opts.extend(
                            flags
                                .lines()
                                .map(str::trim)
                                .filter(|flag| !flag.is_empty())
                                .map(|flag| format!("extern_path={flag}")),
                        );
                    }
                }
                "--descriptor_set" => descriptor_set = Some(PathBuf::from(value)),
                "--out_librs" => out_librs = Some(PathBuf::from(value)),
                "--rustfmt" => rustfmt = Some(PathBuf::from(value)),
                "--proto_path" => proto_paths.push(value.to_string()),
                "--label" => label = Some(value.to_string()),
                _ => extra_args.push(format!("{key}={value}")),
            }
        }

        // `--is_tonic` may appear after `--deps_info`, so plugin options are
        // only expanded once all arguments have been seen.
        for tonic_or_prost_opt in tonic_or_prost_opts {
            extra_args.push(format!("--prost_opt={tonic_or_prost_opt}"));
            if is_tonic {
                extra_args.push(format!("--tonic_opt={tonic_or_prost_opt}"));
            }
        }

        // Checked in a fixed order so the first missing flag is reported.
        let protoc = protoc.ok_or_else(|| {
            "No `--protoc` value was found. Unable to parse path to proto compiler.".to_string()
        })?;
        let out_dir = out_dir.ok_or_else(|| {
            "No `--prost_out` value was found. Unable to parse output directory.".to_string()
        })?;
        let crate_name = crate_name.ok_or_else(|| {
            "No `--package_info_output` value was found. Unable to parse target crate name."
                .to_string()
        })?;
        let package_info_file = package_info_file.ok_or_else(|| {
            "No `--package_info_output` value was found. Unable to parse package info output file.".to_string()
        })?;
        let out_librs = out_librs.ok_or_else(|| {
            "No `--out_librs` value was found. Unable to parse the output location for all combined prost outputs.".to_string()
        })?;
        let descriptor_set = descriptor_set.ok_or_else(|| {
            "No `--descriptor_set` value was found. Unable to parse descriptor set path."
                .to_string()
        })?;
        let label = label.ok_or_else(|| {
            "No `--label` value was found. Unable to parse the label of the target crate."
                .to_string()
        })?;

        Ok(Args {
            protoc,
            out_dir,
            crate_name,
            label,
            package_info_file,
            proto_files,
            includes,
            descriptor_set,
            out_librs,
            proto_paths,
            rustfmt,
            is_tonic,
            extra_args,
        })
    }
}
607
/// Get the output directory with the label suffixed.
///
/// The label is flattened into a single path component: `@` is dropped and
/// `//`, `/` and `:` become `_`, so label `@repo//pkg:name` under `out` yields
/// `out/prost-build-repo_pkg_name`.
fn get_output_dir(out_dir: &Path, label: &str) -> PathBuf {
    let label_as_path = label
        .replace('@', "")
        .replace("//", "_")
        .replace(['/', ':'], "_");
    // Join as a path instead of formatting `out_dir.display()`, which would
    // lossily re-encode a non-UTF-8 `out_dir`.
    out_dir.join(format!("prost-build-{label_as_path}"))
}
620
621/// Get the output directory with the label suffixed, and create it if it doesn't exist.
622///
623/// This will remove the directory first if it already exists.
624fn get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf {
625 let out_dir = get_output_dir(out_dir, label);
626 if out_dir.exists() {
627 fs::remove_dir_all(&out_dir).expect("Failed to remove old output directory");
628 }
629 fs::create_dir_all(&out_dir).expect("Failed to create output directory");
630 out_dir
631}
632
633/// Parse the descriptor set file into a `FileDescriptorSet`.
634fn parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet {
635 let descriptor_set_bytes =
636 fs::read(descriptor_set_path).expect("Failed to read descriptor set");
637 let descriptor_set = FileDescriptorSet::decode(descriptor_set_bytes.as_slice())
638 .expect("Failed to decode descriptor set");
639
640 descriptor_set
641}
642
643/// Get the package name from the descriptor set.
644fn get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String> {
645 let mut package_name = None;
646
647 for file in &descriptor_set.file {
648 if let Some(package) = &file.package {
649 package_name = Some(package.clone());
650 break;
651 }
652 }
653
654 package_name
655}
656
657/// Whether the proto file should expect to generate a .rs file.
658///
659/// If the proto file contains any messages, enums, or services, then it should generate a rust file.
660/// If the proto file only contains extensions, then it will not generate any rust files.
661fn expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool {
662 let mut expect_rs = false;
663
664 for file in descriptor_set.file.iter() {
665 let has_messages = !file.message_type.is_empty();
666 let has_enums = !file.enum_type.is_empty();
667 let has_services = !file.service.is_empty();
668 let has_extensions = !file.extension.is_empty();
669
670 let has_definition = has_messages || has_enums || has_services;
671
672 if has_definition {
673 return true;
674 } else if !has_definition && !has_extensions {
675 expect_rs = true;
676 }
677 }
678
679 expect_rs
680}
681
682/// Whether the proto file should expect to generate service definitions.
683fn has_services(descriptor_set: &FileDescriptorSet) -> bool {
684 descriptor_set
685 .file
686 .iter()
687 .any(|file| !file.service.is_empty())
688}
689
690fn main() {
691 // Always enable backtraces for the protoc wrapper.
692 env::set_var("RUST_BACKTRACE", "1");
693
694 let Args {
695 protoc,
696 out_dir,
697 crate_name,
698 label,
699 package_info_file,
700 proto_files,
701 includes,
702 descriptor_set,
703 out_librs,
704 rustfmt,
705 proto_paths,
706 is_tonic,
707 extra_args,
708 } = Args::parse().expect("Failed to parse args");
709
710 let out_dir = get_and_create_output_dir(&out_dir, &label);
711
712 let descriptor_set = parse_descriptor_set_file(&descriptor_set);
713 let package_name = get_package_name(&descriptor_set).unwrap_or_default();
714 let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
715 let has_services = has_services(&descriptor_set);
716
717 if has_services && !is_tonic {
718 println!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
719 }
720
721 let mut cmd = process::Command::new(&protoc);
722 cmd.arg(format!("--prost_out={}", out_dir.display()));
723 if is_tonic {
724 cmd.arg(format!("--tonic_out={}", out_dir.display()));
725 }
726 cmd.args(extra_args);
727 cmd.args(
728 proto_paths
729 .iter()
730 .map(|proto_path| format!("--proto_path={}", proto_path)),
731 );
732 cmd.args(includes.iter().map(|include| format!("-I{}", include)));
733 cmd.args(&proto_files);
734
735 let status = cmd.status().expect("Failed to spawn protoc process");
736 if !status.success() {
737 panic!(
738 "protoc failed with status: {}",
739 status.code().expect("failed to get exit code")
740 );
741 }
742
743 // Not all proto files will consistently produce `.rs` or `.tonic.rs` files. This is
744 // caused by the proto file being transpiled not having an RPC service or other protos
745 // defined (a natural and expected situation). To guarantee consistent outputs, all
746 // `.rs` files are either renamed to `.tonic.rs` if there is no `.tonic.rs` or prepended
747 // to the existing `.tonic.rs`.
748 if is_tonic {
749 let tonic_files: BTreeSet<PathBuf> = find_generated_rust_files(&out_dir);
750
751 for tonic_file in tonic_files.iter() {
752 let tonic_path_str = tonic_file.to_str().expect("Failed to convert to str");
753 let filename = tonic_file
754 .file_name()
755 .expect("Failed to get file name")
756 .to_str()
757 .expect("Failed to convert to str");
758
759 let is_tonic_file = filename.ends_with(".tonic.rs");
760
761 if is_tonic_file {
762 let rs_file_str = format!(
763 "{}.rs",
764 tonic_path_str
765 .strip_suffix(".tonic.rs")
766 .expect("Failed to strip suffix.")
767 );
768 let rs_file = PathBuf::from(&rs_file_str);
769
770 if rs_file.exists() {
771 let rs_content = fs::read_to_string(&rs_file).expect("Failed to read file.");
772 let tonic_content =
773 fs::read_to_string(tonic_file).expect("Failed to read file.");
774 fs::write(tonic_file, format!("{}\n{}", rs_content, tonic_content))
775 .expect("Failed to write file.");
776 fs::remove_file(&rs_file).unwrap_or_else(|err| {
777 panic!("Failed to remove file: {err:?}: {rs_file:?}")
778 });
779 }
780 } else {
781 let real_tonic_file = PathBuf::from(format!(
782 "{}.tonic.rs",
783 tonic_path_str
784 .strip_suffix(".rs")
785 .expect("Failed to strip suffix.")
786 ));
787 if real_tonic_file.exists() {
788 continue;
789 }
790 fs::rename(tonic_file, &real_tonic_file).unwrap_or_else(|err| {
791 panic!("Failed to rename file: {err:?}: {tonic_file:?} -> {real_tonic_file:?}");
792 });
793 }
794 }
795 }
796
797 // Locate all prost-generated outputs.
798 let mut rust_files = find_generated_rust_files(&out_dir);
799 if rust_files.is_empty() {
800 if expect_rs {
801 panic!("No .rs files were generated by prost.");
802 } else {
803 let file_stem = if package_name.is_empty() {
804 "_"
805 } else {
806 &package_name
807 };
808 let file_stem = format!("{}{}", file_stem, if is_tonic { ".tonic" } else { "" });
809 let empty_rs_file = out_dir.join(format!("{}.rs", file_stem));
810 fs::write(&empty_rs_file, "").expect("Failed to write file.");
811 rust_files.insert(empty_rs_file);
812 }
813 }
814
815 let extern_paths = get_extern_paths(&descriptor_set, &crate_name)
816 .expect("Failed to compute proto package info");
817
818 // Write outputs
819 fs::write(&out_librs, generate_lib_rs(&rust_files, is_tonic)).expect("Failed to write file.");
820 fs::write(
821 package_info_file,
822 extern_paths
823 .into_iter()
824 .map(|(proto_path, rust_path)| format!(".{}=::{}", proto_path, rust_path))
825 .collect::<Vec<_>>()
826 .join("\n"),
827 )
828 .expect("Failed to write file.");
829
830 // Finally run rustfmt on the output lib.rs file
831 if let Some(rustfmt) = rustfmt {
832 let fmt_status = process::Command::new(rustfmt)
833 .arg("--edition")
834 .arg("2021")
835 .arg("--quiet")
836 .arg(&out_librs)
837 .status()
838 .expect("Failed to spawn rustfmt process");
839 if !fmt_status.success() {
840 panic!(
841 "rustfmt failed with exit code: {}",
842 fmt_status.code().expect("Failed to get exit code")
843 );
844 }
845 }
846}
847
/// Rust built-in keywords and reserved keywords.
const RUST_KEYWORDS: [&str; 51] = [
    "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "crate",
    "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if", "impl", "in",
    "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub", "ref",
    "return", "self", "Self", "static", "struct", "super", "trait", "true", "try", "type",
    "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
];

/// Keywords the compiler rejects as raw identifiers (`r#self` is an error), so
/// they must be escaped another way.
const NON_RAW_KEYWORDS: [&str; 4] = ["crate", "self", "Self", "super"];

/// Returns true if the given string is a Rust keyword.
fn is_keyword(s: &str) -> bool {
    RUST_KEYWORDS.contains(&s)
}

/// Escapes a Rust keyword so it can be used as an identifier.
///
/// Most keywords are prefixed with `r#` (`type` becomes `r#type`). `crate`,
/// `self`, `Self` and `super` cannot be raw identifiers, so they get a trailing
/// underscore instead (`self` becomes `self_`). Non-keywords are returned
/// unchanged.
fn escape_keyword(s: String) -> String {
    if NON_RAW_KEYWORDS.contains(&s.as_str()) {
        return format!("{s}_");
    }
    if is_keyword(&s) {
        return format!("r#{s}");
    }
    s
}
869
870#[cfg(test)]
871mod test {
872
873 use super::*;
874
875 use prost_types::{FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto};
876 use std::collections::BTreeMap;
877
878 #[test]
879 fn oneof_type_to_extern_paths_test() {
880 let oneof_descriptor = OneofDescriptorProto {
881 name: Some("Foo".to_string()),
882 ..OneofDescriptorProto::default()
883 };
884
885 {
886 let mut extern_paths = BTreeMap::new();
887 oneof_type_to_extern_paths(
888 &mut extern_paths,
889 &ProtoPath::from("bar"),
890 &RustModulePath::from("bar"),
891 &oneof_descriptor,
892 );
893
894 assert_eq!(extern_paths.len(), 1);
895 assert_eq!(
896 extern_paths.get(&ProtoPath::from("bar.Foo")),
897 Some(&RustModulePath::from("bar::Foo"))
898 );
899 }
900
901 {
902 let mut extern_paths = BTreeMap::new();
903 oneof_type_to_extern_paths(
904 &mut extern_paths,
905 &ProtoPath::from("bar.baz"),
906 &RustModulePath::from("bar::baz"),
907 &oneof_descriptor,
908 );
909
910 assert_eq!(extern_paths.len(), 1);
911 assert_eq!(
912 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
913 Some(&RustModulePath::from("bar::baz::Foo"))
914 );
915 }
916 }
917
918 #[test]
919 fn enum_type_to_extern_paths_test() {
920 let enum_descriptor = EnumDescriptorProto {
921 name: Some("Foo".to_string()),
922 ..EnumDescriptorProto::default()
923 };
924
925 {
926 let mut extern_paths = BTreeMap::new();
927 enum_type_to_extern_paths(
928 &mut extern_paths,
929 &ProtoPath::from("bar"),
930 &RustModulePath::from("bar"),
931 &enum_descriptor,
932 );
933
934 assert_eq!(extern_paths.len(), 1);
935 assert_eq!(
936 extern_paths.get(&ProtoPath::from("bar.Foo")),
937 Some(&RustModulePath::from("bar::Foo"))
938 );
939 }
940
941 {
942 let mut extern_paths = BTreeMap::new();
943 enum_type_to_extern_paths(
944 &mut extern_paths,
945 &ProtoPath::from("bar.baz"),
946 &RustModulePath::from("bar::baz"),
947 &enum_descriptor,
948 );
949
950 assert_eq!(extern_paths.len(), 1);
951 assert_eq!(
952 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
953 Some(&RustModulePath::from("bar::baz::Foo"))
954 );
955 }
956 }
957
958 #[test]
959 fn message_type_to_extern_paths_test() {
960 let message_descriptor = DescriptorProto {
961 name: Some("Foo".to_string()),
962 nested_type: vec![
963 DescriptorProto {
964 name: Some("Bar".to_string()),
965 ..DescriptorProto::default()
966 },
967 DescriptorProto {
968 name: Some("Nested".to_string()),
969 nested_type: vec![DescriptorProto {
970 name: Some("Baz".to_string()),
971 enum_type: vec![EnumDescriptorProto {
972 name: Some("Chuck".to_string()),
973 ..EnumDescriptorProto::default()
974 }],
975 ..DescriptorProto::default()
976 }],
977 ..DescriptorProto::default()
978 },
979 ],
980 enum_type: vec![EnumDescriptorProto {
981 name: Some("Qux".to_string()),
982 ..EnumDescriptorProto::default()
983 }],
984 ..DescriptorProto::default()
985 };
986
987 {
988 let mut extern_paths = BTreeMap::new();
989 message_type_to_extern_paths(
990 &mut extern_paths,
991 &ProtoPath::from("bar"),
992 &RustModulePath::from("bar"),
993 &message_descriptor,
994 );
995 assert_eq!(extern_paths.len(), 6);
996 assert_eq!(
997 extern_paths.get(&ProtoPath::from("bar.Foo")),
998 Some(&RustModulePath::from("bar::Foo"))
999 );
1000 assert_eq!(
1001 extern_paths.get(&ProtoPath::from("bar.foo.Bar")),
1002 Some(&RustModulePath::from("bar::foo::Bar"))
1003 );
1004 assert_eq!(
1005 extern_paths.get(&ProtoPath::from("bar.foo.Nested")),
1006 Some(&RustModulePath::from("bar::foo::Nested"))
1007 );
1008 assert_eq!(
1009 extern_paths.get(&ProtoPath::from("bar.foo.nested.Baz")),
1010 Some(&RustModulePath::from("bar::foo::nested::Baz"))
1011 );
1012 }
1013
1014 {
1015 let mut extern_paths = BTreeMap::new();
1016 message_type_to_extern_paths(
1017 &mut extern_paths,
1018 &ProtoPath::from("bar.bob"),
1019 &RustModulePath::from("bar::bob"),
1020 &message_descriptor,
1021 );
1022 assert_eq!(extern_paths.len(), 6);
1023 assert_eq!(
1024 extern_paths.get(&ProtoPath::from("bar.bob.Foo")),
1025 Some(&RustModulePath::from("bar::bob::Foo"))
1026 );
1027 assert_eq!(
1028 extern_paths.get(&ProtoPath::from("bar.bob.foo.Bar")),
1029 Some(&RustModulePath::from("bar::bob::foo::Bar"))
1030 );
1031 assert_eq!(
1032 extern_paths.get(&ProtoPath::from("bar.bob.foo.Nested")),
1033 Some(&RustModulePath::from("bar::bob::foo::Nested"))
1034 );
1035 assert_eq!(
1036 extern_paths.get(&ProtoPath::from("bar.bob.foo.nested.Baz")),
1037 Some(&RustModulePath::from("bar::bob::foo::nested::Baz"))
1038 );
1039 }
1040 }
1041
1042 #[test]
1043 fn proto_path_test() {
1044 {
1045 let proto_path = ProtoPath::from("");
1046 assert_eq!(proto_path.to_string(), "");
1047 assert_eq!(proto_path.join("foo"), ProtoPath::from("foo"));
1048 }
1049 {
1050 let proto_path = ProtoPath::from("foo");
1051 assert_eq!(proto_path.to_string(), "foo");
1052 assert_eq!(proto_path.join(""), ProtoPath::from("foo"));
1053 }
1054 {
1055 let proto_path = ProtoPath::from("foo");
1056 assert_eq!(proto_path.to_string(), "foo");
1057 assert_eq!(proto_path.join("bar"), ProtoPath::from("foo.bar"));
1058 }
1059 {
1060 let proto_path = ProtoPath::from("foo.bar");
1061 assert_eq!(proto_path.to_string(), "foo.bar");
1062 assert_eq!(proto_path.join("baz"), ProtoPath::from("foo.bar.baz"));
1063 }
1064 {
1065 let proto_path = ProtoPath::from("Foo.baR");
1066 assert_eq!(proto_path.to_string(), "Foo.baR");
1067 assert_eq!(proto_path.join("baz"), ProtoPath::from("Foo.baR.baz"));
1068 }
1069 }
1070
1071 #[test]
1072 fn rust_module_path_test() {
1073 {
1074 let rust_module_path = RustModulePath::from("");
1075 assert_eq!(rust_module_path.to_string(), "");
1076 assert_eq!(rust_module_path.join("foo"), RustModulePath::from("foo"));
1077 }
1078 {
1079 let rust_module_path = RustModulePath::from("foo");
1080 assert_eq!(rust_module_path.to_string(), "foo");
1081 assert_eq!(rust_module_path.join(""), RustModulePath::from("foo"));
1082 }
1083 {
1084 let rust_module_path = RustModulePath::from("foo");
1085 assert_eq!(rust_module_path.to_string(), "foo");
1086 assert_eq!(
1087 rust_module_path.join("bar"),
1088 RustModulePath::from("foo::bar")
1089 );
1090 }
1091 {
1092 let rust_module_path = RustModulePath::from("foo::bar");
1093 assert_eq!(rust_module_path.to_string(), "foo::bar");
1094 assert_eq!(
1095 rust_module_path.join("baz"),
1096 RustModulePath::from("foo::bar::baz")
1097 );
1098 }
1099 }
1100
#[test]
fn expect_fs_file_to_be_generated_test() {
    // Wraps `contents` as the single file `foo.proto` inside a descriptor set.
    let single_file = |contents: FileDescriptorProto| FileDescriptorSet {
        file: vec![FileDescriptorProto {
            name: Some("foo.proto".to_string()),
            ..contents
        }],
    };

    // A file that declares nothing at all is still expected to produce output.
    let no_definitions = single_file(FileDescriptorProto::default());
    assert!(expect_fs_file_to_be_generated(&no_definitions));

    // A file declaring only a message produces output.
    let message_only = single_file(FileDescriptorProto {
        message_type: vec![DescriptorProto {
            name: Some("Foo".to_string()),
            ..DescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&message_only));

    // A file declaring only an enum produces output.
    let enum_only = single_file(FileDescriptorProto {
        enum_type: vec![EnumDescriptorProto {
            name: Some("Foo".to_string()),
            ..EnumDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&enum_only));

    // A file declaring only a service produces output.
    let service_only = single_file(FileDescriptorProto {
        service: vec![ServiceDescriptorProto {
            name: Some("Foo".to_string()),
            ..ServiceDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&service_only));

    // A file declaring nothing but extensions is the one case that must NOT
    // produce output.
    let extensions_only = single_file(FileDescriptorProto {
        extension: vec![FieldDescriptorProto {
            name: Some("Foo".to_string()),
            ..FieldDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(!expect_fs_file_to_be_generated(&extensions_only));
}
1170
#[test]
fn has_services_test() {
    // Wraps `contents` as the single file `foo.proto` inside a descriptor set.
    let single_file = |contents: FileDescriptorProto| FileDescriptorSet {
        file: vec![FileDescriptorProto {
            name: Some("foo.proto".to_string()),
            ..contents
        }],
    };

    // No declarations at all: no services.
    let bare = single_file(FileDescriptorProto::default());
    assert!(!has_services(&bare));

    // A message alone does not count as a service.
    let with_message = single_file(FileDescriptorProto {
        message_type: vec![DescriptorProto {
            name: Some("Foo".to_string()),
            ..DescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(!has_services(&with_message));

    // A declared service is detected.
    let with_service = single_file(FileDescriptorProto {
        service: vec![ServiceDescriptorProto {
            name: Some("Foo".to_string()),
            ..ServiceDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(has_services(&with_service));
}
1212
#[test]
fn get_package_name_test() {
    // A single file that declares package `foo`; its package name should be
    // reported for the whole set.
    let foo_file = FileDescriptorProto {
        name: Some("foo.proto".to_string()),
        package: Some("foo".to_string()),
        ..FileDescriptorProto::default()
    };
    let descriptor_set = FileDescriptorSet {
        file: vec![foo_file],
    };

    let package = get_package_name(&descriptor_set);
    assert_eq!(package, Some("foo".to_string()));
}
1225
#[test]
fn is_keyword_test() {
    // Placeholder identifiers that must never be classified as keywords.
    let ordinary_words = [
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
        "plugh", "xyzzy", "thud",
    ];
    ordinary_words
        .iter()
        .for_each(|word| assert!(!is_keyword(word)));

    // Every entry of the keyword table must be recognised.
    RUST_KEYWORDS
        .iter()
        .for_each(|word| assert!(is_keyword(word)));
}
1240
#[test]
fn escape_keyword_test() {
    // Identifiers that are not Rust keywords must be returned unchanged.
    let non_keywords = [
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
        "plugh", "xyzzy", "thud",
    ];
    for non_keyword in &non_keywords {
        // `non_keyword` is `&&str`; compare against the `&str` directly rather
        // than via `.to_owned()`, which would only clone the reference.
        assert_eq!(escape_keyword(non_keyword.to_string()), *non_keyword);
    }

    // Every keyword in the table must come back with an `r#` prefix.
    for keyword in &RUST_KEYWORDS {
        assert_eq!(escape_keyword(keyword.to_string()), format!("r#{keyword}"));
    }
}
1261}