Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,11 @@ type Config struct {

MaxOpenNum int `json:"maxOpenNum"`
MaxIdleNum int `json:"maxIdleNum"`

// UseIAMAuth authenticates to AWS RDS/Aurora with short-lived IAM tokens
// instead of a DSN password. The DSN should omit the password and enable TLS.
UseIAMAuth bool `json:"useIAMAuth"`
// AWSRegion signs the IAM tokens. Optional; falls back to the default AWS
// config chain (e.g. AWS_REGION) when empty.
AWSRegion string `json:"awsRegion"`
}
23 changes: 22 additions & 1 deletion common/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ func InitDB(config *Config) (*gorm.DB, error) {
gethLogger: log.Root(),
}

db, err := gorm.Open(postgres.Open(config.DSN), &gorm.Config{
dialector, err := newDialector(config)
if err != nil {
return nil, err
}

db, err := gorm.Open(dialector, &gorm.Config{
CreateBatchSize: 1000,
Logger: &tmpGormLogger,
NowFunc: func() time.Time {
Expand Down Expand Up @@ -80,6 +85,22 @@ func InitDB(config *Config) (*gorm.DB, error) {
return db, nil
}

// newDialector selects the gorm postgres dialector for config.
//
// When config.UseIAMAuth is false the DSN is handed straight to the postgres
// driver. Otherwise an RDS IAM connector is built from the DSN and region,
// wrapped in a *sql.DB, and passed to the dialector as its Conn. An error from
// building the IAM connector is returned unchanged.
func newDialector(config *Config) (gorm.Dialector, error) {
	if config.UseIAMAuth {
		log.Info("connecting to database with AWS RDS IAM auth", "region", config.AWSRegion)

		iamConnector, err := NewRDSIAMConnector(context.Background(), config.DSN, config.AWSRegion)
		if err != nil {
			return nil, err
		}

		pool := sql.OpenDB(iamConnector)
		return postgres.New(postgres.Config{Conn: pool}), nil
	}

	log.Info("connecting to database with password auth")
	return postgres.Open(config.DSN), nil
}

// CloseDB close the db handler. notice the db handler only can close when then program exit.
func CloseDB(db *gorm.DB) error {
sqlDB, err := db.DB()
Expand Down
104 changes: 104 additions & 0 deletions common/database/iam.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package database

import (
"context"
"database/sql/driver"
"fmt"
"net"
"strconv"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
)

// rdsIAMConnector implements driver.Connector for AWS RDS IAM authentication.
// Every call to Connect signs a new auth token and uses it as the connection
// password. Tokens expire after 15 minutes, but because a token is only
// checked at connect time, connections that are already open keep working.
type rdsIAMConnector struct {
	// base is the parsed DSN. Connect copies it and sets Password on the copy
	// to the freshly signed token.
	base *pgx.ConnConfig
	// endpoint is the host:port the token is signed for.
	endpoint string
	// region is the AWS region the token is signed for.
	region string
	// creds supplies the AWS credentials used for signing; the provider
	// refreshes credentials internally.
	creds aws.CredentialsProvider
	// metrics records token generation counts, failures and latency.
	metrics *rdsIAMMetrics
}

// NewRDSIAMConnector builds a driver.Connector that authenticates to RDS with
// IAM: each new connection gets a freshly signed token as its password.
//
// dsn must name a database user and must require TLS (sslmode=require or
// higher; verify-full with sslrootcert is recommended). A DSN with
// sslmode=disable, allow or prefer, or with more than one host, is rejected
// instead of being silently downgraded. When region is empty it is taken from
// the default AWS config chain (e.g. AWS_REGION); an error is returned if it
// is still empty after that.
func NewRDSIAMConnector(ctx context.Context, dsn, region string) (driver.Connector, error) {
	parsed, parseErr := pgx.ParseConfig(dsn)
	if parseErr != nil {
		return nil, fmt.Errorf("rds iam: parse dsn: %w", parseErr)
	}
	if parsed.User == "" {
		return nil, fmt.Errorf("rds iam: dsn must specify a database user")
	}

	// Pin the region only when the caller supplied one; otherwise let the SDK
	// resolve it from its default chain.
	var loadOpts []func(*awsconfig.LoadOptions) error
	if region != "" {
		loadOpts = append(loadOpts, awsconfig.WithRegion(region))
	}
	sdkCfg, loadErr := awsconfig.LoadDefaultConfig(ctx, loadOpts...)
	if loadErr != nil {
		return nil, fmt.Errorf("rds iam: load aws config: %w", loadErr)
	}

	signingRegion := region
	if signingRegion == "" {
		signingRegion = sdkCfg.Region
	}
	if signingRegion == "" {
		return nil, fmt.Errorf("rds iam: aws region is not set (configure awsRegion or AWS_REGION)")
	}

	if err := checkRDSIAMTransport(parsed); err != nil {
		return nil, err
	}

	hostPort := net.JoinHostPort(parsed.Host, strconv.Itoa(int(parsed.Port)))
	return &rdsIAMConnector{
		base:     parsed,
		endpoint: hostPort,
		region:   signingRegion,
		creds:    sdkCfg.Credentials,
		metrics:  initRDSIAMMetrics(),
	}, nil
}

// checkRDSIAMTransport fails closed on any DSN that could send the IAM token
// (which is the password) in cleartext or to a host it was not signed for.
// The token is scoped to a single host:port, so it rejects:
//   - TLS-less DSNs: sslmode=disable/allow leave the primary TLSConfig nil;
//   - plaintext fallbacks: sslmode=prefer (the pgx default) keeps a non-TLS
//     entry in Fallbacks that pgx retries on if SSL negotiation fails;
//   - multi-host DSNs: a fallback host the single token is not signed for.
func checkRDSIAMTransport(cfg *pgx.ConnConfig) error {
	if cfg.TLSConfig == nil {
		return fmt.Errorf("rds iam: TLS is required, set sslmode=require or higher (verify-full recommended) in the dsn")
	}
	for _, alt := range cfg.Fallbacks {
		switch {
		case alt.TLSConfig == nil:
			return fmt.Errorf("rds iam: dsn allows a plaintext fallback (sslmode=prefer); set sslmode=require or higher")
		case alt.Host != cfg.Host || alt.Port != cfg.Port:
			return fmt.Errorf("rds iam: multi-host dsn is not supported; the IAM token is scoped to a single host:port")
		}
	}
	return nil
}

// Connect signs a new IAM auth token and opens a connection that uses the
// token as its password. Signing latency is observed whether or not signing
// succeeds; the success or failure counter is then incremented accordingly.
func (c *rdsIAMConnector) Connect(ctx context.Context) (driver.Conn, error) {
	began := time.Now()
	authToken, signErr := auth.BuildAuthToken(ctx, c.endpoint, c.region, c.base.User, c.creds)
	elapsed := time.Since(began)
	c.metrics.tokenDuration.Observe(elapsed.Seconds())

	if signErr != nil {
		c.metrics.tokenFailureTotal.Inc()
		return nil, fmt.Errorf("rds iam: build auth token: %w", signErr)
	}
	c.metrics.tokenTotal.Inc()

	// Work on a copy so the shared base config never carries a token.
	connCfg := c.base.Copy()
	connCfg.Password = authToken
	pgxConnector := stdlib.GetConnector(*connCfg)
	return pgxConnector.Connect(ctx)
}

// Driver reports the pgx stdlib default driver; together with Connect it makes
// rdsIAMConnector satisfy driver.Connector.
func (c *rdsIAMConnector) Driver() driver.Driver {
	return stdlib.GetDefaultDriver()
}
41 changes: 41 additions & 0 deletions common/database/iam_metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package database

import (
"sync"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)

// rdsIAMMetrics groups the Prometheus collectors that track AWS RDS IAM auth
// token generation.
type rdsIAMMetrics struct {
	// tokenTotal counts tokens generated successfully (one per new connection).
	tokenTotal prometheus.Counter
	// tokenFailureTotal counts token generation failures.
	tokenFailureTotal prometheus.Counter
	// tokenDuration records token generation latency in seconds.
	tokenDuration prometheus.Histogram
}

var (
	// initRDSIAMMetricsOnce guards the one-time construction of rdsIAMMetric
	// in initRDSIAMMetrics.
	initRDSIAMMetricsOnce sync.Once
	// rdsIAMMetric is the shared metrics instance; it is nil until
	// initRDSIAMMetrics has run once.
	rdsIAMMetric *rdsIAMMetrics
)

// initRDSIAMMetrics returns the shared RDS IAM metrics. The collectors are
// created and registered with prometheus.DefaultRegisterer on the first call
// only; later calls return the same instance.
func initRDSIAMMetrics() *rdsIAMMetrics {
	initRDSIAMMetricsOnce.Do(func() {
		factory := promauto.With(prometheus.DefaultRegisterer)

		generated := factory.NewCounter(prometheus.CounterOpts{
			Name: "database_rds_iam_token_total",
			Help: "Total number of AWS RDS IAM auth tokens generated (one per new connection).",
		})
		failed := factory.NewCounter(prometheus.CounterOpts{
			Name: "database_rds_iam_token_failure_total",
			Help: "Total number of AWS RDS IAM auth token generation failures.",
		})
		latency := factory.NewHistogram(prometheus.HistogramOpts{
			Name:    "database_rds_iam_token_duration_seconds",
			Help:    "Latency of AWS RDS IAM auth token generation; spikes indicate a credential refresh.",
			Buckets: []float64{.0001, .001, .005, .01, .05, .1, .5, 1, 5},
		})

		rdsIAMMetric = &rdsIAMMetrics{
			tokenTotal:        generated,
			tokenFailureTotal: failed,
			tokenDuration:     latency,
		}
	})
	return rdsIAMMetric
}
91 changes: 91 additions & 0 deletions common/database/iam_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package database

import (
"context"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestNewRDSIAMConnector covers DSN parsing and the fail-closed validation in
// NewRDSIAMConnector. The two happy-path cases inspect the returned connector;
// the rejection cases are table-driven and only check the returned error.
func TestNewRDSIAMConnector(t *testing.T) {
	ctx := context.Background()

	t.Run("parses endpoint, user and region", func(t *testing.T) {
		const dsn = "postgres://svc_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/scroll?sslmode=verify-full"
		connector, err := NewRDSIAMConnector(ctx, dsn, "us-east-1")
		require.NoError(t, err)

		iam, isIAM := connector.(*rdsIAMConnector)
		require.True(t, isIAM)
		assert.Equal(t, "mydb.abc123.us-east-1.rds.amazonaws.com:5432", iam.endpoint)
		assert.Equal(t, "svc_user", iam.base.User)
		assert.Equal(t, "us-east-1", iam.region)
		// The TLS settings requested in the DSN must survive parsing.
		require.NotNil(t, iam.base.TLSConfig)
	})

	t.Run("defaults port to 5432", func(t *testing.T) {
		const dsn = "postgres://svc_user@mydb.example.rds.amazonaws.com/scroll?sslmode=require"
		connector, err := NewRDSIAMConnector(ctx, dsn, "eu-west-1")
		require.NoError(t, err)
		assert.Equal(t, "mydb.example.rds.amazonaws.com:5432", connector.(*rdsIAMConnector).endpoint)
	})

	rejections := []struct {
		name   string
		dsn    string
		region string
		// errSubstr is the text the error must contain; when empty, any
		// non-nil error is accepted.
		errSubstr string
		// prepare, when set, runs inside the subtest before the call.
		prepare func(t *testing.T)
	}{
		{
			// IAM auth requires TLS; a plaintext DSN must fail closed rather
			// than silently sending the token in cleartext.
			name:      "rejects sslmode=disable",
			dsn:       "postgres://svc_user@mydb.example.rds.amazonaws.com:5432/scroll?sslmode=disable",
			region:    "us-east-1",
			errSubstr: "TLS",
		},
		{
			// prefer keeps a non-TLS entry in Fallbacks that pgx would retry
			// on, which would leak the token in cleartext.
			name:      "rejects sslmode=prefer (plaintext fallback)",
			dsn:       "postgres://svc_user@mydb.example.rds.amazonaws.com:5432/scroll?sslmode=prefer",
			region:    "us-east-1",
			errSubstr: "plaintext",
		},
		{
			name:   "rejects sslmode=allow (plaintext primary)",
			dsn:    "postgres://svc_user@mydb.example.rds.amazonaws.com:5432/scroll?sslmode=allow",
			region: "us-east-1",
		},
		{
			// The token is signed for a single host:port, so failover would
			// fail auth.
			name:      "rejects multi-host dsn",
			dsn:       "host=host1.rds.amazonaws.com,host2.rds.amazonaws.com port=5432 user=svc_user dbname=scroll sslmode=require",
			region:    "us-east-1",
			errSubstr: "multi-host",
		},
		{
			// With no region argument, no AWS_REGION in the environment, and
			// no shared config file, region resolution must fail rather than
			// silently signing with an empty region.
			name:      "region falls back to explicit empty error when unresolved",
			dsn:       "postgres://svc_user@mydb.example.rds.amazonaws.com:5432/scroll?sslmode=require",
			region:    "",
			errSubstr: "region",
			prepare: func(t *testing.T) {
				t.Setenv("AWS_REGION", "")
				t.Setenv("AWS_DEFAULT_REGION", "")
				// Isolate from any ~/.aws/config on the dev/CI machine, which
				// would otherwise resolve a default region and make this case
				// nondeterministic.
				t.Setenv("AWS_CONFIG_FILE", filepath.Join(t.TempDir(), "no-config"))
				t.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(t.TempDir(), "no-creds"))
			},
		},
		{
			name:   "rejects invalid dsn",
			dsn:    "::not a dsn::",
			region: "us-east-1",
		},
	}

	for _, tc := range rejections {
		t.Run(tc.name, func(t *testing.T) {
			if tc.prepare != nil {
				tc.prepare(t)
			}
			_, err := NewRDSIAMConnector(ctx, tc.dsn, tc.region)
			if tc.errSubstr == "" {
				assert.Error(t, err)
				return
			}
			assert.ErrorContains(t, err, tc.errSubstr)
		})
	}
}
7 changes: 4 additions & 3 deletions common/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ toolchain go1.22.2

require (
github.com/Masterminds/semver/v3 v3.2.1
github.com/aws/aws-sdk-go-v2 v1.21.2
github.com/aws/aws-sdk-go-v2/config v1.18.45
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.2.21
github.com/bits-and-blooms/bitset v1.20.0
github.com/docker/docker v26.1.0+incompatible
github.com/gin-contrib/pprof v1.4.0
github.com/gin-gonic/gin v1.9.1
github.com/jackc/pgx/v5 v5.5.4
github.com/mattn/go-colorable v0.1.13
github.com/mattn/go-isatty v0.0.20
github.com/mitchellh/mapstructure v1.5.0
Expand All @@ -34,8 +38,6 @@ require (
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/Microsoft/hcsshim v0.11.4 // indirect
github.com/VictoriaMetrics/fastcache v1.12.2 // indirect
github.com/aws/aws-sdk-go-v2 v1.21.2 // indirect
github.com/aws/aws-sdk-go-v2/config v1.18.45 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.13.43 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.13 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43 // indirect
Expand Down Expand Up @@ -128,7 +130,6 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions common/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.13.43 h1:LU8vo40zBlo3R7bAvBVy/ku4nxG
github.com/aws/aws-sdk-go-v2/credentials v1.13.43/go.mod h1:zWJBz1Yf1ZtX5NGax9ZdNjhhI4rgjfgsyk6vTY1yfVg=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.13 h1:PIktER+hwIG286DqXyvVENjgLTAwGgoeriLDD5C+YlQ=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.13/go.mod h1:f/Ib/qYjhV2/qdsf79H3QP/eRE4AkVyEf6sk7XfZ1tg=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.2.21 h1:m/oetLggG4HFTcU0CkY1uR18uKRNTm+V1XocGd3Wcxk=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.2.21/go.mod h1:XoCNC17AXoRDfkX2bsFsGsn036fch7ATgchnAy+PsOQ=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43 h1:nFBQlGtkbPzp/NjZLuFxRqmT91rLJkgvsEQs68h962Y=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.43/go.mod h1:auo+PiyLl0n1l8A0e8RIeR8tOzYPfZZH/JNlrJ8igTQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.37 h1:JRVhO25+r3ar2mKGP7E0LDl8K9/G36gjlqca5iQbaqc=
Expand Down
2 changes: 1 addition & 1 deletion common/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"runtime/debug"
)

var tag = "v4.7.13"
var tag = "v4.7.14"

var commit = func() string {
if info, ok := debug.ReadBuildInfo(); ok {
Expand Down
37 changes: 37 additions & 0 deletions database/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,43 @@ db_cli version
db_cli rollback
```

## AWS RDS IAM authentication

Instead of a static password in the DSN, services can authenticate to AWS
RDS/Aurora PostgreSQL using short-lived IAM auth tokens. This removes the need
to rotate database passwords: access is granted via IAM, and a fresh token is
generated automatically for each new connection. Tokens expire after 15
minutes, but they are only checked at connect time, so connections that are
already open keep working.

Enable it per service in the DB config block:

```json
{
"dsn": "postgres://svc_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/scroll?sslmode=require",
"driver_name": "postgres",
"maxOpenNum": 200,
"maxIdleNum": 20,
"useIAMAuth": true,
"awsRegion": "us-east-1"
}
```

Notes:

- Omit the password from the `dsn`; it is supplied by the generated IAM token.
- `awsRegion` is optional — when empty it is resolved from the default AWS
config chain (e.g. the `AWS_REGION` environment variable).
- IAM auth requires TLS, so the DSN must set `sslmode=require` or higher.
`sslmode=disable`/`allow`/`prefer` (including an unset `sslmode`, which
defaults to `prefer`) are rejected at startup rather than silently sending the
token over a connection that may fall back to plaintext. `require` encrypts
but does not verify the server certificate; for that, use `sslmode=verify-full`
with `sslrootcert` pointing at the
[RDS CA bundle](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.SSL.html).
- The database role must be granted `rds_iam`, and the service's IAM role needs
  the `rds-db:connect` permission on the matching `dbuser` resource
  (`arn:aws:rds-db:<region>:<account-id>:dbuser:<resource-id>/<db-user>`).
- Leaving `useIAMAuth` unset preserves the previous password-based behavior.

## Test

```bash
Expand Down
7 changes: 7 additions & 0 deletions database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ type DBConfig struct {

MaxOpenNum int `json:"maxOpenNum"`
MaxIdleNum int `json:"maxIdleNum"`

// UseIAMAuth authenticates to AWS RDS/Aurora with short-lived IAM tokens
// instead of a DSN password. The DSN should omit the password and enable TLS.
UseIAMAuth bool `json:"useIAMAuth"`
// AWSRegion signs the IAM tokens. Optional; falls back to the default AWS
// config chain (e.g. AWS_REGION) when empty.
AWSRegion string `json:"awsRegion"`
}

// NewConfig returns a new instance of Config.
Expand Down
Loading
Loading