A taste of Go code generator magic: a quick guide to getting started


Many languages have support for metaprogramming and code generation, but in Go, it feels like something that requires more effort. For instance, if you Google “How to write a Go code generator”, you’ll find very few articles explaining the tools and concepts needed to create even a simple code generator. So, in this short post, I’d like to share a small program that generates wrapper functions for the methods of a given type; this specific example can then act as a good starting point for your own Go code generator. If you want to write your own Go code generator but don’t know how to start, then read on!

TL;DR:

  1. Access the metadata of your packages with golang.org/x/tools/go/packages
  2. Access the metadata of your types with std go/types
  3. Generate Go code with github.com/dave/jennifer/jen
  4. Care more about readability than performance
  5. Start by reading the code from a sample generator

Why did I need to write a generator?

I found myself repeatedly writing boring boilerplate code, so naturally, I decided to generate it instead to save time and reduce errors. I searched the web for some simple examples, but the only useful article I could find was this one: Metaprogramming with Go - or how to build code generators that parse Go code. Most other articles just introduced the basic go generate abilities with a Stringer generator example, like this one.


Irina Nazarova CEO at Evil Martians


So, I dug into the docs of the packages I found in that first article and created a “not-a-hello-world” Go code generator. It’s simple, and I want to share it so you can grasp the concepts without the need to read a whole mess of docs.

The project I’m working on is a daemon app that uses an SQLite3 database to store the metadata of downloaded files and folders. Since many different places in the code use that database connection, passing it everywhere eventually became a serious pain. To remove this dependency injection, I decided to use a private global variable, initialize it once, and use public proxy functions to access its methods. This reduced the complexity of the codebase but required a lot of typing:

package db

type DB struct {
    // ... some private fields
}

func (d *DB) InsertDocuments(documents []*DocumentParams) ([]*Document, error) {
    // ...
}

func (d *DB) DeleteDocument(id string) error {
    // ...
}

var globalDB *DB

func SetGlobalDB(value *DB) {
    globalDB = value
}

// Boilerplate code: candidates for code generation

func InsertDocuments(documents []*DocumentParams) ([]*Document, error) {
    if globalDB == nil {
        panic("globalDB is not set")
    }

    return globalDB.InsertDocuments(documents)
}

func DeleteDocument(id string) error {
    if globalDB == nil {
        panic("globalDB is not set")
    }

    return globalDB.DeleteDocument(id)
}

Since there were 50 methods like that, I figured that writing a small code generator would be more efficient and less time-consuming than maintaining these proxy functions myself.

How to write a code generator in Go?

Let’s start with a simple example that uses the jen package to generate a piece of code.

package main

import (
	"log"

	"github.com/dave/jennifer/jen"
)

func main() {
	f := jen.NewFile("main")

	// A comment to mark the file as generated
	f.PackageComment("Code generated by generator, DO NOT EDIT.")

	// Params() renders the empty parameter list; without it the output
	// would be "func Hello {", which is not valid Go
	f.Func().Id("Hello").Params().Block(
		jen.Qual("fmt", "Println").Call(jen.Lit("Hello from generated code!")),
	)

	targetFilename := "main_gen.go"

	if err := f.Save(targetFilename); err != nil {
		log.Fatal(err)
	}
}

After running go run main.go, the generated file main_gen.go will have the following content:

// Code generated by generator, DO NOT EDIT.
package main

import "fmt"

func Hello() {
        fmt.Println("Hello from generated code!")
}

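In practice, though, a code generator is usually invoked with a go generate ./... command, which finds all the comments in the Go code starting with //go:generate and runs the command that follows each one. go generate also sets a few environment variables for that command, such as GOFILE (the base name of the file containing the directive) and GOPACKAGE (the name of that file's package), so we can get at least the filename from there. Other options, like the full package name and the type name, can be passed as command arguments.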
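As a sketch of how these pieces connect, assume the generator lives in a file named gen_functions.go next to the package's db.go, and db.go carries a directive such as `//go:generate go run gen_functions.go github.com/example/app/db DB` (the import path is a placeholder). Because gen_functions.go declares package main inside the db directory, a common convention is to start it with a `//go:build ignore` line so regular builds of the package skip it, while `go run gen_functions.go` still builds it. The hypothetical outputFilename helper below shows how the _gen.go name can be derived from GOFILE.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// outputFilename derives the generated file's name from the name of the file
// that contains the //go:generate directive, e.g. "db.go" -> "db_gen.go".
// It is a hypothetical helper, not part of any library.
func outputFilename(gofile string) string {
	return strings.TrimSuffix(gofile, filepath.Ext(gofile)) + "_gen.go"
}

func main() {
	// go generate sets these variables for the command it runs;
	// outside of go generate they are usually empty.
	gofile := os.Getenv("GOFILE")
	gopackage := os.Getenv("GOPACKAGE")

	if gofile == "" {
		fmt.Println("GOFILE is not set; run this program via go generate")
		return
	}

	fmt.Printf("package %s: writing %s\n", gopackage, outputFilename(gofile))
}
```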

The generator will do the following:

  1. Load the package metadata and get all public methods of a given type.
  2. Generate the functions; this includes collecting the import paths of all packages the generated code has to import.
  3. Assemble the generated code, including the global var declaration and a SetGlobal<Type> function that sets it.
  4. Save the resulting code to a file with a _gen.go suffix.

package main

import (
	"go/types"
	"log"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/dave/jennifer/jen"
	"golang.org/x/tools/go/packages"
)

func main() {
	// 1. Load the package metadata and get all public methods of a given type.

	if len(os.Args) < 3 {
		log.Fatal("Usage: //go:generate go run gen_functions.go <full package name> <type name>")
	}

	pkgName := os.Args[1]
	typeName := os.Args[2]

	// Load package metadata. We need it for accessing typeName's methods.
	pkg := loadPackage(pkgName)

	typeObject := pkg.Types.Scope().Lookup(typeName)
	if typeObject == nil {
		log.Fatalf("unable to find type %s", typeName)
	}

	// We will support only pointer semantics: the method set of *T
	// contains methods with both pointer and value receivers
	typeObjectMeta := types.NewPointer(typeObject.Type())

	methods := types.NewMethodSet(typeObjectMeta)
	for i := 0; i < methods.Len(); i++ {
		method := methods.At(i).Obj()

		// Skip private methods
		if !method.Exported() {
			continue
		}

		// 2. Generate the functions.
		// ...
	}

	// 3. Put the generated code
	// ...

	// 4. Save the results to a file
	// ...
}

// loadPackage loads type information for the package with the given import path.
func loadPackage(pkgPath string) *packages.Package {
	cfg := &packages.Config{Mode: packages.NeedName | packages.NeedTypes}
	pkgs, err := packages.Load(cfg, pkgPath)
	if err != nil {
		log.Fatalf("failed to load package '%s' for inspection: %v", pkgPath, err)
	}
	// PrintErrors prints every package error to stderr and returns their count
	if n := packages.PrintErrors(pkgs); n > 0 {
		log.Fatalf("package '%s' has %d error(s)", pkgPath, n)
	}
	if len(pkgs) == 0 {
		log.Fatalf("no package found for '%s'", pkgPath)
	}

	return pkgs[0]
}

For every public method, we want to generate a function with the same signature that checks the global variable and then proxies the call. We also have to handle the imports manually if the signatures use types from other packages:

package main

// ...

func main() {
	// 1. Load the package metadata and get all public methods of a given type.
	// ...

	varName := "global" + typeName
	imports := make(map[string]struct{})
	funcs := jen.Empty()

	for i := 0; i < methods.Len(); i++ {
		// ...

		// 2. Generate the functions.
		signature := method.Type().(*types.Signature)
		signatureParams := signature.Params()

		params := make([]jen.Code, 0, signatureParams.Len())
		paramNames := make([]jen.Code, 0, signatureParams.Len())

		for j := 0; j < signatureParams.Len(); j++ {
			param := signatureParams.At(j)
			paramType := param.Type().String()

			// Remove current package prefix (pkgName is the full import path)
			paramType = strings.ReplaceAll(paramType, pkgName+".", "")

			// Remove prefixes of imported packages
			paramType = extractImports(paramType, imports)

			params = append(params, jen.Id(param.Name()).Id(paramType))
			paramNames = append(paramNames, jen.Id(param.Name()))
		}

		results := signature.Results()
		returnTypes := make([]string, 0, results.Len())

		for j := 0; j < results.Len(); j++ {
			resultType := results.At(j).Type().String()

			// Remove current package prefix
			resultType = strings.ReplaceAll(resultType, pkgName+".", "")

			// Remove prefixes of imported packages
			resultType = extractImports(resultType, imports)

			returnTypes = append(returnTypes, resultType)
		}

		// Call the method on the global variable and return its results, if any
		var returnType string
		var lastStatement jen.Code = jen.Id(varName).Dot(method.Name()).Call(paramNames...)

		if len(returnTypes) > 0 {
			returnType = "(" + strings.Join(returnTypes, ", ") + ")"
			lastStatement = jen.Return(lastStatement)
		}

		// Generate the proxy function
		funcs.Comment(method.Name() + " generated.")
		funcs.Line()
		funcs.Func().Id(method.Name()).Params(params...).Id(returnType).Block(
			jen.If(jen.Id(varName).Op("==").Id("nil")).Block(
				jen.Panic(jen.Lit(varName+" is not set")),
			),
			jen.Line(),
			lastStatement,
		)
		funcs.Line()
	}

	// 3. Put the generated code
	// ...

	// 4. Save the results to a file
	// ...
}


var rePkgReferencedType = regexp.MustCompile(`([^*\[\]]+)\.[^.]+`)

// Remove imported packages prefixes and accumulate missing imports.
//
//	  github.com/somelib/name.Type   -> name.Type
//    []github.com/somelib/name.Item -> []name.Item
func extractImports(paramType string, imports map[string]struct{}) string {
	if strings.Index(paramType, ".") < 0 {
		return paramType
	}

	submatches := rePkgReferencedType.FindStringSubmatch(paramType)

	imports[submatches[1]] = struct{}{}

	for packageName := range imports {
		paramType = strings.ReplaceAll(paramType, filepath.Dir(packageName)+"/", "")
	}

	return paramType
}

Finally, we want to collect the generated parts and save the code into a file with the same base name plus a _gen.go suffix. To get the filename, we use the GOFILE environment variable, which go generate sets to the base name of the file where the //go:generate line was found.

package main

// ...

func main() {
	// 1. Load the package metadata and get all public methods of a given type.
	// ...

	// 2. Generate the functions.
	// ...

	// 3. Put the generated code
	f := jen.NewFile(pkg.Types.Name())
	f.PackageComment("Code generated by gen_functions.go, DO NOT EDIT.")

	// NOTE: `jen` has a Qual function that automatically adds new imports
	// but here we parse imports ourselves so we have to add them manually.
	if len(imports) > 0 {
		importsList := jen.Empty()
		for importPackage := range imports {
			importsList.Add(
				jen.Lit(importPackage),
				jen.Line(),
			)
		}

		f.Comment("Import missing dependencies")
		f.Add(jen.Id("import").Parens(importsList))
	}

	// Declare a global variable for our singleton. It is a pointer to match
	// the *T method set and the nil check in the proxy functions.
	f.Var().Id(varName).Op("*").Id(typeName)

	// Add a function to set the global variable
	f.Func().Id("SetGlobal"+typeName).Params(jen.Id("value").Op("*").Id(typeName)).Block(
		jen.Id(varName).Op("=").Id("value"),
	)

	// Add generated proxy functions
	f.Add(funcs)

	// 4. Save the results to a file

	// GOFILE is only set when the generator runs via go generate
	filename := os.Getenv("GOFILE")
	if filename == "" {
		log.Fatal("GOFILE is not set; run the generator via go generate")
	}
	ext := filepath.Ext(filename)
	filename, _ = strings.CutSuffix(filename, ext)
	filename = filename + "_gen.go"

	if err := f.Save(filename); err != nil {
		log.Fatal(err)
	}
}


If you need the full code, you can check out the gen-singleton-functions repo. Use it as inspiration for your own Go code generation tool! By the way, we’ve only used about 20% of what the jen and go/types packages can do, so I encourage you to explore these libraries and find what’s useful for you.


Ready to add some magic to your Go development process? Hire us to create elegant, efficient solutions, and let’s write less boilerplate and more awesome Go code together!