Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 1 | /* |
| 2 | * Copyright 2020 Google Inc. All rights reserved. |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | /* |
 * NOTE: The following implementation is a translation of the Swift-grpc
 * generator, since FlatBuffers does not support generator plugins yet. If a
 * problem arises, please open an issue in the flatbuffers repository. This
 * file should always be kept in sync with the Swift-grpc repository.
| 22 | */ |
| 23 | #include <map> |
| 24 | #include <sstream> |
| 25 | |
| 26 | #include "flatbuffers/util.h" |
| 27 | #include "src/compiler/schema_interface.h" |
| 28 | #include "src/compiler/swift_generator.h" |
| 29 | |
| 30 | namespace grpc_swift_generator { |
| 31 | |
| 32 | std::string WrapInNameSpace(const std::vector<std::string> &components, const grpc::string &name) { |
| 33 | std::string qualified_name; |
| 34 | for (auto it = components.begin(); it != components.end(); ++it) |
| 35 | qualified_name += *it + "_"; |
| 36 | return qualified_name + name; |
| 37 | } |
| 38 | |
| 39 | grpc::string GenerateMessage(const std::vector<std::string> &components, const grpc::string &name) { |
| 40 | return "Message<" + WrapInNameSpace(components, name) + ">"; |
| 41 | } |
| 42 | |
| 43 | // MARK: - Client |
| 44 | |
| 45 | grpc::string GenerateClientFuncName(const grpc_generator::Method *method) { |
| 46 | if (method->NoStreaming()) { |
| 47 | return "$GenAccess$ func $MethodName$(_ request: $Input$" |
| 48 | ", callOptions: CallOptions?$isNil$) -> UnaryCall<$Input$,$Output$>"; |
| 49 | } |
| 50 | |
| 51 | if (method->ClientStreaming()) { |
| 52 | return "$GenAccess$ func $MethodName$" |
| 53 | "(callOptions: CallOptions?$isNil$) -> " |
| 54 | "ClientStreamingCall<$Input$,$Output$>"; |
| 55 | } |
| 56 | |
| 57 | if (method->ServerStreaming()) { |
| 58 | return "$GenAccess$ func $MethodName$(_ request: $Input$" |
| 59 | ", callOptions: CallOptions?$isNil$, handler: @escaping ($Output$" |
| 60 | ") -> Void) -> ServerStreamingCall<$Input$, $Output$>"; |
| 61 | } |
| 62 | return "$GenAccess$ func $MethodName$" |
| 63 | "(callOptions: CallOptions?$isNil$, handler: @escaping ($Output$" |
| 64 | ") -> Void) -> BidirectionalStreamingCall<$Input$, $Output$>"; |
| 65 | } |
| 66 | |
| 67 | grpc::string GenerateClientFuncBody(const grpc_generator::Method *method) { |
| 68 | if (method->NoStreaming()) { |
| 69 | return "return self.makeUnaryCall(path: " |
| 70 | "\"/$PATH$$ServiceName$/$MethodName$\", request: request, " |
| 71 | "callOptions: callOptions ?? self.defaultCallOptions)"; |
| 72 | } |
| 73 | |
| 74 | if (method->ClientStreaming()) { |
| 75 | return "return self.makeClientStreamingCall(path: " |
| 76 | "\"/$PATH$$ServiceName$/$MethodName$\", callOptions: callOptions ?? " |
| 77 | "self.defaultCallOptions)"; |
| 78 | } |
| 79 | |
| 80 | if (method->ServerStreaming()) { |
| 81 | return "return self.makeServerStreamingCall(path: " |
| 82 | "\"/$PATH$$ServiceName$/$MethodName$\", request: request, " |
| 83 | "callOptions: callOptions ?? self.defaultCallOptions, handler: " |
| 84 | "handler)"; |
| 85 | } |
| 86 | return "return self.makeBidirectionalStreamingCall(path: " |
| 87 | "\"/$PATH$$ServiceName$/$MethodName$\", callOptions: callOptions ?? " |
| 88 | "self.defaultCallOptions, handler: handler)"; |
| 89 | } |
| 90 | |
| 91 | void GenerateClientProtocol(const grpc_generator::Service *service, |
| 92 | grpc_generator::Printer *printer, |
| 93 | std::map<grpc::string, grpc::string> *dictonary) { |
| 94 | auto vars = *dictonary; |
| 95 | printer->Print(vars, "$ACCESS$ protocol $ServiceQualifiedName$Service {\n"); |
| 96 | vars["GenAccess"] = ""; |
| 97 | for (auto it = 0; it < service->method_count(); it++) { |
| 98 | auto method = service->method(it); |
| 99 | vars["Input"] = GenerateMessage(method->get_input_namespace_parts(), method->get_input_type_name()); |
| 100 | vars["Output"] = GenerateMessage(method->get_output_namespace_parts(), method->get_output_type_name()); |
| 101 | vars["MethodName"] = method->name(); |
| 102 | vars["isNil"] = ""; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 103 | printer->Print(" "); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 104 | auto func = GenerateClientFuncName(method.get()); |
| 105 | printer->Print(vars, func.c_str()); |
| 106 | printer->Print("\n"); |
| 107 | } |
| 108 | printer->Print("}\n\n"); |
| 109 | } |
| 110 | |
| 111 | void GenerateClientClass(const grpc_generator::Service *service, |
| 112 | grpc_generator::Printer *printer, |
| 113 | std::map<grpc::string, grpc::string> *dictonary) { |
| 114 | auto vars = *dictonary; |
| 115 | printer->Print(vars, |
| 116 | "$ACCESS$ final class $ServiceQualifiedName$ServiceClient: GRPCClient, " |
| 117 | "$ServiceQualifiedName$Service {\n"); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 118 | printer->Print(vars, " $ACCESS$ let channel: GRPCChannel\n"); |
| 119 | printer->Print(vars, " $ACCESS$ var defaultCallOptions: CallOptions\n"); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 120 | printer->Print("\n"); |
| 121 | printer->Print(vars, |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 122 | " $ACCESS$ init(channel: GRPCChannel, " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 123 | "defaultCallOptions: CallOptions = CallOptions()) {\n"); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 124 | printer->Print(" self.channel = channel\n"); |
| 125 | printer->Print(" self.defaultCallOptions = defaultCallOptions\n"); |
| 126 | printer->Print(" }"); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 127 | printer->Print("\n"); |
| 128 | vars["GenAccess"] = service->is_internal() ? "internal" : "public"; |
| 129 | for (auto it = 0; it < service->method_count(); it++) { |
| 130 | auto method = service->method(it); |
| 131 | vars["Input"] = GenerateMessage(method->get_input_namespace_parts(), method->get_input_type_name()); |
| 132 | vars["Output"] = GenerateMessage(method->get_output_namespace_parts(), method->get_output_type_name()); |
| 133 | vars["MethodName"] = method->name(); |
| 134 | vars["isNil"] = " = nil"; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 135 | printer->Print("\n "); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 136 | auto func = GenerateClientFuncName(method.get()); |
| 137 | printer->Print(vars, func.c_str()); |
| 138 | printer->Print(" {\n"); |
| 139 | auto body = GenerateClientFuncBody(method.get()); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 140 | printer->Print(" "); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 141 | printer->Print(vars, body.c_str()); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 142 | printer->Print("\n }\n"); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 143 | } |
| 144 | printer->Print("}\n"); |
| 145 | } |
| 146 | |
| 147 | // MARK: - Server |
| 148 | |
| 149 | grpc::string GenerateServerFuncName(const grpc_generator::Method *method) { |
| 150 | if (method->NoStreaming()) { |
| 151 | return "func $MethodName$(_ request: $Input$" |
| 152 | ", context: StatusOnlyCallContext) -> EventLoopFuture<$Output$>"; |
| 153 | } |
| 154 | |
| 155 | if (method->ClientStreaming()) { |
| 156 | return "func $MethodName$(context: UnaryResponseCallContext<$Output$>) -> " |
| 157 | "EventLoopFuture<(StreamEvent<$Input$" |
| 158 | ">) -> Void>"; |
| 159 | } |
| 160 | |
| 161 | if (method->ServerStreaming()) { |
| 162 | return "func $MethodName$(request: $Input$" |
| 163 | ", context: StreamingResponseCallContext<$Output$>) -> " |
| 164 | "EventLoopFuture<GRPCStatus>"; |
| 165 | } |
| 166 | return "func $MethodName$(context: StreamingResponseCallContext<$Output$>) " |
| 167 | "-> EventLoopFuture<(StreamEvent<$Input$>) -> Void>"; |
| 168 | } |
| 169 | |
| 170 | grpc::string GenerateServerExtensionBody(const grpc_generator::Method *method) { |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 171 | grpc::string start = " case \"$MethodName$\":\n "; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 172 | if (method->NoStreaming()) { |
| 173 | return start + |
| 174 | "return CallHandlerFactory.makeUnary(callHandlerContext: callHandlerContext) { " |
| 175 | "context in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 176 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 177 | "return { request in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 178 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 179 | "self.$MethodName$(request, context: context)" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 180 | "\n }" |
| 181 | "\n }"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 182 | } |
| 183 | if (method->ClientStreaming()) { |
| 184 | return start + |
| 185 | "return CallHandlerFactory.makeClientStreaming(callHandlerContext: " |
| 186 | "callHandlerContext) { context in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 187 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 188 | "self.$MethodName$(context: context)" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 189 | "\n }"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 190 | } |
| 191 | if (method->ServerStreaming()) { |
| 192 | return start + |
| 193 | "return CallHandlerFactory.makeServerStreaming(callHandlerContext: " |
| 194 | "callHandlerContext) { context in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 195 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 196 | "return { request in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 197 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 198 | "self.$MethodName$(request: request, context: context)" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 199 | "\n }" |
| 200 | "\n }"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 201 | } |
| 202 | if (method->BidiStreaming()) { |
| 203 | return start + |
| 204 | "return CallHandlerFactory.makeBidirectionalStreaming(callHandlerContext: " |
| 205 | "callHandlerContext) { context in" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 206 | "\n " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 207 | "self.$MethodName$(context: context)" |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 208 | "\n }"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 209 | } |
| 210 | return ""; |
| 211 | } |
| 212 | |
| 213 | void GenerateServerProtocol(const grpc_generator::Service *service, |
| 214 | grpc_generator::Printer *printer, |
| 215 | std::map<grpc::string, grpc::string> *dictonary) { |
| 216 | auto vars = *dictonary; |
| 217 | printer->Print( |
| 218 | vars, "$ACCESS$ protocol $ServiceQualifiedName$Provider: CallHandlerProvider {\n"); |
| 219 | for (auto it = 0; it < service->method_count(); it++) { |
| 220 | auto method = service->method(it); |
| 221 | vars["Input"] = GenerateMessage(method->get_input_namespace_parts(), method->get_input_type_name()); |
| 222 | vars["Output"] = GenerateMessage(method->get_output_namespace_parts(), method->get_output_type_name()); |
| 223 | vars["MethodName"] = method->name(); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 224 | printer->Print(" "); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 225 | auto func = GenerateServerFuncName(method.get()); |
| 226 | printer->Print(vars, func.c_str()); |
| 227 | printer->Print("\n"); |
| 228 | } |
| 229 | printer->Print("}\n\n"); |
| 230 | |
| 231 | printer->Print(vars, "$ACCESS$ extension $ServiceQualifiedName$Provider {\n"); |
| 232 | printer->Print("\n"); |
| 233 | printer->Print(vars, |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 234 | " var serviceName: Substring { return " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 235 | "\"$PATH$$ServiceName$\" }\n"); |
| 236 | printer->Print("\n"); |
| 237 | printer->Print( |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 238 | " func handleMethod(_ methodName: Substring, callHandlerContext: " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 239 | "CallHandlerContext) -> GRPCCallHandler? {\n"); |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 240 | printer->Print(" switch methodName {\n"); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 241 | for (auto it = 0; it < service->method_count(); it++) { |
| 242 | auto method = service->method(it); |
| 243 | vars["Input"] = GenerateMessage(method->get_input_namespace_parts(), method->get_input_type_name()); |
| 244 | vars["Output"] = GenerateMessage(method->get_output_namespace_parts(), method->get_output_type_name()); |
| 245 | vars["MethodName"] = method->name(); |
| 246 | auto body = GenerateServerExtensionBody(method.get()); |
| 247 | printer->Print(vars, body.c_str()); |
| 248 | printer->Print("\n"); |
| 249 | } |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 250 | printer->Print(" default: return nil;\n"); |
| 251 | printer->Print(" }\n"); |
| 252 | printer->Print(" }\n\n"); |
| 253 | printer->Print("}"); |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 254 | } |
| 255 | |
| 256 | grpc::string Generate(grpc_generator::File *file, |
| 257 | const grpc_generator::Service *service) { |
| 258 | grpc::string output; |
| 259 | std::map<grpc::string, grpc::string> vars; |
| 260 | vars["PATH"] = file->package(); |
| 261 | if (!file->package().empty()) { vars["PATH"].append("."); } |
| 262 | vars["ServiceQualifiedName"] = WrapInNameSpace(service->namespace_parts(), service->name()); |
| 263 | vars["ServiceName"] = service->name(); |
| 264 | vars["ACCESS"] = service->is_internal() ? "internal" : "public"; |
| 265 | auto printer = file->CreatePrinter(&output); |
| 266 | printer->Print(vars, |
| 267 | "/// Usage: instantiate $ServiceQualifiedName$ServiceClient, then call " |
| 268 | "methods of this protocol to make API calls.\n"); |
| 269 | GenerateClientProtocol(service, &*printer, &vars); |
| 270 | GenerateClientClass(service, &*printer, &vars); |
| 271 | printer->Print("\n"); |
| 272 | GenerateServerProtocol(service, &*printer, &vars); |
| 273 | return output; |
| 274 | } |
| 275 | |
| 276 | grpc::string GenerateHeader() { |
| 277 | grpc::string code; |
| 278 | code += |
| 279 | "/// The following code is generated by the Flatbuffers library which " |
| 280 | "might not be in sync with grpc-swift\n"; |
| 281 | code += |
| 282 | "/// in case of an issue please open github issue, though it would be " |
| 283 | "maintained\n"; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 284 | code += "\n"; |
| 285 | code += "// swiftlint:disable all\n"; |
| 286 | code += "// swiftformat:disable all\n"; |
| 287 | code += "\n"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 288 | code += "import Foundation\n"; |
| 289 | code += "import GRPC\n"; |
| 290 | code += "import NIO\n"; |
| 291 | code += "import NIOHTTP1\n"; |
| 292 | code += "import FlatBuffers\n"; |
| 293 | code += "\n"; |
| 294 | code += |
| 295 | "public protocol GRPCFlatBufPayload: GRPCPayload, FlatBufferGRPCMessage " |
| 296 | "{}\n"; |
| 297 | |
| 298 | code += "public extension GRPCFlatBufPayload {\n"; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 299 | code += " init(serializedByteBuffer: inout NIO.ByteBuffer) throws {\n"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 300 | code += |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 301 | " self.init(byteBuffer: FlatBuffers.ByteBuffer(contiguousBytes: " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 302 | "serializedByteBuffer.readableBytesView, count: " |
| 303 | "serializedByteBuffer.readableBytes))\n"; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 304 | code += " }\n"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 305 | |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 306 | code += " func serialize(into buffer: inout NIO.ByteBuffer) throws {\n"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 307 | code += |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 308 | " let buf = UnsafeRawBufferPointer(start: self.rawPointer, count: " |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 309 | "Int(self.size))\n"; |
Austin Schuh | 58b9b47 | 2020-11-25 19:12:44 -0800 | [diff] [blame^] | 310 | code += " buffer.writeBytes(buf)\n"; |
| 311 | code += " }\n"; |
Austin Schuh | 272c613 | 2020-11-14 16:37:52 -0800 | [diff] [blame] | 312 | code += "}\n"; |
| 313 | code += "extension Message: GRPCFlatBufPayload {}\n"; |
| 314 | return code; |
| 315 | } |
| 316 | } // namespace grpc_swift_generator |