63 schemaFolderPath =
"./../types/schemas/"
64 schemaFolderPath = os.path.abspath(
65 os.path.join(os.path.dirname(__file__), schemaFolderPath)
68 schemaFilePath = os.path.join(schemaFolderPath, f
"{fname}.fbs")
69 if not os.path.isfile(schemaFilePath):
72 schemaFile = open(schemaFilePath,
"r")
79 for line
in schemaFile:
82 match = re.search(
r'^table [a-zA-Z_]*_fb', line)
84 line = line.strip().split()
85 tableIdx = line.index(
"table")
86 schemaTableName = line[tableIdx + 1]
89 line = line.strip().split()
90 subNameIdx = line.index(
"table")
91 subName = line[subNameIdx + 1]
92 subTables[subName] = []
95 if not inTable
and "{" in line:
109 if (
"deprecated" in line):
113 fieldParts = line.strip().rstrip(
";").split(
":")
115 fieldType = fieldParts[1].split()[0]
117 if curSubTable
is not None:
119 subTables[curSubTable].append((name, fieldType))
121 schemaFieldInfo.append((name, fieldType))
124 if len(subTables) == 0:
125 return schemaTableName, tuple(schemaFieldInfo)
129 newSchemaFieldInfo = []
130 for field
in schemaFieldInfo:
132 if fieldType
in subTables.keys():
133 newSchemaFieldInfo.append({field[0] : subTables[fieldType]})
135 newSchemaFieldInfo.append(field)
137 return schemaTableName, tuple(newSchemaFieldInfo)
Quick check that the field types in the .fbs schema correspond, mainly that
strings match strings and vectors match vectors.
If they do not correspond, the behavior when comparing the fb values in the tests
is undefined, and action beyond this generator will need to be taken.
183 headerFile = open(hppFname,
"r")
184 headerLines = headerFile.readlines()
187 fNameParts = hppFname.split(
"/")
188 returnInfo[
"name"] = fNameParts[-1].strip().split(
".")[0]
189 CamelCase =
"".join([word.capitalize()
for word
in returnInfo[
"name"].split(
"_")])
190 returnInfo[
"nameCamelCase"] = CamelCase[0].lower() + CamelCase[1:]
192 returnInfo[
"genTestFname"] = f
"{returnInfo['name']}_generated_tests.cpp"
193 returnInfo[
"className"] =
"C" +
"".join([word.capitalize()
for word
in returnInfo[
"name"].split(
"_")])
194 returnInfo[
"classVarName"] =
"".join([word[0].lower()
for word
in returnInfo[
"name"].split(
"_")])
200 if returnInfo[
"name"]
not in baseTypesDict:
201 baseTypesDict[returnInfo[
"name"]] = set()
208 fbMethodName = f
"Create{returnInfo["name
"][0].upper() + returnInfo["name
"][1:]}_fb"
210 messageStructIdxs = []
211 for i
in range(len(headerLines)):
212 if "messageT(" in headerLines[i]:
213 messageStructIdxs.append(i)
214 if fbMethodName
in headerLines[i]:
218 returnInfo[
"schemaTableName"] = schemaTableName
221 if len(messageStructIdxs) == 0:
223 if returnInfo[
"baseType"]
not in baseTypesDict:
224 baseTypesDict[returnInfo[
"baseType"]] = set()
227 baseTypesDict[returnInfo[
"baseType"]].add(returnInfo[
"name"])
233 for line
in headerLines:
234 if re.search(
"^.*Create[a-zA-Z_]*_fb.*$", line)
and returnInfo[
"schemaTableName"] ==
"":
236 startIndex = line.find(
"Create") + len(
"Create")
237 endIndex = line.find(
"_fb")
238 returnInfo[
"schemaTableName"] = f
"{line[startIndex:endIndex]}_fb"
240 returnInfo[
"messageTypes"] =
getMessageFieldInfo(messageStructIdxs, headerLines, schemaFieldInfo)
Parse the field type and the field name out of a field declaration string.
249 typeIdxStart = 1
if (fieldParts[0] ==
"const")
else 0
250 fieldType = fieldParts[typeIdxStart]
252 if fieldParts[typeIdxStart + 1] ==
"&":
253 nameIdx = (typeIdxStart + 2)
254 elif fieldParts[typeIdxStart + 1] ==
"*":
255 nameIdx = (typeIdxStart + 2)
258 nameIdx = (typeIdxStart + 1)
260 name = fieldParts[nameIdx].rstrip(
")").rstrip(
",")
265 name = name.lstrip(
"&*")
267 return fieldType, name
Check whether the log type has a corresponding generated header ({logName}_generated.h) in ../types/generated/, relative to this script's directory.
273 generatedFolderPath =
"./../types/generated/"
274 generatedFolderPath = os.path.abspath(
275 os.path.join(os.path.dirname(__file__), generatedFolderPath)
278 generatedFilePath = os.path.join(generatedFolderPath, f
"{logName}_generated.h")
279 if os.path.isfile(generatedFilePath):
341 if "bool" in fieldType
or (schemaFieldType
is not None and "bool" in schemaFieldType):
343 elif "string" in fieldType
or "char *" in fieldType:
344 if gIncrementingVals:
345 gNextVals[
"string"] += 1
346 return f
'"{gNextVals["string"]}"'
347 randString =
''.join(random.choices(string.ascii_lowercase + string.digits, k=10))
348 return f
'"{randString}"'
349 elif "int" in fieldType:
350 if gIncrementingVals:
354 return f
'{str(getRandInt(fieldType))}u' if "uint64_t" in fieldType
else str(
getRandInt(fieldType))
355 elif "float" in fieldType:
356 if gIncrementingVals:
357 gNextVals[
"float"] += 1
358 return str(round( (gNextVals[
"float"] / 100000), 6))
359 return str(round(random.random(), 6))
360 elif "double" in fieldType:
361 if gIncrementingVals:
362 gNextVals[
"double"] += 1
363 return str(round( (gNextVals[
"double"] / 10000000000), 14))
364 return str(round(random.random(), 14))
386 for schemaField
in schemaFieldInfo:
387 if isinstance(schemaField, tuple)
and schemaField[0] == fieldName:
388 return schemaField,
None
389 if isinstance(schemaField, dict):
390 subTableName = next(iter(schemaField))
391 for subField
in schemaField[subTableName]:
392 if len(subField) != 2:
394 if subField[0] == fieldName:
395 return subField, subTableName
439 subTableDictIndex = 0
442 inMultilineComment =
False
443 inDefaultArgDef =
False
444 for i
in range(len(messageStructIdxs)):
445 structIdx = messageStructIdxs[i]
450 while not closed
and structIdx < len(lines):
452 line = lines[structIdx]
456 if (
"//" in line
and line.find(
")") > line.find(
"//")):
459 elif line.strip().strip(
")") ==
"":
462 openParenCount = line.count(
"(")
463 closeParenCount = line.count(
")")
465 if (closeParenCount > openParenCount)
or \
466 (closeParenCount == openParenCount
and "messageT(" in line):
468 line = line[:line.rfind(
")")]
470 if inMultilineComment:
472 inMultilineComment =
False
478 indexStart = (line.find(
"messageT(") + len(
"messageT("))
if "messageT(" in line
else 0
482 if "/*" in line
and line.find(
"/*") < indexEnd:
483 indexEnd = line.find(
"/*")
486 inMultilineComment =
True
488 indexEnd = line.find(
"//")
491 line = line[indexStart:indexEnd].strip()
493 fieldParts = [part.strip().split()
for part
in line.strip().rstrip(
",").split(
",")]
495 for field
in fieldParts:
498 if len(field) > 0
and "//" in field[0]:
504 if inDefaultArgDef
and len(msgsFieldsList) > 1:
506 inDefaultArgDef =
False
518 inDefaultArgDef =
True
524 fieldDict[
"type"] = fieldType
525 fieldDict[
"name"] = name
527 if "std::vector" in fieldDict[
"type"]:
528 typeParts = fieldDict[
"type"].split(
"<")
529 vectorIdx = [i
for i, e
in enumerate(typeParts)
if "std::vector" in e][0]
530 vectorType = typeParts[vectorIdx + 1].strip(
">")
531 fieldDict[
"vectorType"] = vectorType
533 if len(schemaFieldInfo) != 0:
535 if isinstance(schemaFieldInfo[fieldCount], tuple):
536 fieldDict[
"schemaName"] = schemaFieldInfo[fieldCount][0]
537 fieldDict[
"schemaType"] = schemaFieldInfo[fieldCount][1]
542 if matchingSchemaField !=
None and len(matchingSchemaField) == 2:
543 subTableStr = f
"{subTableName}()->" if subTableName
is not None else ""
544 fieldDict[
"schemaName"] = f
"{subTableStr}{matchingSchemaField[0]}"
545 fieldDict[
"schemaType"] = matchingSchemaField[1]
549 subTableName = next(iter(schemaFieldInfo[fieldCount]))
550 schemaFieldName = schemaFieldInfo[fieldCount][subTableName][subTableDictIndex][0]
551 schemaFieldType = schemaFieldInfo[fieldCount][subTableName][subTableDictIndex][1]
552 fieldDict[
"schemaName"] = f
"{subTableName}()->{schemaFieldName}"
553 fieldDict[
"schemaType"] = schemaFieldType
554 subTableDictIndex += 1
555 if (subTableDictIndex >= len(schemaFieldInfo[fieldCount][subTableName])):
557 subTableDictIndex = 0
565 del fieldDict[
"schemaName"]
566 del fieldDict[
"schemaType"]
571 msgsFieldsList.append(fieldDict)
582 msgTypesList.append(msgsFieldsList)
589 baseFilePath = os.path.join(typesFolderPath, f
"{baseName}.hpp")
590 baseHFile = open(baseFilePath,
"r")
594 returnInfo[
"name"] = logName
595 returnInfo[
"genTestFname"] = f
"{returnInfo['name']}_generated_tests.cpp"
596 returnInfo[
"className"] =
"C" +
"".join([word.capitalize()
for word
in returnInfo[
"name"].split(
"_")])
597 CamelCase =
"".join([word.capitalize()
for word
in returnInfo[
"name"].split(
"_")])
598 returnInfo[
"nameCamelCase"] = CamelCase[0].lower() + CamelCase[1:]
599 returnInfo[
"classVarName"] =
"".join([word[0].lower()
for word
in returnInfo[
"name"].split(
"_")])
600 returnInfo[
"baseType"] = baseName
603 baseHLines = baseHFile.readlines()
606 messageStructIdxs = []
607 for i
in range(len(baseHLines)):
608 if "messageT(" in baseHLines[i]:
609 messageStructIdxs.append(i)
613 returnInfo[
"schemaTableName"] = schemaTableName
616 returnInfo[
"messageTypes"] = [[]]
if "empty_log" in baseName
else msgFieldInfo
626 print(
"Error: Python version must be >= 3.9")
630 global gIncrementingVals
631 gIncrementingVals =
False
635 opts, args = getopt.getopt(sys.argv[1:],
"is:")
637 print(
"Error: Only one option allowed. -s <seed> or -i for incrementing values.")
640 except getopt.GetoptError:
641 print(
"Usage: python3 ./generateTemplatedCatch2Tests.py -s <seed> | -i")
643 for opt, arg
in opts:
645 if not arg.isdigit():
646 print(f
"Error: random seed {arg} provided is not an integer.")
649 random.seed(int(arg))
651 gIncrementingVals =
True
654 env = jinja2.Environment(
655 loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__))
657 env.trim_blocks =
True
658 env.lstrip_blocks =
True
660 catchTemplate = env.get_template(
"catch2TestTemplate.jinja2")
663 typesFolderPath =
"./../types"
664 typesFolderPath = os.path.abspath(
665 os.path.join(os.path.dirname(__file__), typesFolderPath)
669 generatedTestsFolderPath =
"./generated_tests/"
670 generatedTestsFolderPath = os.path.abspath(
671 os.path.join(os.path.dirname(__file__), generatedTestsFolderPath)
675 pathlib.Path(generatedTestsFolderPath).mkdir(exist_ok=
True)
676 oldFiles = glob.glob(os.path.join(generatedTestsFolderPath,
"*"))
677 for file
in oldFiles:
680 types = os.listdir(typesFolderPath)
682 baseTypesDict = dict()
683 print(
"generating tests for...")
688 if ".hpp" not in type:
691 typePath = os.path.join(typesFolderPath, type)
700 renderedHeader = catchTemplate.render(info)
703 outPath = os.path.join(generatedTestsFolderPath, info[
"genTestFname"])
704 with open(outPath,
"w")
as outfile:
705 print(renderedHeader,file=outfile)
708 for baseType, inheritedTypes
in baseTypesDict.items():
710 if len(inheritedTypes) == 0:
713 for inheritedType
in inheritedTypes:
720 renderedHeader = catchTemplate.render(info)
723 outPath = os.path.join(generatedTestsFolderPath, info[
"genTestFname"])
724 with open(outPath,
"w")
as outfile:
725 print(renderedHeader,file=outfile)