Notes on
Let's Go Further!
by Alex Edwards
Like the previous book, Let’s Go!, this has been a great learning experience. The author, Alex Edwards, is great at explaining and teaching Go. I particularly enjoyed this book for the more advanced topics and deep dives.
If you find these notes useful, please Go and support Alex by buying the book.
You can find the code I’ve written while reading the book here: lets-go-further-book.
2. Getting Started
| Method | Usage |
| --- | --- |
| GET | Use for actions that retrieve information only and don’t change the state of your application or any data. |
| POST | Use for non-idempotent actions that modify state. In the context of a REST API, POST is generally used for actions that create a new resource. |
| PUT | Use for idempotent actions that modify the state of a resource at a specific URL. In the context of a REST API, PUT is generally used for actions that replace or update an existing resource. |
| PATCH | Use for actions that partially update a resource at a specific URL. It’s OK for the action to be either idempotent or non-idempotent. |
| DELETE | Use for actions that delete a resource at a specific URL. |
“Idempotent”:
> A request method is considered “idempotent” if the intended effect on the server of multiple identical requests with that method is the same as the effect for a single such request. Of the request methods defined by this specification, PUT, DELETE, and safe request methods are idempotent.
3. Sending JSON Responses
There are various structures for JSON responses you can use. For example:
- JSON:API — A specification for building APIs in JSON
- JSend — a specification for a simple, no-frills, JSON-based format for application-level communication.
In this project, we envelop the responses, basically wrapping them in an object before sending them to the user.
Enveloping is not strictly necessary, but there are some benefits to using envelopes. Enveloping a response means to wrap the response in a JSON object that contains the response data and any metadata. This is useful for returning a response to the client with a consistent structure.
- The response is more self-documenting.
- Explicit access: the client has to explicitly access the response via, for example, the `movie` key, making it clear what the response is for.
- Security mitigation: addresses a potential security issue. More details can be found in this article on JSON security vulnerability.
For sending JSON responses, using `json.Marshal` is generally better: while very slightly less efficient, the API is much more ergonomic.
4. Parsing JSON Requests
For parsing JSON requests, using `json.Decoder` is generally better; `json.Unmarshal` requires more (verbose) code and is less efficient.
5. Database Setup and Configuration
Default Postgres config may not be optimal. Can optimize:
- How to tune PostgreSQL for memory | EDB
- PGTune - calculate configuration for PostgreSQL based on the maximum performance for a given hardware configuration
How the `sql.DB` connection pool works
- Has two types of connections:
- In-use: when using it to perform DB tasks (query, executing SQL, etc.)
- Idle: after a task completes, the connection becomes idle
- When you have Go do a database task, it checks for available idle connections. If there is one, it reuses that, marking it as in-use for the duration of the task.
- If no idle connections, it creates one.
- When Go reuses an idle connection, problems with the connection are handled gracefully: it automatically retries bad connections twice, and if still unsuccessful, removes the connection and creates a new one.
There are 4 methods to configure the database pool:
- `SetMaxOpenConns()` — sets the maximum number of open (in-use + idle) connections. Default is unlimited, but by default PostgreSQL has a (configured) hard limit of 100 open connections, so leaving it unlimited isn’t necessarily the best option.
  - A consequence of lowering max open connections is that tasks have to wait if there are no available connections. This could leave them hanging forever. To mitigate that, we use timeouts on database tasks with `context.Context`.
- `SetMaxIdleConns()` — default is 2. The trade-off here is time vs. memory: increasing it likely means better performance, as it’s less likely a new connection has to be established from scratch, but idle connections take up memory, and if idle for too long they can become unusable (e.g. auto-closed by the DBMS if unused for X hours).
  - Guideline: only keep a connection idle if you’re likely to use it again soon.
- `SetConnMaxLifetime()`
- `SetConnMaxIdleTime()`
Tips:
- Explicitly set a `MaxOpenConns`. It should be comfortably below the hard limit on connections imposed by your database and infrastructure. You can also consider setting it low, so it acts as a throttle.
  - 25 is reasonable for small-to-medium web applications and APIs.
- Generally, higher `MaxOpenConns` and `MaxIdleConns` lead to better performance, though with diminishing returns.
- Generally, you should set a `ConnMaxIdleTime` to remove idle connections that haven’t been used for a long time, e.g. 15 minutes.
- It can be OK to leave `ConnMaxLifetime` unlimited, unless there’s a hard limit on connection lifetime, or you need to e.g. gracefully swap databases.
6. SQL Migrations
How SQL migrations work at a high level:
- For every change you want to make to your database schema, you create a pair of migration files. One is the ‘up’ migration, containing the SQL statements to implement the change, and the other is the ‘down’ migration, which contains the SQL statements to reverse the change.
- Each pair is numbered sequentially (`0001`, `0002`, …) or with a Unix timestamp to indicate order.
- You’d use some kind of script or tool to execute/rollback the SQL statements. This tool also keeps track of the migrations that have been applied, so only the necessary SQL statements are executed.
This is obviously better than manually running migrations by writing SQL statements yourself. The database schema is described by the migration files, which can be put under version control and live with your other code. It’s much easier to get set up on another machine by simply running the ‘up’ migrations, and it’s also easy to roll back.
We’re using `golang-migrate/migrate` here. I’ve heard good things about goose as well.
$ migrate create -seq -ext=.sql -dir=./migrations create_movies_table
/home/christian/projects/golang/lets-go-further-book/migrations/000001_create_movies_table.up.sql
/home/christian/projects/golang/lets-go-further-book/migrations/000001_create_movies_table.down.sql
`-seq` for sequential numbering (instead of Unix timestamps), `-ext` to specify the extension, `-dir` to indicate where we want to store the migration files, and `create_movies_table` is the migration label.
Let’s create the ‘up’ migration:
CREATE TABLE IF NOT EXISTS movies (
    id bigserial PRIMARY KEY,
    created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(),
    title text NOT NULL,
    year integer NOT NULL,
    runtime integer NOT NULL,
    genres text[] NOT NULL,
    version integer NOT NULL DEFAULT 1
);
`bigserial` is a 64-bit auto-incrementing integer, starting at 1.
Genres have the type `text[]`, which is an array of 0 or more text values. Arrays in PostgreSQL are queryable and indexable.
Working with `NULL` values in Go is awkward, so we just set `NOT NULL` constraints and default values.
We store strings with `text` instead of `varchar` or `varchar(n)` types; `text` seems generally better.
The ‘down’ migration is quite simple:
DROP TABLE IF EXISTS movies;
And let’s add another migration that sets some constraints to enforce business logic:
$ migrate create -seq -ext=.sql -dir=./migrations add_movies_check_constraints
-- === UP ===
ALTER TABLE movies ADD CONSTRAINT movies_runtime_check CHECK (runtime >= 0);
ALTER TABLE movies ADD CONSTRAINT movies_year_check CHECK (year BETWEEN 1888 AND date_part('year', now()));
ALTER TABLE movies ADD CONSTRAINT genres_length_check CHECK (array_length(genres, 1) BETWEEN 1 AND 5);
-- === DOWN ===
ALTER TABLE movies DROP CONSTRAINT IF EXISTS movies_runtime_check;
ALTER TABLE movies DROP CONSTRAINT IF EXISTS movies_year_check;
ALTER TABLE movies DROP CONSTRAINT IF EXISTS genres_length_check;
Now we can’t insert or update data in the `movies` table that fails these checks.
Run the migrations with:
$ migrate -path=./migrations -database="DSN_GOES_HERE" up
If you are using PostgreSQL v15, you may see the error `pq: permission denied for schema public...` when running the above command. This is because v15 revokes `CREATE` from all users except a database owner. You can fix it with:
-- Set owner:
ALTER DATABASE greenlight OWNER TO greenlight;
-- And if that didn't work:
GRANT CREATE ON DATABASE greenlight TO greenlight;
We now have two tables in the database: `movies` and `schema_migrations`. The `schema_migrations` table was made by `migrate` to keep track of which migrations it has applied.
Here are a few useful `migrate` commands:
# Show current migration version
$ migrate -path=./migrations -database=$EXAMPLE_DSN version
2
# Migrate up or down to a specific version
$ migrate -path=./migrations -database=$EXAMPLE_DSN goto 1
# Execute down migrations by a specific number of migrations
$ migrate -path=./migrations -database=$EXAMPLE_DSN down 1
Fixing errors in SQL migrations
When you run a migration that contains an error, all SQL statements up to the one with the error will be applied, and then `migrate` exits with the error. This can mean your migration file was partially applied, leaving your database in an ‘unknown state’ (to `migrate`, at least).
If you try to run another migration (even down), you’ll get an error saying you have a dirty database version.
You’ll have to investigate the original error and figure out if the migration file was partially applied. If so, you need to manually roll back the partially applied migration.
Then you need to force the version number in `schema_migrations` to the correct value:
# Force version to 1
$ migrate -path=./migrations -database=$EXAMPLE_DSN force 1
Now it’ll be clean, and you can run the migrations again.
You could have `migrate` run your migrations when your application starts, but it can be problematic to tightly couple the execution of migrations with your application source code in the long term.
7. CRUD Operations
Just basic CRUD stuff.
Step 1: define data layer.
Step 2: consume data layer in routes.
Step 3: add routes with proper REST semantics.
Data Layer
package data
import (
"database/sql"
"errors"
"time"
"github.com/lib/pq"
"greenlight.bagerbach.com/internal/validator"
)
type Movie struct {
ID int64 `json:"id"` // Unique identifier for the movie
CreatedAt time.Time `json:"-"` // Time when the movie was added to our db
Title string `json:"title"` // The title of the movie
Year int32 `json:"year,omitempty"` // The release year of the movie
Runtime Runtime `json:"runtime,omitempty"` // The runtime of the movie in minutes
Genres []string `json:"genres,omitempty"` // The genres of the movie
Version int32 `json:"version"` // The version of the movie: starts at 1 and increments each time the movie is updated
}
func ValidateMovie(v *validator.Validator, movie *Movie) {
v.Check(movie.Title != "", "title", "must be provided")
v.Check(len(movie.Title) <= 500, "title", "must not be more than 500 bytes long")
v.Check(movie.Year != 0, "year", "must be provided")
v.Check(movie.Year >= 1888, "year", "must be greater than 1888")
v.Check(movie.Year <= int32(time.Now().Year()), "year", "must not be in the future")
v.Check(movie.Runtime != 0, "runtime", "must be provided")
v.Check(movie.Runtime > 0, "runtime", "must be a positive integer")
v.Check(movie.Genres != nil, "genres", "must be provided")
v.Check(len(movie.Genres) >= 1, "genres", "must contain at least 1 genre")
v.Check(len(movie.Genres) <= 5, "genres", "must not contain more than 5 genres")
v.Check(validator.Unique(movie.Genres), "genres", "must not contain duplicate values")
}
type MovieModel struct {
DB *sql.DB
}
func (m MovieModel) Insert(movie *Movie) error {
query := `
INSERT INTO movies (title, year, runtime, genres)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`
args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres)}
// Need to use QueryRow because of the RETURNING clause (which returns the id, created_at and version)
// RETURNING is a Postgres feature that is not part of SQL standard
return m.DB.QueryRow(query, args...).Scan(&movie.ID, &movie.CreatedAt, &movie.Version)
}
func (m MovieModel) Get(id int64) (*Movie, error) {
if id < 1 {
return nil, ErrRecordNotFound
}
query := `
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE id = $1`
var movie Movie
err := m.DB.QueryRow(query, id).Scan(
&movie.ID,
&movie.CreatedAt,
&movie.Title,
&movie.Year,
&movie.Runtime,
pq.Array(&movie.Genres),
&movie.Version,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrRecordNotFound
} else {
return nil, err
}
}
return &movie, nil
}
func (m MovieModel) Update(movie *Movie) error {
query := `
UPDATE movies
SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
WHERE id = $5
RETURNING version`
args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID}
return m.DB.QueryRow(query, args...).Scan(&movie.Version)
}
func (m MovieModel) Delete(id int64) error {
if id < 1 {
return ErrRecordNotFound
}
query := `
DELETE FROM movies
WHERE id = $1`
result, err := m.DB.Exec(query, id)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return ErrRecordNotFound
}
return nil
}
Route Handlers
package main
import (
"errors"
"fmt"
"net/http"
"greenlight.bagerbach.com/internal/data"
"greenlight.bagerbach.com/internal/validator"
)
func (app *application) createMovieHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Title string `json:"title"`
Year int32 `json:"year"`
Runtime data.Runtime `json:"runtime"`
Genres []string `json:"genres"`
}
if err := app.readJSON(w, r, &input); err != nil {
app.badRequestResponse(w, r, err)
return
}
movie := &data.Movie{
Title: input.Title,
Year: input.Year,
Runtime: input.Runtime,
Genres: input.Genres,
}
v := validator.New()
if data.ValidateMovie(v, movie); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
if err := app.models.Movies.Insert(movie); err != nil {
app.serverErrorResponse(w, r, err)
return
}
headers := make(http.Header)
headers.Set("Location", fmt.Sprintf("/v1/movies/%d", movie.ID))
if err := app.writeJSON(w, http.StatusCreated, envelope{"movie": movie}, headers); err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) showMovieHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
movie, err := app.models.Movies.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
movie, err := app.models.Movies.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
var input struct {
Title string `json:"title"`
Year int32 `json:"year"`
Runtime data.Runtime `json:"runtime"`
Genres []string `json:"genres"`
}
if err := app.readJSON(w, r, &input); err != nil {
app.badRequestResponse(w, r, err)
return
}
movie.Title = input.Title
movie.Year = input.Year
movie.Runtime = input.Runtime
movie.Genres = input.Genres
v := validator.New()
if data.ValidateMovie(v, movie); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
err = app.models.Movies.Update(movie)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
func (app *application) deleteMovieHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
err = app.models.Movies.Delete(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"message": "movie successfully deleted"}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
Routes
package main
import (
"net/http"
"github.com/julienschmidt/httprouter"
)
func (app *application) routes() http.Handler {
router := httprouter.New()
router.NotFound = http.HandlerFunc(app.notFoundResponse)
router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
router.HandlerFunc(http.MethodPut, "/v1/movies/:id", app.updateMovieHandler)
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
return app.recoverPanic(router)
}
8. Advanced CRUD Operations
Some advanced patterns:
- Partial updates to a resource, so clients only need to send the data they want to change.
- Optimistic concurrency control to avoid race conditions when two clients try to update the same resource at the same time.
- Context timeouts to terminate long-running database queries and prevent unnecessary resource use.
Partial updates
We couldn’t support this in the previous version because we couldn’t detect which fields were supplied and which weren’t. This is because of the zero values of the various types we were using: if they weren’t supplied, strings would be `""`, integers would be `0`, and so on. How would we know it wasn’t the user who sent those values?
So now we use pointers, whose zero value is `nil`.
func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Request) {
id, err := app.readIDParam(r)
if err != nil {
app.notFoundResponse(w, r)
return
}
movie, err := app.models.Movies.Get(id)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.notFoundResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Pointers' zero-value is nil, so turning these into pointers lets us do partial updates
// (whereas e.g. the string zero-value is "" - you wouldn't know if it was or wasn't supplied!)
var input struct {
Title *string `json:"title"`
Year *int32 `json:"year"`
Runtime *data.Runtime `json:"runtime"`
Genres []string `json:"genres"`
}
if err := app.readJSON(w, r, &input); err != nil {
app.badRequestResponse(w, r, err)
return
}
if input.Title != nil {
movie.Title = *input.Title
}
if input.Year != nil {
movie.Year = *input.Year
}
if input.Runtime != nil {
movie.Runtime = *input.Runtime
}
if input.Genres != nil {
movie.Genres = input.Genres
}
v := validator.New()
if data.ValidateMovie(v, movie); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
err = app.models.Movies.Update(movie)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"movie": movie}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
And because we now support partial updates, we change the route method from `PUT` to `PATCH`:
func (app *application) routes() http.Handler {
router := httprouter.New()
router.NotFound = http.HandlerFunc(app.notFoundResponse)
router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
- router.HandlerFunc(http.MethodPut, "/v1/movies/:id", app.updateMovieHandler)
+ router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
return app.recoverPanic(router)
}
Optimistic Concurrency Control
Optimistic locking is when you read a record, note the version number, and check it hasn’t changed before writing the record back. You can also use dates, timestamps, or checksums/hashes. The update should filter on the version to ensure it is atomic and updates the version in one go.
Pessimistic locking requires you to lock the record for exclusive use until you’re finished with it. This has better integrity than optimistic locking, but you need to be careful to avoid deadlocks. (src)
We implemented optimistic locking to prevent race conditions on updates by following these steps:
- Edit conflict response: we added a new function, `editConflictResponse`, to handle edit conflicts by sending a 409 Conflict status code to the client.
- Expected version header: in the `updateMovieHandler` function, we check whether the client provided an `X-Expected-Version` header. If the provided version does not match the current version of the record, we return an edit conflict response.
- Version check in update: we modified the SQL update query in the `Update` method of the `MovieModel` to include a check on the version number. The query now only updates the record if the version matches the expected version.
- Error handling: if the update query did not affect any rows, indicating a version mismatch, we return an `ErrEditConflict` error.
- Conflict handling: we adjusted the error handling in `updateMovieHandler` to send an edit conflict response if an `ErrEditConflict` error occurs.
This approach ensures that if multiple users attempt to update the same record simultaneously, only the user with the correct version can proceed, thereby preventing race conditions.
diff --git a/lets-go-further-book/cmd/api/errors.go b/lets-go-further-book/cmd/api/errors.go
index bc5c1ae..78b7b7a 100644
--- a/lets-go-further-book/cmd/api/errors.go
+++ b/lets-go-further-book/cmd/api/errors.go
@@ -53,3 +53,7 @@ func (app *application) failedValidationResponse(w http.ResponseWriter, r *http.
app.errorResponse(w, r, http.StatusUnprocessableEntity, errors)
}
+func (app *application) editConflictResponse(w http.ResponseWriter, r *http.Request) {
+ message := "unable to update the record due to an edit conflict, please try again"
+ app.errorResponse(w, r, http.StatusConflict, message)
+}
diff --git a/lets-go-further-book/cmd/api/movies.go b/lets-go-further-book/cmd/api/movies.go
index ae6abf5..ff8f302 100644
--- a/lets-go-further-book/cmd/api/movies.go
+++ b/lets-go-further-book/cmd/api/movies.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
+ "strconv"
"greenlight.bagerbach.com/internal/data"
"greenlight.bagerbach.com/internal/validator"
@@ -90,6 +91,16 @@ func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Reques
return
}
+ // If the client provided an "X-Expected-Version" header, check that the version
+ // matches the version of the record being updated. If not, return a 409 Conflict
+ // status code.
+ if r.Header.Get("X-Expected-Version") != "" {
+ if strconv.Itoa(int(movie.Version)) != r.Header.Get("X-Expected-Version") {
+ app.editConflictResponse(w, r)
+ return
+ }
+ }
+
// Pointers' zero-value is nil, so turning these into pointers lets us do partial updates
// (whereas e.g. the string zero-value is "" - you wouldn't know if it was or wasn't supplied!)
var input struct {
@@ -126,7 +137,12 @@ func (app *application) updateMovieHandler(w http.ResponseWriter, r *http.Reques
err = app.models.Movies.Update(movie)
if err != nil {
- app.serverErrorResponse(w, r, err)
+ switch {
+ case errors.Is(err, data.ErrEditConflict):
+ app.editConflictResponse(w, r)
+ default:
+ app.serverErrorResponse(w, r, err)
+ }
return
}
diff --git a/lets-go-further-book/internal/data/models.go b/lets-go-further-book/internal/data/models.go
index c93d17b..8fe842d 100644
--- a/lets-go-further-book/internal/data/models.go
+++ b/lets-go-further-book/internal/data/models.go
@@ -7,6 +7,7 @@ import (
var (
ErrRecordNotFound = errors.New("record not found")
+ ErrEditConflict = errors.New("edit conflict")
)
type Models struct {
diff --git a/lets-go-further-book/internal/data/movies.go b/lets-go-further-book/internal/data/movies.go
index 1822ccc..50e55f0 100644
--- a/lets-go-further-book/internal/data/movies.go
+++ b/lets-go-further-book/internal/data/movies.go
@@ -89,12 +89,24 @@ func (m MovieModel) Update(movie *Movie) error {
query := `
UPDATE movies
SET title = $1, year = $2, runtime = $3, genres = $4, version = version + 1
- WHERE id = $5
+ WHERE id = $5 AND version = $6
RETURNING version`
- args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID}
+ args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres), movie.ID, movie.Version}
- return m.DB.QueryRow(query, args...).Scan(&movie.Version)
+ err := m.DB.QueryRow(query, args...).Scan(&movie.Version)
+ if err != nil {
+ switch {
+ // If no rows were found, we know that the version has changed since we last read it
+ // If that's the case, we return an ErrEditConflict error
+ case errors.Is(err, sql.ErrNoRows):
+ return ErrEditConflict
+ default:
+ return err
+ }
+ }
+
+ return nil
}
func (m MovieModel) Delete(id int64) error {
If you want to avoid the version identifier being guessable, you could use a high-entropy random string (like a UUID) in the `version` field.
Timeouts
Handling potential delays or hangs is important when working with database operations.
One way to do this is by using contexts for timeouts.
This ensures the app doesn’t wait indefinitely for a response from the database.
Using Context for Timeouts in Go
In the following example, we modify our database operations to include context with a timeout. This will help in controlling the maximum time allowed for each database operation to complete.
Here’s how we modify the `Insert` method to use a context with a timeout:
package data
import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/lib/pq"
)
func (m MovieModel) Insert(movie *Movie) error {
query := `
INSERT INTO movies (title, year, runtime, genres)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`
args := []interface{}{movie.Title, movie.Year, movie.Runtime, pq.Array(movie.Genres)}
// Create a context with a 3-second timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// Use QueryRowContext with the context
return m.DB.QueryRowContext(ctx, query, args...).Scan(&movie.ID, &movie.CreatedAt, &movie.Version)
}
9. Filtering, Sorting, and Pagination
Full text search
You can add full-text search (supporting partial matches) by using a SQL query like this:
func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
query := `
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres && $2 OR $2 = '{}')
ORDER BY id
`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
rows, err := m.DB.QueryContext(ctx, query, title, pq.Array(genres))
if err != nil {
return nil, err
}
// ...
}
Here, we use PostgreSQL’s full-text search feature to support partial matches on the title.
Here’s what happens:
- `to_tsvector('simple', title)` takes a movie title and splits it into lexemes. We use the `simple` configuration, so lexemes are just lowercase words in the title. Other configurations may apply other rules to the lexemes, like removing common words for a given language or applying language-specific stemming.
- `plainto_tsquery('simple', $1)` takes a search value and turns it into a formatted query term: it normalizes the search value (with the `simple` config), strips special characters, and inserts the `&` operator between words. For example, `Iron Man` becomes `'iron' & 'man'`.
- The `@@` operator is the matches operator, which we use to check whether the generated query term matches the lexemes. `'iron' & 'man'` would match rows containing both `iron` and `man`.
To keep SQL queries performant as the dataset grows, we add indexes to avoid full table scans and to avoid generating lexemes for the `title` field every time the query is run.
$ migrate create -seq -ext .sql -dir ./migrations add_movies_indexes
Up migration:
CREATE INDEX IF NOT EXISTS movies_title_idx ON movies USING GIN (to_tsvector('simple', title));
CREATE INDEX IF NOT EXISTS movies_genres_idx ON movies USING GIN (genres);
Down migration:
DROP INDEX IF EXISTS movies_title_idx;
DROP INDEX IF EXISTS movies_genres_idx;
We use Generalized Inverted Index (GIN) indexes.
GIN indexes are specialized data structures designed for efficient handling of complex data types, such as arrays, full-text search, and JSONB. They’re optimized for scenarios where each indexed record contains multiple keys, and are particularly adept at search operations involving many-to-many relationships. GIN indexes use an inverted index structure where each distinct key points to a set of rows containing that key, making query lookups with operations like `@>`, `<@`, and `%` highly efficient. This index type supports fast insertion and retrieval but can be slower on updates and deletions due to the complexity of maintaining the multi-key mappings.
If you’d prefer not to use full-text search, there are also `STRPOS()` and `ILIKE`:
- `STRPOS()`: checks for the existence of a substring in a database field. Not ideal, as it could lead to a full-table scan each time it’s run.
- `ILIKE`: finds rows matching a specific, case-insensitive pattern. You can create an index on your target field using the `pg_trgm` extension and a GIN index, so it’s better for performance than `STRPOS()`. It also lets the user control matching behavior by prefixing/suffixing with a `%` wildcard character.
Sorting
If you don’t specify an order with `ORDER BY`, PostgreSQL may return the movies in any order, and that order may or may not change each time the query is run. And if you order by, e.g., year, items sharing the same year may be ordered differently on each query.
For endpoints that provide pagination, you need to ensure the order is perfectly consistent between requests to prevent items from jumping between pages.
Doing so is simple: just ensure the `ORDER BY` clause always includes a primary key column, or one with a unique constraint on it.
Here’s how you can implement sorting for your endpoints.
We’re using a `filters.go` that looks like this:
package data
import (
"strings"
"greenlight.bagerbach.com/internal/validator"
)
type Filters struct {
Page int
PageSize int
Sort string
SortSafelist []string
}
func (f Filters) sortColumn() string {
for _, safeValue := range f.SortSafelist {
if f.Sort == safeValue {
return strings.TrimPrefix(f.Sort, "-")
}
}
	// If the sort parameter is not in the safelist, we panic.
	// This shouldn't happen, as you'd have checked it in the validator,
	// but it's a fine failsafe against SQL injection attacks.
panic("unsafe sort parameter: " + f.Sort)
}
func (f Filters) sortDirection() string {
if strings.HasPrefix(f.Sort, "-") {
return "DESC"
}
return "ASC"
}
func ValidateFilters(v *validator.Validator, f Filters) {
v.Check(f.Page > 0, "page", "must be greater than zero")
v.Check(f.Page <= 10_000_000, "page", "must be a maximum of 10 million")
v.Check(f.PageSize > 0, "page_size", "must be greater than zero")
v.Check(f.PageSize <= 100, "page_size", "must be a maximum of 100")
v.Check(validator.PermittedValue(f.Sort, f.SortSafelist...), "sort", "invalid sort value")
}
The handler looks like this, defining a safelist:
func (app *application) listMoviesHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Title string
Genres []string
data.Filters
}
v := validator.New()
qs := r.URL.Query()
input.Title = app.readString(qs, "title", "")
input.Genres = app.readCSV(qs, "genres", []string{})
input.Filters.Page = app.readInt(qs, "page", 1, v)
input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)
input.Filters.Sort = app.readString(qs, "sort", "id")
input.Filters.SortSafelist = []string{
"id", "title", "year", "runtime", "-id", "-title", "-year", "-runtime",
}
if data.ValidateFilters(v, input.Filters); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
movies, err := app.models.Movies.GetAll(input.Title, input.Genres, input.Filters)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"movies": movies}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
And we have the `GetAll` function to get all movies, sorted in the given order.
func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
// Add secondary sorting by id to ensure consistent sorting
query := fmt.Sprintf(`
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres && $2 OR $2 = '{}')
ORDER BY %s %s, id ASC`, filters.sortColumn(), filters.sortDirection())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
rows, err := m.DB.QueryContext(ctx, query, title, pq.Array(genres))
if err != nil {
return nil, err
}
defer rows.Close()
movies := []*Movie{}
for rows.Next() {
var movie Movie
err := rows.Scan(
&movie.ID,
&movie.CreatedAt,
&movie.Title,
&movie.Year,
&movie.Runtime,
pq.Array(&movie.Genres),
&movie.Version,
)
if err != nil {
return nil, err
}
movies = append(movies, &movie)
}
if err := rows.Err(); err != nil {
return nil, err
}
return movies, nil
}
Pagination
We’ll be using `LIMIT` and `OFFSET` to implement pagination:
LIMIT = page_size
OFFSET = (page - 1) * page_size
We’ll need some helper methods for this:
type Filters struct {
Page int
PageSize int
Sort string
SortSafelist []string
}
func (f Filters) limit() int {
return f.PageSize
}
// Risks integer overflow, but that should be prevented by input validation
func (f Filters) offset() int {
return (f.Page - 1) * f.PageSize
}
And now we implement these in our `GetAll` function. Notice the `LIMIT` and `OFFSET`:
func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, error) {
query := fmt.Sprintf(`
SELECT id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres && $2 OR $2 = '{}')
ORDER BY %s %s, id ASC
LIMIT $3 OFFSET $4`, filters.sortColumn(), filters.sortDirection())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
args := []interface{}{title, pq.Array(genres), filters.limit(), filters.offset()}
rows, err := m.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
movies := []*Movie{}
for rows.Next() {
var movie Movie
err := rows.Scan(
&movie.ID,
&movie.CreatedAt,
&movie.Title,
&movie.Year,
&movie.Runtime,
pq.Array(&movie.Genres),
&movie.Version,
)
if err != nil {
return nil, err
}
movies = append(movies, &movie)
}
if err := rows.Err(); err != nil {
return nil, err
}
return movies, nil
}
Adding pagination metadata
We’ll add some pagination metadata, so our responses will look like this:
{
"metadata": {
"current_page": 1,
"page_size": 20,
"first_page": 1,
"last_page": 42,
"total_records": 832
},
"movies": [
...
]
}
The `total_records` count should match the total number of records matching the `title` and `genres` filters that are applied, not just the absolute total of records in the table.
We can do this by adding a window function to the SQL query, which counts the total number of filtered rows.
Let’s start.
In `internal/data/filters.go`, we add:
type Metadata struct {
CurrentPage int `json:"current_page,omitempty"`
PageSize int `json:"page_size,omitempty"`
FirstPage int `json:"first_page,omitempty"`
LastPage int `json:"last_page,omitempty"`
TotalRecords int `json:"total_records,omitempty"`
}
func calculateMetadata(totalRecords, page, pageSize int) Metadata {
if totalRecords == 0 {
return Metadata{}
}
return Metadata{
CurrentPage: page,
PageSize: pageSize,
FirstPage: 1,
LastPage: (totalRecords + pageSize - 1) / pageSize,
TotalRecords: totalRecords,
}
}
We need to ensure we return the metadata in `internal/data/movies.go`:
func (m MovieModel) GetAll(title string, genres []string, filters Filters) ([]*Movie, Metadata, error) {
// Add `count(*) OVER()`
query := fmt.Sprintf(`
SELECT count(*) OVER(), id, created_at, title, year, runtime, genres, version
FROM movies
WHERE (to_tsvector('simple', title) @@ plainto_tsquery('simple', $1) OR $1 = '')
AND (genres && $2 OR $2 = '{}')
ORDER BY %s %s, id ASC
LIMIT $3 OFFSET $4`, filters.sortColumn(), filters.sortDirection())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
args := []interface{}{title, pq.Array(genres), filters.limit(), filters.offset()}
rows, err := m.DB.QueryContext(ctx, query, args...)
if err != nil {
return nil, Metadata{}, err // return metadata as well
}
defer rows.Close()
totalRecords := 0 // define this
movies := []*Movie{}
for rows.Next() {
var movie Movie
err := rows.Scan(
&totalRecords, // scan count
&movie.ID,
&movie.CreatedAt,
&movie.Title,
&movie.Year,
&movie.Runtime,
pq.Array(&movie.Genres),
&movie.Version,
)
if err != nil {
return nil, Metadata{}, err // return metadata as well
}
movies = append(movies, &movie)
}
if err := rows.Err(); err != nil {
return nil, Metadata{}, err // return metadata as well
}
// Calculate and return metadata with the records
metadata := calculateMetadata(totalRecords, filters.Page, filters.PageSize)
return movies, metadata, nil
}
And we need to consume the returned metadata in the API endpoint, in `cmd/api/movies.go`:
func (app *application) listMoviesHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Title string
Genres []string
data.Filters
}
v := validator.New()
qs := r.URL.Query()
input.Title = app.readString(qs, "title", "")
input.Genres = app.readCSV(qs, "genres", []string{})
input.Filters.Page = app.readInt(qs, "page", 1, v)
input.Filters.PageSize = app.readInt(qs, "page_size", 20, v)
input.Filters.Sort = app.readString(qs, "sort", "id")
input.Filters.SortSafelist = []string{
"id", "title", "year", "runtime", "-id", "-title", "-year", "-runtime",
}
if data.ValidateFilters(v, input.Filters); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
// Get movies, metadata, err
movies, metadata, err := app.models.Movies.GetAll(input.Title, input.Genres, input.Filters)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
// Send both movies and metadata to client
if err := app.writeJSON(w, http.StatusOK, envelope{"movies": movies, "metadata": metadata}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
10. Rate Limiting
We’ll use the `x/time/rate` package: `go get golang.org/x/time/rate@latest`.
We’ll be using a token-bucket rate limiter.
This kind of limiter controls how frequently events are allowed to occur by implementing a token bucket of size `b`, which starts out full and is refilled at a rate of `r` tokens per second.
In our case, we’d start with a bucket of `b` tokens. Each time we get an HTTP request, we remove one token from the bucket. Every `1/r` seconds, a token is added back, up to a maximum of `b` tokens.
If we get too many HTTP requests and the bucket is empty, we return an HTTP `429 Too Many Requests` response.
So it allows a maximum ‘burst’ of `b` requests, but over time it allows an average of `r` requests per second.
Start by adding an error response for when the rate limit is exceeded:
// cmd/api/errors.go
func (app *application) rateLimitExceededResponse(w http.ResponseWriter, r *http.Request) {
message := "rate limit exceeded"
app.errorResponse(w, r, http.StatusTooManyRequests, message)
}
Add configuration for the limiter to your `main` entry:
// cmd/api/main.go
package main
import (
"context"
"database/sql"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"time"
"greenlight.bagerbach.com/internal/data"
// Import the pq driver - it needs to register itself with the database/sql package
"github.com/joho/godotenv"
_ "github.com/lib/pq"
)
// Will be generated automatically later.
const version = "1.0.0"
type config struct {
port int
env string
db struct {
dsn string
maxOpenConns int
maxIdleConns int
maxIdleTime time.Duration
}
limiter struct {
rps float64
burst int
enabled bool
}
}
type application struct {
config config
logger *slog.Logger
models data.Models
}
func main() {
err := godotenv.Load()
if err != nil {
slog.Error("error loading .env file", "error", err)
os.Exit(1)
}
var cfg config
flag.IntVar(&cfg.port, "port", 4000, "Server port to listen on")
flag.StringVar(&cfg.env, "env", "development", "Application environment (development|staging|production)")
flag.StringVar(&cfg.db.dsn, "db-dsn", os.Getenv("DATABASE_URL"), "PostgreSQL DSN")
flag.IntVar(&cfg.db.maxOpenConns, "db-max-open-conns", 25, "Maximum number of open connections to the database")
flag.IntVar(&cfg.db.maxIdleConns, "db-max-idle-conns", 25, "Maximum number of idle connections to the database")
// DurationVar lets us pass in any value acceptable to time.ParseDuration(), e.g. 300ms, 5s, 2h45m.
flag.DurationVar(&cfg.db.maxIdleTime, "db-max-idle-time", 15*time.Minute, "Maximum idle time for a connection to the database")
flag.Float64Var(&cfg.limiter.rps, "limiter-rps", 2, "Rate limit to apply to requests per second")
flag.IntVar(&cfg.limiter.burst, "limiter-burst", 4, "Burst limit to apply to requests")
flag.BoolVar(&cfg.limiter.enabled, "limiter-enabled", true, "Enable rate limiting")
flag.Parse()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
db, err := openDB(cfg)
if err != nil {
logger.Error("error opening db", "error", err)
os.Exit(1)
}
defer db.Close()
logger.Info("database connection pool established")
app := &application{
config: cfg,
logger: logger,
models: data.NewModels(db),
}
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.port),
Handler: app.routes(),
IdleTimeout: time.Minute,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
}
logger.Info("starting server", "addr", srv.Addr, "env", cfg.env)
if err := srv.ListenAndServe(); err != nil {
logger.Error("error starting server", "error", err)
os.Exit(1)
}
}
func openDB(cfg config) (*sql.DB, error) {
db, err := sql.Open("postgres", cfg.db.dsn)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(cfg.db.maxOpenConns)
db.SetMaxIdleConns(cfg.db.maxIdleConns)
db.SetConnMaxIdleTime(cfg.db.maxIdleTime)
// Create context with 5s timeout. If we can't connect in 5s, we cancel the context and return an error.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = db.PingContext(ctx); err != nil {
db.Close()
return nil, err
}
return db, nil
}
Implement the rate-limiting middleware
This solution only works when the server is running on a single machine.
If your infrastructure is distributed, and you’re running multiple servers behind a load balancer, you’ll need something else. If that’s the case, and you’re running HAProxy or Nginx as a load balancer or reverse proxy, you can use their built-in rate-limiting features.
Alternatively, you could use a fast database like Redis to maintain a request count for clients, running on a server with which all your application servers can communicate.
This uses a map, keyed by client IP address, whose values hold a rate limiter and a `lastSeen` time for each client. A background goroutine periodically cleans up this map.
We have to use a mutex here because Go maps are not concurrency safe.
// cmd/api/middleware.go
func (app *application) rateLimit(next http.Handler) http.Handler {
type client struct {
limiter *rate.Limiter
lastSeen time.Time
}
var (
// Maps are not thread-safe. We need to lock the mutex before reading from the map.
mu sync.Mutex
clients = make(map[string]*client)
)
// Background goroutine to clean up old clients.
go func() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, client := range clients {
if time.Since(client.lastSeen) > 3*time.Minute {
delete(clients, ip)
}
}
mu.Unlock()
}
}()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !app.config.limiter.enabled {
next.ServeHTTP(w, r)
return
}
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
mu.Lock()
if _, found := clients[ip]; !found {
clients[ip] = &client{
limiter: rate.NewLimiter(rate.Limit(app.config.limiter.rps), app.config.limiter.burst),
}
}
clients[ip].lastSeen = time.Now()
// Check if the request is allowed by the rate limiter.
// `limiter.Allow()` returns `true` if the request is allowed, and `false` if the request is not allowed.
// It consumes a token from the bucket.
if !clients[ip].limiter.Allow() {
mu.Unlock()
app.rateLimitExceededResponse(w, r)
return
}
// We aren't using defer here because we want to unlock the mutex as soon as possible.
// If we deferred the unlock, it would only be executed after all the downstream handlers have returned.
mu.Unlock()
next.ServeHTTP(w, r)
})
}
Use the middleware in your routes
// cmd/api/routes.go
package main
import (
"net/http"
"github.com/julienschmidt/httprouter"
)
func (app *application) routes() http.Handler {
router := httprouter.New()
router.NotFound = http.HandlerFunc(app.notFoundResponse)
router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies", app.listMoviesHandler)
router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
return app.recoverPanic(app.rateLimit(router))
}
11. Graceful Shutdown
Applications can be closed with certain (POSIX) signals. For example, pressing `CTRL+C` sends an interrupt signal called `SIGINT`.
Some of these signals are catchable, while others are not. Those that are catchable can be intercepted by our application and used to trigger certain actions, like graceful shutdowns. `SIGKILL` is not catchable.
Go provides the tools for intercepting signals in `os/signal`.
We can create a `serve` function to handle setting up the server and tearing it down:
// cmd/api/server.go
package main
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
func (app *application) serve() error {
srv := &http.Server{
Addr: fmt.Sprintf(":%d", app.config.port),
Handler: app.routes(),
IdleTimeout: time.Minute,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
ErrorLog: slog.NewLogLogger(app.logger.Handler(), slog.LevelError),
}
// To receive any error returned by the graceful Shutdown() function
shutdownError := make(chan error)
go func() {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
s := <-quit
app.logger.Info("shutting down server", "signal", s.String())
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
shutdownError <- srv.Shutdown(ctx)
}()
app.logger.Info("starting server", "addr", srv.Addr, "port", app.config.port)
// Shutdown() causes ListenAndServe() to return http.ErrServerClosed,
// so that error signals a graceful shutdown has started, not a failure.
if err := srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}
// Wait for the result of Shutdown() from the goroutine above.
if err := <-shutdownError; err != nil {
return err
}
app.logger.Info("server closed", "addr", srv.Addr)
return nil
}
Which we should use in our `main()`:
// cmd/api/main.go
// ...
func main() {
// ...
if err := app.serve(); err != nil {
logger.Error("error starting server", "error", err)
os.Exit(1)
}
}
// ...
12. User Model Setup and Registration
Create the `users` table migration:
migrate create -seq -ext .sql -dir ./migrations create_users_table
And write the migrations:
-- UP
CREATE EXTENSION IF NOT EXISTS citext; -- I didn't have this
CREATE TABLE IF NOT EXISTS users (
id bigserial PRIMARY KEY,
created_at timestamp(0) with time zone NOT NULL DEFAULT NOW(),
name text NOT NULL,
email citext UNIQUE NOT NULL,
password_hash bytea NOT NULL,
activated bool NOT NULL,
version integer NOT NULL DEFAULT 1
);
-- DOWN
DROP TABLE IF EXISTS users;
DROP EXTENSION IF EXISTS citext;
`bytea` is a binary string. This is what we use to store a one-way hash of the user’s password.
While the email we store is in a case-insensitive format, email addresses are not actually guaranteed to be case-insensitive. The domain part is (RFC 2821), but the local part isn’t guaranteed to be (it’s up to the provider, and most providers treat it as case-insensitive).
So, you should always store emails using the exact casing provided during user registration, and you should only send emails using that exact casing.
However, since they are very likely to be the same user, we generally treat email addresses as case-insensitive for comparison purposes. For example, if a user mistypes and accidentally capitalizes one of the characters in their email, they can’t create another account with the ‘same’ email.
First, we implement the `User` model:
// internal/data/users.go
package data
import (
"context"
"database/sql"
"errors"
"time"
"golang.org/x/crypto/bcrypt"
"greenlight.bagerbach.com/internal/validator"
)
var (
ErrDuplicateEmail = errors.New("duplicate email")
)
type User struct {
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
Name string `json:"name"`
Email string `json:"email"`
Password password `json:"-"`
Activated bool `json:"activated"`
Version int `json:"-"`
}
type password struct {
plaintext *string
hash []byte
}
// Set sets the password hash and plaintext
func (p *password) Set(plaintextPassword string) error {
// Generates a bcrypt hash of a password using specific cost parameters (12 here)
// The higher the cost, the slower (more computationally expensive) the hash generation will be
// Need to strike a balance between security and performance
// Inputs are truncated to 72 bytes (bcrypt max) when creating the hash, so it's a good idea to
// e.g. enforce a hard max on the length of the password
hash, err := bcrypt.GenerateFromPassword([]byte(plaintextPassword), 12)
if err != nil {
return err
}
p.plaintext = &plaintextPassword
p.hash = hash
return nil
}
// Checks if the provided plaintext password matches the hashed password
func (p *password) Matches(plaintextPassword string) (bool, error) {
// Use the bcrypt package to compare the hashed password with the plaintext password
// Works by re-hashing the provided plaintext password (using same salt and cost) and comparing the result
// to the hashed password.
// Compares using subtle.ConstantTimeCompare to avoid timing attacks, which are side-channel
// attacks that allow an attacker to determine if the hashed password is correct by
// measuring the time taken to compare the hashed password.
if err := bcrypt.CompareHashAndPassword(p.hash, []byte(plaintextPassword)); err != nil {
switch {
case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
return false, nil
default:
return false, err
}
}
return true, nil
}
func ValidateEmail(v *validator.Validator, email string) {
v.Check(email != "", "email", "must be provided")
v.Check(validator.Matches(email, validator.EmailRX), "email", "must be a valid email address")
}
func ValidatePasswordPlaintext(v *validator.Validator, password string) {
v.Check(password != "", "password", "must be provided")
v.Check(len(password) <= 72, "password", "must not be more than 72 bytes long")
v.Check(len(password) >= 8, "password", "must be at least 8 bytes long")
}
func ValidateUser(v *validator.Validator, user *User) {
v.Check(user.Name != "", "name", "must be provided")
v.Check(len(user.Name) <= 500, "name", "must not be more than 500 bytes long")
ValidateEmail(v, user.Email)
if user.Password.plaintext != nil {
ValidatePasswordPlaintext(v, *user.Password.plaintext)
}
// Will only be nil due to logic error, e.g. forgetting to set a password for the user.
// This is a sanity check, but isn't a problem with the data itself.
if user.Password.hash == nil {
panic("missing password hash for user")
}
}
type UserModel struct {
DB *sql.DB
}
func (m UserModel) Insert(user *User) error {
query := `
INSERT INTO users (name, email, password_hash, activated)
VALUES ($1, $2, $3, $4)
RETURNING id, created_at, version`
args := []interface{}{user.Name, user.Email, user.Password.hash, user.Activated}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.ID, &user.CreatedAt, &user.Version); err != nil {
switch {
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
return ErrDuplicateEmail
default:
return err
}
}
return nil
}
func (m UserModel) GetByEmail(email string) (*User, error) {
query := `
SELECT id, created_at, name, email, password_hash, activated, version
FROM users
WHERE email = $1`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
var user User
if err := m.DB.QueryRowContext(ctx, query, email).Scan(
&user.ID,
&user.CreatedAt,
&user.Name,
&user.Email,
&user.Password.hash,
&user.Activated,
&user.Version,
); err != nil {
switch {
case err == sql.ErrNoRows:
return nil, ErrRecordNotFound
default:
return nil, err
}
}
return &user, nil
}
func (m UserModel) Update(user *User) error {
query := `
UPDATE users
SET name = $1, email = $2, password_hash = $3, activated = $4, version = version + 1
WHERE id = $5 AND version = $6
RETURNING version`
args := []interface{}{user.Name, user.Email, user.Password.hash, user.Activated, user.ID, user.Version}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := m.DB.QueryRowContext(ctx, query, args...).Scan(&user.Version); err != nil {
switch {
case err.Error() == `pq: duplicate key value violates unique constraint "users_email_key"`:
return ErrDuplicateEmail
case errors.Is(err, sql.ErrNoRows):
return ErrRecordNotFound
default:
return err
}
}
return nil
}
And we include that in our `Models` struct:
// internal/data/models.go
package data
import (
"database/sql"
"errors"
)
var (
ErrRecordNotFound = errors.New("record not found")
ErrEditConflict = errors.New("edit conflict")
)
type Models struct {
Movies MovieModel
Users UserModel // Added
}
func NewModels(db *sql.DB) Models {
return Models{
Movies: MovieModel{DB: db},
Users: UserModel{DB: db}, // Added
}
}
Then we create the route handler:
package main
import (
"errors"
"net/http"
"greenlight.bagerbach.com/internal/data"
"greenlight.bagerbach.com/internal/validator"
)
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
if err := app.readJSON(w, r, &input); err != nil {
app.badRequestResponse(w, r, err)
return
}
user := &data.User{
Name: input.Name,
Email: input.Email,
Activated: false,
}
if err := user.Password.Set(input.Password); err != nil {
app.badRequestResponse(w, r, err)
return
}
v := validator.New()
if data.ValidateUser(v, user); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
if err := app.models.Users.Insert(user); err != nil {
switch {
case errors.Is(err, data.ErrDuplicateEmail):
v.AddError("email", "a user with this email address already exists")
app.failedValidationResponse(w, r, v.Errors)
default:
app.serverErrorResponse(w, r, err)
}
return
}
if err := app.writeJSON(w, http.StatusCreated, envelope{"user": user}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
And we use it in `routes`:
// cmd/api/routes.go
package main
import (
"net/http"
"github.com/julienschmidt/httprouter"
)
func (app *application) routes() http.Handler {
router := httprouter.New()
router.NotFound = http.HandlerFunc(app.notFoundResponse)
router.MethodNotAllowed = http.HandlerFunc(app.methodNotAllowedResponse)
router.HandlerFunc(http.MethodGet, "/v1/healthcheck", app.healthcheckHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies", app.listMoviesHandler)
router.HandlerFunc(http.MethodPost, "/v1/movies", app.createMovieHandler)
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.showMovieHandler)
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.updateMovieHandler)
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.deleteMovieHandler)
// Added this
router.HandlerFunc(http.MethodPost, "/v1/users", app.registerUserHandler)
return app.recoverPanic(app.rateLimit(router))
}
13. Sending Emails
We’re using Mailtrap because it has a nice testing suite for emails. We’re also using `go-mail/mail`, as the standard Go SMTP package is not receiving updates.
$ go get github.com/go-mail/mail/v2@v2
Add SMTP configuration fields to the `config` struct in `cmd/api/main.go`:
type config struct {
// ...
smtp struct {
host string
port int
username string
password string
sender string
}
}
Update the `main.go` file to include the new SMTP configuration:
package main
import (
...
"greenlight.bagerbach.com/internal/mailer"
...
"sync"
"strconv"
)
func main() {
var cfg config
...
flag.StringVar(&cfg.smtp.host, "smtp-host", os.Getenv("SMTP_HOST"), "SMTP host")
port, err := strconv.Atoi(os.Getenv("SMTP_PORT"))
if err != nil {
slog.Error("invalid SMTP_PORT", "error", err)
os.Exit(1)
}
flag.IntVar(&cfg.smtp.port, "smtp-port", port, "SMTP port")
flag.StringVar(&cfg.smtp.username, "smtp-username", os.Getenv("SMTP_USERNAME"), "SMTP username")
flag.StringVar(&cfg.smtp.password, "smtp-password", os.Getenv("SMTP_PASSWORD"), "SMTP password")
flag.StringVar(&cfg.smtp.sender, "smtp-sender", os.Getenv("SMTP_SENDER"), "SMTP sender")
...
app := &application{
config: cfg,
...
mailer: mailer.New(cfg.smtp.host, cfg.smtp.port, cfg.smtp.username, cfg.smtp.password, cfg.smtp.sender),
wg: sync.WaitGroup{}, // we'll use this to ensure graceful shutdown for background jobs
}
...
}
Create a new package for handling emails in `internal/mailer/mailer.go`:
package mailer
import (
"bytes"
"embed"
"text/template"
"time"
"github.com/go-mail/mail/v2"
)
//go:embed "templates"
var templateFS embed.FS
type Mailer struct {
dialer *mail.Dialer
sender string
}
func New(host string, port int, username, password, sender string) Mailer {
dialer := mail.NewDialer(host, port, username, password)
dialer.Timeout = 5 * time.Second
return Mailer{
dialer: dialer,
sender: sender,
}
}
func (m Mailer) Send(recipient, templateFile string, data any) error {
tmpl, err := template.ParseFS(templateFS, "templates/"+templateFile)
if err != nil {
return err
}
subject := new(bytes.Buffer)
if err := tmpl.ExecuteTemplate(subject, "subject", data); err != nil {
return err
}
plainBody := new(bytes.Buffer)
if err := tmpl.ExecuteTemplate(plainBody, "plainBody", data); err != nil {
return err
}
htmlBody := new(bytes.Buffer)
if err := tmpl.ExecuteTemplate(htmlBody, "htmlBody", data); err != nil {
return err
}
msg := mail.NewMessage()
msg.SetHeader("To", recipient)
msg.SetHeader("From", m.sender)
msg.SetHeader("Subject", subject.String())
msg.SetBody("text/plain", plainBody.String())
msg.AddAlternative("text/html", htmlBody.String())
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
if err := m.dialer.DialAndSend(msg); err != nil {
if i == maxRetries-1 {
return err
}
time.Sleep(2 * time.Second)
continue
}
break
}
return nil
}
Add the email template `user_welcome.tmpl` in `internal/mailer/templates`:
{{define "subject"}}Welcome to Greenlight!{{end}}
{{define "plainBody"}}
Hi {{.Name}},
Thanks for signing up for a Greenlight account. We're excited to have you on board!
Thanks,
The Greenlight Team
{{end}}
{{define "htmlBody"}}
<!doctype html>
<html>
<head>
<meta name="viewport" content="width=device-width" />
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>
<body>
<p>Hi {{.Name}},</p>
<p>Thanks for signing up for a Greenlight account. We're excited to have you on board!</p>
<p>Thanks,</p>
<p>The Greenlight Team</p>
</body>
</html>
{{end}}
Modify the `registerUserHandler` function in `cmd/api/users.go` to send a welcome email:
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
...
app.background(func() {
if err := app.mailer.Send(user.Email, "user_welcome.tmpl", user); err != nil {
app.logger.Error("failed to send email", "error", err)
}
})
if err := app.writeJSON(w, http.StatusAccepted, envelope{"user": user}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
Add a method to handle background tasks in `cmd/api/helpers.go`:
func (app *application) background(fn func()) {
app.wg.Add(1)
go func() {
defer app.wg.Done()
defer func() {
if err := recover(); err != nil {
app.logger.Error(fmt.Sprintf("%v", err))
}
}()
fn()
}()
}
Ensure that the server waits for all background tasks to complete before shutting down, in `cmd/api/server.go`:
func (app *application) serve() error {
...
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err := srv.Shutdown(ctx)
if err != nil {
shutdownError <- err
return // don't send a second value on the channel below
}
app.logger.Info("completing background tasks", "addr", srv.Addr)
app.wg.Wait()
shutdownError <- nil
}()
...
err := <-shutdownError
if err != nil {
return err
}
app.logger.Info("server closed", "addr", srv.Addr)
return nil
}
14. User Activation
Create migration:
$ migrate create -seq -ext .sql -dir ./migrations create_tokens_table
Create up & down migrations:
-- UP
CREATE TABLE IF NOT EXISTS tokens (
hash bytea PRIMARY KEY,
user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
expiry timestamp(0) with time zone NOT NULL,
scope text NOT NULL
);
-- DOWN
DROP TABLE IF EXISTS tokens;
We use a SHA-256 hash for the activation token. A fast algorithm like SHA-256 is fine as we use a high-entropy random string for our activation token.
We reference `users`, creating a foreign key constraint against the primary key of the `users` table. This means each `user_id` will have a corresponding `id` entry in the `users` table. The `ON DELETE CASCADE` ensures that if a user is deleted, the corresponding tokens are deleted as well.
The `tokens` table will store more than just activation tokens, which is why we also create the `scope` column. We’ll use it to restrict the scope of what a token can be used for.
Creating tokens
Tokens should be ‘unguessable’ – not easy to guess, not possible to brute-force.
So we want the token to be generated by a cryptographically secure random number generator (CSPRNG) and have enough entropy (randomness) that it’s impossible to guess.
The activation tokens will be created with `crypto/rand` and 128 bits (16 bytes) of entropy.
Never use `math/rand`, which provides a deterministic pseudo-random number generator (PRNG), for any purpose requiring cryptographic security.
// internal/data/tokens.go
package data
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base32"
"time"
"greenlight.bagerbach.com/internal/validator"
)
const (
ScopeActivation = "activation"
)
type Token struct {
Plaintext string
Hash []byte
UserID int64
Expiry time.Time
Scope string
}
func generateToken(userID int64, ttl time.Duration, scope string) (*Token, error) {
token := &Token{
UserID: userID,
Expiry: time.Now().Add(ttl),
Scope: scope,
}
// 16 doesn't mean the plaintext tokens are 16 characters long, but that they have an underlying entropy of 16 bytes of randomness.
// The length of the plaintext token depends on how the 16 random bytes are encoded to create a string.
// Since we'll encode them as a base-32 string, it'll be 26 characters long.
randomBytes := make([]byte, 16)
// Fill randomBytes with random data using the OS' CSPRNG
if _, err := rand.Read(randomBytes); err != nil {
return nil, err
}
// We encode the random bytes to a base-32 string.
// By default, base-32 strings may be padded with '=' characters; we don't need the padding, so we use WithPadding(base32.NoPadding) to omit it.
token.Plaintext = base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(randomBytes)
// Now we generate the hash of the plaintext token.
hash := sha256.Sum256([]byte(token.Plaintext))
token.Hash = hash[:] // It returned an array, so we convert it to a slice to make it easier to work with
return token, nil
}
func ValidateTokenPlaintext(v *validator.Validator, tokenPlaintext string) {
v.Check(tokenPlaintext != "", "token", "must be provided")
v.Check(len(tokenPlaintext) == 26, "token", "must be 26 characters long")
}
type TokenModel struct {
DB *sql.DB
}
func (m TokenModel) New(userID int64, ttl time.Duration, scope string) (*Token, error) {
token, err := generateToken(userID, ttl, scope)
if err != nil {
return nil, err
}
err = m.Insert(token)
return token, err
}
func (m TokenModel) Insert(token *Token) error {
query := `
INSERT INTO tokens (hash, user_id, expiry, scope)
VALUES ($1, $2, $3, $4)
`
args := []interface{}{token.Hash, token.UserID, token.Expiry, token.Scope}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, args...)
return err
}
func (m TokenModel) DeleteAllForUser(userID int64, scope string) error {
query := `
DELETE FROM tokens
WHERE user_id = $1 AND scope = $2
`
args := []interface{}{userID, scope}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, args...)
return err
}
And add it to `models`:
// internal/data/models.go
package data
import (
"database/sql"
"errors"
)
var (
ErrRecordNotFound = errors.New("record not found")
ErrEditConflict = errors.New("edit conflict")
)
type Models struct {
Movies MovieModel
Users UserModel
Tokens TokenModel // here
}
func NewModels(db *sql.DB) Models {
return Models{
Movies: MovieModel{DB: db},
Users: UserModel{DB: db},
Tokens: TokenModel{DB: db}, // here
}
}
Send the activation token to users
We’ll put the token in the email, asking the user to send a `PUT` request.
The book argues against letting users click a link to activate. That would mean activating with a `GET` request, which is more convenient but has some drawbacks. First, it violates the HTTP principle that `GET` should only be used for ‘safe’ requests which retrieve resources, not modify something. And it’s possible that the link would get pre-fetched, inadvertently activating the account.
Generally, ensure any actions which change state are only ever executed via `POST`, `PUT`, `PATCH`, or `DELETE`. Never `GET`.
If the API is a backend for a website, you could ask users to click a link which takes them to a page on the website. There they can click a button which sends the `PUT` request to confirm their activation. Alternatively, you could ask them to paste in the activation code.
If you go with the former, the URL you send them to could be, e.g.: https://example.com/users/activate?token=Y3QMGX3PJ3WLRL2YRTQGQ6KRHU
But if you go with that option, make sure the token isn’t leaked in a referrer header. You can set a `Referrer-Policy: origin` response header or use the equivalent `meta` tag.
You also don’t want to rely on the `Host` header of the incoming request to construct the URL, as you’d be vulnerable to a host header injection attack. So hard-code the URL or pass it in as a CLI flag.
We start by updating the welcome email template:
{{define "subject"}}Welcome to Greenlight!{{end}}
{{define "plainBody"}}
Hi {{.name}},
Thanks for signing up for a Greenlight account. We're excited to have you on board!
Please send a request to the `PUT /v1/users/activated` endpoint with the following JSON
body to activate your account:
{"token": "{{.activationToken}}"}
Please note that this is a one-time use token, and it will expire in 3 days.
Thanks,
The Greenlight Team
{{end}}
{{define "htmlBody"}}
<!doctype html>
<html>
<head>
<meta name="viewport" content="width=device-width" />
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>
<body>
<p>Hi {{.name}},</p>
<p>Thanks for signing up for a Greenlight account. We're excited to have you on board!</p>
<p>Please send a request to the <code>PUT /v1/users/activated</code> endpoint with the
following JSON body to activate your account:</p>
<pre><code>
{"token": "{{.activationToken}}"}
</code></pre>
<p>Please note that this is a one-time use token, and it will expire in 3 days.</p>
<p>Thanks,</p>
<p>The Greenlight Team</p>
</body>
</html>
{{end}}
And then we update the `registerUserHandler` we made earlier:
// cmd/api/users.go
// ...
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
// ...
token, err := app.models.Tokens.New(user.ID, 3*24*time.Hour, "activation")
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
app.background(func() {
data := map[string]interface{}{
"activationToken": token.Plaintext,
"name": user.Name,
}
if err := app.mailer.Send(user.Email, "user_welcome.tmpl", data); err != nil {
app.logger.Error("failed to send email", "error", err)
}
})
if err := app.writeJSON(w, http.StatusAccepted, envelope{"user": user}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
A nice addition to this would be to provide a standalone endpoint for generating tokens.
For example, if a user doesn’t activate within the 3-day limit, or if they never receive their welcome email, then you’d be able to resend it.
Activating users
Since we have a one-to-many relationship between our users and tokens, it can be useful to execute queries against the relationship from both sides.
We could update the database models like so:
UserModel.GetForToken(token) // -> Get user associated with token
TokenModel.GetAllForUser(user) // -> Get all tokens associated with a user
This also aligns with the responsibility of the models: the `UserModel` method returns a user, and the `TokenModel` method returns tokens.
We create the handler for activating users:
// cmd/api/users.go
func (app *application) activateUserHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
TokenPlaintext string `json:"token"`
}
if err := app.readJSON(w, r, &input); err != nil {
app.badRequestResponse(w, r, err)
return
}
v := validator.New()
if data.ValidateTokenPlaintext(v, input.TokenPlaintext); !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
user, err := app.models.Users.GetForToken(data.ScopeActivation, input.TokenPlaintext)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
v.AddError("token", "invalid or expired activation token")
app.failedValidationResponse(w, r, v.Errors)
default:
app.serverErrorResponse(w, r, err)
}
return
}
user.Activated = true
if err := app.models.Users.Update(user); err != nil {
switch {
case errors.Is(err, data.ErrEditConflict):
app.editConflictResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
if err := app.models.Tokens.DeleteAllForUser(user.ID, data.ScopeActivation); err != nil {
app.serverErrorResponse(w, r, err)
return
}
if err := app.writeJSON(w, http.StatusOK, envelope{"user": user}, nil); err != nil {
app.serverErrorResponse(w, r, err)
}
}
Which needs the `UserModel.GetForToken` method:
// internal/data/users.go
// ...
func (m UserModel) GetForToken(tokenScope, tokenPlaintext string) (*User, error) {
tokenHash := sha256.Sum256([]byte(tokenPlaintext))
query := `
SELECT users.id, users.created_at, users.name, users.email, users.password_hash, users.activated, users.version
FROM users
INNER JOIN tokens ON users.id = tokens.user_id
WHERE tokens.hash = $1
AND tokens.scope = $2
AND tokens.expiry > $3`
// use [:] to turn the byte array into a slice
args := []interface{}{tokenHash[:], tokenScope, time.Now()}
var user User
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := m.DB.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.CreatedAt,
&user.Name,
&user.Email,
&user.Password.hash,
&user.Activated,
&user.Version,
); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrRecordNotFound
default:
return nil, err
}
}
return &user, nil
}
We also add the route:
// cmd/api/routes.go
func (app *application) routes() http.Handler {
// ...
router.HandlerFunc(http.MethodPut, "/v1/users/activated", app.activateUserHandler)
// ...
We use `PUT` and not `POST` because the endpoint is idempotent. If the same request is sent multiple times, the first will succeed (provided the token is valid) and any subsequent requests will return an error to the user, as the token has been consumed and deleted. However, nothing in the application state changes after the first request.
15. Authentication
There are many approaches to authenticating requests to APIs.
Here are some common ones:
- Basic authentication
- Stateful token authentication
- Stateless token authentication
- API key authentication
- OAuth 2.0 / OpenID Connect
HTTP Basic Authentication
The client includes an `Authorization` header with every request containing their credentials. They’ll be in the format `username:password`, base64-encoded.
Like:
Authorization: Basic YWxpY2VAZXhhbXBsZS5jb206cGE1NXdvcmQ=
You can extract these with Go’s `Request.BasicAuth()` method.
Comparing the password with a hashed password is a slow operation (deliberately), and when you use basic auth, you’ll have to do it on each request.
Token Authentication
Also known as bearer token authentication.
- Client sends a request to the API with credentials (usually username/email and password).
- API verifies credentials, generates a bearer token representing the user, and sends it back. The token expires after some time, after which users will have to resubmit their credentials to get a new token.
- For subsequent requests, the client includes the token in an `Authorization` header as `Authorization: Bearer <token>`.
- When the API receives that, it will check that the token hasn’t expired and examine it to determine who the user is.
This avoids the issue of having to do the expensive checking of password vs. hashed password on every request. However, managing tokens can be complicated for clients (they need to handle caching tokens, monitoring and managing expiry, and periodically generating new ones).
There are two subtypes: stateful and stateless.
Stateful token authentication
The token value is a high-entropy, cryptographically secure random string.
It is usually stored in a server-side database with the user ID and expiry time.
The client sends the token in subsequent requests, and the server looks up the token, checks its expiry, and retrieves the user ID.
The API maintains control over the tokens. It’s easy to remove them by deleting or marking them as expired. This approach is also simple and robust, as we get security from an unguessable high-entropy token.
There aren’t many downsides besides those inherent with token authentication.
It requires a database lookup, though that’s necessary for user status checks anyway.
Stateless token authentication
Encodes the user ID and expiry time in the token itself. The token is cryptographically signed to prevent tampering and sometimes encrypted to prevent reading the contents.
Encoding the information in a JSON Web Token (JWT) is common. However, PASETO, Branca, and nacl/secretbox are viable alternatives.
The work required to encode/decode the token can be done in memory, and all necessary information to identify the user is in the token itself – no database lookups are needed to figure out who made the request.
But the tokens aren’t easy to revoke once they’re issued.
You could revoke all tokens by changing the secret used for signing the tokens, which forces all users to re-authenticate. Or maintain a blocklist of revoked tokens, but that defeats the stateless aspect.
Avoid storing additional information in these tokens, like activation status or permissions. Save that for authorization checks.
JWTs are highly configurable, meaning lots can go wrong.
Because of these downsides, stateless tokens (and especially JWTs) aren’t the best choice for managing authentication in most API applications. But they can be useful for delegated authentication, which is where the application creating the token is different from the one consuming it, and those apps don’t share state (so stateful isn’t an option). This is common in a microservice-style architecture.
API-Key Authentication
The user gets a non-expiring secret ‘key’ associated with their account.
This key should be a high-entropy, cryptographically secure random string. You store a fast hash of the key (SHA256 or SHA512) alongside the corresponding user ID in your database.
The user passes their key with each request to the API in the authorization header:
Authorization: Key <key>
When received, your API can regenerate the fast hash of the key and use it to look up the corresponding user ID.
This approach is similar to the stateful token approach, except keys are permanent, not temporary tokens.
This is also nice for the client, as they don’t have to manage tokens or expiry.
But now they have both their password and their API key to manage.
You need to build ways for users to regenerate their API keys if they are lost or compromised. Users may want multiple keys for the same account.
They should only be communicated to users over a secure channel – you should treat them with the same level of care as a password.
OAuth 2.0 / OpenID Connect
With this approach, information about your users is stored by a third-party identity provider, like Google, rather than by yourself.
OAuth 2.0 isn’t really an authentication protocol, and you shouldn’t use it for authenticating users.
If you want authentication checks against a third-party provider, you should use OpenID Connect.
Generating Authentication Tokens
Create an error for invalid credentials:
// cmd/api/errors.go
// ...
func (app *application) invalidCredentialsResponse(w http.ResponseWriter, r *http.Request) {
message := "invalid authentication credentials"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
Update the `Token` struct to include JSON struct tags and add a new constant for the authentication scope:
// internal/data/tokens.go
const (
ScopeActivation = "activation"
ScopeAuthentication = "authentication"
)
type Token struct {
Plaintext string `json:"token"`
Hash []byte `json:"-"`
UserID int64 `json:"-"`
Expiry time.Time `json:"expiry"`
Scope string `json:"-"`
}
Create a new route handler for the token creation endpoint:
// cmd/api/tokens.go
package main
import (
"errors"
"net/http"
"time"
"greenlight.bagerbach.com/internal/data"
"greenlight.bagerbach.com/internal/validator"
)
func (app *application) createAuthenticationTokenHandler(w http.ResponseWriter, r *http.Request) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
err := app.readJSON(w, r, &input)
if err != nil {
app.badRequestResponse(w, r, err)
return
}
v := validator.New()
data.ValidateEmail(v, input.Email)
data.ValidatePasswordPlaintext(v, input.Password)
if !v.Valid() {
app.failedValidationResponse(w, r, v.Errors)
return
}
user, err := app.models.Users.GetByEmail(input.Email)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.invalidCredentialsResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
match, err := user.Password.Matches(input.Password)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if !match {
app.invalidCredentialsResponse(w, r)
return
}
token, err := app.models.Tokens.New(user.ID, 24*time.Hour, data.ScopeAuthentication)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
err = app.writeJSON(w, http.StatusCreated, envelope{"authentication_token": token}, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
And then include it in the routes:
func (app *application) routes() http.Handler {
// ...
router.HandlerFunc(http.MethodPost, "/v1/tokens/authentication", app.createAuthenticationTokenHandler)
return app.recoverPanic(app.rateLimit(router))
}
Authenticating Requests
We’ll expect clients to send their authentication tokens, once they have them, in all requests in an `Authorization` header (as a bearer token).
We’ll use an `authenticate()` middleware to handle validation and adding user info to the request context. If the token is valid, we’ll populate the request context with user info. If not, we’ll send a `401 Unauthorized`. And if no `Authorization` header is given, we add details for an anonymous user.
Add context handling for users:
// cmd/api/context.go
package main
import (
"context"
"net/http"
"greenlight.bagerbach.com/internal/data"
)
type contextKey string
const userContextKey = contextKey("user")
func (app *application) contextSetUser(r *http.Request, user *data.User) *http.Request {
ctx := context.WithValue(r.Context(), userContextKey, user)
return r.WithContext(ctx)
}
func (app *application) contextGetUser(r *http.Request) *data.User {
user, ok := r.Context().Value(userContextKey).(*data.User)
if !ok {
panic("missing user value in request context")
}
return user
}
Add error response for invalid authentication tokens:
// cmd/api/errors.go
func (app *application) invalidAuthenticationTokenResponse(w http.ResponseWriter, r *http.Request) {
// Remind client that we expect them to authenticate using a bearer token.
w.Header().Set("WWW-Authenticate", "Bearer")
message := "invalid or missing authentication token"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
Implement middleware to handle authentication:
// cmd/api/middleware.go
func (app *application) authenticate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// This header indicates that the response may vary based on the value of the Authorization header in the request.
w.Header().Add("Vary", "Authorization")
authorizationHeader := r.Header.Get("Authorization")
if authorizationHeader == "" {
app.contextSetUser(r, data.AnonymousUser)
next.ServeHTTP(w, r)
return
}
// We expect the value in the Authorization header to be in the format of "Bearer <token>".
headerParts := strings.Split(authorizationHeader, " ")
if len(headerParts) != 2 || headerParts[0] != "Bearer" {
app.invalidAuthenticationTokenResponse(w, r)
return
}
token := headerParts[1]
// Validate token
v := validator.New()
if data.ValidateTokenPlaintext(v, token); !v.Valid() {
app.invalidAuthenticationTokenResponse(w, r)
return
}
// Get user associated with authentication token
user, err := app.models.Users.GetForToken(data.ScopeAuthentication, token)
if err != nil {
switch {
case errors.Is(err, data.ErrRecordNotFound):
app.invalidAuthenticationTokenResponse(w, r)
default:
app.serverErrorResponse(w, r, err)
}
return
}
r = app.contextSetUser(r, user)
next.ServeHTTP(w, r)
})
}
And we use it in `routes.go`:
// cmd/api/routes.go
func (app *application) routes() http.Handler {
// ...
return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}
We also add an anonymous user and a method to check if a user is anonymous:
// internal/data/users.go
var AnonymousUser = &User{}
type User struct {
// ...
}
func (u *User) IsAnonymous() bool {
return u == AnonymousUser
}
16. Permission-based Authorization
Ensure only activated users can access some endpoints and implement a permission-based authorization pattern for fine-grained control over who can access which endpoints.
Requiring User Activation
We’ll add some errors to handle the new cases:
// cmd/api/errors.go
func (app *application) authenticationRequiredResponse(w http.ResponseWriter, r *http.Request) {
message := "you must be authenticated to access this resource"
app.errorResponse(w, r, http.StatusUnauthorized, message)
}
func (app *application) inactiveAccountResponse(w http.ResponseWriter, r *http.Request) {
message := "your account must be activated to access this resource"
app.errorResponse(w, r, http.StatusForbidden, message)
}
Implement middleware for requiring authenticated & activated user accounts:
// cmd/api/middleware.go
func (app *application) requireAuthenticatedUser(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := app.contextGetUser(r)
if user.IsAnonymous() {
app.authenticationRequiredResponse(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (app *application) requireActivatedUser(next http.HandlerFunc) http.HandlerFunc {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := app.contextGetUser(r)
if !user.Activated {
app.inactiveAccountResponse(w, r)
return
}
next.ServeHTTP(w, r)
})
return app.requireAuthenticatedUser(fn)
}
Update routes to use middleware:
// cmd/api/routes.go
func (app *application) routes() http.Handler {
// ...
router.HandlerFunc(http.MethodGet, "/v1/movies", app.requireActivatedUser(app.listMoviesHandler))
router.HandlerFunc(http.MethodPost, "/v1/movies", app.requireActivatedUser(app.createMovieHandler))
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.requireActivatedUser(app.showMovieHandler))
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.requireActivatedUser(app.updateMovieHandler))
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.requireActivatedUser(app.deleteMovieHandler))
// ...
return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}
Permissions Database Table
We’ll use this to restrict routes based on permissions like movies:read
and movies:write
.
The relationship between users and permissions will be many-to-many. A user can have many permissions, and the same permission can belong to many users.
This kind of relationship can be managed in a relational database by creating a joining table between two entities.
So we may store all users in one table (id, email, etc.) and all permissions in another table (id, code). Then our joining table will be `users_permissions`, storing which users have which permissions by their respective IDs (`user_id`, `permission_id`).
We start by creating the migration:
$ migrate create -seq -ext .sql -dir ./migrations add_permissions
Up migration
CREATE TABLE IF NOT EXISTS permissions (
id bigserial PRIMARY KEY,
code text NOT NULL
);
CREATE TABLE IF NOT EXISTS users_permissions (
user_id bigint NOT NULL REFERENCES users ON DELETE CASCADE,
permission_id bigint NOT NULL REFERENCES permissions ON DELETE CASCADE,
-- This is a composite primary key
-- It means that the combination of user_id and permission_id must be unique
PRIMARY KEY (user_id, permission_id)
);
INSERT INTO permissions (code) VALUES ('movies:read'), ('movies:write');
Down migration
DROP TABLE IF EXISTS users_permissions;
DROP TABLE IF EXISTS permissions;
And now we’ll create the permissions model.
// internal/data/permissions.go
package data
import (
"context"
"database/sql"
"time"
)
type Permissions []string
func (p Permissions) Include(code string) bool {
for _, permission := range p {
if permission == code {
return true
}
}
return false
}
type PermissionModel struct {
DB *sql.DB
}
func (m PermissionModel) GetAllForUser(userID int64) (Permissions, error) {
query := `
SELECT permissions.code FROM permissions
INNER JOIN users_permissions ON users_permissions.permission_id = permissions.id
INNER JOIN users ON users.id = users_permissions.user_id
WHERE users.id = $1
`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
rows, err := m.DB.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var permissions Permissions
for rows.Next() {
var permission string
err := rows.Scan(&permission)
if err != nil {
return nil, err
}
permissions = append(permissions, permission)
}
if err := rows.Err(); err != nil {
return nil, err
}
return permissions, nil
}
And include it in the parent Model struct:
// internal/data/models.go
type Models struct {
Movies MovieModel
Users UserModel
Tokens TokenModel
Permissions PermissionModel
}
func NewModels(db *sql.DB) Models {
return Models{
Movies: MovieModel{DB: db},
Users: UserModel{DB: db},
Tokens: TokenModel{DB: db},
Permissions: PermissionModel{DB: db},
}
}
Checking Permissions
Add an error for insufficient permissions:
// cmd/api/errors.go
func (app *application) nonPermittedResponse(w http.ResponseWriter, r *http.Request) {
message := "your account does not have the necessary permissions to access this resource"
app.errorResponse(w, r, http.StatusForbidden, message)
}
Add middleware to require a permission:
// cmd/api/middleware.go
func (app *application) requirePermission(code string, next http.HandlerFunc) http.HandlerFunc {
fn := func(w http.ResponseWriter, r *http.Request) {
user := app.contextGetUser(r)
permissions, err := app.models.Permissions.GetAllForUser(user.ID)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}
if !permissions.Include(code) {
app.nonPermittedResponse(w, r)
return
}
next.ServeHTTP(w, r)
}
return app.requireActivatedUser(fn)
}
Use the middleware:
// cmd/api/routes.go
func (app *application) routes() http.Handler {
// ...
router.HandlerFunc(http.MethodGet, "/v1/movies", app.requirePermission("movies:read", app.listMoviesHandler))
router.HandlerFunc(http.MethodPost, "/v1/movies", app.requirePermission("movies:write", app.createMovieHandler))
router.HandlerFunc(http.MethodGet, "/v1/movies/:id", app.requirePermission("movies:read", app.showMovieHandler))
router.HandlerFunc(http.MethodPatch, "/v1/movies/:id", app.requirePermission("movies:write", app.updateMovieHandler))
router.HandlerFunc(http.MethodDelete, "/v1/movies/:id", app.requirePermission("movies:write", app.deleteMovieHandler))
// ...
return app.recoverPanic(app.rateLimit(app.authenticate(router)))
}
Granting Permissions
We’ll need a method for adding permissions to a user.
// internal/data/permissions.go
func (m PermissionModel) AddForUser(userID int64, codes ...string) error {
query := `
INSERT INTO users_permissions
SELECT $1, permissions.id FROM permissions WHERE permissions.code = ANY($2)
`
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err := m.DB.ExecContext(ctx, query, userID, pq.Array(codes))
return err
}
You’ll need to import `github.com/lib/pq` for `pq.Array`.
And we’ll add `movies:read` as a default permission for new users:
// cmd/api/users.go
func (app *application) registerUserHandler(w http.ResponseWriter, r *http.Request) {
// ...
if err := app.models.Users.Insert(user); err != nil {
switch {
case errors.Is(err, data.ErrDuplicateEmail):
v.AddError("email", "a user with this email address already exists")
app.failedValidationResponse(w, r, v.Errors)
default:
app.serverErrorResponse(w, r, err)
}
return
}
// Add this:
if err := app.models.Permissions.AddForUser(user.ID, "movies:read"); err != nil {
app.serverErrorResponse(w, r, err)
return
}
// ...
}
17. Cross Origin Requests
First, what is meant by origin?
If two URLs have the same scheme, host, and port (if specified), they are said to share the same origin. For example, `https://foo.com/a` and `http://foo.com/a` do not have the same origin, as they have different schemes. Likewise, `http://foo.com/a` and `http://www.foo.com/a` have different hosts.
This is important because browsers implement a security mechanism called the same-origin policy.
In practice, this means a webpage on one origin can embed certain resources from another in its HTML (like images, CSS, and JavaScript files). A webpage on one origin can send data to a different origin – like an HTML form sending to another origin. But a webpage on one origin is not allowed to receive data from a different origin.
This prevents potentially malicious websites on another origin from reading potentially confidential information on your website.
Sending data to other origins is not prevented by the policy. However, it is still dangerous, which is why CSRF attacks are possible.
Say you have a webpage at `https://foo.com`. If the JavaScript on that site tries to make an HTTP request to `https://bar.com/data.json` (a different origin), the request is sent and processed by the `bar.com` server, but the user’s browser will block the response so the JavaScript code from `https://foo.com` cannot see it.
The same-origin policy is a useful safeguard. However, sometimes you may want to relax it.
Say you have an API at `api.example.com`, and a trusted JavaScript front-end running on `www.example.com`. Then you might want to allow cross-origin requests from the trusted `www.example.com` domain to your API.
Or you have a completely open public API, and you want to allow cross-origin requests from anywhere.
Modern web browsers let you allow or disallow specific cross-origin requests to your API by setting `Access-Control` headers on your API responses.
Notably, the same-origin policy is a web browser thing only. Outside browsers, anyone can make requests to the API from anywhere.
Cross-origin requests can be classified as simple if the following conditions are met:
- The HTTP method is either `HEAD`, `GET`, or `POST` (the CORS-safe methods).
- The request headers are all either forbidden headers or one of the four CORS-safe headers: `Accept`, `Accept-Language`, `Content-Language`, `Content-Type`.
- The value of the `Content-Type` header, if set, is either `application/x-www-form-urlencoded`, `multipart/form-data`, or `text/plain`.
If a cross-origin request doesn’t meet these conditions, the browser will trigger an initial ‘preflight’ request before the real request to check if the real request will be permitted or not.
Preflight requests are sent with the HTTP method `OPTIONS`. They also always have an `Origin` header and an `Access-Control-Request-Method` header.
Supporting CORS
If you just want to enable access from any origin, add this middleware:
// cmd/api/middleware.go
func (app *application) enableCORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Don't use `*` if you want to use auth / credentials
w.Header().Set("Access-Control-Allow-Origin", "*")
next.ServeHTTP(w, r)
})
}
And then use it early in your middleware chain:
func (app *application) routes() http.Handler {
// ...
return app.recoverPanic(
app.enableCORS(
app.rateLimit(
app.authenticate(router)
)
)
)
}
However, we want to support a set of trusted domains. These will be passed via a string argument, separated by spaces.
// cmd/api/main.go
type config struct {
// ...
cors struct {
trustedOrigins []string
}
}
func main() {
// ...
flag.Func("cors-trusted-origins", "Trusted CORS origins (space separated)", func(s string) error {
cfg.cors.trustedOrigins = strings.Fields(s)
return nil
})
// ...
}
And now the middleware:
// cmd/api/middleware.go
// ...
func (app *application) enableCORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Vary", "Origin")
		w.Header().Add("Vary", "Access-Control-Request-Method")

		origin := r.Header.Get("Origin")

		if origin != "" && slices.Contains(app.config.cors.trustedOrigins, origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)

			// If the request is a preflight OPTIONS request, we need to set the
			// Access-Control-Allow-Methods and Access-Control-Allow-Headers headers.
			if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
				w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, PUT, PATCH, DELETE")

				// Since we're allowing Authorization, Allow-Origin should be checked against a
				// list of trusted origins. Never use `*` in this case.
				w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")

				// Write headers along with 200 OK status and return from the middleware
				// with no further action.
				w.WriteHeader(http.StatusOK)
				return
			}
		}

		next.ServeHTTP(w, r)
	})
}
Just make sure to never include `null` in the list of trusted origins, as an attacker can forge it by sending a request from a sandboxed iframe.

If your API endpoints require credentials (cookies or HTTP basic authentication), set an `Access-Control-Allow-Credentials: true` header in your responses. If you don't, the web browser will prevent cross-origin responses with credentials from being read by JavaScript. And if you do use the credentials header, never combine it with the wildcard `Access-Control-Allow-Origin: *`, as that would allow any website to make credentialed cross-origin requests to your API.
18. Metrics
To get insights into how your application is performing and what resources it's using, we'll use Go's standard library `expvar` package.
// cmd/api/routes.go
func (app *application) routes() http.Handler {
	// ...
	// Using /debug/vars, which is conventional for expvar, to display the metrics
	// and debug information.
	router.Handler(http.MethodGet, "/debug/vars", expvar.Handler())
	// ...
}
You'll get a JSON object with two top-level items: `cmdline` and `memstats`.

- `cmdline` contains an array of the command-line arguments used to run the application.
- `memstats` contains a 'moment-in-time' snapshot of memory usage.
You may also want custom metrics.
// cmd/api/main.go
func main() {
	// ...
	expvar.NewString("version").Set(version)

	expvar.Publish("goroutines", expvar.Func(func() any {
		return runtime.NumGoroutine()
	}))

	expvar.Publish("database", expvar.Func(func() any {
		return db.Stats()
	}))

	expvar.Publish("timestamp", expvar.Func(func() any {
		return time.Now().Unix()
	}))
	// ...
}
You absolutely need to protect the metrics endpoint. The information it provides is very useful to attackers and may even expose sensitive information.
One way to do so is by using authentication, e.g. with a `metrics:view` permission or HTTP Basic Authentication.

We'll run the application behind Caddy as a reverse proxy, where we'll restrict access to `GET /debug/vars` so it can only be accessed via connections from the local machine.
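For reference, that restriction in Caddy might look something like the following hypothetical Caddyfile sketch (Caddy v2 syntax; the domain and upstream port are assumptions, not the book's exact config):

```
# Hypothetical Caddyfile sketch: reject /debug/vars unless the request
# comes from the local machine, and proxy everything else to the API.
example.com {
	@metricsFromOutside {
		path /debug/vars
		not remote_ip 127.0.0.1 ::1
	}
	respond @metricsFromOutside "Forbidden" 403

	reverse_proxy localhost:4000
}
```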
More Metrics
// cmd/api/middleware.go
type metricsResponseWriter struct {
	wrapped       http.ResponseWriter
	statusCode    int
	headerWritten bool
}

func newMetricsResponseWriter(w http.ResponseWriter) *metricsResponseWriter {
	return &metricsResponseWriter{wrapped: w, statusCode: http.StatusOK}
}

func (mrw *metricsResponseWriter) Header() http.Header {
	return mrw.wrapped.Header()
}

func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
	mrw.wrapped.WriteHeader(statusCode)

	if !mrw.headerWritten {
		mrw.statusCode = statusCode
		mrw.headerWritten = true
	}
}

func (mrw *metricsResponseWriter) Write(b []byte) (int, error) {
	mrw.headerWritten = true
	return mrw.wrapped.Write(b)
}

func (mrw *metricsResponseWriter) Unwrap() http.ResponseWriter {
	return mrw.wrapped
}

func (app *application) metrics(next http.Handler) http.Handler {
	var (
		totalRequestsReceived           = expvar.NewInt("total_requests_received")
		totalResponsesSent              = expvar.NewInt("total_responses_sent")
		totalProcessingTimeMicroseconds = expvar.NewInt("total_processing_time_μs")
		totalResponsesSentByStatus      = expvar.NewMap("total_responses_sent_by_status")
	)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		totalRequestsReceived.Add(1)

		mrw := newMetricsResponseWriter(w)
		next.ServeHTTP(mrw, r)

		totalResponsesSent.Add(1)
		totalResponsesSentByStatus.Add(strconv.Itoa(mrw.statusCode), 1)

		duration := time.Since(start).Microseconds()
		totalProcessingTimeMicroseconds.Add(duration)
	})
}
Now, use that middleware in the routes. Put it at the beginning of the chain.
You can use the metrics endpoint in conjunction with something like Prometheus.
19. Building, Versioning and Quality Control
Makefiles and environment variables
Here's an example Makefile. You will need a `.envrc` file to load environment variables. The `audit` target also requires `staticcheck`:

go install honnef.co/go/tools/cmd/staticcheck@latest
include .envrc

## help: print this help message
.PHONY: help
help:
	@echo 'Usage:'
	@sed -n 's/^##//p' ${MAKEFILE_LIST} | column -t -s ':' | sed -e 's/^/ /'

.PHONY: confirm
confirm:
	@echo -n 'Are you sure? [y/N] ' && read ans && [ $${ans:-N} = y ]

## run/api: run the cmd/api application
.PHONY: run/api
run/api:
	go run ./cmd/api -db-dsn="${DATABASE_URL}"

## db/psql: connect to the database using psql
.PHONY: db/psql
db/psql:
	@echo 'Connecting to database...'
	psql "${DATABASE_URL}"

## db/migrations/up: apply all up database migrations
.PHONY: db/migrations/up
db/migrations/up: confirm
	@echo 'Running up migrations...'
	migrate -path ./migrations -database "${DATABASE_URL}" up

## db/migrations/new: create a new database migration
.PHONY: db/migrations/new
db/migrations/new:
	@echo 'Creating migration files for ${name}...'
	migrate create -seq -ext=.sql -dir=./migrations ${name}

## audit: tidy dependencies and format, vet and test all code
.PHONY: audit
audit:
	@echo 'Tidying and verifying module dependencies...'
	go mod tidy
	go mod verify
	@echo 'Formatting code...'
	go fmt ./...
	@echo 'Vetting code...'
	go vet ./...
	staticcheck ./...
	@echo 'Testing code...'
	go test -race -vet=off ./...

## build/api: build the cmd/api application
.PHONY: build/api
build/api:
	@echo 'Building cmd/api...'
	go build -ldflags='-s -w' -o=./bin/api ./cmd/api
	GOOS=linux GOARCH=amd64 go build -ldflags='-s -w' -o=./bin/linux_amd64/api ./cmd/api
This setup also means we don't have to use `os.Getenv` or `godotenv`.
By running `go mod vendor`, you'll include copies of all your dependencies' packages in your repository.
This is great because you won't be depending on them being available online later.
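In the same Makefile style as above, a vendoring target could look like this sketch (the echo messages are mine; tidy and verify are repeated so the vendor directory always matches go.mod):

```
## vendor: tidy, verify and vendor dependencies
.PHONY: vendor
vendor:
	@echo 'Tidying and verifying module dependencies...'
	go mod tidy
	go mod verify
	@echo 'Vendoring dependencies...'
	go mod vendor
```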
We’re building for two targets: the current machine and AMD64 Linux (where our app will run when deployed).
Realip
We're getting ready for deploying behind Caddy by using `realip`. Behind a reverse proxy, `r.RemoteAddr` holds the proxy's address rather than the client's, so rate limiting on it would throttle all clients as one; `realip` recovers the client's IP from forwarding headers instead:

go get github.com/tomasen/realip@latest
And then editing the rate limiter:
--- a/lets-go-further-book/cmd/api/middleware.go
+++ b/lets-go-further-book/cmd/api/middleware.go
@@ -4,7 +4,6 @@ import (
 	"errors"
 	"expvar"
 	"fmt"
-	"net"
 	"net/http"
 	"slices"
 	"strconv"
@@ -12,6 +11,7 @@ import (
 	"sync"
 	"time"
+	"github.com/tomasen/realip"
 	"golang.org/x/time/rate"
 	"greenlight.bagerbach.com/internal/data"
 	"greenlight.bagerbach.com/internal/validator"
@@ -72,11 +72,7 @@ func (app *application) rateLimit(next http.Handler) http.Handler {
 			return
 		}
-		ip, _, err := net.SplitHostPort(r.RemoteAddr)
-		if err != nil {
-			app.serverErrorResponse(w, r, err)
-			return
-		}
+		ip := realip.FromRequest(r)
 		mu.Lock()
 		if _, found := clients[ip]; !found {