xref: /aosp_15_r20/external/flatbuffers/src/bfbs_gen_lua.cpp (revision 890232f25432b36107d06881e0a25aaa6b473652)
1 /*
2  * Copyright 2021 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "bfbs_gen_lua.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

// Ensure no includes to flatc internals. bfbs_gen.h and generator.h are OK.
#include "bfbs_gen.h"
#include "bfbs_namer.h"
#include "flatbuffers/bfbs_generator.h"

// The intermediate representation schema.
#include "flatbuffers/reflection.h"
#include "flatbuffers/reflection_generated.h"
34 
35 namespace flatbuffers {
36 namespace {
37 
38 // To reduce typing
39 namespace r = ::reflection;
40 
// Returns the reserved words of the Lua language. The generator's namer is
// constructed with this set so identifiers colliding with a reserved word get
// escaped (with the "_" keyword suffix configured in LuaDefaultConfig).
std::set<std::string> LuaKeywords() {
  static const char *const kReservedWords[] = {
    "and",   "break", "do",       "else", "elseif", "end",
    "false", "for",   "function", "goto", "if",     "in",
    "local", "nil",   "not",      "or",   "repeat", "return",
    "then",  "true",  "until",    "while",
  };
  return std::set<std::string>(
      kReservedWords,
      kReservedWords + sizeof(kReservedWords) / sizeof(kReservedWords[0]));
}
47 
LuaDefaultConfig()48 Namer::Config LuaDefaultConfig() {
49   return { /*types=*/Case::kUpperCamel,
50            /*constants=*/Case::kUnknown,
51            /*methods=*/Case::kUpperCamel,
52            /*functions=*/Case::kUpperCamel,
53            /*fields=*/Case::kUpperCamel,
54            /*variables=*/Case::kLowerCamel,
55            /*variants=*/Case::kKeep,
56            /*enum_variant_seperator=*/"",
57            /*escape_keywords=*/Namer::Config::Escape::AfterConvertingCase,
58            /*namespaces=*/Case::kKeep,
59            /*namespace_seperator=*/"__",
60            /*object_prefix=*/"",
61            /*object_suffix=*/"",
62            /*keyword_prefix=*/"",
63            /*keyword_suffix=*/"_",
64            /*filenames=*/Case::kKeep,
65            /*directories=*/Case::kKeep,
66            /*output_path=*/"",
67            /*filename_suffix=*/"",
68            /*filename_extension=*/".lua" };
69 }
70 
71 class LuaBfbsGenerator : public BaseBfbsGenerator {
72  public:
LuaBfbsGenerator(const std::string & flatc_version)73   explicit LuaBfbsGenerator(const std::string &flatc_version)
74       : BaseBfbsGenerator(),
75         keywords_(),
76         requires_(),
77         current_obj_(nullptr),
78         current_enum_(nullptr),
79         flatc_version_(flatc_version),
80         namer_(LuaDefaultConfig(), LuaKeywords()) {}
81 
GenerateFromSchema(const r::Schema * schema)82   GeneratorStatus GenerateFromSchema(const r::Schema *schema)
83       FLATBUFFERS_OVERRIDE {
84     if (!GenerateEnums(schema->enums())) { return FAILED; }
85     if (!GenerateObjects(schema->objects(), schema->root_table())) {
86       return FAILED;
87     }
88     return OK;
89   }
90 
SupportedAdvancedFeatures() const91   uint64_t SupportedAdvancedFeatures() const FLATBUFFERS_OVERRIDE {
92     return 0xF;
93   }
94 
95  protected:
GenerateEnums(const flatbuffers::Vector<flatbuffers::Offset<r::Enum>> * enums)96   bool GenerateEnums(
97       const flatbuffers::Vector<flatbuffers::Offset<r::Enum>> *enums) {
98     ForAllEnums(enums, [&](const r::Enum *enum_def) {
99       std::string code;
100 
101       StartCodeBlock(enum_def);
102 
103       std::string ns;
104       const std::string enum_name =
105           namer_.Type(namer_.Denamespace(enum_def, ns));
106 
107       GenerateDocumentation(enum_def->documentation(), "", code);
108       code += "local " + enum_name + " = {\n";
109 
110       ForAllEnumValues(enum_def, [&](const reflection::EnumVal *enum_val) {
111         GenerateDocumentation(enum_val->documentation(), "  ", code);
112         code += "  " + namer_.Variant(enum_val->name()->str()) + " = " +
113                 NumToString(enum_val->value()) + ",\n";
114       });
115       code += "}\n";
116       code += "\n";
117 
118       EmitCodeBlock(code, enum_name, ns, enum_def->declaration_file()->str());
119     });
120     return true;
121   }
122 
GenerateObjects(const flatbuffers::Vector<flatbuffers::Offset<r::Object>> * objects,const r::Object * root_object)123   bool GenerateObjects(
124       const flatbuffers::Vector<flatbuffers::Offset<r::Object>> *objects,
125       const r::Object *root_object) {
126     ForAllObjects(objects, [&](const r::Object *object) {
127       std::string code;
128 
129       StartCodeBlock(object);
130 
131       // Register the main flatbuffers module.
132       RegisterRequires("flatbuffers", "flatbuffers");
133 
134       std::string ns;
135       const std::string object_name =
136           namer_.Type(namer_.Denamespace(object, ns));
137 
138       GenerateDocumentation(object->documentation(), "", code);
139 
140       code += "local " + object_name + " = {}\n";
141       code += "local mt = {}\n";
142       code += "\n";
143       code += "function " + object_name + ".New()\n";
144       code += "  local o = {}\n";
145       code += "  setmetatable(o, {__index = mt})\n";
146       code += "  return o\n";
147       code += "end\n";
148       code += "\n";
149 
150       if (object == root_object) {
151         code += "function " + object_name + ".GetRootAs" + object_name +
152                 "(buf, offset)\n";
153         code += "  if type(buf) == \"string\" then\n";
154         code += "    buf = flatbuffers.binaryArray.New(buf)\n";
155         code += "  end\n";
156         code += "\n";
157         code += "  local n = flatbuffers.N.UOffsetT:Unpack(buf, offset)\n";
158         code += "  local o = " + object_name + ".New()\n";
159         code += "  o:Init(buf, n + offset)\n";
160         code += "  return o\n";
161         code += "end\n";
162         code += "\n";
163       }
164 
165       // Generates a init method that receives a pre-existing accessor object,
166       // so that objects can be reused.
167 
168       code += "function mt:Init(buf, pos)\n";
169       code += "  self.view = flatbuffers.view.New(buf, pos)\n";
170       code += "end\n";
171       code += "\n";
172 
173       // Create all the field accessors.
174       ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
175         // Skip writing deprecated fields altogether.
176         if (field->deprecated()) { return; }
177 
178         const std::string field_name = namer_.Field(field->name()->str());
179         const r::BaseType base_type = field->type()->base_type();
180 
181         // Generate some fixed strings so we don't repeat outselves later.
182         const std::string getter_signature =
183             "function mt:" + field_name + "()\n";
184         const std::string offset_prefix = "local o = self.view:Offset(" +
185                                           NumToString(field->offset()) + ")\n";
186         const std::string offset_prefix_2 = "if o ~= 0 then\n";
187 
188         GenerateDocumentation(field->documentation(), "", code);
189 
190         if (IsScalar(base_type)) {
191           code += getter_signature;
192 
193           if (object->is_struct()) {
194             // TODO(derekbailey): it would be nice to modify the view:Get to
195             // just pass in the offset and not have to add it its own
196             // self.view.pos.
197             code += "  return " + GenerateGetter(field->type()) +
198                     "self.view.pos + " + NumToString(field->offset()) + ")\n";
199           } else {
200             // Table accessors
201             code += "  " + offset_prefix;
202             code += "  " + offset_prefix_2;
203 
204             std::string getter =
205                 GenerateGetter(field->type()) + "self.view.pos + o)";
206             if (IsBool(base_type)) { getter = "(" + getter + " ~=0)"; }
207             code += "    return " + getter + "\n";
208             code += "  end\n";
209             code += "  return " + DefaultValue(field) + "\n";
210           }
211           code += "end\n";
212           code += "\n";
213         } else {
214           switch (base_type) {
215             case r::String: {
216               code += getter_signature;
217               code += "  " + offset_prefix;
218               code += "  " + offset_prefix_2;
219               code += "    return " + GenerateGetter(field->type()) +
220                       "self.view.pos + o)\n";
221               code += "  end\n";
222               code += "end\n";
223               code += "\n";
224               break;
225             }
226             case r::Obj: {
227               if (object->is_struct()) {
228                 code += "function mt:" + field_name + "(obj)\n";
229                 code += "  obj:Init(self.view.bytes, self.view.pos + " +
230                         NumToString(field->offset()) + ")\n";
231                 code += "  return obj\n";
232                 code += "end\n";
233                 code += "\n";
234               } else {
235                 code += getter_signature;
236                 code += "  " + offset_prefix;
237                 code += "  " + offset_prefix_2;
238 
239                 const r::Object *field_object = GetObject(field->type());
240                 if (!field_object) {
241                   // TODO(derekbailey): this is an error condition. we
242                   // should report it better.
243                   return;
244                 }
245                 code += "    local x = " +
246                         std::string(
247                             field_object->is_struct()
248                                 ? "self.view.pos + o\n"
249                                 : "self.view:Indirect(self.view.pos + o)\n");
250                 const std::string require_name = RegisterRequires(field);
251                 code += "    local obj = " + require_name + ".New()\n";
252                 code += "    obj:Init(self.view.bytes, x)\n";
253                 code += "    return obj\n";
254                 code += "  end\n";
255                 code += "end\n";
256                 code += "\n";
257               }
258               break;
259             }
260             case r::Union: {
261               code += getter_signature;
262               code += "  " + offset_prefix;
263               code += "  " + offset_prefix_2;
264               code +=
265                   "   local obj = "
266                   "flatbuffers.view.New(flatbuffers.binaryArray.New("
267                   "0), 0)\n";
268               code += "    " + GenerateGetter(field->type()) + "obj, o)\n";
269               code += "    return obj\n";
270               code += "  end\n";
271               code += "end\n";
272               code += "\n";
273               break;
274             }
275             case r::Array:
276             case r::Vector: {
277               const r::BaseType vector_base_type = field->type()->element();
278               int32_t element_size = field->type()->element_size();
279               code += "function mt:" + field_name + "(j)\n";
280               code += "  " + offset_prefix;
281               code += "  " + offset_prefix_2;
282 
283               if (IsStructOrTable(vector_base_type)) {
284                 code += "    local x = self.view:Vector(o)\n";
285                 code +=
286                     "    x = x + ((j-1) * " + NumToString(element_size) + ")\n";
287                 if (IsTable(field->type(), /*use_element=*/true)) {
288                   code += "    x = self.view:Indirect(x)\n";
289                 } else {
290                   // Vector of structs are inline, so we need to query the
291                   // size of the struct.
292                   const reflection::Object *obj =
293                       GetObjectByIndex(field->type()->index());
294                   element_size = obj->bytesize();
295                 }
296 
297                 // Include the referenced type, thus we need to make sure
298                 // we set `use_element` to true.
299                 const std::string require_name =
300                     RegisterRequires(field, /*use_element=*/true);
301                 code += "    local obj = " + require_name + ".New()\n";
302                 code += "    obj:Init(self.view.bytes, x)\n";
303                 code += "    return obj\n";
304               } else {
305                 code += "    local a = self.view:Vector(o)\n";
306                 code += "    return " + GenerateGetter(field->type()) +
307                         "a + ((j-1) * " + NumToString(element_size) + "))\n";
308               }
309               code += "  end\n";
310               // Only generate a default value for those types that are
311               // supported.
312               if (!IsStructOrTable(vector_base_type)) {
313                 code +=
314                     "  return " +
315                     std::string(vector_base_type == r::String ? "''\n" : "0\n");
316               }
317               code += "end\n";
318               code += "\n";
319 
320               // If the vector is composed of single byte values, we
321               // generate a helper function to get it as a byte string in
322               // Lua.
323               if (IsSingleByte(vector_base_type)) {
324                 code += "function mt:" + field_name + "AsString(start, stop)\n";
325                 code += "  return self.view:VectorAsString(" +
326                         NumToString(field->offset()) + ", start, stop)\n";
327                 code += "end\n";
328                 code += "\n";
329               }
330 
331               // We also make a new accessor to query just the length of the
332               // vector.
333               code += "function mt:" + field_name + "Length()\n";
334               code += "  " + offset_prefix;
335               code += "  " + offset_prefix_2;
336               code += "    return self.view:VectorLen(o)\n";
337               code += "  end\n";
338               code += "  return 0\n";
339               code += "end\n";
340               code += "\n";
341               break;
342             }
343             default: {
344               return;
345             }
346           }
347         }
348         return;
349       });
350 
351       // Create all the builders
352       if (object->is_struct()) {
353         code += "function " + object_name + ".Create" + object_name +
354                 "(builder" + GenerateStructBuilderArgs(object) + ")\n";
355         code += AppendStructBuilderBody(object);
356         code += "  return builder:Offset()\n";
357         code += "end\n";
358         code += "\n";
359       } else {
360         // Table builders
361         code += "function " + object_name + ".Start(builder)\n";
362         code += "  builder:StartObject(" +
363                 NumToString(object->fields()->size()) + ")\n";
364         code += "end\n";
365         code += "\n";
366 
367         ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
368           if (field->deprecated()) { return; }
369 
370           const std::string field_name = namer_.Field(field->name()->str());
371           const std::string variable_name =
372               namer_.Variable(field->name()->str());
373 
374           code += "function " + object_name + ".Add" + field_name +
375                   "(builder, " + variable_name + ")\n";
376           code += "  builder:Prepend" + GenerateMethod(field) + "Slot(" +
377                   NumToString(field->id()) + ", " + variable_name + ", " +
378                   DefaultValue(field) + ")\n";
379           code += "end\n";
380           code += "\n";
381 
382           if (IsVector(field->type()->base_type())) {
383             code += "function " + object_name + ".Start" + field_name +
384                     "Vector(builder, numElems)\n";
385 
386             const int32_t element_size = field->type()->element_size();
387             int32_t alignment = 0;
388             if (IsStruct(field->type(), /*use_element=*/true)) {
389               alignment = GetObjectByIndex(field->type()->index())->minalign();
390             } else {
391               alignment = element_size;
392             }
393 
394             code += "  return builder:StartVector(" +
395                     NumToString(element_size) + ", numElems, " +
396                     NumToString(alignment) + ")\n";
397             code += "end\n";
398             code += "\n";
399           }
400         });
401 
402         code += "function " + object_name + ".End(builder)\n";
403         code += "  return builder:EndObject()\n";
404         code += "end\n";
405         code += "\n";
406       }
407 
408       EmitCodeBlock(code, object_name, ns, object->declaration_file()->str());
409     });
410     return true;
411   }
412 
413  private:
GenerateDocumentation(const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> * documentation,std::string indent,std::string & code) const414   void GenerateDocumentation(
415       const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>
416           *documentation,
417       std::string indent, std::string &code) const {
418     flatbuffers::ForAllDocumentation(
419         documentation, [&](const flatbuffers::String *str) {
420           code += indent + "--" + str->str() + "\n";
421         });
422   }
423 
GenerateStructBuilderArgs(const r::Object * object,std::string prefix="") const424   std::string GenerateStructBuilderArgs(const r::Object *object,
425                                         std::string prefix = "") const {
426     std::string signature;
427     ForAllFields(object, /*reverse=*/false, [&](const r::Field *field) {
428       if (IsStructOrTable(field->type()->base_type())) {
429         const r::Object *field_object = GetObject(field->type());
430         signature += GenerateStructBuilderArgs(
431             field_object, prefix + namer_.Variable(field->name()->str()) + "_");
432       } else {
433         signature += ", " + prefix + namer_.Variable(field->name()->str());
434       }
435     });
436     return signature;
437   }
438 
AppendStructBuilderBody(const r::Object * object,std::string prefix="") const439   std::string AppendStructBuilderBody(const r::Object *object,
440                                       std::string prefix = "") const {
441     std::string code;
442     code += "  builder:Prep(" + NumToString(object->minalign()) + ", " +
443             NumToString(object->bytesize()) + ")\n";
444 
445     // We need to reverse the order we iterate over, since we build the
446     // buffer backwards.
447     ForAllFields(object, /*reverse=*/true, [&](const r::Field *field) {
448       const int32_t num_padding_bytes = field->padding();
449       if (num_padding_bytes) {
450         code += "  builder:Pad(" + NumToString(num_padding_bytes) + ")\n";
451       }
452       if (IsStructOrTable(field->type()->base_type())) {
453         const r::Object *field_object = GetObject(field->type());
454         code += AppendStructBuilderBody(
455             field_object, prefix + namer_.Variable(field->name()->str()) + "_");
456       } else {
457         code += "  builder:Prepend" + GenerateMethod(field) + "(" + prefix +
458                 namer_.Variable(field->name()->str()) + ")\n";
459       }
460     });
461 
462     return code;
463   }
464 
GenerateMethod(const r::Field * field) const465   std::string GenerateMethod(const r::Field *field) const {
466     const r::BaseType base_type = field->type()->base_type();
467     if (IsScalar(base_type)) { return namer_.Type(GenerateType(base_type)); }
468     if (IsStructOrTable(base_type)) { return "Struct"; }
469     return "UOffsetTRelative";
470   }
471 
GenerateGetter(const r::Type * type,bool element_type=false) const472   std::string GenerateGetter(const r::Type *type,
473                              bool element_type = false) const {
474     switch (element_type ? type->element() : type->base_type()) {
475       case r::String: return "self.view:String(";
476       case r::Union: return "self.view:Union(";
477       case r::Vector: return GenerateGetter(type, true);
478       default:
479         return "self.view:Get(flatbuffers.N." +
480                namer_.Type(GenerateType(type, element_type)) + ", ";
481     }
482   }
483 
GenerateType(const r::Type * type,bool element_type=false) const484   std::string GenerateType(const r::Type *type,
485                            bool element_type = false) const {
486     const r::BaseType base_type =
487         element_type ? type->element() : type->base_type();
488     if (IsScalar(base_type)) { return GenerateType(base_type); }
489     switch (base_type) {
490       case r::String: return "string";
491       case r::Vector: return GenerateGetter(type, true);
492       case r::Obj: return namer_.Type(namer_.Denamespace(GetObject(type)));
493 
494       default: return "*flatbuffers.Table";
495     }
496   }
497 
GenerateType(const r::BaseType base_type) const498   std::string GenerateType(const r::BaseType base_type) const {
499     // Need to override the default naming to match the Lua runtime libraries.
500     // TODO(derekbailey): make overloads in the runtime libraries to avoid this.
501     switch (base_type) {
502       case r::None: return "uint8";
503       case r::UType: return "uint8";
504       case r::Byte: return "int8";
505       case r::UByte: return "uint8";
506       case r::Short: return "int16";
507       case r::UShort: return "uint16";
508       case r::Int: return "int32";
509       case r::UInt: return "uint32";
510       case r::Long: return "int64";
511       case r::ULong: return "uint64";
512       case r::Float: return "Float32";
513       case r::Double: return "Float64";
514       default: return r::EnumNameBaseType(base_type);
515     }
516   }
517 
DefaultValue(const r::Field * field) const518   std::string DefaultValue(const r::Field *field) const {
519     const r::BaseType base_type = field->type()->base_type();
520     if (IsFloatingPoint(base_type)) {
521       return NumToString(field->default_real());
522     }
523     if (IsBool(base_type)) {
524       return field->default_integer() ? "true" : "false";
525     }
526     if (IsScalar(base_type)) { return NumToString((field->default_integer())); }
527     // represents offsets
528     return "0";
529   }
530 
StartCodeBlock(const reflection::Enum * enum_def)531   void StartCodeBlock(const reflection::Enum *enum_def) {
532     current_enum_ = enum_def;
533     current_obj_ = nullptr;
534     requires_.clear();
535   }
536 
StartCodeBlock(const reflection::Object * object)537   void StartCodeBlock(const reflection::Object *object) {
538     current_obj_ = object;
539     current_enum_ = nullptr;
540     requires_.clear();
541   }
542 
RegisterRequires(const r::Field * field,bool use_element=false)543   std::string RegisterRequires(const r::Field *field,
544                                bool use_element = false) {
545     std::string type_name;
546 
547     const r::BaseType type =
548         use_element ? field->type()->element() : field->type()->base_type();
549 
550     if (IsStructOrTable(type)) {
551       const r::Object *object = GetObjectByIndex(field->type()->index());
552       if (object == current_obj_) { return namer_.Denamespace(object); }
553       type_name = object->name()->str();
554     } else {
555       const r::Enum *enum_def = GetEnumByIndex(field->type()->index());
556       if (enum_def == current_enum_) { return namer_.Denamespace(enum_def); }
557       type_name = enum_def->name()->str();
558     }
559 
560     // Prefix with double __ to avoid name clashing, since these are defined
561     // at the top of the file and have lexical scoping. Replace '.' with '_'
562     // so it can be a legal identifier.
563     std::string name = "__" + type_name;
564     std::replace(name.begin(), name.end(), '.', '_');
565 
566     return RegisterRequires(name, type_name);
567   }
568 
RegisterRequires(const std::string & local_name,const std::string & requires_name)569   std::string RegisterRequires(const std::string &local_name,
570                                const std::string &requires_name) {
571     requires_[local_name] = requires_name;
572     return local_name;
573   }
574 
EmitCodeBlock(const std::string & code_block,const std::string & name,const std::string & ns,const std::string & declaring_file) const575   void EmitCodeBlock(const std::string &code_block, const std::string &name,
576                      const std::string &ns,
577                      const std::string &declaring_file) const {
578     const std::string root_type = schema_->root_table()->name()->str();
579     const std::string root_file =
580         schema_->root_table()->declaration_file()->str();
581     const std::string full_qualified_name = ns.empty() ? name : ns + "." + name;
582 
583     std::string code = "--[[ " + full_qualified_name + "\n\n";
584     code +=
585         "  Automatically generated by the FlatBuffers compiler, do not "
586         "modify.\n";
587     code += "  Or modify. I'm a message, not a cop.\n";
588     code += "\n";
589     code += "  flatc version: " + flatc_version_ + "\n";
590     code += "\n";
591     code += "  Declared by  : " + declaring_file + "\n";
592     code += "  Rooting type : " + root_type + " (" + root_file + ")\n";
593     code += "\n--]]\n\n";
594 
595     if (!requires_.empty()) {
596       for (auto it = requires_.cbegin(); it != requires_.cend(); ++it) {
597         code += "local " + it->first + " = require('" + it->second + "')\n";
598       }
599       code += "\n";
600     }
601 
602     code += code_block;
603     code += "return " + name;
604 
605     // Namespaces are '.' deliminted, so replace it with the path separator.
606     std::string path = ns;
607 
608     if (ns.empty()) {
609       path = ".";
610     } else {
611       std::replace(path.begin(), path.end(), '.', '/');
612     }
613 
614     // TODO(derekbailey): figure out a save file without depending on util.h
615     EnsureDirExists(path);
616     const std::string file_name = path + "/" + namer_.File(name);
617     SaveFile(file_name.c_str(), code, false);
618   }
619 
620   std::unordered_set<std::string> keywords_;
621   std::map<std::string, std::string> requires_;
622   const r::Object *current_obj_;
623   const r::Enum *current_enum_;
624   const std::string flatc_version_;
625   const BfbsNamer namer_;
626 };
627 }  // namespace
628 
NewLuaBfbsGenerator(const std::string & flatc_version)629 std::unique_ptr<BfbsGenerator> NewLuaBfbsGenerator(
630     const std::string &flatc_version) {
631   return std::unique_ptr<LuaBfbsGenerator>(new LuaBfbsGenerator(flatc_version));
632 }
633 
634 }  // namespace flatbuffers