blob: 868bb14ed0c593468621050ce06e6ad62fc020cb [file] [log] [blame]
James Kuszmaul8e62b022022-03-22 09:33:25 -07001/*
2 * Copyright 2021 Google Inc. All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "bfbs_gen_lua.h"
18
19#include <cstdint>
20#include <map>
21#include <memory>
22#include <string>
23#include <unordered_set>
24#include <vector>
25
26// Ensure no includes to flatc internals. bfbs_gen.h and generator.h are OK.
27#include "bfbs_gen.h"
28#include "flatbuffers/bfbs_generator.h"
29
30// The intermediate representation schema.
31#include "flatbuffers/reflection_generated.h"
32
33namespace flatbuffers {
34namespace {
35
36// To reduce typing
37namespace r = ::reflection;
38
39class LuaBfbsGenerator : public BaseBfbsGenerator {
40 public:
41 explicit LuaBfbsGenerator(const std::string &flatc_version)
42 : BaseBfbsGenerator(),
43 keywords_(),
44 requires_(),
45 current_obj_(nullptr),
46 current_enum_(nullptr),
47 flatc_version_(flatc_version) {
48 static const char *const keywords[] = {
49 "and", "break", "do", "else", "elseif", "end", "false", "for",
50 "function", "goto", "if", "in", "local", "nil", "not", "or",
51 "repeat", "return", "then", "true", "until", "while"
52 };
53 keywords_.insert(std::begin(keywords), std::end(keywords));
54 }
55
56 GeneratorStatus GenerateFromSchema(const r::Schema *schema)
57 FLATBUFFERS_OVERRIDE {
58 if (!GenerateEnums(schema->enums())) { return FAILED; }
59 if (!GenerateObjects(schema->objects(), schema->root_table())) {
60 return FAILED;
61 }
62 return OK;
63 }
64
65 uint64_t SupportedAdvancedFeatures() const FLATBUFFERS_OVERRIDE {
66 return 0xF;
67 }
68
69 protected:
70 bool GenerateEnums(
71 const flatbuffers::Vector<flatbuffers::Offset<r::Enum>> *enums) {
72 ForAllEnums(enums, [&](const r::Enum *enum_def) {
73 std::string code;
74
75 StartCodeBlock(enum_def);
76
77 std::string ns;
78 const std::string enum_name =
79 NormalizeName(Denamespace(enum_def->name(), ns));
80
81 GenerateDocumentation(enum_def->documentation(), "", code);
82 code += "local " + enum_name + " = {\n";
83
84 ForAllEnumValues(enum_def, [&](const reflection::EnumVal *enum_val) {
85 GenerateDocumentation(enum_val->documentation(), " ", code);
86 code += " " + NormalizeName(enum_val->name()) + " = " +
87 NumToString(enum_val->value()) + ",\n";
88 });
89 code += "}\n";
90 code += "\n";
91
92 EmitCodeBlock(code, enum_name, ns, enum_def->declaration_file()->str());
93 });
94 return true;
95 }
96
97 bool GenerateObjects(
98 const flatbuffers::Vector<flatbuffers::Offset<r::Object>> *objects,
99 const r::Object *root_object) {
100 ForAllObjects(objects, [&](const r::Object *object) {
101 std::string code;
102
103 StartCodeBlock(object);
104
105 // Register the main flatbuffers module.
106 RegisterRequires("flatbuffers", "flatbuffers");
107
108 std::string ns;
109 const std::string object_name =
110 NormalizeName(Denamespace(object->name(), ns));
111
112 GenerateDocumentation(object->documentation(), "", code);
113
114 code += "local " + object_name + " = {}\n";
115 code += "local mt = {}\n";
116 code += "\n";
117 code += "function " + object_name + ".New()\n";
118 code += " local o = {}\n";
119 code += " setmetatable(o, {__index = mt})\n";
120 code += " return o\n";
121 code += "end\n";
122 code += "\n";
123
124 if (object == root_object) {
125 code += "function " + object_name + ".GetRootAs" + object_name +
126 "(buf, offset)\n";
127 code += " if type(buf) == \"string\" then\n";
128 code += " buf = flatbuffers.binaryArray.New(buf)\n";
129 code += " end\n";
130 code += "\n";
131 code += " local n = flatbuffers.N.UOffsetT:Unpack(buf, offset)\n";
132 code += " local o = " + object_name + ".New()\n";
133 code += " o:Init(buf, n + offset)\n";
134 code += " return o\n";
135 code += "end\n";
136 code += "\n";
137 }
138
139 // Generates a init method that receives a pre-existing accessor object,
140 // so that objects can be reused.
141
142 code += "function mt:Init(buf, pos)\n";
143 code += " self.view = flatbuffers.view.New(buf, pos)\n";
144 code += "end\n";
145 code += "\n";
146
147 // Create all the field accessors.
148 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
149 // Skip writing deprecated fields altogether.
150 if (field->deprecated()) { return; }
151
152 const std::string field_name = NormalizeName(field->name());
153 const std::string field_name_camel_case =
154 ConvertCase(field_name, Case::kUpperCamel);
155 const r::BaseType base_type = field->type()->base_type();
156
157 // Generate some fixed strings so we don't repeat outselves later.
158 const std::string getter_signature =
159 "function mt:" + field_name_camel_case + "()\n";
160 const std::string offset_prefix = "local o = self.view:Offset(" +
161 NumToString(field->offset()) + ")\n";
162 const std::string offset_prefix_2 = "if o ~= 0 then\n";
163
164 GenerateDocumentation(field->documentation(), "", code);
165
166 if (IsScalar(base_type)) {
167 code += getter_signature;
168
169 if (object->is_struct()) {
170 // TODO(derekbailey): it would be nice to modify the view:Get to
171 // just pass in the offset and not have to add it its own
172 // self.view.pos.
173 code += " return " + GenerateGetter(field->type()) +
174 "self.view.pos + " + NumToString(field->offset()) + ")\n";
175 } else {
176 // Table accessors
177 code += " " + offset_prefix;
178 code += " " + offset_prefix_2;
179
180 std::string getter =
181 GenerateGetter(field->type()) + "self.view.pos + o)";
182 if (IsBool(base_type)) { getter = "(" + getter + " ~=0)"; }
183 code += " return " + getter + "\n";
184 code += " end\n";
185 code += " return " + DefaultValue(field) + "\n";
186 }
187 code += "end\n";
188 code += "\n";
189 } else {
190 switch (base_type) {
191 case r::String: {
192 code += getter_signature;
193 code += " " + offset_prefix;
194 code += " " + offset_prefix_2;
195 code += " return " + GenerateGetter(field->type()) +
196 "self.view.pos + o)\n";
197 code += " end\n";
198 code += "end\n";
199 code += "\n";
200 break;
201 }
202 case r::Obj: {
203 if (object->is_struct()) {
204 code += "function mt:" + field_name_camel_case + "(obj)\n";
205 code += " obj:Init(self.view.bytes, self.view.pos + " +
206 NumToString(field->offset()) + ")\n";
207 code += " return obj\n";
208 code += "end\n";
209 code += "\n";
210 } else {
211 code += getter_signature;
212 code += " " + offset_prefix;
213 code += " " + offset_prefix_2;
214
215 const r::Object *field_object = GetObject(field->type());
216 if (!field_object) {
217 // TODO(derekbailey): this is an error condition. we
218 // should report it better.
219 return;
220 }
221 code += " local x = " +
222 std::string(
223 field_object->is_struct()
224 ? "self.view.pos + o\n"
225 : "self.view:Indirect(self.view.pos + o)\n");
226 const std::string require_name = RegisterRequires(field);
227 code += " local obj = " + require_name + ".New()\n";
228 code += " obj:Init(self.view.bytes, x)\n";
229 code += " return obj\n";
230 code += " end\n";
231 code += "end\n";
232 code += "\n";
233 }
234 break;
235 }
236 case r::Union: {
237 code += getter_signature;
238 code += " " + offset_prefix;
239 code += " " + offset_prefix_2;
240 code +=
241 " local obj = "
242 "flatbuffers.view.New(flatbuffers.binaryArray.New("
243 "0), 0)\n";
244 code += " " + GenerateGetter(field->type()) + "obj, o)\n";
245 code += " return obj\n";
246 code += " end\n";
247 code += "end\n";
248 code += "\n";
249 break;
250 }
251 case r::Array:
252 case r::Vector: {
253 const r::BaseType vector_base_type = field->type()->element();
254 int32_t element_size = field->type()->element_size();
255 code += "function mt:" + field_name_camel_case + "(j)\n";
256 code += " " + offset_prefix;
257 code += " " + offset_prefix_2;
258
259 if (IsStructOrTable(vector_base_type)) {
260 code += " local x = self.view:Vector(o)\n";
261 code +=
262 " x = x + ((j-1) * " + NumToString(element_size) + ")\n";
263 if (IsTable(field->type(), /*use_element=*/true)) {
264 code += " x = self.view:Indirect(x)\n";
265 } else {
266 // Vector of structs are inline, so we need to query the
267 // size of the struct.
268 const reflection::Object *obj =
269 GetObjectByIndex(field->type()->index());
270 element_size = obj->bytesize();
271 }
272
273 // Include the referenced type, thus we need to make sure
274 // we set `use_element` to true.
275 const std::string require_name =
276 RegisterRequires(field, /*use_element=*/true);
277 code += " local obj = " + require_name + ".New()\n";
278 code += " obj:Init(self.view.bytes, x)\n";
279 code += " return obj\n";
280 } else {
281 code += " local a = self.view:Vector(o)\n";
282 code += " return " + GenerateGetter(field->type()) +
283 "a + ((j-1) * " + NumToString(element_size) + "))\n";
284 }
285 code += " end\n";
286 // Only generate a default value for those types that are
287 // supported.
288 if (!IsStructOrTable(vector_base_type)) {
289 code +=
290 " return " +
291 std::string(vector_base_type == r::String ? "''\n" : "0\n");
292 }
293 code += "end\n";
294 code += "\n";
295
296 // If the vector is composed of single byte values, we
297 // generate a helper function to get it as a byte string in
298 // Lua.
299 if (IsSingleByte(vector_base_type)) {
300 code += "function mt:" + field_name_camel_case +
301 "AsString(start, stop)\n";
302 code += " return self.view:VectorAsString(" +
303 NumToString(field->offset()) + ", start, stop)\n";
304 code += "end\n";
305 code += "\n";
306 }
307
308 // We also make a new accessor to query just the length of the
309 // vector.
310 code += "function mt:" + field_name_camel_case + "Length()\n";
311 code += " " + offset_prefix;
312 code += " " + offset_prefix_2;
313 code += " return self.view:VectorLen(o)\n";
314 code += " end\n";
315 code += " return 0\n";
316 code += "end\n";
317 code += "\n";
318 break;
319 }
320 default: {
321 return;
322 }
323 }
324 }
325 return;
326 });
327
328 // Create all the builders
329 if (object->is_struct()) {
330 code += "function " + object_name + ".Create" + object_name +
331 "(builder" + GenerateStructBuilderArgs(object) + ")\n";
332 code += AppendStructBuilderBody(object);
333 code += " return builder:Offset()\n";
334 code += "end\n";
335 code += "\n";
336 } else {
337 // Table builders
338 code += "function " + object_name + ".Start(builder)\n";
339 code += " builder:StartObject(" +
340 NumToString(object->fields()->size()) + ")\n";
341 code += "end\n";
342 code += "\n";
343
344 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
345 if (field->deprecated()) { return; }
346
347 const std::string field_name = NormalizeName(field->name());
348
349 code += "function " + object_name + ".Add" +
350 ConvertCase(field_name, Case::kUpperCamel) + "(builder, " +
351 ConvertCase(field_name, Case::kLowerCamel) + ")\n";
352 code += " builder:Prepend" + GenerateMethod(field) + "Slot(" +
353 NumToString(field->id()) + ", " +
354 ConvertCase(field_name, Case::kLowerCamel) + ", " +
355 DefaultValue(field) + ")\n";
356 code += "end\n";
357 code += "\n";
358
359 if (IsVector(field->type()->base_type())) {
360 code += "function " + object_name + ".Start" +
361 ConvertCase(field_name, Case::kUpperCamel) +
362 "Vector(builder, numElems)\n";
363
364 const int32_t element_size = field->type()->element_size();
365 int32_t alignment = 0;
366 if (IsStruct(field->type(), /*use_element=*/true)) {
367 alignment = GetObjectByIndex(field->type()->index())->minalign();
368 } else {
369 alignment = element_size;
370 }
371
372 code += " return builder:StartVector(" +
373 NumToString(element_size) + ", numElems, " +
374 NumToString(alignment) + ")\n";
375 code += "end\n";
376 code += "\n";
377 }
378 });
379
380 code += "function " + object_name + ".End(builder)\n";
381 code += " return builder:EndObject()\n";
382 code += "end\n";
383 code += "\n";
384 }
385
386 EmitCodeBlock(code, object_name, ns, object->declaration_file()->str());
387 });
388 return true;
389 }
390
391 private:
392 void GenerateDocumentation(
393 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>
394 *documentation,
395 std::string indent, std::string &code) const {
396 flatbuffers::ForAllDocumentation(
397 documentation, [&](const flatbuffers::String *str) {
398 code += indent + "--" + str->str() + "\n";
399 });
400 }
401
402 std::string GenerateStructBuilderArgs(const r::Object *object,
403 std::string prefix = "") const {
404 std::string signature;
405 ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
406 if (IsStructOrTable(field->type()->base_type())) {
407 const r::Object *field_object = GetObject(field->type());
408 signature += GenerateStructBuilderArgs(
409 field_object, prefix + NormalizeName(field->name()) + "_");
410 } else {
411 signature +=
412 ", " + prefix +
413 ConvertCase(NormalizeName(field->name()), Case::kLowerCamel);
414 }
415 });
416 return signature;
417 }
418
419 std::string AppendStructBuilderBody(const r::Object *object,
420 std::string prefix = "") const {
421 std::string code;
422 code += " builder:Prep(" + NumToString(object->minalign()) + ", " +
423 NumToString(object->bytesize()) + ")\n";
424
425 // We need to reverse the order we iterate over, since we build the
426 // buffer backwards.
427 ForAllFields(object, /*reverse=*/true, [&](const r::Field *field) {
428 const int32_t num_padding_bytes = field->padding();
429 if (num_padding_bytes) {
430 code += " builder:Pad(" + NumToString(num_padding_bytes) + ")\n";
431 }
432 if (IsStructOrTable(field->type()->base_type())) {
433 const r::Object *field_object = GetObject(field->type());
434 code += AppendStructBuilderBody(
435 field_object, prefix + NormalizeName(field->name()) + "_");
436 } else {
437 code += " builder:Prepend" + GenerateMethod(field) + "(" + prefix +
438 ConvertCase(NormalizeName(field->name()), Case::kLowerCamel) +
439 ")\n";
440 }
441 });
442
443 return code;
444 }
445
446 std::string GenerateMethod(const r::Field *field) const {
447 const r::BaseType base_type = field->type()->base_type();
448 if (IsScalar(base_type)) {
449 return ConvertCase(GenerateType(base_type), Case::kUpperCamel);
450 }
451 if (IsStructOrTable(base_type)) { return "Struct"; }
452 return "UOffsetTRelative";
453 }
454
455 std::string GenerateGetter(const r::Type *type,
456 bool element_type = false) const {
457 switch (element_type ? type->element() : type->base_type()) {
458 case r::String: return "self.view:String(";
459 case r::Union: return "self.view:Union(";
460 case r::Vector: return GenerateGetter(type, true);
461 default:
462 return "self.view:Get(flatbuffers.N." +
463 ConvertCase(GenerateType(type, element_type),
464 Case::kUpperCamel) +
465 ", ";
466 }
467 }
468
469 std::string GenerateType(const r::Type *type,
470 bool element_type = false) const {
471 const r::BaseType base_type =
472 element_type ? type->element() : type->base_type();
473 if (IsScalar(base_type)) { return GenerateType(base_type); }
474 switch (base_type) {
475 case r::String: return "string";
476 case r::Vector: return GenerateGetter(type, true);
477 case r::Obj: {
478 const r::Object *obj = GetObject(type);
479 return NormalizeName(Denamespace(obj->name()));
480 };
481 default: return "*flatbuffers.Table";
482 }
483 }
484
485 std::string GenerateType(const r::BaseType base_type) const {
486 // Need to override the default naming to match the Lua runtime libraries.
487 // TODO(derekbailey): make overloads in the runtime libraries to avoid this.
488 switch (base_type) {
489 case r::None: return "uint8";
490 case r::UType: return "uint8";
491 case r::Byte: return "int8";
492 case r::UByte: return "uint8";
493 case r::Short: return "int16";
494 case r::UShort: return "uint16";
495 case r::Int: return "int32";
496 case r::UInt: return "uint32";
497 case r::Long: return "int64";
498 case r::ULong: return "uint64";
499 case r::Float: return "Float32";
500 case r::Double: return "Float64";
501 default: return r::EnumNameBaseType(base_type);
502 }
503 }
504
505 std::string DefaultValue(const r::Field *field) const {
506 const r::BaseType base_type = field->type()->base_type();
507 if (IsFloatingPoint(base_type)) {
508 return NumToString(field->default_real());
509 }
510 if (IsBool(base_type)) {
511 return field->default_integer() ? "true" : "false";
512 }
513 if (IsScalar(base_type)) { return NumToString((field->default_integer())); }
514 // represents offsets
515 return "0";
516 }
517
518 std::string NormalizeName(const std::string name) const {
519 return keywords_.find(name) == keywords_.end() ? name : "_" + name;
520 }
521
522 std::string NormalizeName(const flatbuffers::String *name) const {
523 return NormalizeName(name->str());
524 }
525
526 void StartCodeBlock(const reflection::Enum *enum_def) {
527 current_enum_ = enum_def;
528 current_obj_ = nullptr;
529 requires_.clear();
530 }
531
532 void StartCodeBlock(const reflection::Object *object) {
533 current_obj_ = object;
534 current_enum_ = nullptr;
535 requires_.clear();
536 }
537
538 std::string RegisterRequires(const r::Field *field,
539 bool use_element = false) {
540 std::string type_name;
541
542 const r::BaseType type =
543 use_element ? field->type()->element() : field->type()->base_type();
544
545 if (IsStructOrTable(type)) {
546 const r::Object *object = GetObjectByIndex(field->type()->index());
547 if (object == current_obj_) { return Denamespace(object->name()); }
548 type_name = object->name()->str();
549 } else {
550 const r::Enum *enum_def = GetEnumByIndex(field->type()->index());
551 if (enum_def == current_enum_) { return Denamespace(enum_def->name()); }
552 type_name = enum_def->name()->str();
553 }
554
555 // Prefix with double __ to avoid name clashing, since these are defined
556 // at the top of the file and have lexical scoping. Replace '.' with '_'
557 // so it can be a legal identifier.
558 std::string name = "__" + type_name;
559 std::replace(name.begin(), name.end(), '.', '_');
560
561 return RegisterRequires(name, type_name);
562 }
563
564 std::string RegisterRequires(const std::string &local_name,
565 const std::string &requires_name) {
566 requires_[local_name] = requires_name;
567 return local_name;
568 }
569
570 void EmitCodeBlock(const std::string &code_block, const std::string &name,
571 const std::string &ns,
572 const std::string &declaring_file) const {
573 const std::string root_type = schema_->root_table()->name()->str();
574 const std::string root_file =
575 schema_->root_table()->declaration_file()->str();
576 const std::string full_qualified_name = ns.empty() ? name : ns + "." + name;
577
578 std::string code = "--[[ " + full_qualified_name + "\n\n";
579 code +=
580 " Automatically generated by the FlatBuffers compiler, do not "
581 "modify.\n";
582 code += " Or modify. I'm a message, not a cop.\n";
583 code += "\n";
584 code += " flatc version: " + flatc_version_ + "\n";
585 code += "\n";
586 code += " Declared by : " + declaring_file + "\n";
587 code += " Rooting type : " + root_type + " (" + root_file + ")\n";
588 code += "\n--]]\n\n";
589
590 if (!requires_.empty()) {
591 for (auto it = requires_.cbegin(); it != requires_.cend(); ++it) {
592 code += "local " + it->first + " = require('" + it->second + "')\n";
593 }
594 code += "\n";
595 }
596
597 code += code_block;
598 code += "return " + name;
599
600 // Namespaces are '.' deliminted, so replace it with the path separator.
601 std::string path = ns;
602
603 if (path.empty()) {
604 path = ".";
605 } else {
606 std::replace(path.begin(), path.end(), '.', '/');
607 }
608
609 // TODO(derekbailey): figure out a save file without depending on util.h
610 EnsureDirExists(path);
611 const std::string file_name = path + "/" + name + ".lua";
612 SaveFile(file_name.c_str(), code, false);
613 }
614
615 std::unordered_set<std::string> keywords_;
616 std::map<std::string, std::string> requires_;
617 const r::Object *current_obj_;
618 const r::Enum *current_enum_;
619 const std::string flatc_version_;
620};
621} // namespace
622
623std::unique_ptr<BfbsGenerator> NewLuaBfbsGenerator(
624 const std::string &flatc_version) {
625 return std::unique_ptr<LuaBfbsGenerator>(new LuaBfbsGenerator(flatc_version));
626}
627
628} // namespace flatbuffers