API
 
Loading...
Searching...
No Matches
generateTemplatedCatch2Tests.py
Go to the documentation of this file.
1#!/bin/env python3
2
3'''
4Generate Catch2 tests from template.
5See README.md for more details.
6'''
7
8import os
9import sys
10import subprocess
11import glob
12import re
13import pathlib
14import string
15import random
16import getopt
17
18
# Per-type counters used to produce deterministic, incrementing test values
# when the script runs with `-i` (see getIncrementingInt() and
# getTestValFromType()). Every counter starts at 0.
gNextVals = dict.fromkeys(
    (
        "string",
        "int64", "uint64",
        "int32", "uint32",
        "int16", "uint16",
        "int8", "uint8",
        "float",
        "double",
    ),
    0,
)

# True -> incrementing test values; False -> random test values.
# main() sets this from the `-i` command line option.
gIncrementingVals = False
33
# check jinja2 is installed. install it if not
# NOTE(review): when Jinja2 is missing this runs `pip install Jinja2` into the
# interpreter running this script (sys.executable) as a side effect of
# importing the module. That typically needs network access and may modify a
# shared site-packages. subprocess.check_call raises CalledProcessError if pip
# fails. Consider listing Jinja2 as a build prerequisite instead.
try:
    import jinja2
except ModuleNotFoundError:
    print("module 'Jinja2' is not installed. Installing Jinja2...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", 'Jinja2'])
    import jinja2
41
42
def getBaseType(lines : list) -> str:
    '''
    Get the base type of a log type from the lines of its header.

    This is needed for log types that inherit from a base type that specifies
    the messageT(...) constructors. The header is scanned for a line that
    starts with ``struct <name> : public <base>``. Template arguments on the
    base (``<base><...>``) and anything after it on the line (``{``,
    comments) are ignored. If several lines match, the last one wins.

    :param lines: lines of a log type .hpp file
    :returns: the base type name, or "" if no such line was found
    '''
    baseType = ""
    for line in lines:
        # Capture the base name directly instead of taking the last token of
        # the line. Otherwise a trailing "{", a comment, or a multi-argument
        # template ("<a, b>") would be returned as the base type. Struct names
        # may contain digits.
        match = re.search(r'^struct\s+[a-z0-9_]+\s*:\s*public\s+([\w:]+)', line)
        if match is not None:
            baseType = match.group(1)

    return baseType
57
58
def getSchemaFieldInfo(fname : str) -> tuple[str, tuple] :
    '''
    Parse the flatbuffer schema ``../types/schemas/<fname>.fbs`` (relative to
    this script) and return its root table name and field list.

    NOTE: This relies on name order in .fbs schema and .hpp files to be the same.

    Fields are returned in declaration order as ``(fieldName, fbsType)``
    tuples. A field whose type is a sub-table declared in the same schema is
    returned instead as ``{fieldName: [(subFieldName, fbsType), ...]}``.

    :param fname: log type name, i.e. the schema file name without ``.fbs``
    :returns: ``(schemaTableName, fields)``. Returns ``("", ())`` if the schema
              file does not exist. ``schemaTableName`` is "" if no
              ``table <name>_fb`` line was found.
    '''
    schemaFolderPath = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "./../types/schemas/")
    )
    schemaFilePath = os.path.join(schemaFolderPath, f"{fname}.fbs")
    if not os.path.isfile(schemaFilePath):
        return "", tuple()

    schemaFieldInfo = []  # [(fieldName, fbsType), ...] outside sub-tables
    subTables = dict()  # sub-table name -> [(fieldName, fbsType), ...]
    curSubTable = None  # sub-table currently being parsed, if any
    inTable = False
    schemaTableName = ""

    # `with` so the schema file is closed even if parsing raises
    with open(schemaFilePath, "r") as schemaFile:
        for rawLine in schemaFile:
            # Drop `//` comments first. Commented text must never be parsed as
            # a table header or field. A field (or closing `}`) followed by a
            # trailing comment must still be seen.
            line = rawLine.split("//")[0].strip()
            tokens = line.split()

            # `table <name>`: only an exact `table` token followed by a name
            # counts. Identifiers that merely contain "table" are not headers.
            if "table" in tokens[:-1]:
                tableName = tokens[tokens.index("table") + 1].rstrip("{")
                if re.search(r'^table\s+\w*_fb', line) is not None:
                    # `table <log_type>_fb` is the log type's root table
                    schemaTableName = tableName
                else:
                    # any other table is a sub-table used as a field type
                    subTables[tableName] = []
                    curSubTable = tableName

            if not inTable and "{" in line:
                inTable = True
                continue

            if not inTable or line == "":
                continue

            if "}" in line:
                inTable = False
                curSubTable = None
                continue

            if "deprecated" in line:
                continue

            # `name : type [= default] [(attributes)];`
            fieldParts = line.rstrip(";").split(":")
            typeTokens = fieldParts[1].split() if len(fieldParts) > 1 else []
            if not typeTokens:
                continue  # not a `name:type` field (e.g. an enum value)
            field = (fieldParts[0].strip(), typeTokens[0])
            if curSubTable is not None:
                # add to sub-table for now; merged into its parent field below
                subTables[curSubTable].append(field)
            else:
                schemaFieldInfo.append(field)

    if not subTables:
        return schemaTableName, tuple(schemaFieldInfo)

    # replace each field whose type is a sub-table by {fieldName: subTableFields}
    mergedFieldInfo = tuple(
        {name: subTables[fieldType]} if fieldType in subTables else (name, fieldType)
        for name, fieldType in schemaFieldInfo
    )
    return schemaTableName, mergedFieldInfo
138
139
def typesCorrespond(fbsType : str, cType : str) -> bool:
    '''
    Quick check that a .fbs schema type and a C++ field type correspond.

    Only two families are checked:

    - a flatbuffer vector (``[`` in the schema type) must pair with a C++
      type containing "vector";
    - a schema type containing "string" must pair with a C++ type containing
      "string" or "char *".

    Every other pairing is accepted. If the types do not correspond, the
    behavior for comparing the fb values in the tests is undefined, and
    action beyond this generator will need to be taken.
    '''
    fbsIsVector, cIsVector = "[" in fbsType, "vector" in cType
    if fbsIsVector != cIsVector:
        return False
    if fbsIsVector:
        # both sides are vectors
        return True

    fbsIsString = "string" in fbsType
    cIsString = "string" in cType or "char *" in cType
    return fbsIsString == cIsString
154
155
def isValidLogType(lines : list) -> bool:
    '''
    Check that a header describes a concrete log type, not a base log type.

    A concrete log type declares both
    ``flatlogs::eventCodeT eventCode = eventCodes::<X>;`` and
    ``flatlogs::logPrioT defaultLevel = flatlogs::logPrio::<Y>;``
    (possibly on different lines).

    :param lines: lines of a log type .hpp file
    :returns: True only if both declarations are present
    '''
    eventCodePattern = "flatlogs::eventCodeT eventCode = eventCodes::[A-Za-z_0-9]*;"
    defaultLevelPattern = "flatlogs::logPrioT defaultLevel = flatlogs::logPrio::[A-Za-z_0-9]*;"

    hasEventCode = any(re.search(eventCodePattern, ln) is not None for ln in lines)
    if not hasEventCode:
        return False
    return any(re.search(defaultLevelPattern, ln) is not None for ln in lines)
180
def makeTestInfoDict(hppFname : str, baseTypesDict : dict) -> dict:
    '''
    Build the Jinja2 template context for one log type header.

    :param hppFname: path to a log type .hpp file
    :param baseTypesDict: maps a base type name to the set of log type names
                          that inherit from it. Updated in place.
    :returns: the template context dict, or None when nothing should be
              rendered from this file alone. That happens when it is a base
              type, or when it inherits its messageT(...) definitions from a
              base type; main() renders those later via
              makeInheritedTypeInfoDict().
    '''
    returnInfo = dict()
    # `with` so the header is closed (the original left it open)
    with open(hppFname, "r") as headerFile:
        headerLines = headerFile.readlines()

    # add name of test/file/type to be generated
    logName = os.path.basename(hppFname).strip().split(".")[0]
    nameWords = logName.split("_")
    camelCase = "".join([word.capitalize() for word in nameWords])
    returnInfo["name"] = logName
    returnInfo["nameCamelCase"] = camelCase[:1].lower() + camelCase[1:]
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    # word[:1] rather than word[0]: names with "__" produce empty words
    returnInfo["classVarName"] = "".join([word[:1].lower() for word in nameWords])
    returnInfo["baseType"] = getBaseType(headerLines)
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # cannot generate tests from this file alone, need base type
    if not isValidLogType(headerLines):
        baseTypesDict.setdefault(logName, set())
        return None  # don't render anything from this file

    # 1. find where messageT structs are being made -> describes fields
    # 2. check that it has its own <Get|Create|Verify><name>_fb methods
    # (single quotes inside the f-string keep this parseable before Python 3.12)
    fbMethodName = f"Create{logName[:1].upper() + logName[1:]}_fb"
    messageStructIdxs = [i for i, line in enumerate(headerLines) if "messageT(" in line]
    hasFbMethods = any(fbMethodName in line for line in headerLines)

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(logName)
    returnInfo["schemaTableName"] = schemaTableName

    # handle log types that inherit from base types
    if not messageStructIdxs:
        baseType = returnInfo["baseType"]
        if baseType == "":
            # No `struct x : public y` line was found, so there is no base
            # header to take messageT(...) fields from. Registering "" would
            # make main() open "<types>/.hpp" and crash, so skip instead.
            print(f"warning: {logName} has no messageT(...) and no base type; skipping")
            return None

        # add inherited type to the set for the base type it inherits from
        baseTypesDict.setdefault(baseType, set()).add(logName)
        return None  # don't render me yet!

    # if it does not have its own fb method, find name of the one it re-uses
    if not hasFbMethods:
        for line in headerLines:
            if re.search("^.*Create[a-zA-Z_]*_fb.*$", line) and returnInfo["schemaTableName"] == "":
                # figure out name of fb methods this type is re-using, e.g. ao_observer -> observer
                startIndex = line.find("Create") + len("Create")
                endIndex = line.find("_fb")
                returnInfo["schemaTableName"] = f"{line[startIndex:endIndex]}_fb"

    returnInfo["messageTypes"] = getMessageFieldInfo(messageStructIdxs, headerLines, schemaFieldInfo)

    return returnInfo
243
def getTypeAndName(fieldParts : list) -> tuple[str, str]:
    '''
    Parse a C++ parameter's type and name out of its whitespace-split tokens.

    A leading ``const`` is skipped. A ``*`` (as its own token, or stuck to
    the front of the name) appends `` *`` to the type. A ``&`` reference
    marker is dropped. Trailing ``)`` then ``,`` are stripped from the name.

    e.g. ``["const", "std::string", "&", "msg"]`` -> ``("std::string", "msg")``
         ``["char", "*file"]`` -> ``("char *", "file")``
    '''
    tokens = fieldParts[1:] if fieldParts[0] == "const" else fieldParts
    fieldType = tokens[0]

    marker = tokens[1]
    if marker in ("&", "*"):
        rawName = tokens[2]
        if marker == "*":
            fieldType += " *"
    else:
        rawName = marker

    rawName = rawName.rstrip(")").rstrip(",")
    if rawName[0] == "*":
        fieldType += " *"

    return fieldType, rawName.lstrip("&*")
268
def hasGeneratedHFile(logName : str) -> bool:
    '''
    Check whether ``<logName>_generated.h`` exists in ``../types/generated``
    (relative to this script's directory).
    '''
    scriptDir = os.path.dirname(__file__)
    generatedDir = os.path.abspath(os.path.join(scriptDir, "./../types/generated/"))
    return os.path.isfile(os.path.join(generatedDir, f"{logName}_generated.h"))
283
def getIntSize(type : str) -> int:
    '''
    Get the width in bits of a fixed-width C++ integer type name.

    The width is read from the digits immediately before ``_t``
    (``int8_t`` -> 8, ``uint16_t`` -> 16, ``std::int64_t`` -> 64).
    Any type without such digits (``int``, ``size_t``, ...) defaults to 32.
    The old version crashed on such names, and reported ``int128_t`` as 8.
    '''
    match = re.search(r'(\d+)_t', type)
    return int(match.group(1)) if match is not None else 32
292
293
def getRandInt(type : str) -> int:
    '''
    Return a uniformly random integer that fits in the C++ integer type `type`.

    With ``bits = getIntSize(type)``, unsigned types ("uint" in the name)
    span [0, 2**bits - 1] and signed types span
    [-2**(bits-1), 2**(bits-1) - 1].
    '''
    bits = getIntSize(type)
    if "uint" in type:
        return random.randint(0, 2 ** bits - 1)

    upper = 2 ** (bits - 1) - 1
    return random.randint(-upper - 1, upper)
306
def getIncrementingInt(type : str) -> int:
    '''
    Return the next deterministic test value for the C++ integer type `type`.

    Each fixed-width type (int8_t ... uint64_t) advances its own counter in
    gNextVals. Any other type shares the "int32" counter. Counters wrap so
    values stay in [0, 2**bits - 2] for unsigned types and
    [0, 2**(bits-1) - 2] for signed types, i.e. always representable.
    '''
    # Match "uint" before "int": the old substring checks tested "int8_t"
    # first, which is also contained in "uint8_t". That made every uint*
    # branch unreachable.
    match = re.search(r'(u?)int(8|16|32|64)_t', type)
    if match is None:
        counterKey, unsigned, bits = "int32", False, 32
    else:
        unsigned = match.group(1) == "u"
        bits = int(match.group(2))
        counterKey = f"{match.group(1)}int{bits}"

    # signed types only have bits-1 value bits for positive numbers
    modulusBits = bits if unsigned else bits - 1
    modulus = (2 ** modulusBits) - 1

    gNextVals[counterKey] = (gNextVals[counterKey] + 1) % modulus
    return gNextVals[counterKey]
339
def getTestValFromType(fieldType : str, schemaFieldType = None) -> str:
    '''
    Make a C++ literal to use as a test value for a field of type `fieldType`.

    When gIncrementingVals is set, values come from per-type counters in
    gNextVals. Otherwise they are random.

    :param fieldType: C++ type of the field (from the .hpp messageT)
    :param schemaFieldType: optional .fbs type. A "bool" schema type forces a
                            boolean value.
    :returns: "1" for bools; a quoted literal for string / ``char *``; an
              integer (``u``-suffixed for random uint64_t); a decimal for
              float / double; "{}" for anything else.
    '''
    if "bool" in fieldType or (schemaFieldType is not None and "bool" in schemaFieldType):
        return "1"

    if "string" in fieldType or "char *" in fieldType:
        if not gIncrementingVals:
            randString = ''.join(random.choices(string.ascii_lowercase + string.digits, k=10))
            return f'"{randString}"'
        gNextVals["string"] += 1
        return f'"{gNextVals["string"]}"'

    if "int" in fieldType:
        if gIncrementingVals:
            return str(getIncrementingInt(fieldType))
        # need 'u' suffix for randomly generated uint64_t to avoid:
        # "warning: integer constant is so large that it is unsigned"
        suffix = "u" if "uint64_t" in fieldType else ""
        return str(getRandInt(fieldType)) + suffix

    # (counter key / type name, divisor for incrementing values, rounding digits)
    for floatType, divisor, digits in (("float", 100000, 6), ("double", 10000000000, 14)):
        if floatType in fieldType:
            if gIncrementingVals:
                gNextVals[floatType] += 1
                return str(round(gNextVals[floatType] / divisor, digits))
            return str(round(random.random(), digits))

    return "{}"
367
368
def makeTestVal(fieldDict : dict) -> str:
    '''
    Make the C++ test value literal for one messageT field.

    Vector fields (type containing "vector") get a 10-element brace list
    ``{ a,b,... }`` built from fieldDict["vectorType"]. Other fields get a
    single value from getTestValFromType(), using the schema type when known.

    Side effect: for the telem_pokecenter ``pokes`` vector<float> field,
    fieldDict["specialAssertVal"] is set to a brace list of every other
    element (indices 0, 2, 4, ...).
    '''
    if "vector" in fieldDict["type"]:
        vals = [getTestValFromType(fieldDict["vectorType"]) for _ in range(10)]

        # special case telem_pokecenter because vector follows specific format
        if fieldDict["name"] == "pokes" and "vector<float" in fieldDict["type"]:
            catchAssertVals = vals[::2]
            # plain concatenation: the previous nested-quote f-string
            # (f"{{ {",".join(...)} }}") is a SyntaxError before Python 3.12
            fieldDict["specialAssertVal"] = "{ " + ",".join(catchAssertVals) + " }"
        return "{ " + ",".join(vals) + " }"

    if "schemaType" in fieldDict:
        return getTestValFromType(fieldDict["type"], fieldDict["schemaType"])

    return getTestValFromType(fieldDict["type"])
383
def findMatchingSchemaField(schemaFieldInfo, fieldName):
    '''
    Find the schema field named `fieldName` in getSchemaFieldInfo() output.

    :param schemaFieldInfo: sequence of ``(name, fbsType)`` tuples and
                            ``{subTableName: [(name, fbsType), ...]}`` dicts
    :param fieldName: field name to look for
    :returns: ``(schemaField, subTableName)`` for the first match.
              ``subTableName`` is None for a top-level field. Returns
              ``(None, None)`` if there is no matching field.
    '''
    for entry in schemaFieldInfo:
        if isinstance(entry, tuple):
            if entry[0] == fieldName:
                return entry, None
            continue

        if not isinstance(entry, dict):
            continue

        subTableName = next(iter(entry))
        hit = next(
            (sub for sub in entry[subTableName] if len(sub) == 2 and sub[0] == fieldName),
            None,
        )
        if hit is not None:
            return hit, subTableName

    return None, None
397
def setDefaultArgOfLastField(fieldsList, fieldParts):
    '''
    Record a default argument on the most recently parsed field.

    The tokens in `fieldParts` are joined with spaces, stripped of
    surrounding ``=`` characters and whitespace, and stored as
    fieldsList[-1]["defaultArg"].

    Special case (software_log): a ``std::source_location`` field named
    ``loc`` is removed and replaced by two fields, ``file`` (char *) and
    ``line`` (uint32_t), which it aliases.
    '''
    lastField = fieldsList[-1]
    lastField["defaultArg"] = " ".join(fieldParts).strip("=").strip()

    # the name check can be dropped if "loc" should not be strictly tied to the alias
    if not ("std::source_location" in lastField["type"] and lastField["name"] == "loc"):
        return

    fieldsList.pop()
    fileField = {
        "type": "char *",
        "name": "file",
        "schemaName": "file",
        "schemaType": "string",
        "testVal": None,
        "defaultArg": True,
        "defaultTestVal": "__FILE__",
    }
    lineField = {
        "type": "uint32_t",
        "name": "line",
        "schemaName": "line",
        "schemaType": "uint32",
        "testVal": None,
        "defaultArg": True,
    }
    fieldsList.extend([fileField, lineField])
429
430
431
'''
make 2d array. each inner array contains dictionaries corresponding to
the type(s) and name(s) of field(s) in a message:
[ [ {type : x, name: y ...}, {name: type, ...} ], ... ]
'''
def getMessageFieldInfo(messageStructIdxs: list, lines : list, schemaFieldInfo : tuple):
    '''
    Parse the parameters of every ``messageT(...)`` constructor in a header.

    Parsing starts at each index in `messageStructIdxs` and continues line by
    line until the constructor's closing ``)``. Each comma-separated
    parameter becomes a dict with "type", "name" and "testVal". The keys
    "vectorType", "schemaName", "schemaType", "defaultArg" and
    "specialAssertVal" are added when applicable.

    Schema names/types are first paired with header fields positionally
    (`fieldCount` walks the top-level entries of `schemaFieldInfo`, and
    `subTableDictIndex` walks inside a sub-table entry). For top-level tuple
    entries they are then overridden by an exact name match from
    findMatchingSchemaField(). They are removed again if typesCorrespond()
    reports the types disagree.

    :param messageStructIdxs: indices into `lines` of lines containing "messageT("
    :param lines: lines of the .hpp file
    :param schemaFieldInfo: second element of getSchemaFieldInfo()'s result
    :returns: one list of field dicts per messageT constructor, in order
    '''
    msgTypesList = []
    # NOTE(review): initialized once for all constructors, not per constructor
    # like fieldCount, so a partially consumed sub-table carries its position
    # into the next messageT -- confirm this is intended.
    subTableDictIndex = 0

    # extract log field types and names
    # parser state carried between lines (and between constructors)
    inMultilineComment = False
    inDefaultArgDef = False  # previous line ended in "=": its default value is on this line
    for i in range(len(messageStructIdxs)):
        structIdx = messageStructIdxs[i]
        msgsFieldsList = []

        closed = False
        fieldCount = 0  # index of the next top-level entry in schemaFieldInfo
        while not closed and structIdx < len(lines):

            line = lines[structIdx]

            # check if this is a closing line
            if ")" in line:
                if ("//" in line and line.find(")") > line.find("//")):
                    # parenthesis is in comment
                    pass
                elif line.strip().strip(")") == "":
                    break # line holds only ")": the parameter list is done
                else:
                    openParenCount = line.count("(")
                    closeParenCount = line.count(")")
                    # check if truly closed or not: an extra ")" closes the
                    # constructor, or a one-line messageT(...) is balanced
                    if (closeParenCount > openParenCount) or \
                       (closeParenCount == openParenCount and "messageT(" in line):
                        closed = True # parse the field, don't leave loop yet
                        # cut the closing ")" and anything after it
                        line = line[:line.rfind(")")]

            # skip lines inside a /* ... */ comment (including its closing line)
            if inMultilineComment:
                if "*/" in line:
                    inMultilineComment = False
                structIdx += 1
                continue


            # trim line to just get field info
            indexStart = (line.find("messageT(") + len("messageT(")) if "messageT(" in line else 0

            indexEnd = len(line)
            # adjust for comments. Note /* takes precedence over //
            if "/*" in line and line.find("/*") < indexEnd:
                indexEnd = line.find("/*")
                # handle multiline comments
                if "*/" not in line:
                    inMultilineComment = True
            elif "//" in line:
                indexEnd = line.find("//")


            line = line[indexStart:indexEnd].strip()

            # one token list per comma-separated parameter on this line
            fieldParts = [part.strip().split() for part in line.strip().rstrip(",").split(",")]

            for field in fieldParts:
                fieldDict = {}

                if len(field) > 0 and "//" in field[0]:
                    break
                if len(field) == 0:
                    break

                # check if this is a default arg value
                # NOTE(review): `> 1` means a default value continued onto the
                # next line is not recognized when the defaulted parameter is
                # the first one; its tokens then reach getTypeAndName() (an
                # IndexError for a single token). `> 0` looks intended -- confirm.
                if inDefaultArgDef and len(msgsFieldsList) > 1:
                    setDefaultArgOfLastField(msgsFieldsList, field)
                    inDefaultArgDef = False
                    field.pop(0)
                    if len(field) == 0:
                        break

                # handle default arguments that expand across two lines
                if field[0] == "=":
                    setDefaultArgOfLastField(msgsFieldsList, field)
                    continue
                if field[-1] == '=':
                    # set flag but still need to parse this fields info.
                    # don't leave this loop iteration yet.
                    inDefaultArgDef = True


                # find type and name
                fieldType, name = getTypeAndName(field)

                fieldDict["type"] = fieldType
                fieldDict["name"] = name
                # get vector type if necessary, e.g. "std::vector<float>" -> "float"
                if "std::vector" in fieldDict["type"]:
                    typeParts = fieldDict["type"].split("<")
                    vectorIdx = [i for i, e in enumerate(typeParts) if "std::vector" in e][0]
                    vectorType = typeParts[vectorIdx + 1].strip(">")
                    fieldDict["vectorType"] = vectorType

                if len(schemaFieldInfo) != 0:

                    # NOTE(review): no bounds check -- raises IndexError if the
                    # messageT has more parameters than schemaFieldInfo has entries.
                    if isinstance(schemaFieldInfo[fieldCount], tuple):
                        # positional pairing with a top-level schema field
                        fieldDict["schemaName"] = schemaFieldInfo[fieldCount][0]
                        fieldDict["schemaType"] = schemaFieldInfo[fieldCount][1]
                        fieldCount += 1

                        # check if matching name in schema file exists, let this overwrite
                        matchingSchemaField, subTableName = findMatchingSchemaField(schemaFieldInfo, fieldDict["name"])
                        if matchingSchemaField != None and len(matchingSchemaField) == 2:
                            subTableStr = f"{subTableName}()->" if subTableName is not None else ""
                            fieldDict["schemaName"] = f"{subTableStr}{matchingSchemaField[0]}"
                            fieldDict["schemaType"] = matchingSchemaField[1]

                    else:
                        # go into dictionary..
                        # positional pairing with the next field of a sub-table;
                        # schemaName becomes "<subTable>()-><field>"
                        subTableName = next(iter(schemaFieldInfo[fieldCount]))
                        schemaFieldName = schemaFieldInfo[fieldCount][subTableName][subTableDictIndex][0]
                        schemaFieldType = schemaFieldInfo[fieldCount][subTableName][subTableDictIndex][1]
                        fieldDict["schemaName"] = f"{subTableName}()->{schemaFieldName}"
                        fieldDict["schemaType"] = schemaFieldType
                        subTableDictIndex += 1
                        if (subTableDictIndex >= len(schemaFieldInfo[fieldCount][subTableName])):
                            # reset dictionary index if we need to
                            subTableDictIndex = 0
                            fieldCount += 1

                    # check schemaType correlates to type in .hpp file
                    if not typesCorrespond(fieldDict["schemaType"], fieldDict["type"]):
                        # if types don't correspond, then use name in messageT and hope for best.
                        # this is why if types are different, then names MUST correspond between
                        # .fbs and .hpp file
                        del fieldDict["schemaName"]
                        del fieldDict["schemaType"]

                fieldDict["testVal"] = makeTestVal(fieldDict)

                # add field dict to list of fields
                msgsFieldsList.append(fieldDict)

                # note we do this after the fieldDict has been appended to msgsFieldList.
                # This is because the function setDefaultArgOfLastField() does some
                # special casing to replace std::current_location with field and line
                # within the msgsFieldsList
                if "=" in field:
                    setDefaultArgOfLastField(msgsFieldsList, field)

            structIdx += 1

        msgTypesList.append(msgsFieldsList)

    return msgTypesList
585
def makeInheritedTypeInfoDict(typesFolderPath : str, baseName : str, logName : str) -> dict:
    '''
    Build the Jinja2 template context for a log type that inherits its
    messageT(...) definitions from a base log type.

    Fields are parsed from ``<typesFolderPath>/<baseName>.hpp`` and the base
    type's schema. Names in the context describe the inheriting type `logName`.

    :param typesFolderPath: directory containing the log type .hpp files
    :param baseName: name of the base log type (its .hpp file stem)
    :param logName: name of the inheriting log type
    :returns: the template context dict
    '''
    baseFilePath = os.path.join(typesFolderPath, f"{baseName}.hpp")
    # `with` so the base header is closed (the original left it open)
    with open(baseFilePath, "r") as baseHFile:
        baseHLines = baseHFile.readlines()

    nameWords = logName.split("_")
    camelCase = "".join([word.capitalize() for word in nameWords])

    # add name of test/file/type to be generated
    returnInfo = dict()
    returnInfo["name"] = logName
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    returnInfo["nameCamelCase"] = camelCase[:1].lower() + camelCase[1:]
    # word[:1] rather than word[0]: names with "__" produce empty words
    returnInfo["classVarName"] = "".join([word[:1].lower() for word in nameWords])
    returnInfo["baseType"] = baseName
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # find where messageT structs are being made in base log file -> describes fields
    messageStructIdxs = [i for i, line in enumerate(baseHLines) if "messageT(" in line]

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(baseName)
    returnInfo["schemaTableName"] = schemaTableName

    # Parsed even for empty_log so the sequence of generated (random) test
    # values is the same as before. empty_log bases are rendered as a single
    # message with no fields.
    msgFieldInfo = getMessageFieldInfo(messageStructIdxs, baseHLines, schemaFieldInfo)
    returnInfo["messageTypes"] = [[]] if "empty_log" in baseName else msgFieldInfo

    return returnInfo
619
def versionAsNumber(major, minor):
    '''Encode a (major, minor) version as one comparable number, e.g. (3, 9) -> 3009.'''
    return minor + 1000 * major
622
def main():
    '''
    Generate one Catch2 test file per log type into ./generated_tests/.

    Usage: python3 ./generateTemplatedCatch2Tests.py [-s <seed> | -i]
      -s <seed>  seed the random test value generator
      -i         use incrementing test values instead of random ones

    Existing files in ./generated_tests/ are deleted first. Exits with
    status 1 on bad arguments or an unsupported Python version.
    '''
    # check python version >= 3.9
    if versionAsNumber(sys.version_info[0], sys.version_info[1]) < versionAsNumber(3, 9):
        print("Error: Python version must be >= 3.9")
        sys.exit(1)  # non-zero so build systems notice the failure

    global gIncrementingVals
    gIncrementingVals = False

    # getopt for random seed or incrementing vals
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "is:")
    except getopt.GetoptError:
        print("Usage: python3 ./generateTemplatedCatch2Tests.py -s <seed> | -i")
        sys.exit(1)
    if len(opts) > 1:
        print("Error: Only one option allowed. -s <seed> or -i for incrementing values.")
        sys.exit(1)

    for opt, arg in opts:
        if opt == "-s":
            if not arg.isdigit():
                print(f"Error: random seed {arg} provided is not an integer.")
                sys.exit(1)
            # use random seed if provided with -s
            random.seed(int(arg))
        elif opt == "-i":
            gIncrementingVals = True

    scriptDir = os.path.dirname(__file__)

    # load template
    env = jinja2.Environment(
        loader = jinja2.FileSystemLoader(searchpath=scriptDir)
    )
    env.trim_blocks = True
    env.lstrip_blocks = True
    catchTemplate = env.get_template("catch2TestTemplate.jinja2")

    # path to .hpp files here
    typesFolderPath = os.path.abspath(os.path.join(scriptDir, "./../types"))

    # generated tests output path
    generatedTestsFolderPath = os.path.abspath(os.path.join(scriptDir, "./generated_tests/"))

    # make directory if it doesn't exist, then clear previously generated files
    pathlib.Path(generatedTestsFolderPath).mkdir(exist_ok=True)
    for oldFile in glob.glob(os.path.join(generatedTestsFolderPath, "*")):
        if os.path.isfile(oldFile):  # os.remove fails on directories
            os.remove(oldFile)

    def writeTests(info):
        '''Render the template with `info` and write it to its genTestFname.'''
        renderedHeader = catchTemplate.render(info)
        outPath = os.path.join(generatedTestsFolderPath, info["genTestFname"])
        with open(outPath, "w") as outfile:
            print(renderedHeader, file=outfile)

    typeFiles = sorted(os.listdir(typesFolderPath))
    baseTypesDict = dict()  # map baseTypes to the types that inherit from them
    print("generating tests for...")
    for typeFile in typeFiles:

        print(typeFile)
        # check valid type to generate tests for
        if ".hpp" not in typeFile:
            continue

        # make dictionary with info for template
        info = makeTestInfoDict(os.path.join(typesFolderPath, typeFile), baseTypesDict)
        if info is None:
            # no tests to make from this file alone
            continue

        writeTests(info)

    # handle types that inherit from baseTypes
    for baseType, inheritedTypes in baseTypesDict.items():
        # Sorted: iteration order of a set of str depends on hash
        # randomization (PYTHONHASHSEED) and varies between runs. Rendering
        # order decides which random values each type receives, so without
        # sorting `-s <seed>` output would not be reproducible.
        for inheritedType in sorted(inheritedTypes):
            info = makeInheritedTypeInfoDict(typesFolderPath, baseType, inheritedType)
            if info is None:
                # no tests to make
                continue

            writeTests(info)


if (__name__ == "__main__"):
    main()
dict makeInheritedTypeInfoDict(str typesFolderPath, str baseName, str logName)
dict makeTestInfoDict(str hppFname, dict baseTypesDict)
findMatchingSchemaField(schemaFieldInfo, fieldName)
str getTestValFromType(str fieldType, schemaFieldType=None)
tuple[str, tuple] getSchemaFieldInfo(str fname)
tuple[str, str] getTypeAndName(list fieldParts)
bool typesCorrespond(str fbsType, str cType)
getMessageFieldInfo(list messageStructIdxs, list lines, tuple schemaFieldInfo)
setDefaultArgOfLastField(fieldsList, fieldParts)