Many languages have support for metaprogramming and code generation, but in Go, it feels like something that requires more effort. For instance, if you Google “How to write a Go code generator” you’ll find very few articles explaining the tools and concepts even for creating a simple code generator. So, in this short post, I’d like to share a small program that generates wrapper functions for the methods of a given type, and this specific example can then act as a good starting point for your own Go code generator. If you want to write your own Go code generator but don’t know how to start, then read on!
TL;DR:
- Access the metadata of your packages with golang.org/x/tools/go/packages
- Access the metadata of your types with std go/types
- Generate Go code with github.com/dave/jennifer/jen
- Care more about readability than performance
- Start by reading the code from a sample generator
Why did I need to write a generator?
I found myself repeatedly writing boring boilerplate code, so naturally, I decided to try and somehow generate it in order to save time and reduce errors. So, I set out and attempted to search for some simple examples on the net. However, the only useful article I could find was this one: Metaprogramming with Go - or how to build code generators that parse Go code. Many other articles just introduced the basic go generate abilities with a Stringer generator example, like this one.
So, I dug into the docs of the packages I found in that first article and created a “not-a-hello-world” Go code generator. It’s simple, and I want to share it so you can grasp the concepts without the need to read a whole mess of docs.
The project I’m working on is a daemon app that uses an SQLite3 database to store the metadata of downloaded files and folders. Many different places in the code make use of that database connection, so the need to pass it everywhere eventually became a serious pain. To remove this dependency injection, I decided to use a private global variable, initialize it once, and simply use public proxy functions to access its methods. This reduced the complexity of the codebase but required a lot of typing:
package db

type DB struct {
	// ... some private fields
}

func (d *DB) InsertDocuments(documents []*DocumentParams) ([]*Document, error) {
	// ...
}

func (d *DB) DeleteDocument(id string) error {
	// ...
}

var globalDB *DB

func SetGlobalDB(value *DB) {
	globalDB = value
}

// Boilerplate code: candidates for code generation
func InsertDocuments(documents []*DocumentParams) ([]*Document, error) {
	if globalDB == nil {
		panic("globalDB is not set")
	}
	return globalDB.InsertDocuments(documents)
}

func DeleteDocument(id string) error {
	if globalDB == nil {
		panic("globalDB is not set")
	}
	return globalDB.DeleteDocument(id)
}
Since there were 50 methods like that, I considered writing a small code generator to be more efficient and less time consuming than maintaining these proxy functions myself.
How to write a code generator in Go?
Let’s take a simple example: using the jen package to generate code parts.
package main

import (
	"log"

	"github.com/dave/jennifer/jen"
)

func main() {
	f := jen.NewFile("main")

	// A comment to mark the file as generated
	f.PackageComment("Code generated by generator, DO NOT EDIT.")

	f.Func().Id("Hello").Block(
		jen.Qual("fmt", "Println").Call(jen.Lit("Hello from generated code!")),
	)

	targetFilename := "main_gen.go"
	if err := f.Save(targetFilename); err != nil {
		log.Fatal(err)
	}
}
After running go run main.go, the generated file main_gen.go will have the following content:
// Code generated by generator, DO NOT EDIT.
package main

import "fmt"

func Hello() {
	fmt.Println("Hello from generated code!")
}
However, a code generator is usually invoked through the go generate ./... command, which finds all the comments in the Go code that start with //go:generate and runs the command that follows each of them. go generate also sets a few environment variables for that command, such as GOFILE, so we can use them to get at least the filename. Other options, like the full package name and the type name, can be passed as command arguments.
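To make the split between environment variables and arguments concrete, here is a minimal sketch of how a generator might collect its input. GOFILE and GOPACKAGE are set by go generate; the generatorInput struct and the readInput helper are hypothetical names used only for this illustration, and the argument order simply mirrors the usage string of the generator shown later (full package name, then type name).

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// generatorInput is a hypothetical container for everything the generator
// needs to know before it starts inspecting types.
type generatorInput struct {
	PkgPath  string // full package path, first command argument
	TypeName string // type to wrap, second command argument
	GoFile   string // $GOFILE: file that contains the //go:generate directive
	GoPkg    string // $GOPACKAGE: package name of that file
}

// readInput is a hypothetical helper. args are the command arguments without
// the program name, and getenv is normally os.Getenv (it is a parameter so
// the function is easy to test).
func readInput(args []string, getenv func(string) string) (generatorInput, error) {
	if len(args) < 2 {
		return generatorInput{}, errors.New("usage: <full package name> <type name>")
	}
	in := generatorInput{
		PkgPath:  args[0],
		TypeName: args[1],
		GoFile:   getenv("GOFILE"),
		GoPkg:    getenv("GOPACKAGE"),
	}
	if in.GoFile == "" {
		return generatorInput{}, errors.New("GOFILE is not set; run this through go generate")
	}
	return in, nil
}

func main() {
	in, err := readInput(os.Args[1:], os.Getenv)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Printf("generating wrappers for %s.%s (directive in %s, package %s)\n",
		in.PkgPath, in.TypeName, in.GoFile, in.GoPkg)
}
```

Run directly, this sketch exits with an error because GOFILE is empty; it is meant to be started from a //go:generate comment.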
The generator will do the following:
- Load the package metadata and get all public methods of a given type.
- Generate the functions; this includes collecting all package names we have to add to imports.
- Put the generated code together, including the special var declaration and the SetGlobal<Type> function to set it.
- Save the resulting code to a file with a _gen.go suffix.
package main

import (
	"go/types"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/dave/jennifer/jen"
	"golang.org/x/tools/go/packages"
)

func main() {
	// 1. Load the package metadata and get all public methods of a given type.
	if len(os.Args) < 3 {
		log.Fatal("Usage: //go:generate go run gen_functions.go <full package name> <type name>")
	}
	pkgName := os.Args[1]
	typeName := os.Args[2]

	// Load package metadata. We need it for accessing typeName's methods.
	pkg := loadPackage(pkgName)
	typeObject := pkg.Types.Scope().Lookup(typeName)
	if typeObject == nil {
		log.Fatalf("unable to find type %s", typeName)
	}

	// We will support only pointer semantics
	typeObjectMeta := types.NewPointer(typeObject.Type())
	methods := types.NewMethodSet(typeObjectMeta)
	for i := 0; i < methods.Len(); i++ {
		method := methods.At(i).Obj()
		// Skip private methods
		if !method.Exported() {
			continue
		}

		// 2. Generate the functions.
		// ...
	}

	// 3. Put the generated code
	// ...

	// 4. Save the results to a file
	// ...
}

func loadPackage(path string) *packages.Package {
	cfg := &packages.Config{Mode: packages.NeedTypes}
	pkgs, err := packages.Load(cfg, path)
	if err != nil {
		log.Fatalf("failed to load package '%s' for inspection: %v", path, err)
	}
	// PrintErrors writes the errors to stderr and returns their count,
	// so we call it only once.
	if n := packages.PrintErrors(pkgs); n > 0 {
		log.Fatalf("package '%s' has %d error(s)", path, n)
	}
	return pkgs[0]
}
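If you would like to experiment with method sets without loading a real package, the following sketch uses only the standard library: it parses and type-checks a small source string and then lists the exported methods in the same way as the loop above. The src constant is a made-up stand-in for the db package, and exportedMethods is a hypothetical helper name; the real generator gets its types from golang.org/x/tools/go/packages instead.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"log"
)

// src is a made-up package that stands in for the real db package.
const src = `package db

type DB struct{}

func (d *DB) InsertDocuments(ids []string) error { return nil }
func (d *DB) DeleteDocument(id string) error     { return nil }
func (d *DB) reset()                             {}
`

// exportedMethods type-checks source and returns the names of the exported
// methods in the method set of *typeName.
func exportedMethods(source, typeName string) ([]string, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "db.go", source, 0)
	if err != nil {
		return nil, err
	}

	// The sample file has no imports, so no Importer is configured.
	conf := types.Config{}
	pkg, err := conf.Check("db", fset, []*ast.File{file}, nil)
	if err != nil {
		return nil, err
	}

	obj := pkg.Scope().Lookup(typeName)
	if obj == nil {
		return nil, fmt.Errorf("unable to find type %s", typeName)
	}

	// Pointer semantics: methods with pointer receivers only show up
	// in the method set of the pointer type.
	mset := types.NewMethodSet(types.NewPointer(obj.Type()))
	var names []string
	for i := 0; i < mset.Len(); i++ {
		if m := mset.At(i).Obj(); m.Exported() {
			names = append(names, m.Name())
		}
	}
	return names, nil
}

func main() {
	names, err := exportedMethods(src, "DB")
	if err != nil {
		log.Fatal(err)
	}
	for _, name := range names {
		fmt.Println(name)
	}
}
```

Building the method set from the pointer type matters here: all three sample methods have pointer receivers, so the method set of the plain DB type would be empty.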
For every public method, we want to generate a function with the same signature, add a check for the global var, and proxy the call. We also have to handle the imports manually if the signature uses types from other packages:
package main

// ...

func main() {
	// 1. Load the package metadata and get all public methods of a given type.
	// ...

	varName := "global" + typeName
	imports := make(map[string]struct{})
	funcs := jen.Empty()
	for i := 0; i < methods.Len(); i++ {
		// ...

		// 2. Generate the functions.
		signature := method.Type().(*types.Signature)
		signatureParams := signature.Params()
		params := make([]jen.Code, 0, signatureParams.Len())
		paramNames := make([]jen.Code, 0, signatureParams.Len())
		for j := 0; j < signatureParams.Len(); j++ {
			param := signatureParams.At(j)
			paramType := param.Type().String()
			// Remove current package prefix (pkgName is the full package path from step 1)
			paramType = strings.ReplaceAll(paramType, pkgName+".", "")
			// Remove prefixes of imported packages
			paramType = extractImports(paramType, imports)
			params = append(params, jen.Id(param.Name()).Id(paramType))
			paramNames = append(paramNames, jen.Id(param.Name()))
		}

		results := signature.Results()
		returnTypes := make([]string, 0, results.Len())
		for j := 0; j < results.Len(); j++ {
			resultType := results.At(j).Type().String()
			// Remove current package prefix
			resultType = strings.ReplaceAll(resultType, pkgName+".", "")
			// Remove prefixes of imported packages
			resultType = extractImports(resultType, imports)
			returnTypes = append(returnTypes, resultType)
		}

		var returnType string
		var lastStatement jen.Code = jen.Id(varName).Dot(method.Name()).Call(paramNames...)
		if len(returnTypes) > 0 {
			returnType = "(" + strings.Join(returnTypes, ", ") + ")"
			lastStatement = jen.Return(lastStatement)
		}

		// Generate the proxy function
		funcs.Comment(method.Name() + " generated.")
		funcs.Line()
		funcs.Func().Id(method.Name()).Params(params...).Id(returnType).Block(
			jen.If(jen.Id(varName).Op("==").Id("nil")).Block(
				jen.Panic(jen.Lit(varName + " is not set.")),
			),
			jen.Line(),
			lastStatement,
		)
		funcs.Line()
	}

	// 3. Put the generated code
	// ...

	// 4. Save the results to a file
	// ...
}

var rePkgReferencedType = regexp.MustCompile(`([^*\[\]]+)\.[^.]+`)

// Remove imported packages prefixes and accumulate missing imports.
//
// github.com/somelib/name.Type -> name.Type
// []github.com/somelib/name.Item -> []name.Item
func extractImports(paramType string, imports map[string]struct{}) string {
	if !strings.Contains(paramType, ".") {
		return paramType
	}
	submatches := rePkgReferencedType.FindStringSubmatch(paramType)
	if submatches == nil {
		// Nothing that looks like a package-qualified type: leave it as is.
		return paramType
	}
	imports[submatches[1]] = struct{}{}
	for packageName := range imports {
		paramType = strings.ReplaceAll(paramType, filepath.Dir(packageName)+"/", "")
	}
	return paramType
}
Finally, we want to collect the generated parts and save the code into a file with the same name, but with a _gen.go suffix. To get the filename, we use the GOFILE environment variable, which go generate sets to the base name of the file where the //go:generate line was found.
package main

// ...

func main() {
	// 1. Load the package metadata and get all public methods of a given type.
	// ...

	// 2. Generate the functions.
	// ...

	// 3. Put the generated code
	// f is the *jen.File for the generated code (its creation is not shown here).
	// NOTE: `jen` has a Qual function that automatically adds new imports
	// but here we parse imports ourselves so we have to add them manually.
	if len(imports) > 0 {
		importsList := jen.Empty()
		for importPackage := range imports {
			importsList.Add(
				jen.Lit(importPackage),
				jen.Line(),
			)
		}
		f.Comment("Import missing dependencies")
		f.Add(jen.Id("import").Parens(importsList))
	}

	// Declare a global variable for our singleton.
	// It is a pointer so the generated nil checks compile.
	f.Var().Id(varName).Op("*").Id(typeName)

	// Add a function to set the global variable
	f.Func().Id("SetGlobal" + typeName).Params(jen.Id("value").Op("*").Id(typeName)).Block(
		jen.Id(varName).Op("=").Id("value"),
	)

	// Add generated proxy functions
	f.Add(funcs)

	// 4. Save the results to a file
	filename := os.Getenv("GOFILE")
	ext := filepath.Ext(filename)
	filename, _ = strings.CutSuffix(filename, ext)
	filename = filename + "_gen.go"
	if err := f.Save(filename); err != nil {
		log.Fatal(err)
	}
}
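The filename handling in step 4 is easy to check in isolation. Here is a small sketch with a hypothetical genFilename helper; it uses strings.TrimSuffix rather than strings.CutSuffix, since CutSuffix was only added in Go 1.20 and the result is the same here.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// genFilename turns the value of $GOFILE into the name of the generated
// file by replacing the extension with a _gen.go suffix.
func genFilename(gofile string) string {
	ext := filepath.Ext(gofile)
	return strings.TrimSuffix(gofile, ext) + "_gen.go"
}

func main() {
	fmt.Println(genFilename("db.go")) // db_gen.go
}
```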
If you need the full code, you can check out the gen-singleton-functions repo. Use it to get inspiration for your own Go code generation tool! By the way, we’ve only used about 20% of what the jen and go/types packages can do, so I encourage you to explore these libraries and find something useful for your own generator.