API
 
Loading...
Searching...
No Matches
generateTemplatedCatch2Tests.py
Go to the documentation of this file.
1#!/bin/env python3
2
3'''
4Generate Catch2 tests from template.
5See README.md for more details.
6'''
7
8import os
9import sys
10import subprocess
11import glob
12import re
13import pathlib
14import string
15import random
16import getopt
17
18
# Per-type counters used when incrementing (instead of random) test values are
# requested. Each entry holds the last value handed out for that type.
gNextVals = dict.fromkeys(
    (
        "string",
        "int64",
        "uint64",
        "int32",
        "uint32",
        "int16",
        "uint16",
        "int8",
        "uint8",
        "float",
        "double",
    ),
    0,
)

# When True, generated test values count upwards instead of being random.
gIncrementingVals = False
33
34# check jinja2 is installed. install it if not
35try:
36 import jinja2
37except ModuleNotFoundError:
38 print("module 'Jinja2' is not installed. Installing Jinja2...")
39 subprocess.check_call([sys.executable, "-m", "pip", "install", 'Jinja2'])
40 import jinja2
41
42
def getBaseType(lines : list) -> str:
    '''
    Get base type of log. This is needed for log types that inherit from a
    base type that specifies the messageT(...).

    Every line starting with `struct <name> : public <base>` is considered;
    the last whitespace-separated token of that line, cut at the first "<",
    is taken as the base type. If several lines match, the last one wins.
    Returns "" when no line matches.
    '''
    inheritancePattern = re.compile(r'^struct [a-z_]* : public [a-z_]*')

    baseType = ""
    for headerLine in lines:
        if inheritancePattern.search(headerLine) is None:
            continue
        lastToken = headerLine.strip().split()[-1]
        # drop any template argument list, e.g. base<derived> -> base
        baseType = lastToken.split("<")[0]

    return baseType
57
58
def getSchemaFieldInfo(fname : str) -> tuple[str, tuple] :
    '''
    Read ../types/schemas/<fname>.fbs (relative to this script) and return
    (schemaTableName, fields).

    schemaTableName is the name of the `table <log_type>_fb` table, or "" if
    none was found. fields is a tuple with one entry per field of that table,
    in schema order:
      - (name, type) for an ordinary field
      - {name : [(subName, subType), ...]} for a field whose type is another
        table declared in the same schema (a sub-table)

    Returns ("", ()) when the schema file does not exist.

    NOTE: This relies on name order in .fbs schema and .hpp files to be the same.
    '''
    schemaFolderPath = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "./../types/schemas/")
    )

    schemaFilePath = os.path.join(schemaFolderPath, f"{fname}.fbs")
    if not os.path.isfile(schemaFilePath):
        return "", tuple()

    # read everything up front so the file handle is always closed
    with open(schemaFilePath, "r") as schemaFile:
        schemaLines = schemaFile.readlines()

    mainTablePattern = re.compile(r'^table [a-zA-Z_]*_fb')

    schemaFieldInfo = []
    subTables = dict() # dict where key is sub-table name, value is [(fieldname, type)...]
    curSubTable = None
    inTable = False
    schemaTableName = ""
    for rawLine in schemaLines:
        # drop comments; a field followed by a trailing comment is still a field
        line = rawLine.split("//", 1)[0].strip()
        if line == "":
            continue

        if not inTable:
            tokens = line.split()
            # only a real declaration counts, not any line containing "table"
            if tokens[0] == "table" and len(tokens) > 1:
                tableName = tokens[1].split("{")[0]
                if mainTablePattern.search(line) is not None:
                    # `table <log_type>_fb`
                    schemaTableName = tableName
                else:
                    # otherwise it is a sub-table of the schema
                    subTables[tableName] = []
                    curSubTable = tableName

            if "{" in line:
                inTable = True
            continue

        if "}" in line:
            inTable = False
            curSubTable = None
            continue

        # lines without a "name:type" shape (e.g. enum values) are not fields
        if ":" not in line:
            continue

        fieldName, _, fieldRest = line.rstrip(";").partition(":")
        # strip default value (`= x`) and attributes (`(...)`) from the type
        typeTokens = fieldRest.split("=")[0].split("(")[0].split()
        if not typeTokens:
            continue
        fieldEntry = (fieldName.strip(), typeTokens[0])

        if curSubTable is not None:
            # add to subtable dict for now, will be added in later
            subTables[curSubTable].append(fieldEntry)
        else:
            schemaFieldInfo.append(fieldEntry)

    if len(subTables) == 0:
        return schemaTableName, tuple(schemaFieldInfo)

    # go through fields and replace sub-table typed ones with their contents
    newSchemaFieldInfo = []
    for field in schemaFieldInfo:
        fieldType = field[1]
        if fieldType in subTables:
            newSchemaFieldInfo.append({field[0] : subTables[fieldType]})
        else:
            newSchemaFieldInfo.append(field)

    return schemaTableName, tuple(newSchemaFieldInfo)
135
136
def typesCorrespond(fbsType : str, cType : str) -> bool:
    '''
    Quick check that the .fbs type and the C++ type are the same kind of
    thing: vectors must pair with vectors and strings with strings.
    Anything that is neither a vector nor a string on both sides is accepted.

    If they do not correspond, the behavior for comparing the fb values in
    the tests is undefined, and action beyond this generator will need to be
    taken.
    '''
    fbsIsVector = "[" in fbsType
    cIsVector = "vector" in cType
    if fbsIsVector or cIsVector:
        return fbsIsVector and cIsVector

    fbsIsString = "string" in fbsType
    cIsString = ("string" in cType) or ("char *" in cType)
    if fbsIsString or cIsString:
        return fbsIsString and cIsString

    return True
151
152
def isValidLogType(lines : list) -> bool:
    '''
    Check it is not a base log type.
    A valid log type must declare both an eventCode and a defaultLevel
    somewhere in its header lines.
    '''
    eventCodePattern = "flatlogs::eventCodeT eventCode = eventCodes::[A-Za-z_0-9]*;"
    defaultLevelPattern = "flatlogs::logPrioT defaultLevel = flatlogs::logPrio::[A-Za-z_0-9]*;"

    foundEventCode = False
    foundDefaultLevel = False
    for headerLine in lines:
        if re.search(eventCodePattern, headerLine) is not None:
            foundEventCode = True

        if re.search(defaultLevelPattern, headerLine) is not None:
            foundDefaultLevel = True

        # nothing more to learn once both declarations were seen
        if foundEventCode and foundDefaultLevel:
            return True

    return False
177
178
179
def makeTestInfoDict(hppFname : str, baseTypesDict : dict) -> dict:
    '''
    Build the dictionary handed to the test template for one log type .hpp file.

    hppFname: path to the .hpp file of the log type.
    baseTypesDict: maps base type name -> set of log types inheriting from it.
                   It is updated in place when this file is a base type, or
                   inherits its messageT from one.

    Returns the info dictionary, or None when no test can be rendered from
    this file alone (it is a base type, or it has no messageT of its own).
    '''
    # context manager so the header file handle is always closed
    with open(hppFname, "r") as headerFile:
        headerLines = headerFile.readlines()

    returnInfo = dict()

    # add name of test/file/type to be generated
    logName = os.path.basename(hppFname).strip().split(".")[0]
    nameWords = logName.split("_")
    camelCase = "".join([word.capitalize() for word in nameWords])
    returnInfo["name"] = logName
    returnInfo["nameCamelCase"] = camelCase[0].lower() + camelCase[1:]
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    # word[:1] so an empty segment (e.g. from "a__b") cannot raise
    returnInfo["classVarName"] = "".join([word[:1].lower() for word in nameWords])
    returnInfo["baseType"] = getBaseType(headerLines)
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # cannot generate tests from this file alone, need base type
    if not isValidLogType(headerLines):
        if logName not in baseTypesDict:
            baseTypesDict[logName] = set()
        return None # don't render anything from this file

    # find where messageT structs are being made -> describes fields
    messageStructIdxs = [
        idx for idx, headerLine in enumerate(headerLines) if "messageT(" in headerLine
    ]

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(logName)
    returnInfo["schemaTableName"] = schemaTableName

    # handle log types that inherit from base types
    if len(messageStructIdxs) == 0:
        baseType = returnInfo["baseType"]
        if baseType not in baseTypesDict:
            baseTypesDict[baseType] = set()

        # add inherited type to dict where key is the base type it inherits from
        baseTypesDict[baseType].add(logName)
        return None # don't render me yet!

    returnInfo["messageTypes"] = getMessageFieldInfo(messageStructIdxs, headerLines, schemaFieldInfo)

    return returnInfo
227
def getTypeAndName(lineParts : list) -> tuple[str, str]:
    '''
    Parse out field type and name from the whitespace-split tokens of one
    messageT(...) parameter, e.g. ["const", "std::string", "&", "name,"].

    A leading "const" is skipped. A standalone "&" token is dropped, a
    standalone "*" token is appended to the type ("char *"). Trailing ")"
    and then "," are removed from the name.
    '''
    typeIdx = 1 if lineParts[0] == "const" else 0
    fieldType = lineParts[typeIdx]

    qualifier = lineParts[typeIdx + 1]
    if qualifier == "*":
        fieldType = f"{fieldType} {qualifier}"

    # the name follows the qualifier if there is one, else the type
    nameIdx = typeIdx + 2 if qualifier in ("&", "*") else typeIdx + 1
    fieldName = lineParts[nameIdx].rstrip(")").rstrip(",")

    return fieldType, fieldName
247
def hasGeneratedHFile(logName : str) -> bool:
    '''
    Checks if log type has a corresponding generated .h file, i.e. whether
    ../types/generated/<logName>_generated.h (relative to this script) exists.
    '''
    scriptFolder = os.path.dirname(__file__)
    generatedFolderPath = os.path.abspath(
        os.path.join(scriptFolder, "./../types/generated/")
    )

    return os.path.isfile(
        os.path.join(generatedFolderPath, f"{logName}_generated.h")
    )
262
def getIntSize(type : str) -> int:
    '''
    Return the width in bits of an integer type name.

    The width is the run of digits directly before "_t" (int8_t -> 8,
    uint64_t -> 64). Type names without such a width (int, intptr_t, ...)
    get the default of 32 bits instead of raising.
    '''
    widthMatch = re.search(r'(\d+)_t', type)
    if widthMatch is None:
        return 32 # default size 32 bits

    return int(widthMatch.group(1))
271
272
def getRandInt(type : str) -> int:
    '''
    Return a random integer covering the full range of the given integer
    type. Types containing "uint" are treated as unsigned, all others as
    two's-complement signed; the width comes from getIntSize().
    '''
    totalBits = getIntSize(type)

    if "uint" in type:
        lowest = 0
        highest = (2 ** totalBits) - 1
    else:
        # one bit is used for the sign
        highest = (2 ** (totalBits - 1)) - 1
        lowest = 0 - highest - 1

    return random.randint(lowest, highest)
285
def getIncrementingInt(type : str) -> int:
    '''
    Return the next incrementing test value for the given integer type and
    store it in gNextVals.

    Each signedness/width pair has its own counter ("int8", "uint8", ...).
    Unsigned types are detected first: "int8_t" is a substring of "uint8_t",
    so testing the signed names first would send every unsigned type to the
    signed counter. The counter wraps before it leaves the positive range of
    the type, so signed types lose one bit to the sign.
    Types without an explicit width use the 32 bit counters.
    '''
    isUnsigned = "uint" in type
    intSizeBits = getIntSize(type)

    counterKey = f"{'uint' if isUnsigned else 'int'}{intSizeBits}"
    valueBits = intSizeBits if isUnsigned else intSizeBits - 1
    maxVal = (2 ** valueBits) - 1

    # .get() so an unusual width simply starts its own counter
    gNextVals[counterKey] = (gNextVals.get(counterKey, 0) + 1) % maxVal
    return gNextVals[counterKey]
318
def getTestValFromType(fieldType : str, schemaFieldType = None) -> str:
    '''
    Return a C++ literal (as a string) to use as test value for a field.

    fieldType: type of the field as written in the .hpp file.
    schemaFieldType: type of the field in the .fbs schema, if known.

    Values are random unless gIncrementingVals is set, in which case the
    per-type counters in gNextVals are advanced and used instead.
    Bools always get "1"; unrecognized types get "{}".
    '''
    schemaSaysBool = schemaFieldType is not None and "bool" in schemaFieldType
    if "bool" in fieldType or schemaSaysBool:
        return "1"

    if "string" in fieldType or "char *" in fieldType:
        if gIncrementingVals:
            gNextVals["string"] += 1
            return '"' + str(gNextVals["string"]) + '"'
        alphabet = string.ascii_lowercase + string.digits
        randString = ''.join(random.choices(alphabet, k=10))
        return '"' + randString + '"'

    if "int" in fieldType:
        if gIncrementingVals:
            return str(getIncrementingInt(fieldType))
        randInt = str(getRandInt(fieldType))
        # need 'u' suffix for randomly generated uint64_t to avoid:
        # "warning: integer constant is so large that it is unsigned"
        if "uint64_t" in fieldType:
            return randInt + "u"
        return randInt

    if "float" in fieldType:
        if gIncrementingVals:
            gNextVals["float"] += 1
            return str(round( (gNextVals["float"] / 100000), 6))
        return str(round(random.random(), 6))

    if "double" in fieldType:
        if gIncrementingVals:
            gNextVals["double"] += 1
            return str(round( (gNextVals["double"] / 10000000000), 14))
        return str(round(random.random(), 14))

    return "{}"
346
347
def makeTestVal(fieldDict : dict) -> str:
    '''
    Return the C++ test value (as a string) for the field described by
    fieldDict (keys "type" and "name", plus "vectorType" for vectors and
    optionally "schemaType").

    Vectors get a brace initializer with 10 values. For the "pokes" float
    vector, every other value is additionally stored in
    fieldDict["specialAssertVal"].
    '''
    if "vector" in fieldDict["type"]:
        vals = [getTestValFromType(fieldDict["vectorType"]) for _ in range(10)]

        # special case telem_pokecenter because vector follows specific format
        if fieldDict["name"] == "pokes" and "vector<float" in fieldDict["type"]:
            catchAssertVals = vals[::2]
            # joined outside the f-string: reusing the enclosing quote inside
            # an f-string expression is a SyntaxError before Python 3.12
            joinedAssertVals = ",".join(catchAssertVals)
            fieldDict["specialAssertVal"] = f"{{ {joinedAssertVals} }}"

        joinedVals = ",".join(vals)
        return f"{{ {joinedVals} }}"

    if "schemaType" in fieldDict:
        return getTestValFromType(fieldDict["type"], fieldDict["schemaType"])

    return getTestValFromType(fieldDict["type"])
362
363
364
def getMessageFieldInfo(messageStructIdxs: list, lines : list, schemaFieldInfo : tuple):
    '''
    Make a 2d array. Each inner array contains dictionaries corresponding to
    the type(s) and name(s) of field(s) in a message:
    [ [ {type : x, name: y ...}, {type : x, name: y ...} ], ... ]

    messageStructIdxs: indices into `lines` of the lines containing "messageT(".
    lines: lines of the .hpp file being parsed.
    schemaFieldInfo: field tuple as returned by getSchemaFieldInfo(), may be
                     empty. Entries are matched to the .hpp fields by position.

    Each field dictionary has "type", "name" and "testVal", plus "vectorType"
    for std::vector fields, "schemaType" when a schema entry exists, and
    "schemaName" when in addition the schema and .hpp types correspond.
    '''
    msgTypesList = []

    # extract log field types and names
    for structIdx in messageStructIdxs:
        firstLine = lines[structIdx]
        commentPos = firstLine.find("//")
        if commentPos != -1 and commentPos < firstLine.find("messageT("):
            # the constructor itself is commented out, nothing to parse
            continue

        msgsFieldsList = []

        closed = False
        # position in the schema; both restart with every message
        fieldCount = 0
        subTableDictIndex = 0
        while not closed and structIdx < len(lines):

            line = lines[structIdx]

            # check if this is a closing line
            if ")" in line:
                if "//" in line and line.find(")") > line.find("//"):
                    # parenthesis is in comment
                    pass
                elif line.strip().strip(")") == "":
                    break
                else:
                    closed = True # parse the field, don't leave loop yet

            # trim line to just get field info
            indexStart = (line.find("messageT(") + len("messageT(")) if "messageT(" in line else 0
            indexEnd = line.find("//") if "//" in line else len(line)
            line = line[indexStart:indexEnd]

            lineParts = [part.strip().split() for part in line.strip().rstrip(",").split(",")]

            for field in lineParts:
                # blank / comment-only lines, `messageT(` alone on a line and
                # `messageT()` leave nothing (or only ")") to parse
                if all(token.strip(")") == "" for token in field):
                    continue

                if "//" in field[0]:
                    break

                fieldDict = {}

                # find type and name
                fieldType, fieldName = getTypeAndName(field)

                fieldDict["type"] = fieldType
                fieldDict["name"] = fieldName
                # get vector type if necessary
                if "std::vector" in fieldDict["type"]:
                    typeParts = fieldDict["type"].split("<")
                    vectorIdx = [idx for idx, part in enumerate(typeParts) if "std::vector" in part][0]
                    fieldDict["vectorType"] = typeParts[vectorIdx + 1].strip(">")

                # fields beyond the end of the schema get no schema info
                # instead of raising IndexError
                if fieldCount < len(schemaFieldInfo):
                    schemaEntry = schemaFieldInfo[fieldCount]
                    if isinstance(schemaEntry, tuple):
                        fieldDict["schemaName"] = schemaEntry[0]
                        fieldDict["schemaType"] = schemaEntry[1]
                        fieldCount += 1
                    else:
                        # go into dictionary..
                        subTableName = next(iter(schemaEntry))
                        subTableFields = schemaEntry[subTableName]
                        schemaFieldName = subTableFields[subTableDictIndex][0]
                        schemaFieldType = subTableFields[subTableDictIndex][1]
                        fieldDict["schemaName"] = f"{subTableName}()->{schemaFieldName}"
                        fieldDict["schemaType"] = schemaFieldType
                        subTableDictIndex += 1
                        if subTableDictIndex >= len(subTableFields):
                            # sub-table is used up, move on to next schema field
                            subTableDictIndex = 0
                            fieldCount += 1

                    # check schemaType correlates to type in .hpp file
                    if not typesCorrespond(fieldDict["schemaType"], fieldDict["type"]):
                        # if types don't correspond, then use name in messageT and hope for best.
                        # this is why if types are different, then names MUST correspond between
                        # .fbs and .hpp file
                        del fieldDict["schemaName"]

                fieldDict["testVal"] = makeTestVal(fieldDict)

                # add field dict to list of fields
                msgsFieldsList.append(fieldDict)

            structIdx += 1

        msgTypesList.append(msgsFieldsList)

    return msgTypesList
455
def makeInheritedTypeInfoDict(typesFolderPath : str, baseName : str, logName : str) -> dict:
    '''
    Build the dictionary handed to the test template for a log type that
    inherits its messageT from a base type.

    typesFolderPath: folder containing the .hpp files.
    baseName: name of the base type; <baseName>.hpp and the <baseName> schema
              describe the fields.
    logName: name of the inheriting log type the tests are generated for.
    '''
    baseFilePath = os.path.join(typesFolderPath, f"{baseName}.hpp")
    # context manager so the base header file handle is always closed
    with open(baseFilePath, "r") as baseHFile:
        baseHLines = baseHFile.readlines()

    returnInfo = dict()

    # add name of test/file/type to be generated
    nameWords = logName.split("_")
    camelCase = "".join([word.capitalize() for word in nameWords])
    returnInfo["name"] = logName
    returnInfo["genTestFname"] = f"{logName}_generated_tests.cpp"
    returnInfo["className"] = "C" + camelCase
    returnInfo["nameCamelCase"] = camelCase[0].lower() + camelCase[1:]
    returnInfo["classVarName"] = "".join([word[0].lower() for word in nameWords])
    returnInfo["baseType"] = baseName
    returnInfo["hasGeneratedHfile"] = hasGeneratedHFile(logName)

    # find where messageT structs are being made in base log file -> describes fields
    messageStructIdxs = [
        idx for idx, baseLine in enumerate(baseHLines) if "messageT(" in baseLine
    ]

    schemaTableName, schemaFieldInfo = getSchemaFieldInfo(baseName)

    returnInfo["schemaTableName"] = schemaTableName
    msgFieldInfo = getMessageFieldInfo(messageStructIdxs, baseHLines, schemaFieldInfo)

    # types based on empty_log get a single message without fields
    returnInfo["messageTypes"] = [[]] if "empty_log" in baseName else msgFieldInfo

    return returnInfo
490
def versionAsNumber(major, minor):
    '''
    Fold a (major, minor) version pair into a single number so that two
    versions can be compared with <, e.g. (3, 9) -> 3009.
    '''
    scaledMajor = major * 1000
    return scaledMajor + minor
493
def main():
    '''
    Generate one <log_type>_generated_tests.cpp per log type into
    ./generated_tests/ (relative to this script), rendered from
    catch2TestTemplate.jinja2 using the .hpp files in ../types.

    Options (at most one):
      -s <seed>  seed the random number generator
      -i         use incrementing instead of random test values

    Exits with a non-zero status on a usage or version error.
    '''
    # check python version >= 3.9
    if (versionAsNumber(sys.version_info[0], sys.version_info[1]) < versionAsNumber(3,9)):
        print("Error: Python version must be >= 3.9")
        sys.exit(1)

    global gIncrementingVals
    gIncrementingVals = False

    # getopt for random seed or incrementing vals
    try:
        opts, args = getopt.getopt(sys.argv[1:], "is:")
    except getopt.GetoptError:
        print("Usage: python3 ./generateTemplatedCatch2Tests.py -s <seed> | -i")
        sys.exit(1)

    if len(opts) > 1:
        print("Error: Only one option allowed. -s <seed> or -i for incrementing values.")
        sys.exit(1)

    for opt, arg in opts:
        if opt == "-s":
            if not arg.isdigit():
                print(f"Error: random seed {arg} provided is not an integer.")
                sys.exit(1)
            # use random seed if provided with -s
            random.seed(int(arg))
        if opt == "-i":
            gIncrementingVals = True

    scriptFolder = os.path.dirname(__file__)

    # load template
    env = jinja2.Environment(
        loader = jinja2.FileSystemLoader(searchpath=scriptFolder)
    )
    env.trim_blocks = True
    env.lstrip_blocks = True

    catchTemplate = env.get_template("catch2TestTemplate.jinja2")

    # path to .hpp files here
    typesFolderPath = os.path.abspath(os.path.join(scriptFolder, "./../types"))

    # generated tests output path
    generatedTestsFolderPath = os.path.abspath(
        os.path.join(scriptFolder, "./generated_tests/")
    )

    # make directory if it doesn't exist, then clear out previous results
    pathlib.Path(generatedTestsFolderPath).mkdir(exist_ok=True)
    for oldFile in glob.glob(os.path.join(generatedTestsFolderPath, "*")):
        # os.remove() cannot delete directories, leave them alone
        if os.path.isfile(oldFile):
            os.remove(oldFile)

    baseTypesDict = dict() # map baseTypes to the types that inherit from them
    for typeFname in sorted(os.listdir(typesFolderPath)):

        # check valid type to generate tests for
        if not typeFname.endswith(".hpp"):
            continue

        typePath = os.path.join(typesFolderPath, typeFname)

        # make dictionary with info for template
        info = makeTestInfoDict(typePath, baseTypesDict)
        if (info is None):
            # no tests to make from this file alone
            continue

        # render
        renderedHeader = catchTemplate.render(info)

        # write generated file
        outPath = os.path.join(generatedTestsFolderPath, info["genTestFname"])
        with open(outPath,"w") as outfile:
            print(renderedHeader,file=outfile)

    # handle types that inherit from baseTypes
    for baseType, inheritedTypes in baseTypesDict.items():

        # sorted: set order varies between runs, which would make the test
        # values differ from run to run even with a fixed seed
        for inheritedType in sorted(inheritedTypes):
            info = makeInheritedTypeInfoDict(typesFolderPath, baseType, inheritedType)
            if (info is None):
                # no tests to make
                continue

            # render
            renderedHeader = catchTemplate.render(info)

            # write generated file
            outPath = os.path.join(generatedTestsFolderPath, info["genTestFname"])
            with open(outPath,"w") as outfile:
                print(renderedHeader,file=outfile)
595
596
# run the generator only when executed as a script, not when imported
if __name__ == "__main__":
    main()
dict makeInheritedTypeInfoDict(str typesFolderPath, str baseName, str logName)
dict makeTestInfoDict(str hppFname, dict baseTypesDict)
tuple[str, str] getTypeAndName(list lineParts)
str getTestValFromType(str fieldType, schemaFieldType=None)
tuple[str, tuple] getSchemaFieldInfo(str fname)
bool typesCorrespond(str fbsType, str cType)
getMessageFieldInfo(list messageStructIdxs, list lines, tuple schemaFieldInfo)