Notes on

Let's Go Further!

by Alex Edwards

| 67 min read


Like the previous book, Let’s Go!, this has been a great learning experience. The author, Alex Edwards, is great at explaining and teaching Go. I particularly enjoyed this book for the more advanced topics and deep dives.

If you find these notes useful, please Go and support Alex by buying the book.

You can find the code I’ve written while reading the book here: lets-go-further-book.

2. Getting Started

Method   Usage
GET      Use for actions that retrieve information only and don’t change the state of your application or any data.
POST     Use for non-idempotent actions that modify state. In the context of a REST API, POST is generally used for actions that create a new resource.
PUT      Use for idempotent actions that modify the state of a resource at a specific URL. In the context of a REST API, PUT is generally used for actions that replace or update an existing resource.
PATCH    Use for actions that partially update a resource at a specific URL. It’s OK for the action to be either idempotent or non-idempotent.
DELETE   Use for actions that delete a resource at a specific URL.

“Idempotent”:

A request method is considered “idempotent” if the intended effect on the server of multiple identical requests with that method is the same as the effect for a single such request. Of the request methods defined by this specification, PUT, DELETE, and safe request methods are idempotent.
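
To make the definition concrete, here is a minimal sketch that contrasts a POST-like and a PUT-like action. The movieStore type and its create and replace methods are hypothetical stand-ins for server state, not code from the book.

```go
package main

import "fmt"

// movieStore is a hypothetical in-memory stand-in for the server's state.
type movieStore struct {
	nextID int
	titles map[int]string
}

func newMovieStore() *movieStore {
	return &movieStore{nextID: 1, titles: make(map[int]string)}
}

// create mimics POST /movies: every call adds a new resource,
// so repeating the same request changes the state again (not idempotent).
func (s *movieStore) create(title string) int {
	id := s.nextID
	s.nextID++
	s.titles[id] = title
	return id
}

// replace mimics PUT /movies/:id: repeating the same request leaves
// the state exactly as it was after the first one (idempotent).
func (s *movieStore) replace(id int, title string) {
	s.titles[id] = title
}

func main() {
	s := newMovieStore()
	s.create("Moana")
	s.create("Moana")
	fmt.Println(len(s.titles)) // 2: two identical POSTs created two resources

	s.replace(1, "Black Panther")
	s.replace(1, "Black Panther")
	fmt.Println(len(s.titles), s.titles[1]) // 2 Black Panther: the second PUT changed nothing
}
```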

3. Sending JSON Responses

There are various structures you can use for JSON responses.

In this project, we envelop the responses: each one is wrapped in a JSON object that holds the response data (and any metadata) under a descriptive key before it is sent to the client.

Enveloping is not strictly necessary, but it gives the client a consistent response structure and has some further benefits:

  1. The response is more self-documenting.
  2. Explicit access: The client has to explicitly access the response via, for example, the movie key, making it clear what the response is for.
  3. Security Mitigation: Addresses a potential security issue. More details can be found in this article on JSON security vulnerability.

For sending JSON responses, using json.Marshal is generally better.
While very slightly less efficient, the API is much more ergonomic.
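
As a rough illustration of enveloping combined with json.Marshal, here is a minimal sketch. The envelope type and the writeJSON name mirror the calls used in the handlers later in these notes, but the body of writeJSON is my own assumption, and demo is a hypothetical helper that only exists to exercise it.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// envelope wraps the response data under a descriptive top-level key.
type envelope map[string]interface{}

// writeJSON marshals data first, so an encoding error is caught before
// any headers or status code have been written to the client.
func writeJSON(w http.ResponseWriter, status int, data envelope, headers http.Header) error {
	js, err := json.Marshal(data)
	if err != nil {
		return err
	}
	js = append(js, '\n') // trailing newline makes terminal output easier to read

	for key, values := range headers {
		w.Header()[key] = values
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_, err = w.Write(js)
	return err
}

// demo (hypothetical) runs writeJSON against an in-memory recorder and
// returns the status code and body a client would receive.
func demo() (int, string) {
	rec := httptest.NewRecorder()
	movie := map[string]interface{}{"id": 1, "title": "Casablanca"}
	if err := writeJSON(rec, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
		panic(err)
	}
	return rec.Code, rec.Body.String()
}

func main() {
	status, body := demo()
	fmt.Println(status) // 200
	fmt.Print(body)     // {"movie":{"id":1,"title":"Casablanca"}}
}
```

Because the whole body is built in memory before anything is sent, a marshalling failure can still be turned into a clean error response.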

4. Parsing JSON Requests

For parsing JSON requests, using json.Decoder is generally better.
json.Unmarshal requires more (verbose) code and is less efficient.
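
Here is a minimal sketch of reading a body with json.Decoder. The decodeJSON and decodeMovie helpers and the movieInput type are hypothetical names for illustration; in a handler you would pass r.Body as the reader. The DisallowUnknownFields call and the single-value check are optional extras, not something these notes require.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
)

// movieInput is a hypothetical request payload.
type movieInput struct {
	Title string `json:"title"`
	Year  int32  `json:"year"`
}

// decodeJSON reads exactly one JSON value from r into dst.
func decodeJSON(r io.Reader, dst interface{}) error {
	dec := json.NewDecoder(r)
	dec.DisallowUnknownFields() // optional: reject keys that dst has no field for

	if err := dec.Decode(dst); err != nil {
		return err
	}
	// A second Decode should report io.EOF; anything else means the
	// body held more than one JSON value.
	if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) {
		return errors.New("body must only contain a single JSON value")
	}
	return nil
}

// decodeMovie is a small convenience wrapper for string input.
func decodeMovie(body string) (movieInput, error) {
	var in movieInput
	err := decodeJSON(strings.NewReader(body), &in)
	return in, err
}

func main() {
	in, err := decodeMovie(`{"title": "Moana", "year": 2016}`)
	fmt.Println(in.Title, in.Year, err) // Moana 2016 <nil>

	_, err = decodeMovie(`{"title": "Moana"} {"title": "Top Gun"}`)
	fmt.Println(err) // body must only contain a single JSON value
}
```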

5. Database Setup and Configuration

The default PostgreSQL configuration may not be optimal, so it can be worth tuning.

How the sql.DB connection pool works

  • Has two types of connections:
    • In-use: when using it to perform DB tasks (query, executing SQL, etc.)
    • Idle: after a task completes, the connection becomes idle
  • When you have Go do a database task, it checks for available idle connections. If there is one, it reuses that, marking it as in-use for the duration of the task.
    • If no idle connections, it creates one.
  • When Go reuses an idle connection, problems with the connection are handled gracefully.
  • Bad connections are automatically retried twice. If that is still unsuccessful, Go removes the connection from the pool and creates a new one.

There are 4 methods to configure the database pool:

  • SetMaxOpenConns() - set max on open (in-use + idle) connections. Default is unlimited. But by default PostgreSQL has a (configured) hard limit of 100 open connections. So leaving it unlimited isn’t necessarily the best option (+ for other reasons I imagine).
    • A consequence of lowering max open connections is that tasks have to wait if there are no available connections. This could leave them hanging forever. To mitigate that, we use timeouts on database tasks with context.Context.
  • SetMaxIdleConns() - by default is 2. The trade-off here is time vs. memory. Increasing means likely better performance, as it’s less likely a new connection has to be established from scratch. But idle connections take up memory. And if idle for too long, they can become unusable (e.g. auto-closed by DBMS if unused for X hours).
    • Guideline: only keep a connection idle if you’re likely to use it again soon.
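  • SetConnMaxLifetime() - sets the maximum length of time a connection can be reused before it is closed. Default is unlimited.
  • SetConnMaxIdleTime() - sets the maximum length of time a connection can sit idle before it is closed. Default is unlimited.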

Tips:

  • Explicitly set a MaxOpenConns. Should be comfortably below the hard limit of connections imposed by your database & infrastructure. You can also consider setting it low, so it acts as a throttle.
    • 25 is reasonable for small-to-medium web applications and APIs.
  • Generally, higher MaxOpenConns and MaxIdleConns lead to better performance. Results are diminishing, though.
  • Generally, you should set a ConnMaxIdleTime to remove idle connections that haven’t been used for a long time. E.g. 15 minutes.
  • It can be OK to leave ConnMaxLifetime as unlimited, unless there’s a hard limit on connection lifetime, or you need to e.g. gracefully swap databases.
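
The tips above could be applied roughly like this. It is a sketch under assumptions: poolTuner, configurePool and recorder are hypothetical names, and setting MaxIdleConns to 25 is my own choice, not a value from these notes. The interface is only there so the example runs without a database driver; a real *sql.DB satisfies it and can be passed in directly.

```go
package main

import (
	"database/sql"
	"fmt"
	"time"
)

// poolTuner lists the four pool-tuning methods. *sql.DB satisfies it.
type poolTuner interface {
	SetMaxOpenConns(n int)
	SetMaxIdleConns(n int)
	SetConnMaxLifetime(d time.Duration)
	SetConnMaxIdleTime(d time.Duration)
}

// Compile-time check that a real *sql.DB can be passed to configurePool.
var _ poolTuner = (*sql.DB)(nil)

// configurePool applies the settings suggested in the tips above.
func configurePool(db poolTuner) {
	db.SetMaxOpenConns(25)                  // explicit cap, well below PostgreSQL's default limit of 100
	db.SetMaxIdleConns(25)                  // assumption: allow as many idle as open connections
	db.SetConnMaxIdleTime(15 * time.Minute) // close connections that sit idle for 15 minutes
	db.SetConnMaxLifetime(0)                // 0 means connections are never closed due to age
}

// recorder is a fake pool that only logs the calls it receives.
type recorder struct{ calls []string }

func (r *recorder) SetMaxOpenConns(n int) {
	r.calls = append(r.calls, fmt.Sprintf("SetMaxOpenConns(%d)", n))
}

func (r *recorder) SetMaxIdleConns(n int) {
	r.calls = append(r.calls, fmt.Sprintf("SetMaxIdleConns(%d)", n))
}

func (r *recorder) SetConnMaxLifetime(d time.Duration) {
	r.calls = append(r.calls, fmt.Sprintf("SetConnMaxLifetime(%s)", d))
}

func (r *recorder) SetConnMaxIdleTime(d time.Duration) {
	r.calls = append(r.calls, fmt.Sprintf("SetConnMaxIdleTime(%s)", d))
}

func main() {
	r := &recorder{}
	configurePool(r)
	for _, call := range r.calls {
		fmt.Println(call)
	}
}
```

In the real application you would call configurePool(db) right after sql.Open succeeds and before the first query.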

6. SQL Migrations

How SQL migrations work at a high level:

  1. For every change you want to make to your database schema, you create a pair of migration files. One is the ‘up’ migration, containing the SQL statements to implement the change, and the other is the ‘down’ migration, which contains the SQL statements to reverse the change.
  2. Each pair is numbered sequentially (0001, 0002, …) or with a Unix timestamp to indicate order.
  3. You’d use some kind of script or tool to execute/rollback the SQL statements. This tool also keeps track of the migrations that have been applied, so only necessary SQL statements are executed.

This is obviously better than manually running the migrations by writing SQL statements yourself. The database schema is described by the migration files, which can be put under version control and live with your other code. It’s much easier to get set up on another machine by simply running the ‘up’ migrations. And it’s also easy to roll back.

We’re using golang-migrate/migrate here. I’ve heard good things about goose as well.

$ migrate create -seq -ext=.sql -dir=./migrations create_movies_table
/home/christian/projects/golang/lets-go-further-book/migrations/000001_create_movies_table.up.sql
/home/christian/projects/golang/lets-go-further-book/migrations/000001_create_movies_table.down.sq

-seq for sequential numbering (instead of Unix timestamp), -ext to specify extension, -dir to indicate where we want to store migration files, and create_movies_table is the migration label.

Let’s create the ‘up’ migration:

CREATE TABLE IF NOT EXISTS movies (
    id bigserial PRIMARY KEY,
    created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(),
    title text NOT NULL,
    year integer NOT NULL,
    runtime integer NOT NULL,
    genres text[] NOT NULL,
    version integer NOT NULL DEFAULT 1
);

bigserial is a 64-bit auto-incrementing integer, starting at 1.
Genres have the type text[], which is an array of 0 or more text values. Arrays in PostgreSQL are queryable and indexable.
Working with NULL values in Go is awkward, so we just set NOT NULL constraints and default values.
We store strings with text instead of varchar or varchar(n) types. Text seems generally better.

The ‘down’ migration is quite simple:

DROP TABLE IF EXISTS movies;

And let’s add another migration that sets some constraints to enforce business logic:

$ migrate create -seq -ext=.sql -dir=./migrations add_movies_check_constraints
-- === UP ===
ALTER TABLE movies ADD CONSTRAINT movies_runtime_check CHECK (runtime >= 0);

ALTER TABLE movies ADD CONSTRAINT movies_year_check CHECK (year BETWEEN 1888 AND date_part('year', now()));

ALTER TABLE movies ADD CONSTRAINT genres_length_check CHECK (array_length(genres, 1) BETWEEN 1 AND 5);

-- === DOWN ===
ALTER TABLE movies DROP CONSTRAINT IF EXISTS movies_runtime_check;

ALTER TABLE movies DROP CONSTRAINT IF EXISTS movies_year_check;

ALTER TABLE movies DROP CONSTRAINT IF EXISTS genres_length_check;

Now we can’t insert or update data in the movies table that fails these checks.

Run the migrations with:

$ migrate -path=./migrations -database="DSN_GOES_HERE" up

If you are using PostgreSQL v15, you may see error: pq: permission denied for schema public... when running the above command. This is because v15 revokes the CREATE permission on the public schema from all users except the database owner. You can fix it with:

-- Set owner:
ALTER DATABASE greenlight OWNER TO greenlight;

-- And if that didn't work:
GRANT CREATE ON DATABASE greenlight TO greenlight;

We now have two tables in the database: movies and schema_migrations.
schema_migrations was made by migrate to keep track of which migrations it has applied.

Here are a few useful migrate commands:

# Show current migration version
$ migrate -path=./migrations -database=$EXAMPLE_DSN version
2

# Migrate up or down to a specific version
$ migrate -path=./migrations -database=$EXAMPLE_DSN goto 1

# Roll back a specific number of migrations (here, the most recent one)
$ migrate -path=./migrations -database=$EXAMPLE_DSN down 1

Fixing errors in SQL migrations
When you run a migration that contains an error, all SQL statements up to the one with the error will be applied, and then migrate exits with the error.
This can mean your migration file was partially applied, leaving your database in an ‘unknown state’ (to migrate, at least).
If you try to run another migration (even down), you’ll get an error saying you have a dirty database version.
You’ll have to investigate the original error and figure out if the migration file was partially applied. If so, you need to manually roll-back the partially applied migration.
Then you need to force the version number in schema_migrations to the correct value:

# Force version to 1
$ migrate -path=./migrations -database=$EXAMPLE_DSN force 1

Now it’ll be clean, and you can run the migrations again.

You could have migrate run your migrations when your application starts, but it can be problematic to tightly couple the execution of migrations with your application source code in the long term.

7. CRUD Operations

Just basic CRUD stuff.

Step 1: define data layer.
Step 2: consume data layer in routes.
Step 3: add routes with proper REST semantics.

Data Layer

package data

import (
	"database/sql"
	"errors"
	"time"

	"github.com/lib/pq"
	"greenlight.bagerbach.com/internal/validator"
)

type Movie struct {
	ID        int64     `json:"id"`                // Unique identifier for the movie
	CreatedAt time.Time `json:"-"`                 // Time when the movie was added to our db
	Title     string    `json:"title"`             // The title of the movie
	Year      int32     `json:"year,omitempty"`    // The release year of the movie
	Runtime   Runtime   `json:"runtime,omitempty"` // The runtime of the movie in minutes
	Genres    []string  `json:"genres,omitempty"`  // The genres of the movie
	Version   int32     `json:"version"`           // The version of the movie: starts at 1 and increments each time the movie is updated
}

func ValidateMovie(v *validator.Validator, movie *Movie) {
	v.Check(movie.Title != "", "title", "must be provided")
	v.Check(len(movie.Title) <= 500, "title", "must not be more than 500 bytes long")

	v.Check(movie.Year != 0, "year", "must be provided")
	v.Check(movie.Year >= 1888, "year", "must be greater than 1888")
	v.Check(movie.Year <= int32(time.Now().Year()), "year", "must not be in the future")

	v.Check(movie.Runtime != 0, "runtime", "must be provided")
	v.Check(movie.Runtime > 0, "runtime", "must be a positive integer")

	v.Check(movie.Genres != nil, "genres", "must be provided")
	v.Check(len(movie.Genres) >= 1, "genres", "must contain at least 1 genre")
	v.Check(len(movie.Genres) <= 5, "genres", "must not contain more than 5 genres")
	v.Check(validator.Unique(movie.Genres), "genres", "must not contain duplicate values")
}

type MovieModel struct {
	DB *sql.DB
}

func (m MovieModel) Insert(movie *Movie) error {
	query := `
		INSERT INTO movies (title, year, runtime, genres)
		VALUES ($1, $2, $3, $4)
		RETURNING id, created_at, version`

	args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres)}
	// Need to use QueryRow because of the RETURNING clause (which returns the id, created_at and version)
	// RETURNING is a Postgres feature that is not part of SQL standard
	return m.DB.QueryRow(query, args...).Scan(&movie.ID, &movie.CreatedAt, &movie.Version)
}

func (m MovieModel) Get(id int64) (*Movie, error) {
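	// IDs start at 1, so anything lower can never match a record.
	// Returning ErrRecordNotFound (as Delete does) lets handlers send a 404.
	if id < 1 {
		return nil, ErrRecordNotFound
	}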

	query := `
		SELECT id, created_at, title, year, runtime, genres, version
		FROM movies
		WHERE id = $1`

	var movie Movie

	err := m.DB.QueryRow(query, id).Scan(
		&movie.ID,
		&movie.CreatedAt,
		&movie.Title,
		&movie.Year,
		&movie.Runtime,
		pq.Array(&movie.Genres),
		&movie.Version,
	)

	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrRecordNotFound
		} else {
			return nil, err
		}
	}

	return &movie, nil
}

func (m MovieModel) Update(movie *Movie) error {
	query := `
		UPDATE movies
		SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
		WHERE id = $5
		RETURNING version`

	args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID}

	return m.DB.QueryRow(query, args...).Scan(&movie.Version)
}

func (m MovieModel) Delete(id int64) error {
	if id < 1 {
		return ErrRecordNotFound
	}

	query := `
		DELETE FROM movies
		WHERE id = $1`

	result, err := m.DB.Exec(query, id)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}

	if rowsAffected == 0 {
		return ErrRecordNotFound
	}

	return nil
}

Route Handlers

package main

import (
	"errors"
	"fmt"
	"net/http"

	"greenlight.bagerbach.com/internal/data"
	"greenlight.bagerbach.com/internal/validator"
)

func (app *application) createMovieHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Title   string       `json:"title"`
		Year    int32        `json:"year"`
		Runtime data.Runtime `json:"runtime"`
		Genres  []string     `json:"genres"`
	}

	if err := app.readJSON(w, r, &input); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	movie := &data.Movie{
		Title:   input.Title,
		Year:    input.Year,
		Runtime: input.Runtime,
		Genres:  input.Genres,
	}

	v := validator.New()

	if data.ValidateMovie(v, movie); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	if err := app.models.Movies.Insert(movie); err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	headers := make(http.Header)
	headers.Set("Location", fmt.Sprintf("/v1/movies/%d", movie.ID))

	if err := app.writeJSON(w, http.StatusCreated, envelope{"movie": movie}, headers); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

func (app *application) showMovieHandler(w http.ResponseWriter, r *http.Request) {
	id, err := app.readIDParam(r)
	if err != nil {
		app.notFoundResponse(w, r)
		return
	}

	movie, err := app.models.Movies.Get(id)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			app.notFoundResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Request) {
	id, err := app.readIDParam(r)
	if err != nil {
		app.notFoundResponse(w, r)
		return
	}

	movie, err := app.models.Movies.Get(id)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			app.notFoundResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	var input struct {
		Title   string       `json:"title"`
		Year    int32        `json:"year"`
		Runtime data.Runtime `json:"runtime"`
		Genres  []string     `json:"genres"`
	}

	if err := app.readJSON(w, r, &input); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	movie.Title = input.Title
	movie.Year = input.Year
	movie.Runtime = input.Runtime
	movie.Genres = input.Genres

	v := validator.New()

	if data.ValidateMovie(v, movie); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	err = app.models.Movies.Update(movie)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

func (app *application) deleteMovieHandler(w http.ResponseWriter, r *http.Request) {
	id, err := app.readIDParam(r)
	if err != nil {
		app.notFoundResponse(w, r)
		return
	}

	err = app.models.Movies.Delete(id)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			app.notFoundResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"message": "movie successfully deleted"}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

Routes

package main

import (
	"net/http"

	"github.com/julienschmidt/httprouter"
)

func (app *application) routes() http.Handler {
	router := httprouter.New()

	router.NotFound = http.HandlerFunc(app.notFoundResponse)
	router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)

	router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
	router.HandlerFunc(http.MethodPut, "/v1/movies/:id", app.updateMovieHandler)
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)

	return app.recoverPanic(router)
}

8. Advanced CRUD Operations

Some advanced patterns:

  • Partial updates to a resource, so clients only need to send the data they want to change.
  • Optimistic concurrency control to avoid race conditions when two clients try to update the same resource at the same time.
  • Context timeouts to terminate long-running database queries and prevent unnecessary resource use.

Partial updates

We couldn’t support this in the previous version because we couldn’t detect which fields were supplied and which weren’t. This is because of the zero-values of the various types we were using. If they weren’t supplied, strings would be "", integers would be 0, and so on. How would we know that it wasn’t the user who wrote that?
So now we use pointers, whose zero-values are nil.

func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Request) {
	id, err := app.readIDParam(r)
	if err != nil {
		app.notFoundResponse(w, r)
		return
	}

	movie, err := app.models.Movies.Get(id)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			app.notFoundResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	// Pointers' zero-value is nil, so turning these into pointers lets us do partial updates
	// (whereas e.g. the string zero-value is "" - you wouldn't know if it was or wasn't supplied!)
	var input struct {
		Title   *string       `json:"title"`
		Year    *int32        `json:"year"`
		Runtime *data.Runtime `json:"runtime"`
		Genres  []string      `json:"genres"`
	}

	if err := app.readJSON(w, r, &input); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	if input.Title != nil {
		movie.Title = *input.Title
	}
	if input.Year != nil {
		movie.Year = *input.Year
	}
	if input.Runtime != nil {
		movie.Runtime = *input.Runtime
	}
	if input.Genres != nil {
		movie.Genres = input.Genres
	}

	v := validator.New()

	if data.ValidateMovie(v, movie); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	err = app.models.Movies.Update(movie)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}
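
To see why pointers make the difference, here is a minimal standalone sketch. The patchInput type and the describe helper are hypothetical; they only show how decoding into pointer fields separates "not supplied" from "supplied as the zero value".

```go
package main

import (
	"encoding/json"
	"fmt"
)

// patchInput is a hypothetical, trimmed-down version of the handler's input struct.
type patchInput struct {
	Title *string `json:"title"`
	Year  *int32  `json:"year"`
}

// describe reports which fields the client supplied in the JSON body.
func describe(body string) (titleSet, yearSet bool, err error) {
	var in patchInput
	if err = json.Unmarshal([]byte(body), &in); err != nil {
		return false, false, err
	}
	// A nil pointer means the key was absent (or explicitly null).
	return in.Title != nil, in.Year != nil, nil
}

func main() {
	fmt.Println(describe(`{"title": ""}`)) // true false <nil>
	fmt.Println(describe(`{"year": 0}`))   // false true <nil>
	fmt.Println(describe(`{}`))            // false false <nil>
}
```

One thing to keep in mind: an explicit JSON null also leaves the pointer as nil, so with this approach a client cannot distinguish "set to null" from "leave unchanged".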

And because we now support partial updates, we update the route method to PATCH from PUT:

func (app *application) routes() http.Handler {
	router := httprouter.New()

	router.NotFound = http.HandlerFunc(app.notFoundResponse)
	router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)

	router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
-   router.HandlerFunc(http.MethodPut, "/v1/movies/:id", app.updateMovieHandler)
+   router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)

	return app.recoverPanic(router)
}

Optimistic Concurrency Control

Optimistic locking is when you read a record, note the version number, and check it hasn’t changed before writing the record back. You can also use dates, timestamps, or checksums/hashes. The update should filter on the version to ensure it is atomic and updates the version in one go.

Pessimistic locking requires you to lock the record for exclusive use until you’re finished with it. This has better integrity than optimistic locking, but you need to be careful to avoid deadlocks. (src)

We implemented optimistic locking to prevent race conditions on updates by following these steps:

  1. Edit Conflict Response: We added a new function, editConflictResponse, to handle edit conflicts by sending a 409 Conflict status code to the client.
  2. Expected Version Header: In the updateMovieHandler function, we checked if the client provided an “X-Expected-Version” header. If the provided version did not match the current version of the record, we returned an edit conflict response.
  3. Version Check in Update: We modified the SQL update query in the Update method of the MovieModel to include a check on the version number. The query now only updates the record if the version matches the expected version.
  4. Error Handling: If the update query did not affect any rows, indicating a version mismatch, we returned an ErrEditConflict error.
  5. Conflict Handling: We adjusted the error handling in the updateMovieHandler to send an edit conflict response if an ErrEditConflict error occurred.

This approach ensures that if multiple users attempt to update the same record simultaneously, only the user with the correct version can proceed, thereby preventing race conditions.

diff --git a/lets-go-further-book/cmd/api/errors.go b/lets-go-further-book/cmd/api/errors.go
index bc5c1ae..78b7b7a 100644
--- a/lets-go-further-book/cmd/api/errors.go
+++ b/lets-go-further-book/cmd/api/errors.go
@@ -53,3 +53,7 @@ func (app *application) failedValidationResponse(w http.ResponseWriter, r *http.
 	app.errorResponse(w, r, http.StatusUnprocessableEntity, errors)
 }
 
+func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Request) {
+	message := "unable to update the record due to an edit conflict, please try again"
+	app.errorResponse(w, r, http.StatusConflict, message)
+}
diff --git a/lets-go-further-book/cmd/api/movies.go b/lets-go-further-book/cmd/api/movies.go
index ae6abf5..ff8f302 100644
--- a/lets-go-further-book/cmd/api/movies.go
+++ b/lets-go-further-book/cmd/api/movies.go
@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"strconv"
 
 	"greenlight.bagerbach.com/internal/data"
 	"greenlight.bagerbach.com/internal/validator"
@@ -90,6 +91,16 @@ func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Reques
 		return
 	}
 
+	// If the client provided an "X-Expected-Version" header, check that the version
+	// matches the version of the record being updated. If not, return a 409 Conflict
+	// status code.
+	if r.Header.Get("X-Expected-Version") != "" {
+		if strconv.Itoa(int(movie.Version)) != r.Header.Get("X-Expected-Version") {
+			app.editConflictResponse(w, r)
+			return
+		}
+	}
+
 	// Pointers' zero-value is nil, so turning these into pointers lets us do partial updates
 	// (whereas e.g. the string zero-value is "" - you wouldn't know if it was or wasn't supplied!)
 	var input struct {
@@ -126,7 +137,12 @@ func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Reques
 
 	err = app.models.Movies.Update(movie)
 	if err != nil {
-		app.serverErrorResponse(w, r, err)
+		switch {
+		case errors.Is(err, data.ErrEditConflict):
+			app.editConflictResponse(w, r)
+		default:
+			app.serverErrorResponse(w, r, err)
+		}
 		return
 	}
 
diff --git a/lets-go-further-book/internal/data/models.go b/lets-go-further-book/internal/data/models.go
index c93d17b..8fe842d 100644
--- a/lets-go-further-book/internal/data/models.go
+++ b/lets-go-further-book/internal/data/models.go
@@ -7,6 +7,7 @@ import (
 
 var (
 	ErrRecordNotFound = errors.New("record not found")
+	ErrEditConflict   = errors.New("edit conflict")
 )
 
 type Models struct {
diff --git a/lets-go-further-book/internal/data/movies.go b/lets-go-further-book/internal/data/movies.go
index 1822ccc..50e55f0 100644
--- a/lets-go-further-book/internal/data/movies.go
+++ b/lets-go-further-book/internal/data/movies.go
@@ -89,12 +89,24 @@ func (m MovieModel) Update(movie *Movie) error {
 	query := `
 		UPDATE movies
 		SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
-		WHERE id = $5
+		WHERE id = $5 AND version = $6
 		RETURNING version`
 
-	args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID}
+	args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID, movie.Version}
 
-	return m.DB.QueryRow(query, args...).Scan(&movie.Version)
+	err := m.DB.QueryRow(query, args...).Scan(&movie.Version)
+	if err != nil {
+		switch {
+		// If no rows were found, we know that the version has changed since we last read it
+		// If that's the case, we return an ErrEditConflict error
+		case errors.Is(err, sql.ErrNoRows):
+			return ErrEditConflict
+		default:
+			return err
+		}
+	}
+
+	return nil
 }
 
 func (m MovieModel) Delete(id int64) error {

If you want to avoid the version identifier being guessable, you could use a high-entropy random string (like a UUID) in the version field.

Timeouts

Handling potential delays or hangs is important when working with database operations.
One way to do this is to use a context with a timeout, which caps the maximum time each database operation is allowed to take.
This ensures the app doesn’t wait indefinitely for a response from the database.

Here’s how we modify the Insert method to use a context with a timeout:

package data

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/lib/pq"
)

func (m MovieModel) Insert(movie *Movie) error {
	query := `
		INSERT INTO movies (title, year, runtime, genres)
		VALUES ($1, $2, $3, $4)
		RETURNING id, created_at, version`

	args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres)}

	// Create a context with a 3-second timeout
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Use QueryRowContext with the context
	return m.DB.QueryRowContext(ctx, query, args...).Scan(&movie.ID, &movie.CreatedAt, &movie.Version)
}
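
To watch the timeout behaviour without a database, here is a minimal sketch. The slowQuery function is a hypothetical stand-in for a database call, and runWithTimeout and timedOut are illustrative helper names, not part of the project.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// slowQuery stands in for a database call: it finishes after `work`
// unless the context is cancelled or times out first.
func slowQuery(ctx context.Context, work time.Duration) error {
	select {
	case <-time.After(work):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// runWithTimeout wraps slowQuery in a context that expires after `timeout`.
func runWithTimeout(timeout, work time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel() // always release the context's resources
	return slowQuery(ctx, work)
}

// timedOut reports whether err was caused by the context deadline passing.
func timedOut(err error) bool {
	return errors.Is(err, context.DeadlineExceeded)
}

func main() {
	err := runWithTimeout(50*time.Millisecond, 2*time.Second)
	fmt.Println(timedOut(err), err) // true context deadline exceeded

	err = runWithTimeout(2*time.Second, 10*time.Millisecond)
	fmt.Println(timedOut(err), err) // false <nil>
}
```

The deferred cancel matters even when the work finishes in time, because it frees the timer associated with the context.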

9. Filtering, Sorting, and Pagination

You can add full text search (support partial matches) by using a SQL query like this:

func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
	query := `
		SELECT id, created_at, title, year, runtime, genres, version
		FROM movies
		WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
		AND (genres && $2 OR $2 = '{}')
		ORDER BY id
	`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	rows, err := m.DB.QueryContext(ctx, query, title, pq.Array(genres))
	if err != nil {
		return nil, err
	}
	// ...
}

Here, we use PostgreSQL’s full-text search feature to support partial matches on the title.

Here’s what happens:

  • to_tsvector('simple', title) takes a movie title and splits it into lexemes. We use the simple configuration, so lexemes are just lowercase words in the title. Other configurations may apply other rules to the lexemes, like removing common words for a given language or applying language-specific stemming.
  • plainto_tsquery('simple', $1) takes a search value, turns it into a formatted query term, normalizes the search value (with simple config), strips special characters, and inserts the & operator between words. For example, Iron Man becomes 'iron' & 'man'.
  • The @@ operator is the matches operator, which we use to check whether the generated query term matches the lexemes. 'iron' & 'man' would match rows containing both iron and man.

To keep SQL queries performant as the dataset grows, we add indexes to avoid full table scans and to avoid generating lexemes for the title field every time the query is run.

$ migrate create -seq -ext .sql -dir ./migrations add_movies_indexes

Up migration:

CREATE INDEX IF NOT EXISTS movies_title_idx ON movies using GIN (to_tsvector('simple', title));
CREATE INDEX IF NOT EXISTS movies_genres_idx ON movies using GIN (genres);

Down migration:

DROP INDEX IF EXISTS movies_title_idx;
DROP INDEX IF EXISTS movies_genres_idx;

We use Generalized Inverted Index (GIN) indexes.

GIN indexes are specialized data structures designed for efficient handling of composite values, such as arrays, full-text search documents, and JSONB. They’re optimized for cases where each indexed row contains multiple keys. A GIN index uses an inverted structure in which each distinct key points to the set of rows containing that key, which makes lookups with operators like @>, <@, && and @@ (and % with the pg_trgm extension) efficient. Lookups are fast, but inserts and updates can be slower, because a single row change may touch many index entries.

If you’d prefer not to use full-text search, there’s also STRPOS() and ILIKE.

  • STRPOS(): Checks for the existence of a substring in a database field. Not ideal, as it could lead to a full-table scan each time it’s run.
  • ILIKE: Finds rows matching a specific, case-insensitive pattern. You can create an index on your target field using the pg_trgm extension and a GIN index, so it’s better for performance than STRPOS(). It also lets the user control matching behavior by prefixing/suffixing with a % wildcard character.

Sorting

If you don’t specify an order with ORDER BY, PostgreSQL may return the movies in any order, which may or may not change each time the query is run. And if you were to order by, e.g., year, but you have multiple items with the same year, you will get an ordering by year, except the items with the same year may be ordered differently on each query.

For endpoints that provide pagination, you need to ensure the order is perfectly consistent between requests to prevent items from jumping between pages.
Doing so is simple: just ensure the ORDER BY clause always includes a primary key column, or one with a unique constraint on it.
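
A minimal sketch of building such a clause, assuming a sort parameter where a leading "-" means descending. The orderByClause function and the allowedSortColumns safelist are hypothetical names, and this version returns an error for unsafe values rather than panicking.

```go
package main

import (
	"fmt"
	"strings"
)

// allowedSortColumns is a hypothetical safelist. Column names cannot be
// passed as SQL placeholders, so they must be checked before interpolation.
var allowedSortColumns = map[string]bool{
	"id": true, "title": true, "year": true, "runtime": true,
}

// orderByClause turns a sort parameter such as "-year" into an ORDER BY
// clause, always appending the primary key as a tie-breaker.
func orderByClause(sort string) (string, error) {
	column := strings.TrimPrefix(sort, "-")
	if !allowedSortColumns[column] {
		return "", fmt.Errorf("unsafe sort parameter: %q", sort)
	}

	direction := "ASC"
	if strings.HasPrefix(sort, "-") {
		direction = "DESC"
	}

	// The trailing "id ASC" keeps rows with equal values in a stable order.
	return fmt.Sprintf("ORDER BY %s %s, id ASC", column, direction), nil
}

func main() {
	clause, err := orderByClause("-year")
	fmt.Println(clause, err) // ORDER BY year DESC, id ASC <nil>

	_, err = orderByClause("title; DROP TABLE movies")
	fmt.Println(err)
}
```

Sorting by id itself produces a redundant but still valid "ORDER BY id ASC, id ASC".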

Here’s how you can implement sorting for your endpoints.

We’re using a filters.go that looks like this:

package data

import (
	"strings"

	"greenlight.bagerbach.com/internal/validator"
)

type Filters struct {
	Page         int
	PageSize     int
	Sort         string
	SortSafelist []string
}

func (f Filters) sortColumn() string {
	for _, safeValue := range f.SortSafelist {
		if f.Sort == safeValue {
			return strings.TrimPrefix(f.Sort, "-")
		}
	}

	// If the sort parameter is not in the safelist, we panic
	// This shouldn't happen, as you'd have checked it in the validator
	// But is a fine failsafe against SQL injection attacks.
	panic("unsafe sort parameter: " + f.Sort)
}

func (f Filters) sortDirection() string {
	if strings.HasPrefix(f.Sort, "-") {
		return "DESC"
	}

	return "ASC"
}

func ValidateFilters(v *validator.Validator, f Filters) {
	v.Check(f.Page > 0, "page", "must be greater than zero")
	v.Check(f.Page <= 10_000_000, "page", "must be a maximum of 10 million")
	v.Check(f.PageSize > 0, "page_size", "must be greater than zero")
	v.Check(f.PageSize <= 100, "page_size", "must be a maximum of 100")

	v.Check(validator.PermittedValue(f.Sort, f.SortSafelist...), "sort", "invalid sort value")
}
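
To make the mapping from a sort value to SQL concrete, here is a small standalone sketch. The orderBy function is hypothetical: it folds the logic of sortColumn() and sortDirection() into one helper and returns false instead of panicking, purely so the example is easy to run on its own.

```go
package main

import (
	"fmt"
	"strings"
)

// orderBy is a hypothetical helper that mirrors sortColumn() and
// sortDirection(): it turns a safelisted sort value into an ORDER BY clause
// and reports false for anything that is not on the safelist.
func orderBy(sort string, safelist []string) (string, bool) {
	for _, safe := range safelist {
		if sort == safe {
			column := strings.TrimPrefix(sort, "-")
			direction := "ASC"
			if strings.HasPrefix(sort, "-") {
				direction = "DESC"
			}
			// The trailing "id ASC" is the unique tiebreaker.
			return fmt.Sprintf("ORDER BY %s %s, id ASC", column, direction), true
		}
	}
	return "", false
}

func main() {
	safelist := []string{"id", "title", "year", "-id", "-title", "-year"}

	fmt.Println(orderBy("-year", safelist))
	fmt.Println(orderBy("year; DROP TABLE movies", safelist))
}
```

Only values that match the safelist exactly ever reach the query string, which is what makes interpolating the column name with fmt.Sprintf acceptable here.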

The handler looks like this, defining a safelist:

func (app *application) listMoviesHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Title  string
		Genres []string
		data.Filters
	}

	v := validator.New()
	qs := r.URL.Query()

	input.Title = app.readString(qs, "title", "")
	input.Genres = app.readCSV(qs, "genres", []string{})

	input.Filters.Page = app.readInt(qs, "page", 1, v)
	input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)

	input.Filters.Sort = app.readString(qs, "sort", "id")
	input.Filters.SortSafelist = []string{
		"id", "title", "year", "runtime", "-id", "-title", "-year", "-runtime",
	}

	if data.ValidateFilters(v, input.Filters); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	movies, err := app.models.Movies.GetAll(input.Title, input.Genres, input.Filters)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"movies": movies}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

And we have the GetAll function to get all movies, sorted in the given order.

func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
	// Add secondary sorting by id to ensure consistent sorting
	query := fmt.Sprintf(`
		SELECT id, created_at, title, year, runtime, genres, version
		FROM movies
		WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
		AND (genres && $2 OR $2 = '{}')
		ORDER BY %s %s, id ASC`, filters.sortColumn(), filters.sortDirection())

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	rows, err := m.DB.QueryContext(ctx, query, title, pq.Array(genres))
	if err != nil {
		return nil, err
	}

	defer rows.Close()

	movies := []*Movie{}

	for rows.Next() {
		var movie Movie
		err := rows.Scan(
			&movie.ID,
			&movie.CreatedAt,
			&movie.Title,
			&movie.Year,
			&movie.Runtime,
			pq.Array(&movie.Genres),
			&movie.Version,
		)

		if err != nil {
			return nil, err
		}

		movies = append(movies, &movie)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return movies, nil
}

Pagination

We’ll be using LIMIT and OFFSET to implement pagination:

LIMIT = page_size
OFFSET = (page - 1) * page_size

We’ll need some helper methods for this:

type Filters struct {
	Page         int
	PageSize     int
	Sort         string
	SortSafelist []string
}

func (f Filters) limit() int {
	return f.PageSize
}

// Risks integer overflow, but that should be prevented by input validation
func (f Filters) offset() int {
	return (f.Page - 1) * f.PageSize
}

And now we implement these in our GetAll function. Notice the LIMIT and OFFSET:

func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
	query := fmt.Sprintf(`
		SELECT id, created_at, title, year, runtime, genres, version
		FROM movies
		WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
		AND (genres && $2 OR $2 = '{}')
		ORDER BY %s %s, id ASC
		LIMIT $3 OFFSET $4`, filters.sortColumn(), filters.sortDirection())

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	args := []interface{}{title, pq.Array(genres), filters.limit(), filters.offset()}

	rows, err := m.DB.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}

	defer rows.Close()

	movies := []*Movie{}

	for rows.Next() {
		var movie Movie
		err := rows.Scan(
			&movie.ID,
			&movie.CreatedAt,
			&movie.Title,
			&movie.Year,
			&movie.Runtime,
			pq.Array(&movie.Genres),
			&movie.Version,
		)

		if err != nil {
			return nil, err
		}

		movies = append(movies, &movie)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return movies, nil
}

Adding pagination metadata

We’ll add some pagination metadata, so our responses will look like this:

{
    "metadata": {
        "current_page": 1,
        "page_size": 20,
        "first_page": 1,
        "last_page": 42,
        "total_records": 832
    },
    "movies": [
        ...
    ]
}

The total_records should match the total number of records given the title and genres filters that are applied, not just the absolute total records in the table.

We can do this by adding a window function to the SQL query, which counts the total number of filtered rows.

Let’s start.

In internal/data/filters.go, we add:

type Metadata struct {
	CurrentPage  int `json:"current_page,omitempty"`
	PageSize     int `json:"page_size,omitempty"`
	FirstPage    int `json:"first_page,omitempty"`
	LastPage     int `json:"last_page,omitempty"`
	TotalRecords int `json:"total_records,omitempty"`
}

func calculateMetadata(totalRecords, page, pageSize int) Metadata {
	if totalRecords == 0 {
		return Metadata{}
	}

	return Metadata{
		CurrentPage:  page,
		PageSize:     pageSize,
		FirstPage:    1,
		LastPage:     (totalRecords + pageSize - 1) / pageSize,
		TotalRecords: totalRecords,
	}
}
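
The LastPage expression is integer ceiling division. As a quick check against the example response above, here is a standalone sketch. The function names lastPage and offset are just for illustration; the arithmetic is the same as in calculateMetadata and Filters.offset().

```go
package main

import "fmt"

// lastPage returns the number of the final page using integer ceiling
// division, the same expression used in calculateMetadata.
func lastPage(totalRecords, pageSize int) int {
	return (totalRecords + pageSize - 1) / pageSize
}

// offset returns the number of rows to skip for a 1-based page number.
func offset(page, pageSize int) int {
	return (page - 1) * pageSize
}

func main() {
	fmt.Println(lastPage(832, 20)) // 832/20 = 41.6, rounded up to 42
	fmt.Println(lastPage(840, 20)) // exact multiple, still 42
	fmt.Println(offset(42, 20))    // 820 rows skipped to reach page 42
}
```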

We need to ensure we return the metadata in internal/data/movies.go:

func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, Metadata, error) {
	// Add `count(*) OVER()`
	query := fmt.Sprintf(`
		SELECT count(*) OVER(), id, created_at, title, year, runtime, genres, version
		FROM movies
		WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
		AND (genres && $2 OR $2 = '{}')
		ORDER BY %s %s, id ASC
		LIMIT $3 OFFSET $4`, filters.sortColumn(), filters.sortDirection())

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	args := []interface{}{title, pq.Array(genres), filters.limit(), filters.offset()}

	rows, err := m.DB.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, Metadata{}, err // return metadata as well
	}

	defer rows.Close()

	totalRecords := 0 // define this
	movies := []*Movie{}

	for rows.Next() {
		var movie Movie
		err := rows.Scan(
			&totalRecords, // scan count
			&movie.ID,
			&movie.CreatedAt,
			&movie.Title,
			&movie.Year,
			&movie.Runtime,
			pq.Array(&movie.Genres),
			&movie.Version,
		)

		if err != nil {
			return nil, Metadata{}, err // return metadata as well
		}

		movies = append(movies, &movie)
	}

	if err := rows.Err(); err != nil {
		return nil, Metadata{}, err // return metadata as well
	}

	// Calculate and return metadata with the records
	metadata := calculateMetadata(totalRecords, filters.Page, filters.PageSize)

	return movies, metadata, nil
}

And we need to consume the returned metadata in the API endpoint, in cmd/api/movies.go:

func (app *application) listMoviesHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Title  string
		Genres []string
		data.Filters
	}

	v := validator.New()
	qs := r.URL.Query()

	input.Title = app.readString(qs, "title", "")
	input.Genres = app.readCSV(qs, "genres", []string{})

	input.Filters.Page = app.readInt(qs, "page", 1, v)
	input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)

	input.Filters.Sort = app.readString(qs, "sort", "id")
	input.Filters.SortSafelist = []string{
		"id", "title", "year", "runtime", "-id", "-title", "-year", "-runtime",
	}

	if data.ValidateFilters(v, input.Filters); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	// Get movies, metadata, err
	movies, metadata, err := app.models.Movies.GetAll(input.Title, input.Genres, input.Filters)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	// Send both movies and metadata to client
	if err := app.writeJSON(w, http.StatusOK, envelope{"movies": movies, "metadata": metadata}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

10. Rate Limiting

We’ll use the x/time/rate package: go get golang.org/x/time/rate@latest.

We’ll be using a token-bucket rate limiter.
This kind of limiter controls how frequently events are allowed to occur by implementing a token bucket of size b, which starts out full and is refilled at a rate of r tokens per second.
In our case, we’d start with a bucket with b tokens. Each time we get an HTTP request, we remove one token from the bucket.
Every 1/r seconds, a token is added back, up to a maximum of b tokens.
If we get too many HTTP requests and the bucket is empty, we return an HTTP 429 Too Many Requests response.

So it allows for a maximum ‘burst’ of b requests, but over time it allows for an average of r requests per second.
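
To see the mechanics in isolation, here is a deliberately simplified token bucket. It is a hypothetical sketch for illustration only: it does not reproduce the x/time/rate API, it is not safe for concurrent use, and it takes the current time as a plain number of seconds so the behavior is deterministic.

```go
package main

import "fmt"

// tokenBucket is a simplified, hypothetical limiter used only to show the
// mechanics described above.
type tokenBucket struct {
	rate   float64 // r: tokens added per second
	burst  float64 // b: maximum number of tokens the bucket can hold
	tokens float64 // tokens currently available
	last   float64 // time of the previous call, in seconds
}

// newTokenBucket returns a bucket that starts out full.
func newTokenBucket(r float64, b int) *tokenBucket {
	return &tokenBucket{rate: r, burst: float64(b), tokens: float64(b)}
}

// allow tops the bucket up for the time elapsed since the previous call,
// then tries to take one token. now is a timestamp in seconds.
func (tb *tokenBucket) allow(now float64) bool {
	tb.tokens += (now - tb.last) * tb.rate
	if tb.tokens > tb.burst {
		tb.tokens = tb.burst
	}
	tb.last = now

	if tb.tokens < 1 {
		return false // bucket empty: the caller would respond with 429
	}
	tb.tokens--
	return true
}

func main() {
	tb := newTokenBucket(2, 4) // r = 2 tokens/second, b = 4

	// Five requests arrive at t = 0: the first four drain the bucket.
	for i := 1; i <= 5; i++ {
		fmt.Println("t=0.0 request", i, "allowed:", tb.allow(0))
	}

	// Half a second later, 0.5s * 2 tokens/s = 1 token has been added back.
	fmt.Println("t=0.5 allowed:", tb.allow(0.5))
}
```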

Start by adding an error for when the rate limit is reached

// cmd/api/errors.go
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) {
	message := "rate limit exceeded"
	app.errorResponse(w, r, http.StatusTooManyRequests, message)
}

Add configuration for the limiter to your main entry

// cmd/api/main.go
package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"time"

	"greenlight.bagerbach.com/internal/data"

	"github.com/joho/godotenv"
	// Import the pq driver - it needs to register itself with the database/sql package
	_ "github.com/lib/pq"
)

// Will be generated automatically later.
const version = "1.0.0"

type config struct {
	port int
	env  string
	db   struct {
		dsn          string
		maxOpenConns int
		maxIdleConns int
		maxIdleTime  time.Duration
	}
	limiter struct {
		rps     float64
		burst   int
		enabled bool
	}
}

type application struct {
	config config
	logger *slog.Logger
	models data.Models
}

func main() {
	err := godotenv.Load()
	if err != nil {
		slog.Error("error loading .env file", "error", err)
		os.Exit(1)
	}

	var cfg config

	flag.IntVar(&cfg.port, "port", 4000, "Server port to listen on")
	flag.StringVar(&cfg.env, "env", "development", "Application environment (development|staging|production)")
	flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("DATABASE_URL"), "PostgreSQL DSN")
	flag.IntVar(&cfg.db.maxOpenConns, "db-max-open-conns", 25, "Maximum number of open connections to the database")
	flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "Maximum number of idle connections to the database")
	// DurationVar lets us pass in any value acceptable to time.ParseDuration(), e.g. 300ms, 5s, 2h45m.
	flag.DurationVar(&cfg.db.maxIdleTime, "db-max-idle-time", 15*time.Minute, "Maximum idle time for a connection to the database")
	flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limit to apply to requests per second")
	flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Burst limit to apply to requests")
	flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiting")

	flag.Parse()

	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

	db, err := openDB(cfg)
	if err != nil {
		logger.Error("error opening db", "error", err)
		os.Exit(1)
	}
	defer db.Close()

	logger.Info("database connection pool established")

	app := &application{
		config: cfg,
		logger: logger,
		models: data.NewModels(db),
	}

	srv := &http.Server{
		Addr:         fmt.Sprintf(":%d", cfg.port),
		Handler:      app.routes(),
		IdleTimeout:  time.Minute,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
		ErrorLog:     slog.NewLogLogger(logger.Handler(), slog.LevelError),
	}

	logger.Info("starting server", "addr", srv.Addr, "env", cfg.env)
	if err := srv.ListenAndServe(); err != nil {
		logger.Error("error starting server", "error", err)
		os.Exit(1)
	}
}

func openDB(cfg config) (*sql.DB, error) {
	db, err := sql.Open("postgres", cfg.db.dsn)
	if err != nil {
		return nil, err
	}

	db.SetMaxOpenConns(cfg.db.maxOpenConns)
	db.SetMaxIdleConns(cfg.db.maxIdleConns)
	db.SetConnMaxIdleTime(cfg.db.maxIdleTime)

	// Create context with 5s timeout. If we can't connect in 5s, we cancel the context and return an error.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err = db.PingContext(ctx); err != nil {
		db.Close()
		return nil, err
	}

	return db, nil
}

Implement the rate-limiting middleware
This solution only works when the server is running on a single machine.
If your infrastructure is distributed, and you’re running multiple servers behind a load balancer, you’ll need something else. If that’s the case, and you’re running HAProxy or Nginx as a load balancer or reverse proxy, you can use their built-in rate-limiting features.
Alternatively, you could use a fast database like Redis to maintain a request count for clients, running on a server with which all your application servers can communicate.

This uses a map keyed by client IP address. Each entry holds a rate limiter and a lastSeen time for that client, and a background goroutine periodically removes entries that haven’t been seen recently.
The map is guarded by a mutex because Go maps are not safe for concurrent use.

// cmd/api/middleware.go
func (app *application) rateLimit(next http.Handler) http.Handler {
	type client struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}
	var (
		// Maps are not thread-safe. We need to lock the mutex before reading from the map.
		mu      sync.Mutex
		clients = make(map[string]*client)
	)

	// Background goroutine to clean up old clients.
	go func() {
		for {
			time.Sleep(time.Minute)

			mu.Lock()
			for ip, client := range clients {
				if time.Since(client.lastSeen) > 3*time.Minute {
					delete(clients, ip)
				}
			}
			mu.Unlock()
		}
	}()

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !app.config.limiter.enabled {
			next.ServeHTTP(w, r)
			return
		}

		ip, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			app.serverErrorResponse(w, r, err)
			return
		}

		mu.Lock()
		if _, found := clients[ip]; !found {
			clients[ip] = &client{
				limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst),
			}
		}

		clients[ip].lastSeen = time.Now()

		// Check if the request is allowed by the rate limiter.
		// `limiter.Allow()` returns `true` if the request is allowed, and `false` if the request is not allowed.
		// It consumes a token from the bucket.
		if !clients[ip].limiter.Allow() {
			mu.Unlock()
			app.rateLimitExceededResponse(w, r)
			return
		}
		// We aren't using defer here because we want to unlock the mutex as soon as possible.
		// If we deferred the unlock, it only would be executed after all the downstream handlers have returned.
		mu.Unlock()

		next.ServeHTTP(w, r)
	})
}
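
The middleware keys its map on the host part of r.RemoteAddr. As a small illustration of what net.SplitHostPort returns for typical inputs, here is a sketch with a hypothetical clientIP helper:

```go
package main

import (
	"fmt"
	"net"
)

// clientIP is a hypothetical helper that extracts the host part of an
// address in "host:port" form, as found in http.Request.RemoteAddr.
func clientIP(remoteAddr string) (string, error) {
	host, _, err := net.SplitHostPort(remoteAddr)
	return host, err
}

func main() {
	for _, addr := range []string{"192.0.2.10:51234", "[::1]:4000", "192.0.2.10"} {
		ip, err := clientIP(addr)
		fmt.Printf("%q -> ip=%q err=%v\n", addr, ip, err)
	}
}
```

Note that if the API sits behind a reverse proxy, RemoteAddr will typically be the proxy’s address rather than the client’s, so every client would share one limiter unless the real client IP is obtained some other way.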

Use the middleware in your routes

// cmd/api/routes.go
package main

import (
	"net/http"

	"github.com/julienschmidt/httprouter"
)

func (app *application) routes() http.Handler {
	router := httprouter.New()

	router.NotFound = http.HandlerFunc(app.notFoundResponse)
	router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)

	router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)

	router.HandlerFunc(http.MethodGet, "/v1/movies", app.listMoviesHandler)
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
	router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)

	return app.recoverPanic(app.rateLimit(router))
}

11. Graceful Shutdown

Applications can be closed with certain (POSIX) signals. For example, using CTRL+C sends an interrupt signal called SIGINT.

Some of these signals are catchable, while others are not. Those that are catchable can be intercepted by our application and used to trigger certain actions, like graceful shutdowns. SIGKILL is not catchable.

Go provides the tools for intercepting signals in os/signal.

We can create a serve function to handle setting up the server and tearing it down:

// cmd/api/server.go
package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"
)

func (app *application) serve() error {
	srv := &http.Server{
		Addr:         fmt.Sprintf(":%d", app.config.port),
		Handler:      app.routes(),
		IdleTimeout:  time.Minute,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
		ErrorLog:     slog.NewLogLogger(app.logger.Handler(), slog.LevelError),
	}

	// To receive any error returned by the graceful Shutdown() function
	shutdownError := make(chan error)

	go func() {
		quit := make(chan os.Signal, 1)
		signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
		s := <-quit

		app.logger.Info("shutting down server", "signal", s.String())

		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		shutdownError <- srv.Shutdown(ctx)
	}()

	app.logger.Info("starting server", "addr", srv.Addr, "port", app.config.port)

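	// ListenAndServe returns http.ErrServerClosed as soon as Shutdown() is called,
	// so any other error is a real failure.
	if err := srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
		return err
	}

	// Wait for Shutdown() to finish and report whether it succeeded.
	if err := <-shutdownError; err != nil {
		return err
	}

	app.logger.Info("server closed", "addr", srv.Addr)
	return nil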
}
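
The key detail is that the serving call returns http.ErrServerClosed as soon as Shutdown() is called, while Shutdown() itself reports whether the graceful shutdown succeeded. Here is a minimal standalone sketch of that interaction. The shutdownDemo function is hypothetical, and it uses Serve() on a random loopback port instead of ListenAndServe() so that it can run anywhere.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"
)

// shutdownDemo starts a server on a random local port, shuts it down
// straight away, and reports whether Serve() returned http.ErrServerClosed
// along with the error returned by Shutdown().
func shutdownDemo() (closedCleanly bool, shutdownErr error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return false, err
	}

	srv := &http.Server{Handler: http.NotFoundHandler()}

	served := make(chan error, 1)
	go func() { served <- srv.Serve(ln) }()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	shutdownErr = srv.Shutdown(ctx)
	serveErr := <-served

	return errors.Is(serveErr, http.ErrServerClosed), shutdownErr
}

func main() {
	closedCleanly, err := shutdownDemo()
	fmt.Println("Serve returned ErrServerClosed:", closedCleanly)
	fmt.Println("Shutdown error:", err)
}
```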

Which we should use in our main():

// cmd/api/main.go
// ...
func main() {
	// ...
	if err := app.serve(); err != nil {
		logger.Error("error starting server", "error", err)
		os.Exit(1)
	}
}
// ...

12. User Model Setup and Registration

Create users table migration:

migrate create -seq -ext .sql -dir ./migrations create_users_table

And write the migrations:

-- UP
CREATE EXTENSION IF NOT EXISTS citext; -- I didn't have this

CREATE TABLE IF NOT EXISTS users (
    id bigserial PRIMARY KEY,
    created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(),
    name text NOT NULL,
    email citext UNIQUE NOT NULL,
    password_hash bytea NOT NULL,
    activated bool NOT NULL,
    version integer NOT NULL DEFAULT 1
);

-- DOWN
DROP TABLE IF EXISTS users;
DROP EXTENSION IF EXISTS citext;

bytea is a binary string. This is what we use to store a one-way hash of the user’s password.

Although we store the email in a case-insensitive column (citext), email addresses are not guaranteed to be case-insensitive. The domain part is (RFC 2821), but the local part before the @ is not: that is up to the provider, although most providers treat it as case-insensitive.
So, you should always store emails using the exact casing provided during user registration, and you should only send emails using that exact casing.
However, since they are very likely to be the same user, we generally treat email addresses as case-insensitive for comparison purposes. For example, if a user mistypes and accidentally capitalizes one of the characters in their email, they can’t create another account with the ‘same’ email.
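
In these notes the comparison happens in the database via the citext column. To illustrate the same idea on the application side (compare case-insensitively, but keep the original casing), here is a small sketch. The sameEmail helper is hypothetical, and strings.EqualFold is only an approximation of how citext compares values.

```go
package main

import (
	"fmt"
	"strings"
)

// sameEmail reports whether two addresses should be treated as the same
// account. It only compares; the stored value keeps its original casing.
func sameEmail(a, b string) bool {
	return strings.EqualFold(a, b)
}

func main() {
	stored := "Alice.Smith@Example.com" // exactly as typed at registration

	fmt.Println(sameEmail(stored, "alice.smith@example.com"))
	fmt.Println(sameEmail(stored, "bob@example.com"))
	fmt.Println("send mail to:", stored) // original casing is preserved
}
```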

First, we implement the User model:

// internal/data/users.go
package data

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"golang.org/x/crypto/bcrypt"
	"greenlight.bagerbach.com/internal/validator"
)

var (
	ErrDuplicateEmail = errors.New("duplicate email")
)

type User struct {
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
	Name      string    `json:"name"`
	Email     string    `json:"email"`
	Password  password  `json:"-"`
	Activated bool      `json:"activated"`
	Version   int       `json:"-"`
}

type password struct {
	plaintext *string
	hash      []byte
}

// Set sets the password hash and plaintext
func (p *password) Set(plaintextPassword string) error {
	// Generates a bcrypt hash of a password using specific cost parameters (12 here)
	// The higher the cost, the slower (more computationally expensive) the hash generation will be
	// Need to strike a balance between security and performance
	// Inputs are truncated to 72 bytes (bcrypt max) when creating the hash, so it's a good idea to
	// e.g. enforce a hard max on the length of the password
	hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
	if err != nil {
		return err
	}

	p.plaintext = &plaintextPassword
	p.hash = hash

	return nil
}

// Checks if the provided plaintext password matches the hashed password
func (p *password) Matches(plaintextPassword string) (bool, error) {
	// Use the bcrypt package to compare the hashed password with the plaintext password
	// Works by re-hashing the provided plaintext password (using same salt and cost) and comparing the result
	// to the hashed password.
	// Compares using subtle.ConstantTimeCompare to avoid timing attacks, which are side-channel
	// attacks that allow an attacker to determine if the hashed password is correct by
	// measuring the time taken to compare the hashed password.
	if err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)); err != nil {
		switch {
		case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
			return false, nil
		default:
			return false, err
		}
	}

	return true, nil
}

func ValidateEmail(v *validator.Validator, email string) {
	v.Check(email != "", "email", "must be provided")
	v.Check(validator.Matches(email, validator.EmailRX), "email", "must be a valid email address")
}

func ValidatePasswordPlaintext(v *validator.Validator, password string) {
	v.Check(password != "", "password", "must be provided")
	v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long")
	v.Check(len(password) >= 8, "password", "must be at least 8 bytes long")
}

func ValidateUser(v *validator.Validator, user *User) {
	v.Check(user.Name != "", "name", "must be provided")
	v.Check(len(user.Name) <= 500, "name", "must not be more than 500 bytes long")

	ValidateEmail(v, user.Email)

	if user.Password.plaintext != nil {
		ValidatePasswordPlaintext(v, *user.Password.plaintext)
	}

	// Will only be nil due to logic error, e.g. forgetting to set a password for the user.
	// This is a sanity check, but isn't a problem with the data itself.
	if user.Password.hash == nil {
		panic("missing password hash for user")
	}
}

type UserModel struct {
	DB *sql.DB
}

func (m UserModel) Insert(user *User) error {
	query := `
		INSERT INTO users (name, email, password_hash, activated)
		VALUES ($1, $2, $3, $4)
		RETURNING id, created_at, version`

	args := []interface{}{user.Name, user.Email, user.Password.hash, user.Activated}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	if err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.CreatedAt, &user.Version); err != nil {
		switch {
		case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
			return ErrDuplicateEmail
		default:
			return err
		}
	}

	return nil
}

func (m UserModel) GetByEmail(email string) (*User, error) {
	query := `
		SELECT id, created_at, name, email, password_hash, activated, version
		FROM users
		WHERE email = $1`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	var user User
	if err := m.DB.QueryRowContext(ctx, query, email).Scan(
		&user.ID,
		&user.CreatedAt,
		&user.Name,
		&user.Email,
		&user.Password.hash,
		&user.Activated,
		&user.Version,
	); err != nil {
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, ErrRecordNotFound
		default:
			return nil, err
		}
	}

	return &user, nil
}

func (m UserModel) Update(user *User) error {
	query := `
		UPDATE users
		SET name = $1, email = $2, password_hash = $3, activated = $4, version = version + 1
		WHERE id = $5 AND version = $6
		RETURNING version`

	args := []interface{}{user.Name, user.Email, user.Password.hash, user.Activated, user.ID, user.Version}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	if err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version); err != nil {
		switch {
		case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
			return ErrDuplicateEmail
		case errors.Is(err, sql.ErrNoRows):
			return ErrEditConflict // no row matched id AND version: the record changed or no longer exists
		default:
			return err
		}
	}

	return nil
}
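
One detail worth seeing in practice: the password checks above use len(), which counts bytes, not characters. That matches bcrypt’s 72-byte limit, but it means a password with non-ASCII characters reaches the limits sooner than its visible length suggests. Here is a small sketch, with a hypothetical helper name:

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// byteAndRuneLen returns the length of s in bytes (what len() and bcrypt
// count) and in Unicode code points (closer to what a user perceives).
func byteAndRuneLen(s string) (nBytes, nRunes int) {
	return len(s), utf8.RuneCountInString(s)
}

func main() {
	for _, pw := range []string{"password", "p\u00e4ssw\u00f6rd"} {
		nBytes, nRunes := byteAndRuneLen(pw)
		fmt.Printf("%q: %d bytes, %d characters\n", pw, nBytes, nRunes)
	}
}
```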

And we include that in our Models struct:

// internal/data/models.go
package data

import (
	"database/sql"
	"errors"
)

var (
	ErrRecordNotFound = errors.New("record not found")
	ErrEditConflict   = errors.New("edit conflict")
)

type Models struct {
	Movies MovieModel
	Users  UserModel // Added
}

func NewModels(db *sql.DB) Models {
	return Models{
		Movies: MovieModel{DB: db},
		Users:  UserModel{DB: db}, // Added
	}
}

Then we create the route handler in cmd/api/users.go:

package main

import (
	"errors"
	"net/http"

	"greenlight.bagerbach.com/internal/data"
	"greenlight.bagerbach.com/internal/validator"
)

func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Name     string `json:"name"`
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	if err := app.readJSON(w, r, &input); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	user := &data.User{
		Name:      input.Name,
		Email:     input.Email,
		Activated: false,
	}

	if err := user.Password.Set(input.Password); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	v := validator.New()

	if data.ValidateUser(v, user); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	if err := app.models.Users.Insert(user); err != nil {
		switch {
		case errors.Is(err, data.ErrDuplicateEmail):
			v.AddError("email", "a user with this email address already exists")
			app.failedValidationResponse(w, r, v.Errors)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	if err := app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

And we use it in routes:

// cmd/api/routes.go
package main

import (
	"net/http"

	"github.com/julienschmidt/httprouter"
)

func (app *application) routes() http.Handler {
	router := httprouter.New()

	router.NotFound = http.HandlerFunc(app.notFoundResponse)
	router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)

	router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)

	router.HandlerFunc(http.MethodGet, "/v1/movies", app.listMoviesHandler)
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
	router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)

	// Added this
	router.HandlerFunc(http.MethodPost, "/v1/users", app.registerUserHandler)

	return app.recoverPanic(app.rateLimit(router))
}

13. Sending Emails

We’re using Mailtrap because it has a nice testing suite for emails. We’re also using go-mail/mail, as the standard library’s net/smtp package is frozen and no longer accepts new features.

$ go get github.com/go-mail/mail/v2@v2

Add SMTP configuration fields to the config struct in cmd/api/main.go:

type config struct {
    // ...
    smtp struct {
        host     string
        port     int
        username string
        password string
        sender   string
    }
}

Update the main.go file to include the new SMTP configuration:

package main

import (
    ...
    "greenlight.bagerbach.com/internal/mailer"
    ...
    "sync"
    "strconv"
)

func main() {
    var cfg config
    ...
    flag.StringVar(&cfg.smtp.host, "smtp-host", os.Getenv("SMTP_HOST"), "SMTP host")
    port, err := strconv.Atoi(os.Getenv("SMTP_PORT"))
    if err != nil {
        slog.Error("invalid SMTP_PORT", "error", err)
        os.Exit(1)
    }
    flag.IntVar(&cfg.smtp.port, "smtp-port", port, "SMTP port")
    flag.StringVar(&cfg.smtp.username, "smtp-username", os.Getenv("SMTP_USERNAME"), "SMTP username")
    flag.StringVar(&cfg.smtp.password, "smtp-password", os.Getenv("SMTP_PASSWORD"), "SMTP password")
    flag.StringVar(&cfg.smtp.sender, "smtp-sender", os.Getenv("SMTP_SENDER"), "SMTP sender")
    ...
    app := &application{
        config: cfg,
        ...
        mailer: mailer.New(cfg.smtp.host, cfg.smtp.port, cfg.smtp.username, cfg.smtp.password, cfg.smtp.sender),
        wg:     sync.WaitGroup{},  // we'll use this to ensure graceful shutdown for background jobs
    }
    ...
}

Create a new package for handling emails in internal/mailer/mailer.go:

package mailer

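import (
    "bytes"
    "embed"
    // html/template (rather than text/template) escapes user-supplied values
    // such as .Name when they are rendered into the HTML body.
    "html/template"
    "time"

    "github.com/go-mail/mail/v2"
)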

//go:embed "templates"
var templateFS embed.FS

type Mailer struct {
    dialer *mail.Dialer
    sender string
}

func New(host string, port int, username, password, sender string) Mailer {
    dialer := mail.NewDialer(host, port, username, password)
    dialer.Timeout = 5 * time.Second

    return Mailer{
        dialer: dialer,
        sender: sender,
    }
}

func (m Mailer) Send(recipient, templateFile string, data any) error {
    tmpl, err := template.ParseFS(templateFS, "templates/"+templateFile)
    if err != nil {
        return err
    }

    subject := new(bytes.Buffer)
    if err := tmpl.ExecuteTemplate(subject, "subject", data); err != nil {
        return err
    }

    plainBody := new(bytes.Buffer)
    if err := tmpl.ExecuteTemplate(plainBody, "plainBody", data); err != nil {
        return err
    }

    htmlBody := new(bytes.Buffer)
    if err := tmpl.ExecuteTemplate(htmlBody, "htmlBody", data); err != nil {
        return err
    }

    msg := mail.NewMessage()
    msg.SetHeader("To", recipient)
    msg.SetHeader("From", m.sender)
    msg.SetHeader("Subject", subject.String())
    msg.SetBody("text/plain", plainBody.String())
    msg.AddAlternative("text/html", htmlBody.String())

    const maxRetries = 3
    for i := 0; i < maxRetries; i++ {
        if err := m.dialer.DialAndSend(msg); err != nil {
            if i == maxRetries-1 {
                return err
            }
            time.Sleep(2 * time.Second)
            continue
        }
        break
    }

    return nil
}

Add the email template user_welcome.tmpl in internal/mailer/templates:

{{define "subject"}}Welcome to Greenlight!{{end}}

{{define "plainBody"}}
Hi {{.Name}},

Thanks for signing up for a Greenlight account. We're excited to have you on board!

Thanks,

The Greenlight Team
{{end}}

{{define "htmlBody"}}
<!doctype html>
<html>

<head>
    <meta name="viewport" content="width=device-width" />
    <meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>

<body>
    <p>Hi {{.Name}},</p>
    <p>Thanks for signing up for a Greenlight account. We're excited to have you on board!</p>
    <p>Thanks,</p>
    <p>The Greenlight Team</p>
</body>

</html>
{{end}}
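
Each {{define}} block becomes a named template that can be rendered on its own with ExecuteTemplate, which is how Send() builds the subject and the two bodies from one file. Here is a minimal standalone sketch of that mechanism. It is an illustration under assumptions: it uses html/template, an inline template string instead of an embedded file, and a hypothetical render helper.

```go
package main

import (
	"bytes"
	"fmt"
	"html/template"
)

// emailTemplates is an inline stand-in for a template file.
const emailTemplates = `
{{define "subject"}}Welcome, {{.Name}}!{{end}}
{{define "htmlBody"}}<p>Hi {{.Name}},</p>{{end}}
`

// render executes one named block from emailTemplates.
func render(block string, data any) (string, error) {
	tmpl, err := template.New("email").Parse(emailTemplates)
	if err != nil {
		return "", err
	}

	buf := new(bytes.Buffer)
	if err := tmpl.ExecuteTemplate(buf, block, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	subject, err := render("subject", struct{ Name string }{Name: "Ada"})
	fmt.Println(subject, err)

	// html/template escapes the angle brackets in user-supplied data.
	body, err := render("htmlBody", struct{ Name string }{Name: "<Ada>"})
	fmt.Println(body, err)
}
```

Asking for a block name that was never defined makes ExecuteTemplate return an error, which Send() would pass back to its caller.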

Modify the registerUserHandler function in cmd/api/users.go to send a welcome email:

func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
    ...
    app.background(func() {
        if err := app.mailer.Send(user.Email, "user_welcome.tmpl", user); err != nil {
            app.logger.Error("failed to send email", "error", err)
        }
    })

    if err := app.writeJSON(w, http.StatusAccepted, envelope{"user": user}, nil); err != nil {
        app.serverErrorResponse(w, r, err)
    }
}

Add a method to handle background tasks in cmd/api/helpers.go:

func (app *application) background(fn func()) {
    app.wg.Add(1)

    go func() {
        defer app.wg.Done()

        defer func() {
            if err := recover(); err != nil {
                app.logger.Error(fmt.Sprintf("%v", err))
            }
        }()

        fn()
    }()
}
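
Here is a standalone sketch of the same pattern, showing that a panicking task is recovered and that Wait() still returns once every task has finished. The runner type and runAll function are hypothetical stand-ins for the application struct and its callers.

```go
package main

import (
	"fmt"
	"sync"
)

// runner is a hypothetical stand-in for the application struct.
type runner struct {
	wg sync.WaitGroup
}

// background runs fn in its own goroutine, recovering from panics so that
// a failing task cannot crash the whole process.
func (r *runner) background(fn func()) {
	r.wg.Add(1)

	go func() {
		defer r.wg.Done()

		defer func() {
			if err := recover(); err != nil {
				fmt.Println("recovered:", err)
			}
		}()

		fn()
	}()
}

// runAll starts n tasks, the first of which panics, waits for all of them,
// and returns how many finished normally.
func runAll(n int) int {
	var (
		r        runner
		mu       sync.Mutex
		finished int
	)

	for i := 0; i < n; i++ {
		i := i
		r.background(func() {
			if i == 0 {
				panic("task 0 failed")
			}
			mu.Lock()
			finished++
			mu.Unlock()
		})
	}

	r.wg.Wait() // returns only after every task has called Done()
	return finished
}

func main() {
	fmt.Println("tasks finished normally:", runAll(5))
}
```

Because the deferred functions run in reverse order, the panic is recovered first and Done() is called afterwards, so a failed task still counts as complete for the WaitGroup.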

Ensure that the server waits for all background tasks to complete before shutting down, in cmd/api/server.go:

func (app *application) serve() error {
    ...
    go func() {
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()

        err := srv.Shutdown(ctx)
        if err != nil {
            shutdownError <- err
        }

        app.logger.Info("completing background tasks", "addr", srv.Addr)
        app.wg.Wait()
        shutdownError <- nil
    }()

    ...
    err := <-shutdownError
    if err != nil {
        return err
    }

    app.logger.Info("server closed", "addr", srv.Addr)
    return nil
}
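
To see why the WaitGroup matters, here is a minimal standalone sketch of the same pattern. The demoApp type and runTasks function are made up for illustration and are not part of the Greenlight code; the point is only that wg.Wait() returns once every background goroutine has finished, including one that panicked and was recovered.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// demoApp is a hypothetical stand-in for the application struct; only the
// WaitGroup matters here.
type demoApp struct {
	wg sync.WaitGroup
}

// background mirrors the helper above: it tracks the goroutine in the
// WaitGroup and recovers from any panic raised inside fn.
func (a *demoApp) background(fn func()) {
	a.wg.Add(1)

	go func() {
		defer a.wg.Done()

		defer func() {
			if err := recover(); err != nil {
				fmt.Println("recovered:", err)
			}
		}()

		fn()
	}()
}

// runTasks starts n well-behaved tasks plus one that panics, waits for all
// of them, and reports how many of the well-behaved ones finished.
func runTasks(n int) int64 {
	var app demoApp
	var finished int64

	for i := 0; i < n; i++ {
		app.background(func() { atomic.AddInt64(&finished, 1) })
	}
	app.background(func() { panic("simulated failure") })

	// Blocks until every background goroutine has called Done.
	app.wg.Wait()
	return atomic.LoadInt64(&finished)
}

func main() {
	fmt.Println("finished tasks:", runTasks(5))
}
```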

14. User Activation

Create migration:

$ migrate create -seq -ext .sql -dir ./migrations create_tokens_table

Create up & down migrations:

-- UP
CREATE TABLE IF NOT EXISTS tokens (
    hash bytea PRIMARY KEY,
    user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
    expiry timestamp(0) with time zone NOT NULL,
    scope text NOT NULL
);

-- DOWN
DROP TABLE IF EXISTS tokens;

We store only a SHA-256 hash of the activation token in the database, not the plaintext. A fast algorithm like SHA-256 is fine here because the activation token is a high-entropy random string.

We reference users, creating a foreign key constraint against the primary key of the users table. This means the user_id will have a corresponding id entry in the users table. The ON DELETE CASCADE ensures that if a user is deleted, the corresponding tokens are deleted as well.

The tokens table will store more than just activation tokens, which is why we also create the scope column. We’ll use it to restrict the scope of what a token can be used for.

Creating tokens
Tokens should be ‘unguessable’ – not easy to guess, not possible to brute-force.
So we want the token to be generated by a cryptographically secure random number generator (CSPRNG) and have enough entropy (randomness) that it’s impossible to guess.

The activation tokens will be created with crypto/rand and 128 bits (16 bytes) of entropy.
Never use math/rand, which provides a deterministic pseudo-random number generator (PRNG), for any purpose requiring cryptographic security.

// internal/data/tokens.go
package data

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base32"
	"time"

	"greenlight.bagerbach.com/internal/validator"
)

const (
	ScopeActivation = "activation"
)

type Token struct {
	Plaintext string
	Hash      []byte
	UserID    int64
	Expiry    time.Time
	Scope     string
}

func generateToken(userID int64, ttl time.Duration, scope string) (*Token, error) {
	token := &Token{
		UserID: userID,
		Expiry: time.Now().Add(ttl),
		Scope:  scope,
	}

	// 16 doesn't mean the plaintext tokens are 16 characters long; it means they carry 16 bytes (128 bits) of underlying entropy.
	// The length of the plaintext token depends on how the 16 random bytes are encoded into a string.
	// Since we encode them as a base-32 string without padding, it'll be 26 characters long.
	randomBytes := make([]byte, 16)
	// Fill randomBytes with random data using the OS' CSPRNG
	if _, err := rand.Read(randomBytes); err != nil {
		return nil, err
	}

	// We encode the random bytes to a base32-encoded string
	// By default, base-32 strings may be padded with '=' characters. We don't need the padding, so we use WithPadding(base32.NoPadding) to omit it.
	token.Plaintext = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes)

	// Now we generate the hash of the plaintext token.
	hash := sha256.Sum256([]byte(token.Plaintext))
	token.Hash = hash[:] // It returned an array, so we convert it to a slice to make it easier to work with

	return token, nil
}

func ValidateTokenPlaintext(v *validator.Validator, tokenPlaintext string) {
	v.Check(tokenPlaintext != "", "token", "must be provided")
	v.Check(len(tokenPlaintext) == 26, "token", "must be 26 characters long")
}

type TokenModel struct {
	DB *sql.DB
}

func (m TokenModel) New(userID int64, ttl time.Duration, scope string) (*Token, error) {
	token, err := generateToken(userID, ttl, scope)
	if err != nil {
		return nil, err
	}

	err = m.Insert(token)
	return token, err
}

func (m TokenModel) Insert(token *Token) error {
	query := `
		INSERT INTO tokens (hash, user_id, expiry, scope)
		VALUES ($1, $2, $3, $4)
	`
	args := []interface{}{token.Hash, token.UserID, token.Expiry, token.Scope}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	_, err := m.DB.ExecContext(ctx, query, args...)
	return err
}

func (m TokenModel) DeleteAllForUser(userID int64, scope string) error {
	query := `
		DELETE FROM tokens
		WHERE user_id = $1 AND scope = $2
	`
	args := []interface{}{userID, scope}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	_, err := m.DB.ExecContext(ctx, query, args...)
	return err
}
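
A quick way to convince yourself of the 26-character claim is a small standalone program. This is only a sketch: newPlaintextToken is a made-up helper that repeats the encoding step from generateToken, so nothing here touches the database.

```go
package main

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base32"
	"fmt"
)

// newPlaintextToken returns an unpadded base-32 string built from n bytes
// read from the operating system's CSPRNG.
func newPlaintextToken(n int) (string, error) {
	randomBytes := make([]byte, n)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", err
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes), nil
}

func main() {
	plaintext, err := newPlaintextToken(16)
	if err != nil {
		panic(err)
	}

	// Only the hash would be stored; the plaintext goes to the user.
	hash := sha256.Sum256([]byte(plaintext))

	fmt.Println("plaintext length:", len(plaintext)) // 16 bytes -> 26 base-32 characters
	fmt.Println("hash length:", len(hash))           // SHA-256 always yields 32 bytes
}
```

Each base-32 character carries 5 bits, so 128 bits need 26 characters once rounded up, which is where the length check in ValidateTokenPlaintext comes from.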

And add it to models:

// internal/data/models.go
package data

import (
	"database/sql"
	"errors"
)

var (
	ErrRecordNotFound = errors.New("record not found")
	ErrEditConflict   = errors.New("edit conflict")
)

type Models struct {
	Movies MovieModel
	Users  UserModel
	Tokens TokenModel // here
}

func NewModels(db *sql.DB) Models {
	return Models{
		Movies: MovieModel{DB: db},
		Users:  UserModel{DB: db},
		Tokens: TokenModel{DB: db}, // here
	}
}

Send the activation token to users
We’ll put the token in the email, asking the user to send a PUT request.
The book argues against letting users click a link to activate. That would mean we activate with a GET request, which is more convenient, but has some drawbacks. First, it violates the HTTP principle that GET should only be used for ‘safe’ requests which retrieve resources, not modify something.
And it’s possible that the link would get pre-fetched, inadvertently activating the account.

Generally, ensure any actions which change state are only ever executed via POST, PUT, PATCH, or DELETE. Not GET.

If the API is a backend for a website, you could ask users to click a link, which takes them to a page on the website. Then they can click a button which sends the PUT request to confirm their activation. Alternatively, you could ask them to paste in the activation code.
If you go with the former, the URL you send them to could be, e.g.: https://example.com/users/activate?token=Y3QMGX3PJ3WLRL2YRTQGQ6KRHU
But if you go with that option, make sure the token isn’t leaked in a Referer header. You could use either the Referrer-Policy: origin response header or the equivalent <meta name="referrer" content="origin"> tag.
You also don’t want to rely on the Host header (available as r.Host in Go) to construct the URL, as you’d be vulnerable to a host header injection attack. So hard-code the base URL or pass it in as a command-line flag.
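
As a rough sketch of that last point, the link can be built from a configured base URL rather than anything taken from the request. The activationURL function and the /users/activate path are assumptions for illustration (they match the example URL above, not any code in the book), and the base URL would come from your own configuration.

```go
package main

import (
	"fmt"
	"net/url"
	"path"
)

// activationURL builds the link from a trusted, configured base URL (for
// example one supplied via a command-line flag), never from r.Host.
func activationURL(baseURL, token string) (string, error) {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}

	// Append the front-end route to whatever path the base URL already has.
	u.Path = path.Join("/", u.Path, "users", "activate")

	// Let net/url handle query-string escaping of the token.
	q := url.Values{}
	q.Set("token", token)
	u.RawQuery = q.Encode()

	return u.String(), nil
}

func main() {
	link, err := activationURL("https://example.com", "Y3QMGX3PJ3WLRL2YRTQGQ6KRHU")
	if err != nil {
		panic(err)
	}
	fmt.Println(link)
}
```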

We start by updating the welcome email template:

{{define "subject"}}Welcome to Greenlight!{{end}}

{{define "plainBody"}}
Hi {{.name}},

Thanks for signing up for a Greenlight account. We're excited to have you on board!

Please send a request to the `PUT /v1/users/activated` endpoint with the following JSON
body to activate your account:

{"token": "{{.activationToken}}"}

Please note that this is a one-time use token, and it will expire in 3 days.

Thanks,

The Greenlight Team
{{end}}

{{define "htmlBody"}}
<!doctype html>
<html>

<head>
    <meta name="viewport" content="width=device-width" />
    <meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>

<body>
    <p>Hi {{.name}},</p>
    <p>Thanks for signing up for a Greenlight account. We're excited to have you on board!</p>
    <p>Please send a request to the <code>PUT /v1/users/activated</code> endpoint with the 
    following JSON body to activate your account:</p>
    <pre><code>
    {"token": "{{.activationToken}}"}
    </code></pre>
    <p>Please note that this is a one-time use token, and it will expire in 3 days.</p>
    <p>Thanks,</p>
    <p>The Greenlight Team</p>
</body>

</html>
{{end}}

And then we update the registerUserHandler we made earlier:

// cmd/api/users.go
// ...
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
	// ...

	token, err := app.models.Tokens.New(user.ID, 3*24*time.Hour, data.ScopeActivation)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	app.background(func() {
		data := map[string]interface{}{
			"activationToken": token.Plaintext,
			"name":            user.Name,
		}

		if err := app.mailer.Send(user.Email, "user_welcome.tmpl", data); err != nil {
			app.logger.Error("failed to send email", "error", err)
		}
	})

	if err := app.writeJSON(w, http.StatusAccepted, envelope{"user": user}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

A nice addition to this would be to provide a standalone endpoint for generating tokens.
For example, if a user doesn’t activate within the 3-day limit, or if they never receive their welcome email, then you’d be able to resend it.

Activating users
Since we have a one-to-many relationship between our users and tokens, it can be useful to execute queries against the relationship from both sides.
We could update the database models like so:

UserModel.GetForToken(token)   // -> Get user associated with token
TokenModel.GetAllForUser(user) // -> Get all tokens associated with a user

This also aligns with the responsibility of the models. The UserModel method returns a user, and the TokenModel returns tokens.

We create the handler for activating users:

// cmd/api/users.go
func (app *application) activateUserHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		TokenPlaintext string `json:"token"`
	}

	if err := app.readJSON(w, r, &input); err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	v := validator.New()

	if data.ValidateTokenPlaintext(v, input.TokenPlaintext); !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	user, err := app.models.Users.GetForToken(data.ScopeActivation, input.TokenPlaintext)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			v.AddError("token", "invalid or expired activation token")
			app.failedValidationResponse(w, r, v.Errors)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	user.Activated = true

	if err := app.models.Users.Update(user); err != nil {
		switch {
		case errors.Is(err, data.ErrEditConflict):
			app.editConflictResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	if err := app.models.Tokens.DeleteAllForUser(user.ID, data.ScopeActivation); err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if err := app.writeJSON(w, http.StatusOK, envelope{"user": user}, nil); err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

Which needs the UserModel.GetForToken method:

// internal/data/users.go
// ...
func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) {
	tokenHash := sha256.Sum256([]byte(tokenPlaintext))

	query := `
		SELECT users.id, users.created_at, users.name, users.email, users.password_hash, users.activated, users.version
		FROM users
		INNER JOIN tokens ON users.id = tokens.user_id
		WHERE tokens.hash = $1
		AND tokens.scope = $2
		AND tokens.expiry > $3`

	// use [:] to turn the byte array into a slice
	args := []interface{}{tokenHash[:], tokenScope, time.Now()}

	var user User

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	if err := m.DB.QueryRowContext(ctx, query, args...).Scan(
		&user.ID,
		&user.CreatedAt,
		&user.Name,
		&user.Email,
		&user.Password.hash,
		&user.Activated,
		&user.Version,
	); err != nil {
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return nil, ErrRecordNotFound
		default:
			return nil, err
		}
	}

	return &user, nil
}

We also add the route:

// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...
	router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler)
	// ...

We use PUT and not POST because the endpoint is idempotent. This means if the same request is sent multiple times, the first will succeed—if the token is valid—and any subsequent requests will result in an error for the user, as the token is consumed and deleted. However, nothing in the application state changes after the first request.

15. Authentication

There are many approaches to authenticating requests to APIs.
Here are some common ones:

  • Basic authentication
  • Stateful token authentication
  • Stateless token authentication
  • API key authentication
  • OAuth 2.0 / OpenID Connect

HTTP Basic Authentication

The client includes an Authorization header with every request containing their credentials.
They’ll be in the format username:password and are base-64 encoded.

Like:

Authorization: Basic YWxpY2VAZXhhbXBsZS5jb206cGE1NXdvcmQ=

You can extract these with Go’s Request.BasicAuth() method.

Comparing the password with a hashed password is a slow operation (deliberately), and when you use basic auth, you’ll have to do it on each request.
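
Here is a small sketch of Request.BasicAuth() decoding the example header above. The basicCredentials wrapper is a made-up helper that just builds a bare request so the method can be called outside a handler.

```go
package main

import (
	"fmt"
	"net/http"
)

// basicCredentials sets the given Authorization header value on an empty
// request and lets the standard library parse it.
func basicCredentials(header string) (username, password string, ok bool) {
	r := &http.Request{Header: make(http.Header)}
	r.Header.Set("Authorization", header)
	return r.BasicAuth()
}

func main() {
	username, password, ok := basicCredentials("Basic YWxpY2VAZXhhbXBsZS5jb206cGE1NXdvcmQ=")
	fmt.Println(username, password, ok) // alice@example.com pa55word true

	// A bearer token is not basic auth, so ok is false.
	_, _, ok = basicCredentials("Bearer Y3QMGX3PJ3WLRL2YRTQGQ6KRHU")
	fmt.Println(ok) // false
}
```

Note that base-64 is an encoding, not encryption, so basic auth only makes sense over HTTPS.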

Token Authentication

Also known as bearer token authentication.

  1. Client sends a request to the API with credentials (usually username/email and password).
  2. API verifies credentials, generates a bearer token representing the user, and sends it back. The token expires after some time, after which users will have to resubmit their credentials to get a new token.
  3. For subsequent requests, the client includes the token in an Authorization header as Authorization: Bearer <token>.
  4. When the API receives that, it will check that the token hasn’t expired and examine it to determine who the user is.

This avoids the issue of having to do the expensive checking of password vs. hashed password on every request. However, managing tokens can be complicated for clients (they need to handle caching tokens, monitoring and managing expiry, and periodically generating new ones).

There are two subtypes: stateful and stateless.

Stateful token authentication

The token value is a high-entropy, cryptographically secure random string.
It is usually stored in a server-side database with the user ID and expiry time.

The client sends the token in subsequent requests, and the server looks up the token, checks its expiry, and retrieves the user ID.

The API maintains control over the tokens. It’s easy to remove them by deleting or marking them as expired. This approach is also simple and robust, as we get security from an unguessable high-entropy token.

There aren’t many downsides besides those inherent with token authentication.
It requires a database lookup, though that’s necessary for user status checks anyway.
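
The lookup-and-revoke flow can be sketched with an in-memory map standing in for the database table. Everything here (tokenStore, tokenRecord, demo) is hypothetical and only meant to show the shape of the approach: store a hash, look it up, check the expiry, and delete to revoke.

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"time"
)

type tokenRecord struct {
	userID int64
	expiry time.Time
}

// tokenStore is a hypothetical in-memory stand-in for a tokens table,
// keyed by the SHA-256 hash of the plaintext token.
type tokenStore map[[sha256.Size]byte]tokenRecord

func (s tokenStore) add(plaintext string, userID int64, expiry time.Time) {
	s[sha256.Sum256([]byte(plaintext))] = tokenRecord{userID: userID, expiry: expiry}
}

// lookup returns the user ID for a token that exists and has not expired.
func (s tokenStore) lookup(plaintext string, now time.Time) (int64, bool) {
	rec, ok := s[sha256.Sum256([]byte(plaintext))]
	if !ok || !now.Before(rec.expiry) {
		return 0, false
	}
	return rec.userID, true
}

// revoke is just a delete, which is the main advantage of stateful tokens.
func (s tokenStore) revoke(plaintext string) {
	delete(s, sha256.Sum256([]byte(plaintext)))
}

// demo reports whether the lookup succeeds before expiry, after expiry,
// and after revocation.
func demo() (okFresh, okExpired, okRevoked bool) {
	const token = "Y3QMGX3PJ3WLRL2YRTQGQ6KRHU"
	store := tokenStore{}
	issued := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
	store.add(token, 42, issued.Add(24*time.Hour))

	_, okFresh = store.lookup(token, issued.Add(time.Hour))
	_, okExpired = store.lookup(token, issued.Add(48*time.Hour))
	store.revoke(token)
	_, okRevoked = store.lookup(token, issued.Add(time.Hour))
	return okFresh, okExpired, okRevoked
}

func main() {
	fmt.Println(demo()) // true false false
}
```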

Stateless token authentication

Encodes the user ID and expiry time in the token itself. The token is cryptographically signed to prevent tampering and sometimes encrypted to prevent reading the contents.

Encoding the information in a JSON Web Token (JWT) is common. However, PASETO, Branca, and nacl/secretbox are viable alternatives.

The work required to encode/decode the token can be done in memory, and all necessary information to identify the user is in the token itself – no database lookups are needed to figure out who made the request.

But the tokens aren’t easy to revoke once they’re issued.

You could revoke all tokens by changing the secret used for signing the tokens, which forces all users to re-authenticate. Or maintain a blocklist of revoked tokens, but that defeats the stateless aspect.

Avoid storing additional information in these tokens, like activation status or permissions. Save that for authorization checks.

JWTs are highly configurable, meaning lots can go wrong.

Because of these downsides, stateless tokens (and especially JWTs) aren’t the best choice for managing authentication in most API applications. But they can be useful for delegated authentication, which is where the application creating the token is different from the one consuming it, and those apps don’t share state (so stateful isn’t an option). This is common in a microservice-style architecture.
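
To make the idea concrete, here is a toy signed token using HMAC-SHA256 from the standard library. This is deliberately not a JWT and not something to use in production; the userID.expiry.signature layout and the sign and verify helpers are invented for this sketch. It only shows that the user ID and expiry travel in the token and that verification needs the secret, not a database.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strconv"
	"strings"
	"time"
)

// signature computes the HMAC-SHA256 of the payload with the shared secret.
func signature(secret, payload string) []byte {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(payload))
	return mac.Sum(nil)
}

// sign produces "<userID>.<expiryUnix>.<base64url signature>".
func sign(secret string, userID, expiryUnix int64) string {
	payload := strconv.FormatInt(userID, 10) + "." + strconv.FormatInt(expiryUnix, 10)
	return payload + "." + base64.RawURLEncoding.EncodeToString(signature(secret, payload))
}

// verify checks the signature first, then the expiry, all in memory.
func verify(secret, token string, nowUnix int64) (int64, bool) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return 0, false
	}

	got, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil || !hmac.Equal(got, signature(secret, parts[0]+"."+parts[1])) {
		return 0, false
	}

	userID, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return 0, false
	}
	expiryUnix, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil || nowUnix >= expiryUnix {
		return 0, false
	}
	return userID, true
}

func main() {
	now := time.Now().Unix()
	token := sign("demo-secret", 42, now+3600)

	userID, ok := verify("demo-secret", token, now)
	fmt.Println(userID, ok) // 42 true

	// Changing the secret invalidates every token that was already issued.
	_, ok = verify("rotated-secret", token, now)
	fmt.Println(ok) // false
}
```

Notice there is no way to invalidate a single token here short of waiting for it to expire, which is the revocation problem described above.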

API-Key Authentication

The user gets a non-expiring secret ‘key’ associated with their account.
This key should be a high-entropy, cryptographically secure random string. You store a fast hash of the key (SHA256 or SHA512) alongside the corresponding user ID in your database.

The user passes their key with each request to the API in the authorization header:

Authorization: Key <key>

When received, your API can regenerate the fast hash of the key and use it to look up the corresponding user ID.

This approach is similar to the stateful token approach, except keys are permanent, not temporary tokens.

This is also nice for the client, as they don’t have to manage tokens or expiry.
But now they have both their password and their API key to manage.
You need to build ways for users to regenerate their API keys if they are lost or compromised. Users may want multiple keys for the same account.
They should only be communicated to users over a secure channel – you should treat them with the same level of care as a password.

OAuth 2.0 / OpenID Connect

With this approach, information about your users is stored by a third-party identity provider, like Google, rather than by yourself.

OAuth 2.0 isn’t really an authentication protocol, and you shouldn’t use it for authenticating users.
If you want authentication checks against a third-party provider, you should use OpenID Connect.

Generating Authentication Tokens

Create an error for invalid credentials:

// cmd/api/errors.go
// ...
func (app *application) invalidCredentialsResponse(w http.ResponseWriter, r *http.Request) {
    message := "invalid authentication credentials"
    app.errorResponse(w, r, http.StatusUnauthorized, message)
}

Update Token struct to include JSON struct tags & add new constant for authentication scope:

// internal/data/tokens.go

const (
    ScopeActivation     = "activation"
    ScopeAuthentication = "authentication"
)


type Token struct {
    Plaintext string    `json:"token"`
    Hash      []byte    `json:"-"`
    UserID    int64     `json:"-"`
    Expiry    time.Time `json:"expiry"`
    Scope     string    `json:"-"`
}

Create a new route handler for the token creation endpoint:

// cmd/api/tokens.go
package main

import (
	"errors"
	"net/http"
	"time"

	"greenlight.bagerbach.com/internal/data"
	"greenlight.bagerbach.com/internal/validator"
)

func (app *application) createAuthenticationTokenHandler(w http.ResponseWriter, r *http.Request) {
	var input struct {
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	err := app.readJSON(w, r, &input)
	if err != nil {
		app.badRequestResponse(w, r, err)
		return
	}

	v := validator.New()

	data.ValidateEmail(v, input.Email)
	data.ValidatePasswordPlaintext(v, input.Password)

	if !v.Valid() {
		app.failedValidationResponse(w, r, v.Errors)
		return
	}

	user, err := app.models.Users.GetByEmail(input.Email)
	if err != nil {
		switch {
		case errors.Is(err, data.ErrRecordNotFound):
			app.invalidCredentialsResponse(w, r)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	match, err := user.Password.Matches(input.Password)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	if !match {
		app.invalidCredentialsResponse(w, r)
		return
	}

	token, err := app.models.Tokens.New(user.ID, 24*time.Hour, data.ScopeAuthentication)
	if err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}

	err = app.writeJSON(w, http.StatusCreated, envelope{"authentication_token": token}, nil)
	if err != nil {
		app.serverErrorResponse(w, r, err)
	}
}

And then include it in the routes:

func (app *application) routes() http.Handler {
    // ...
    router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler)

    return app.recoverPanic(app.rateLimit(router))
}

Authenticating Requests

We’ll expect clients to send their authentication tokens, once they have them, in all requests in an Authorization header (as a bearer token).
We’ll use an authenticate() middleware to handle validation and adding user info to the request context. If the token is valid, we’ll populate the request context with user info. If not, we’ll send a 401 Unauthorized. And if no Authorization header is given, we add details for an anonymous user.

Add context handling for users:

// cmd/api/context.go

package main

import (
	"context"
	"net/http"

	"greenlight.bagerbach.com/internal/data"
)

type contextKey string

const userContextKey = contextKey("user")

func (app *application) contextSetUser(r *http.Request, user *data.User) *http.Request {
	ctx := context.WithValue(r.Context(), userContextKey, user)
	return r.WithContext(ctx)
}

func (app *application) contextGetUser(r *http.Request) *data.User {
	user, ok := r.Context().Value(userContextKey).(*data.User)
	if !ok {
		panic("missing user value in request context")
	}

	return user
}
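
The custom contextKey type is what keeps this safe from key collisions. A small sketch of that behaviour, using a made-up lookups helper: a value stored under contextKey("user") is not visible to code that looks it up with the plain string "user", because context keys are compared by type as well as value.

```go
package main

import (
	"context"
	"fmt"
)

type contextKey string

const userContextKey = contextKey("user")

// lookups stores name under the typed key, then reads it back once with the
// typed key and once with an untyped string that has the same text.
func lookups(name string) (viaTypedKey, viaStringKey any) {
	ctx := context.WithValue(context.Background(), userContextKey, name)
	return ctx.Value(userContextKey), ctx.Value("user")
}

func main() {
	typed, plain := lookups("alice")
	fmt.Println(typed) // alice
	fmt.Println(plain) // <nil>
}
```

So another package that happens to use the string "user" as a context key cannot overwrite or read our value by accident.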

Add error response for invalid authentication tokens:

// cmd/api/errors.go
func (app *application) invalidAuthenticationTokenResponse(w http.ResponseWriter, r *http.Request) {
    // Remind client that we expect them to authenticate using a bearer token.
    w.Header().Set("WWW-Authenticate", "Bearer")

    message := "invalid or missing authentication token"
    app.errorResponse(w, r, http.StatusUnauthorized, message)
}

Implement middleware to handle authentication:

// cmd/api/middleware.go
func (app *application) authenticate(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // This header indicates that the response may vary based on the value of the Authorization header in the request.
        w.Header().Add("Vary", "Authorization")

        authorizationHeader := r.Header.Get("Authorization")
        if authorizationHeader == "" {
            // contextSetUser returns a new request, so the result must be reassigned.
            r = app.contextSetUser(r, data.AnonymousUser)
            next.ServeHTTP(w, r)
            return
        }

        // We expect the value in the Authorization header to be in the format of "Bearer <token>".
        headerParts := strings.Split(authorizationHeader, " ")
        if len(headerParts) != 2 || headerParts[0] != "Bearer" {
            app.invalidAuthenticationTokenResponse(w, r)
            return
        }

        token := headerParts[1]

        // Validate token
        v := validator.New()
        if data.ValidateTokenPlaintext(v, token); !v.Valid() {
            app.invalidAuthenticationTokenResponse(w, r)
            return
        }

        // Get user associated with authentication token
        user, err := app.models.Users.GetForToken(data.ScopeAuthentication, token)
        if err != nil {
            switch {
            case errors.Is(err, data.ErrRecordNotFound):
                app.invalidAuthenticationTokenResponse(w, r)
            default:
                app.serverErrorResponse(w, r, err)
            }
            return
        }

        r = app.contextSetUser(r, user)
        next.ServeHTTP(w, r)
    })
}

And we use it in routes.go:

// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...
	return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}

We also add an anonymous user and a method to check if a user is anonymous:

// internal/data/users.go
var AnonymousUser = &User{}

type User struct {
	// ...
}

func (u *User) IsAnonymous() bool {
    return u == AnonymousUser
}

16. Permission-based Authorization

Ensure only activated users can access some endpoints and implement a permission-based authorization pattern for fine-grained control over who can access which endpoints.

Requiring User Activation

We’ll add some errors to handle the new cases:

// cmd/api/errors.go
func (app *application) authenticationRequiredResponse(w http.ResponseWriter, r *http.Request) {
    message := "you must be authenticated to access this resource"
    app.errorResponse(w, r, http.StatusUnauthorized, message)
}

func (app *application) inactiveAccountResponse(w http.ResponseWriter, r *http.Request) {
    message := "your account must be activated to access this resource"
    app.errorResponse(w, r, http.StatusForbidden, message)
}

Implement middleware for requiring authenticated & activated user accounts:

// cmd/api/middleware.go
func (app *application) requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        user := app.contextGetUser(r)

        if user.IsAnonymous() {
            app.authenticationRequiredResponse(w, r)
            return
        }

        next.ServeHTTP(w, r)
    })
}

func (app *application) requireActivatedUser(next http.HandlerFunc) http.HandlerFunc {
    fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        user := app.contextGetUser(r)

        if !user.Activated {
            app.inactiveAccountResponse(w, r)
            return
        }

        next.ServeHTTP(w, r)
    })

    return app.requireAuthenticatedUser(fn)
}
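
Because requireActivatedUser wraps its check in requireAuthenticatedUser, the authentication check always runs first. The sketch below shows the resulting status codes with simplified stand-ins: the user struct, the X-Demo-User header, and currentUser are all invented for this example and replace the real request-context lookup.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type user struct {
	anonymous bool
	activated bool
}

// currentUser stands in for app.contextGetUser. The X-Demo-User header is
// a made-up device for this sketch only.
func currentUser(r *http.Request) user {
	switch r.Header.Get("X-Demo-User") {
	case "active":
		return user{activated: true}
	case "inactive":
		return user{}
	default:
		return user{anonymous: true}
	}
}

func requireAuthenticated(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if currentUser(r).anonymous {
			http.Error(w, "authentication required", http.StatusUnauthorized)
			return
		}
		next(w, r)
	}
}

func requireActivated(next http.HandlerFunc) http.HandlerFunc {
	fn := func(w http.ResponseWriter, r *http.Request) {
		if !currentUser(r).activated {
			http.Error(w, "account not activated", http.StatusForbidden)
			return
		}
		next(w, r)
	}
	// The outer wrapper runs first, so anonymous users never reach fn.
	return requireAuthenticated(fn)
}

// statusFor sends one request through the chain and returns the status code.
func statusFor(demoUser string) int {
	h := requireActivated(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	r := httptest.NewRequest(http.MethodGet, "/v1/movies", nil)
	if demoUser != "" {
		r.Header.Set("X-Demo-User", demoUser)
	}
	rec := httptest.NewRecorder()
	h(rec, r)
	return rec.Code
}

func main() {
	fmt.Println(statusFor(""), statusFor("inactive"), statusFor("active")) // 401 403 200
}
```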

Update routes to use middleware:

// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...
	router.HandlerFunc(http.MethodGet, "/v1/movies", app.requireActivatedUser(app.listMoviesHandler))
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.requireActivatedUser(app.createMovieHandler))
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.requireActivatedUser(app.showMovieHandler))
	router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.requireActivatedUser(app.updateMovieHandler))
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.requireActivatedUser(app.deleteMovieHandler))

	// ...
	return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}

Permissions Database Table

We’ll use this to restrict routes based on permissions like movies:read and movies:write.

The relationship between users and permissions will be many-to-many. A user can have many permissions, and the same permission can belong to many users.
This kind of relationship can be managed in a relational database by creating a joining table between two entities.

So we may store all users in one table (id, email, etc.) and all permissions in another table (id, code). Then our joining table will be users_permissions, storing information about which users have which permissions, by their respective ids (user_id, permission_id).

We start by creating the migration:

migrate create -seq -ext .sql -dir ./migrations add_permissions

Up migration

CREATE TABLE IF NOT EXISTS permissions (
    id bigserial PRIMARY KEY,
    code text NOT NULL
);

CREATE TABLE IF NOT EXISTS users_permissions (
    user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
    permission_id bigint NOT NULL REFERENCES permissions ON DELETE CASCADE,
    -- This is a composite primary key
    -- It means that the combination of user_id and permission_id must be unique
    PRIMARY KEY (user_id, permission_id)
);

INSERT INTO permissions (code) VALUES ('movies:read'), ('movies:write');

Down migration

DROP TABLE IF EXISTS users_permissions;
DROP TABLE IF EXISTS permissions;

And now we’ll create the permissions model.

// internal/data/permissions.go
package data

import (
	"context"
	"database/sql"
	"time"
)

type Permissions []string

func (p Permissions) Include(code string) bool {
	for _, permission := range p {
		if permission == code {
			return true
		}
	}
	return false
}

type PermissionModel struct {
	DB *sql.DB
}

func (m PermissionModel) GetAllForUser(userID int64) (Permissions, error) {
	query := `
		SELECT permissions.code FROM permissions
		INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id
		INNER JOIN users ON users.id = users_permissions.user_id
		WHERE users.id = $1
	`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	rows, err := m.DB.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var permissions Permissions
	for rows.Next() {
		var permission string
		err := rows.Scan(&permission)
		if err != nil {
			return nil, err
		}
		permissions = append(permissions, permission)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return permissions, nil
}

And include it in the parent Models struct:

// internal/data/models.go
type Models struct {
	Movies      MovieModel
	Users       UserModel
	Tokens      TokenModel
	Permissions PermissionModel
}

func NewModels(db *sql.DB) Models {
	return Models{
		Movies:      MovieModel{DB: db},
		Users:       UserModel{DB: db},
		Tokens:      TokenModel{DB: db},
		Permissions: PermissionModel{DB: db},
	}
}

Checking Permissions

Add an error for insufficient permissions:

// cmd/api/errors.go
func (app *application) nonPermittedResponse(w http.ResponseWriter, r *http.Request) {
	message := "your account does not have the necessary permissions to access this resource"
	app.errorResponse(w, r, http.StatusForbidden, message)
}

Add middleware to require a permission:

// cmd/api/middleware.go

func (app *application) requirePermission(code string, next http.HandlerFunc) http.HandlerFunc {
	fn := func(w http.ResponseWriter, r *http.Request) {
		user := app.contextGetUser(r)

		permissions, err := app.models.Permissions.GetAllForUser(user.ID)
		if err != nil {
			app.serverErrorResponse(w, r, err)
			return
		}

		if !permissions.Include(code) {
			app.nonPermittedResponse(w, r)
			return
		}

		next.ServeHTTP(w, r)
	}

	return app.requireActivatedUser(fn)
}

Use the middleware:

// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...

	router.HandlerFunc(http.MethodGet, "/v1/movies", app.requirePermission("movies:read", app.listMoviesHandler))
	router.HandlerFunc(http.MethodPost, "/v1/movies", app.requirePermission("movies:write", app.createMovieHandler))
	router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.requirePermission("movies:read", app.showMovieHandler))
	router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.requirePermission("movies:write", app.updateMovieHandler))
	router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.requirePermission("movies:write", app.deleteMovieHandler))

	// ...
	return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}

Granting Permissions

We’ll need a method for adding permissions to a user.

// internal/data/permissions.go
func (m PermissionModel) AddForUser(userID int64, codes ...string) error {
	query := `
		INSERT INTO users_permissions
		SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)
	`

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	_, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes))
	return err
}

You’ll need to import "github.com/lib/pq" for pq.Array.

And we’ll add movies:read as a default permission for new users:

// cmd/api/users.go
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
	// ...
	if err := app.models.Users.Insert(user); err != nil {
		switch {
		case errors.Is(err, data.ErrDuplicateEmail):
			v.AddError("email", "a user with this email address already exists")
			app.failedValidationResponse(w, r, v.Errors)
		default:
			app.serverErrorResponse(w, r, err)
		}
		return
	}

	// Add this:
	if err := app.models.Permissions.AddForUser(user.ID, "movies:read"); err != nil {
		app.serverErrorResponse(w, r, err)
		return
	}
	// ...
}

17. Cross Origin Requests

First, what is meant by origin?
If two URLs have the same scheme, host, and port (if specified), they are said to share the same origin. For example, https://foo.com/a and http://foo.com/a do not have the same origin, as they have different schemes. Likewise, http://foo.com/a and http://www.foo.com/a have different hosts.
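
A rough way to express that definition in code, using only net/url. The sameOrigin function is a sketch written for these notes; in particular, it treats a missing port and an explicit default port (such as :443 for https) as different, which a real browser would not.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// sameOrigin compares scheme, host name, and port of two URLs.
// Simplification: default ports are not normalised.
func sameOrigin(a, b string) (bool, error) {
	ua, err := url.Parse(a)
	if err != nil {
		return false, err
	}
	ub, err := url.Parse(b)
	if err != nil {
		return false, err
	}

	return ua.Scheme == ub.Scheme &&
		strings.EqualFold(ua.Hostname(), ub.Hostname()) &&
		ua.Port() == ub.Port(), nil
}

func main() {
	pairs := [][2]string{
		{"https://foo.com/a", "http://foo.com/a"},    // different scheme
		{"http://foo.com/a", "http://www.foo.com/a"}, // different host
		{"https://foo.com/a", "https://foo.com/b"},   // only the path differs
	}
	for _, p := range pairs {
		same, err := sameOrigin(p[0], p[1])
		if err != nil {
			panic(err)
		}
		fmt.Println(p[0], p[1], same)
	}
}
```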

This is important because browsers implement a security mechanism called the same-origin policy.
In practice, this means a webpage on one origin can embed certain resources from another in its HTML (like images, CSS, and JavaScript files). A webpage on one origin can send data to a different origin – like an HTML form sending to another origin. But a webpage on one origin is not allowed to receive data from a different origin.

This prevents potentially malicious websites on another origin from reading potentially confidential information on your website.
The policy does not prevent sending data to other origins, though. That is still dangerous, and it is why CSRF attacks are possible.

Say you have a webpage at https://foo.com. If the JavaScript on that site tries to make an HTTP request to https://bar.com/data.json (a different origin), then the request is sent and processed by the bar.com server, but the user’s browser will block the response so the JavaScript code from https://foo.com cannot see it.

The same-origin policy is a useful safeguard. However, sometimes you may want to relax it.
Say you have an API at api.example.com, and a trusted JavaScript front-end running on www.example.com. Then you might want to allow cross-origin requests from the trusted www.example.com domain to your API.
Or you have a completely open public API, and you want to allow cross-origin requests from anywhere.

Modern web browsers let you allow or disallow specific cross-origin requests to your API by setting Access-Control headers on your API responses.

Notably, the same-origin policy is enforced only by web browsers. Outside a browser, anyone can make requests to the API from anywhere.

Cross-origin requests can be classified as simple if the following conditions are met:

  • The HTTP method is either HEAD, GET, or POST (CORS-safe methods).
  • Request headers are all either forbidden headers or one of four CORS-safe headers: Accept, Accept-Language, Content-Language, Content-Type.
  • The value for the Content-Type header, if set, is either application/x-www-form-urlencoded, multipart/form-data, or text/plain.

If a cross-origin request doesn’t meet these conditions, the browser will trigger an initial ‘preflight’ request before the real request to check if the real request will be permitted or not.

Preflight requests are sent with the HTTP method OPTIONS. They also always have an Origin header and an Access-Control-Request-Method header.

Supporting CORS

If you just want to enable access from any origin, add this middleware:

// cmd/api/middleware.go
func (app *application) enableCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Don't use `*` if you want to use auth / credentials
		w.Header().Set("Access-Control-Allow-Origin", "*")

		next.ServeHTTP(w, r)
	})
}

And then use it early in your middleware chain:

func (app *application) routes() http.Handler {
	// ...
	return app.recoverPanic(
		app.enableCORS(
			app.rateLimit(
				app.authenticate(router),
			),
		),
	)
}

However, we want to support a set of trusted origins. These will be passed in a single command-line flag as a space-separated list.

// cmd/api/main.go

type config struct {
	// ...
	cors struct {
		trustedOrigins []string
	}
}

func main() {
	// ...
	flag.Func("cors-trusted-origins", "Trusted CORS origins (space separated)", func(s string) error {
		cfg.cors.trustedOrigins = strings.Fields(s)
		return nil
	})
	// ...
}
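
To see how the flag value is split, here is a small standalone sketch. It uses its own `flag.FlagSet` and a hypothetical `parseTrustedOrigins` helper so it can run without the rest of the application. The flag name matches the one above.

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// parseTrustedOrigins parses args and returns the origins given via
// -cors-trusted-origins. Hypothetical helper for illustration.
func parseTrustedOrigins(args []string) ([]string, error) {
	var origins []string

	fs := flag.NewFlagSet("api", flag.ContinueOnError)
	fs.Func("cors-trusted-origins", "Trusted CORS origins (space separated)", func(s string) error {
		// strings.Fields splits on any run of whitespace and drops empty entries.
		origins = strings.Fields(s)
		return nil
	})

	if err := fs.Parse(args); err != nil {
		return nil, err
	}
	return origins, nil
}

func main() {
	origins, err := parseTrustedOrigins([]string{
		"-cors-trusted-origins=https://www.example.com   https://app.example.com",
	})
	fmt.Println(len(origins), origins, err)
}
```

If the flag is omitted, the callback never runs and the slice stays empty, so no origin is trusted by default.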

And now the middleware:

// cmd/api/middleware.go
// ...
func (app *application) enableCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Vary", "Origin")
		w.Header().Add("Vary", "Access-Control-Request-Method")

		origin := r.Header.Get("Origin")
		if origin != "" && slices.Contains(app.config.cors.trustedOrigins, origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)

			// If the request is a preflight OPTIONS request, we need to set the
			// Access-Control-Allow-Methods and Access-Control-Allow-Headers headers.
			if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
				w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, PUT, PATCH, DELETE")
				// Since we're allowing Authorization, Allow-Origin should be checked against a
				// list of trusted origins. Never use `*` in this case.
				w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
				
				// Write headers along with 200 OK status and return from the middleware with no further action
				w.WriteHeader(http.StatusOK)
				return
			}
		}

		next.ServeHTTP(w, r)
	})
}

Just make sure to never include null in the list, as it can be forged by an attacker by sending a request from a sandboxed iframe.
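
One way to guard against that mistake is to validate the list at startup. This is a sketch of my own, not something the book does; `validateTrustedOrigins` is a hypothetical helper, and rejecting `*` and anything that is not a bare `scheme://host[:port]` value is an extra assumption on my part.

```go
package main

import (
	"fmt"
	"net/url"
)

// validateTrustedOrigins returns an error if any entry should not be used
// as a trusted origin. Hypothetical helper for illustration.
func validateTrustedOrigins(origins []string) error {
	for _, o := range origins {
		// "null" can be forged from a sandboxed iframe; "*" is not a real origin.
		if o == "null" || o == "*" {
			return fmt.Errorf("origin %q must not be trusted", o)
		}
		u, err := url.Parse(o)
		if err != nil {
			return fmt.Errorf("origin %q: %w", o, err)
		}
		// An Origin header value has a scheme and host but no path, so an
		// entry such as "https://example.com/" would never match exactly.
		if (u.Scheme != "http" && u.Scheme != "https") || u.Host == "" || u.Path != "" {
			return fmt.Errorf("origin %q must look like scheme://host[:port]", o)
		}
	}
	return nil
}

func main() {
	fmt.Println(validateTrustedOrigins([]string{"https://www.example.com", "http://localhost:9000"}))
	fmt.Println(validateTrustedOrigins([]string{"https://www.example.com", "null"}))
}
```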

If your API endpoints require credentials (cookies or HTTP basic authentication), set an Access-Control-Allow-Credentials: true header in your responses. If you don’t, the web browser will prevent cross-origin responses to credentialed requests from being read by JavaScript.
If you use the credentials header, never combine it with the wildcard Access-Control-Allow-Origin: *, as that would allow any website to make credentialed cross-origin requests to your API.

18. Metrics

To get insights into how your application is performing and what resources it’s using, we’ll use Go’s standard library expvar package.

// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...
	// Using /debug/vars, which is conventional for expvar, to display the metrics
	// and debug information.
	router.Handler(http.MethodGet, "/debug/vars", expvar.Handler())
	// ...
}

You’ll get a JSON object with two top-level items: cmdline and memstats.

  • cmdline contains an array of command-line arguments used to run the application.
  • memstats contains a ‘moment-in-time’ snapshot of memory usage.

You may also want custom metrics.

// cmd/api/main.go

func main() {
	// ...
	expvar.NewString("version").Set(version)
	expvar.Publish("goroutines", expvar.Func(func() any {
		return runtime.NumGoroutine()
	}))
	expvar.Publish("database", expvar.Func(func() any {
		return db.Stats()
	}))
	expvar.Publish("timestamp", expvar.Func(func() any {
		return time.Now().Unix()
	}))
	// ...
}
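
To check what the endpoint returns without starting the whole application, you can call the expvar handler directly. The sketch below publishes one custom variable in `init` (only so the example is self-contained; the notes above do it in `main`) and decodes the JSON into a map. `snapshot` is a hypothetical helper name.

```go
package main

import (
	"encoding/json"
	"expvar"
	"fmt"
	"net/http"
	"net/http/httptest"
	"runtime"
)

func init() {
	// Publish panics if a name is registered twice, so register exactly once.
	expvar.Publish("goroutines", expvar.Func(func() any {
		return runtime.NumGoroutine()
	}))
}

// snapshot calls the expvar handler in-process and returns the top-level
// JSON items keyed by name. Hypothetical helper for illustration.
func snapshot() (map[string]json.RawMessage, error) {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/debug/vars", nil)
	expvar.Handler().ServeHTTP(rec, req)

	var vars map[string]json.RawMessage
	err := json.Unmarshal(rec.Body.Bytes(), &vars)
	return vars, err
}

func main() {
	vars, err := snapshot()
	if err != nil {
		fmt.Println("decode error:", err)
		return
	}
	// The two built-in items plus the custom one published above.
	for _, key := range []string{"cmdline", "memstats", "goroutines"} {
		_, ok := vars[key]
		fmt.Printf("%s present: %t\n", key, ok)
	}
}
```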

You absolutely need to protect the metrics endpoint. The information it provides is very useful to attackers and may even expose sensitive information.
One way to do so is authentication, e.g., with a metrics:view permission or HTTP Basic Authentication.
Instead, we’ll run the application behind Caddy as a reverse proxy and restrict GET /debug/vars so that it can only be reached via connections from the local machine.

More Metrics

// cmd/api/middleware.go
type metricsResponseWriter struct {
	wrapped       http.ResponseWriter
	statusCode    int
	headerWritten bool
}

func newMetricsResponseWriter(w http.ResponseWriter) *metricsResponseWriter {
	return &metricsResponseWriter{wrapped: w, statusCode: http.StatusOK}
}

func (mrw *metricsResponseWriter) Header() http.Header {
	return mrw.wrapped.Header()
}

func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
	mrw.wrapped.WriteHeader(statusCode)

	if !mrw.headerWritten {
		mrw.statusCode = statusCode
		mrw.headerWritten = true
	}
}

func (mrw *metricsResponseWriter) Write(b []byte) (int, error) {
	mrw.headerWritten = true
	return mrw.wrapped.Write(b)
}

func (mrw *metricsResponseWriter) Unwrap() http.ResponseWriter {
	return mrw.wrapped
}

func (app *application) metrics(next http.Handler) http.Handler {
	var (
		totalRequestsReceived           = expvar.NewInt("total_requests_received")
		totalResponsesSent              = expvar.NewInt("total_responses_sent")
		totalProcessingTimeMicroseconds = expvar.NewInt("total_processing_time_μs")
		totalResponsesSentByStatus      = expvar.NewMap("total_responses_sent_by_status")
	)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		totalRequestsReceived.Add(1)

		mrw := newMetricsResponseWriter(w)

		next.ServeHTTP(mrw, r)

		totalResponsesSent.Add(1)
		totalResponsesSentByStatus.Add(strconv.Itoa(mrw.statusCode), 1)

		duration := time.Since(start).Microseconds()
		totalProcessingTimeMicroseconds.Add(duration)
	})
}

Now, use that middleware in the routes. Put it at the beginning of the chain.
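
"Beginning of the chain" here means the outermost wrapper, so the metrics middleware sees every request before any other middleware can reject it. The sketch below demonstrates the call order with stand-in middleware; `trace` and `callOrder` are hypothetical, and the names passed to `trace` only mirror the chain from the CORS section.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// trace returns a stand-in middleware that records its name when called.
func trace(name string, order *[]string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*order = append(*order, name)
			next.ServeHTTP(w, r)
		})
	}
}

// callOrder builds the chain with metrics outermost, sends one request
// through it, and returns the order in which the layers ran.
func callOrder() []string {
	var order []string

	router := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		order = append(order, "router")
	})

	metrics := trace("metrics", &order)
	recoverPanic := trace("recoverPanic", &order)
	enableCORS := trace("enableCORS", &order)
	rateLimit := trace("rateLimit", &order)
	authenticate := trace("authenticate", &order)

	h := metrics(recoverPanic(enableCORS(rateLimit(authenticate(router)))))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))

	return order
}

func main() {
	fmt.Println(callOrder())
}
```

Because metrics wraps everything else, requests that the rate limiter or authentication layer turn away are still counted.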

You can use the metrics endpoint in conjunction with something like Prometheus.

19. Building, Versioning and Quality Control

Makefiles and environment variables

Here’s an example Makefile. You will need a .envrc file for it to load environment variables from.
The audit target also requires staticcheck, which you can install with:

go install honnef.co/go/tools/cmd/staticcheck@latest

The Makefile itself:

include .envrc

## help: print this help message
.PHONY: help
help:
	@echo 'Usage:'
	@sed -n 's/^##//p' ${MAKEFILE_LIST} | column -t -s ':' |  sed -e 's/^/ /'

.PHONY: confirm
confirm:
	@echo -n 'Are you sure? [y/N] ' && read ans && [ $${ans:-N} = y ]

## run/api: run the cmd/api application
.PHONY: run/api
run/api:
	go run ./cmd/api -db-dsn="${DATABASE_URL}"

## db/psql: connect to the database using psql
.PHONY: db/psql
db/psql:
	@echo 'Connecting to database...'
	psql "${DATABASE_URL}"

## db/migrations/up: apply all up database migrations
.PHONY: db/migrations/up
db/migrations/up: confirm
	@echo 'Running up migrations...'
	migrate -path ./migrations -database "${DATABASE_URL}" up

## db/migrations/new: create a new database migration
.PHONY: db/migrations/new
db/migrations/new:
	@echo 'Creating migration files for ${name}...'
	migrate create -seq -ext=.sql -dir=./migrations ${name}
	
## audit: tidy dependencies and format, vet and test all code
.PHONY: audit
audit:
	@echo 'Tidying and verifying module dependencies...'
	go mod tidy
	go mod verify
	@echo 'Formatting code...'
	go fmt ./...
	@echo 'Vetting code...'
	go vet ./...
	staticcheck ./...
	@echo 'Testing code...'
	go test -race -vet=off ./...

## build/api: build the cmd/api application
.PHONY: build/api
build/api:
	@echo 'Building cmd/api...'
	go build -ldflags='-s -w' -o=./bin/api ./cmd/api
	GOOS=linux GOARCH=amd64 go build -ldflags='-s -w' -o=./bin/linux_amd64/api ./cmd/api

This setup also means we don’t have to use os.Getenv or godotenv.

By vendoring your dependencies (go mod vendor), you include a copy of every package you depend on in your repository.
This is great because your builds no longer depend on those packages still being available online later.

We’re building for two targets: the current machine and AMD64 Linux (where our app will run when deployed).

Realip

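We’re getting ready to deploy behind Caddy. Behind a reverse proxy, r.RemoteAddr holds the proxy’s address rather than the client’s, so the rate limiter has to take the client IP from the headers the proxy sets. We’ll use realip for that: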

go get github.com/tomasen/realip@latest

And then editing the rate limiter:

--- a/lets-go-further-book/cmd/api/middleware.go
+++ b/lets-go-further-book/cmd/api/middleware.go
@@ -4,7 +4,6 @@ import (
 	"errors"
 	"expvar"
 	"fmt"
-	"net"
 	"net/http"
 	"slices"
 	"strconv"
@@ -12,6 +11,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/tomasen/realip"
 	"golang.org/x/time/rate"
 	"greenlight.bagerbach.com/internal/data"
 	"greenlight.bagerbach.com/internal/validator"
@@ -72,11 +72,7 @@ func (app *application) rateLimit(next http.Handler) http.Handler {
 			return
 		}
 
-		ip, _, err := net.SplitHostPort(r.RemoteAddr)
-		if err != nil {
-			app.serverErrorResponse(w, r, err)
-			return
-		}
+		ip := realip.FromRequest(r)
 
 		mu.Lock()
 		if _, found := clients[ip]; !found {
