diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 627f2a96..93781507 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,7 +39,7 @@ jobs: # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed env: CGO_ENABLED: 0 - GOFLAGS: "-tags=sync,search,scrub,metrics,userprefs,containers_image_openpgp" + GOFLAGS: "-tags=sync,search,scrub,metrics,userprefs,apikey,containers_image_openpgp" steps: - name: Checkout repository diff --git a/.github/workflows/ecosystem-tools.yaml b/.github/workflows/ecosystem-tools.yaml index 88995a2c..963396b8 100644 --- a/.github/workflows/ecosystem-tools.yaml +++ b/.github/workflows/ecosystem-tools.yaml @@ -46,6 +46,12 @@ jobs: sudo systemctl enable crio.service sudo systemctl start crio.service sudo chmod 0777 /var/run/crio/crio.sock + # install dex + git clone https://github.com/dexidp/dex.git + cd dex/ + make bin/dex + ./bin/dex serve $GITHUB_WORKSPACE/test/dex/config-dev.yaml & + cd $GITHUB_WORKSPACE - name: Run referrers tests run: | make test-bats-referrers diff --git a/.github/workflows/golangci-lint.yaml b/.github/workflows/golangci-lint.yaml index 2c9dc620..22431394 100644 --- a/.github/workflows/golangci-lint.yaml +++ b/.github/workflows/golangci-lint.yaml @@ -32,7 +32,7 @@ jobs: # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 - args: --config ./golangcilint.yaml --enable-all --build-tags debug,needprivileges,sync,scrub,search,userprefs,metrics,containers_image_openpgp,lint,mgmt ./cmd/... ./pkg/... + args: --config ./golangcilint.yaml --enable-all --build-tags debug,needprivileges,sync,scrub,search,userprefs,metrics,containers_image_openpgp,lint,mgmt,apikey ./cmd/... ./pkg/... # Optional: show only new issues if it's a pull request. The default value is `false`. # only-new-issues: true diff --git a/Makefile b/Makefile index 6d1e1290..231d4217 100644 --- a/Makefile +++ b/Makefile @@ -32,8 +32,8 @@ TESTDATA := $(TOP_LEVEL)/test/data OS ?= linux ARCH ?= amd64 BENCH_OUTPUT ?= stdout -EXTENSIONS ?= sync,search,scrub,metrics,lint,ui,mgmt,userprefs -UI_DEPENDENCIES := search,mgmt,userprefs +EXTENSIONS ?= sync,search,scrub,metrics,lint,ui,mgmt,userprefs,apikey +UI_DEPENDENCIES := search,mgmt,userprefs,apikey comma:= , space := $(null) # hyphen:= - diff --git a/errors/errors.go b/errors/errors.go index 71a595f1..110438ba 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -80,7 +80,14 @@ var ( ErrMediaTypeNotSupported = errors.New("repodb: media type is not supported") ErrTimeout = errors.New("operation timeout") ErrNotImplemented = errors.New("not implemented") - ErrUnableToCreateUserBucket = errors.New("repodb: unable to create a user bucket for user") + ErrDedupeRebuild = errors.New("dedupe: couldn't rebuild dedupe index") + ErrMissingAuthHeader = errors.New("auth: required authorization header is missing") + ErrUserAPIKeyNotFound = errors.New("userDB: user info for given API key hash not found") + ErrUserSessionNotFound = errors.New("userDB: user session for given ID not found") + ErrBucketDoesNotExist = errors.New("DB: bucket does not exist") + ErrOpenIDProviderDoesNotExist = errors.New("openID: provider does not exist in given config") + ErrHashKeyNotCreated = errors.New("cookiestore: generated random hash key is nil, not continuing") + ErrFailedTypeAssertion = errors.New("type assertion failed") ErrInvalidOldUserStarredRepos = errors.New("repodb: invalid old entry for user starred repos") ErrUnmarshalledRepoListIsNil = errors.New("repodb: list of repos is still nil") ErrCouldNotMarshalStarredRepos = errors.New("repodb: could not repack entry for user starred repos") @@ -89,7 +96,6 @@ var ( ErrUserDataNotFound = errors.New("repodb: user data not found for given user identifier") ErrUserDataNotAllowed = errors.New("repodb: user data operations are not allowed") ErrCouldNotPersistData = errors.New("repodb: could not persist to db") - ErrDedupeRebuild = errors.New("dedupe: couldn't rebuild dedupe index") ErrSignConfigDirNotSet = errors.New("signatures: signature config dir not set") ErrBadManifestDigest = errors.New("signatures: bad manifest digest") ErrInvalidSignatureType = errors.New("signatures: invalid signature type") @@ -100,4 +106,5 @@ var ( ErrInvalidTruststoreType = errors.New("signatures: invalid truststore type") ErrInvalidTruststoreName = errors.New("signatures: invalid truststore name") ErrInvalidCertificateContent = errors.New("signatures: invalid certificate content") + ErrInvalidStateCookie = errors.New("auth: state cookie not present or differs from original state") ) diff --git a/examples/README.md b/examples/README.md index 9df51e2c..a97f6f62 100644 --- a/examples/README.md +++ b/examples/README.md @@ -18,13 +18,20 @@ Examples of working configurations for various use cases are available [here](.. # Configuration Parameters -* [Network](#network) -* [Storage](#storage) -* [Authentication](#authentication) -* [Identity-based Authorization](#identity-based-authorization) -* [Logging](#logging) -* [Metrics](#metrics) -* [Sync](#sync) +- [Configuration Parameters](#configuration-parameters) + - [Network](#network) + - [Storage](#storage) + - [Authentication](#authentication) + - [TLS Mutual Authentication](#tls-mutual-authentication) + - [Passphrase Authentication](#passphrase-authentication) + - [Authentication Failures](#authentication-failures) + - [API keys](#api-keys) + - [Identity-based Authorization](#identity-based-authorization) + - [Logging](#logging) + - [Metrics](#metrics) + - [Storage Drivers](#storage-drivers) + - [Specifying S3 credentials](#specifying-s3-credentials) + - [Sync](#sync) ## Network @@ -162,6 +169,98 @@ NOTE: When both htpasswd and LDAP configuration are specified, LDAP authenticati } ``` +### OpenID/OAuth2 social login + +zot supports several openID/OAuth2 providers: + - google + - github + - gitlab + - dex + +zot can be configured to use the above providers with: +``` +{ + "http": { + "auth": { + "openid": { + "providers": { + "github": { + "clientid": , + "clientsecret": , + "scopes": ["read:org", "user", "repo"] + }, + "google": { + "issuer": "https://accounts.google.com", + "clientid": , + "clientsecret": , + "scopes": ["openid", "email"] + }, + "gitlab": { + "issuer": "https://gitlab.com", + "clientid": , + "clientsecret": , + "scopes": ["openid", "read_api", "read_user", "profile", "email"] + } + } + } + } + } +``` + +the login with either provider use http://127.0.0.1:8080/auth/login?provider=\&callback_ui=http://127.0.0.1:8080/home +for example to login with github use http://127.0.0.1:8080/auth/login?provider=github&callback_ui=http://127.0.0.1:8080/home + +callback_ui query parameter is used by zot to redirect to UI after a successful openid/oauth2 authentication + +the callback url which should be used when making oauth2 provider setup is http://127.0.0.1:8080/auth/callback/\ +for example github callback url would be http://127.0.0.1:8080/auth/callback/github + +If network policy doesn't allow inbound connections, this callback wont work! + +dex is an identity service that uses OpenID Connect to drive authentication for other apps https://github.com/dexidp/dex +To setup dex service see https://dexidp.io/docs/getting-started/ + +to configure zot as a client in dex (assuming zot is hosted at 127.0.0.1:8080), we need to configure dex with: + +``` +staticClients: + - id: zot-client + redirectURIs: + - 'http://127.0.0.1:8080/auth/callback/dex' + name: 'zot' + secret: ZXhhbXBsZS1hcHAtc2VjcmV0 +``` + +zot can be configured to use dex with: + +``` + "http": { + "auth": { + "openid": { + "providers": { + "dex": { + "clientid": "zot-client", + "clientsecret": "ZXhhbXBsZS1hcHAtc2VjcmV0", + "keypath": "", + "issuer": "http://127.0.0.1:5556/dex", + "scopes": ["openid", "profile", "email", "groups"] + } + } + } + } + } +``` + +to login using openid dex provider use http://127.0.0.1:8080/auth/login?provider=dex + +### Session based login + +Whenever a user logs in zot using any of the auth options available(basic auth/openid) zot will set a 'session' cookie on its response. +Using that cookie on subsequent calls will authenticate them, asumming the cookie didn't expire. + +In case of using filesystem storage sessions are saved in zot's root directory. +In case of using cloud storage sessions are saved in memory. + #### Authentication Failures Should authentication fail, to prevent automated attacks, a delayed response can be configured with: @@ -172,6 +271,21 @@ Should authentication fail, to prevent automated attacks, a delayed response can "failDelay": 5 ``` +#### API keys + +zot allows authentication for REST API calls using your API key as an alternative to your password. +for more info see [API keys doc](../pkg/extensions/README_apikey.md) + +To activate API keys use: + +``` +"extensions": { + "apikey": { + "enable": true + } +} +``` + ## Identity-based Authorization Allowing actions on one or more repository paths can be tied to user diff --git a/examples/config-metricsbug.json b/examples/config-metricsbug.json new file mode 100644 index 00000000..374a391b --- /dev/null +++ b/examples/config-metricsbug.json @@ -0,0 +1,121 @@ +{ + "distSpecVersion": "1.1.0-dev", + "extensions": { + "metrics": { + "enable": true, + "prometheus": { + "path": "/metrics" + } + }, + "mgmt": { + "enable": true + }, + "scrub": { + "enable": true, + "interval": "24h" + }, + "search": { + "cve": { + "updateInterval": "2h" + }, + "enable": true + }, + "sync": { + "enable": true, + "registries": [ + { + "content": [ + { + "destination": "/docker.io", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://docker.io/library" + ] + }, + { + "content": [ + { + "destination": "/registry.gitlab.com", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://registry.gitlab.com" + ] + }, + { + "content": [ + { + "destination": "ghcr.io", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://ghcr.io" + ] + }, + { + "content": [ + { + "destination": "/quay.io", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://quay.io" + ] + }, + { + "content": [ + { + "destination": "/gcr.io", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://gcr.io" + ] + }, + { + "content": [ + { + "destination": "/registry.k8s.io", + "prefix": "**" + } + ], + "onDemand": true, + "tlsVerify": true, + "urls": [ + "https://registry.k8s.io" + ] + } + ] + }, + "ui": { + "enable": true + } + }, + "http": { + "address": "0.0.0.0", + "port": "5000" + }, + "log": { + "level": "debug" + }, + "storage": { + "gc": true, + "rootDirectory": "/tmp/zot" + } + } diff --git a/examples/config-openid.json b/examples/config-openid.json new file mode 100644 index 00000000..b85e1063 --- /dev/null +++ b/examples/config-openid.json @@ -0,0 +1,75 @@ +{ + "distSpecVersion": "1.1.0-dev", + "storage": { + "rootDirectory": "/tmp/zot", + "dedupe": true + }, + "http": { + "address": "127.0.0.1", + "port": "8080", + "realm": "zot", + "auth": { + "htpasswd": { + "path": "test/data/htpasswd" + }, + "openid": { + "providers": { + "github": { + "clientid": "client_id", + "clientsecret": "client_secret", + "keypath": "", + "scopes": ["read:org", "user", "repo"] + }, + "google": { + "issuer": "https://accounts.google.com", + "clientid": "client_id", + "clientsecret": "client_secret", + "scopes": ["openid", "email"] + }, + "gitlab": { + "issuer": "https://gitlab.com", + "clientid": "client_id", + "clientsecret": "client_secret", + "scopes": ["openid", "read_api", "read_user", "profile", "email"] + }, + "dex": { + "issuer": "http://127.0.0.1:5556/dex", + "clientid": "client_id", + "clientsecret": "client_secret", + "scopes": ["openid", "user", "email", "groups"] + } + } + }, + "failDelay": 5 + }, + "accessControl": { + "repositories": { + "**": { + "policies": [ + { + "users": [ + "test" + ], + "actions": [ + "read", + "create" + ] + } + ], + "defaultPolicy": ["read"] + } + } + } + }, + "log": { + "level": "debug" + }, + "extensions": { + "apikey": { + "enable": true + }, + "mgmt": { + "enable": true + } + } +} diff --git a/go.mod b/go.mod index 21f0c852..2682056a 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,6 @@ require ( github.com/gofrs/uuid v4.4.0+incompatible github.com/google/go-containerregistry v0.15.2 github.com/google/uuid v1.3.0 - github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/hashicorp/golang-lru/v2 v2.0.3 github.com/json-iterator/go v1.1.12 @@ -53,14 +52,29 @@ require ( github.com/aws/aws-sdk-go-v2/service/dynamodb v1.20.0 github.com/containers/image/v5 v5.25.0 github.com/gobwas/glob v0.2.3 + github.com/google/go-github/v52 v52.0.0 + github.com/gorilla/handlers v1.5.1 + github.com/gorilla/securecookie v1.1.1 + github.com/gorilla/sessions v1.2.1 + github.com/migueleliasweb/go-github-mock v0.0.18 github.com/notaryproject/notation-go v1.0.0-rc.6 github.com/opencontainers/distribution-spec/specs-go v0.0.0-20230117141039-067a0f5b0e25 + github.com/project-zot/mockoidc v0.0.0-20230307111146-f607b4b5fb97 github.com/sigstore/cosign/v2 v2.0.2 github.com/swaggo/http-swagger v1.3.4 + github.com/zitadel/oidc v1.12.0 + golang.org/x/oauth2 v0.9.0 modernc.org/sqlite v1.23.1 oras.land/oras-go/v2 v2.2.1 ) +require ( + github.com/google/go-github/v50 v50.2.0 // indirect + golang.org/x/sync v0.3.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect +) + require ( filippo.io/edwards25519 v1.0.0 // indirect github.com/AdamKorcz/go-118-fuzz-build v0.0.0-20221215162035-5330a85ea652 // indirect @@ -108,7 +122,6 @@ require ( github.com/go-jose/go-jose/v3 v3.0.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/gnostic v0.5.7-v3refs // indirect - github.com/google/go-github/v50 v50.2.0 // indirect github.com/google/licenseclassifier/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.3 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect @@ -331,6 +344,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.8.0 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/gorilla/schema v1.2.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect @@ -452,11 +466,7 @@ require ( golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.11.0 // indirect - golang.org/x/oauth2 v0.9.0 // indirect - golang.org/x/sync v0.3.0 // indirect - golang.org/x/sys v0.9.0 // indirect golang.org/x/term v0.9.0 // indirect - golang.org/x/text v0.10.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.8.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 81584d30..8b09ad58 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,7 @@ cloud.google.com/go/compute v1.7.0/go.mod h1:435lt8av5oL9P3fv1OEzSbSUe+ybHXGMPQH cloud.google.com/go/compute v1.10.0/go.mod h1:ER5CLbMxl90o2jtNbGSbtfOpQKR0t15FOtRsugnLrlU= cloud.google.com/go/compute v1.19.1 h1:am86mquDUgjGNWxiGn+5PGLbmgiWXlE/yNWpIpNvuXY= cloud.google.com/go/compute v1.19.1/go.mod h1:6ylj3a05WF8leseCdIf77NK0g1ey+nj5IKd5/kvShxE= +cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/containeranalysis v0.5.1/go.mod h1:1D92jd8gRR/c0fGMlymRgxWD3Qw9C1ff6/T7mLgVL8I= @@ -287,6 +288,7 @@ github.com/Microsoft/hcsshim v0.10.0-rc.7/go.mod h1:ILuwjA+kNW+MrN/w5un7n3mTqkws github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= +github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8/go.mod h1:I0gYDMZ6Z5GRU7l58bNFSkPTFN6Yl12dsUlAZ8xy98g= github.com/ProtonMail/go-crypto v0.0.0-20230518184743-7afd39499903 h1:ZK3C5DtzV2nVAQTx5S5jQvMeDqWtD1By5mOoyY/xJek= github.com/ProtonMail/go-crypto v0.0.0-20230518184743-7afd39499903/go.mod h1:8TI4H3IbrackdNgv+92dI+rhpCaLqM0IfpgCgenFvRE= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= @@ -501,6 +503,7 @@ github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oM github.com/briandowns/spinner v1.23.0 h1:alDF2guRWqa/FOZZYWjlMIx2L6H0wyewPxo/CH4Pt2A= github.com/briandowns/spinner v1.23.0/go.mod h1:rPG4gmXeN3wQV/TsAY4w8lPdIM6RX3yqeBQJSrbXjuE= github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd h1:rFt+Y/IK1aEZkEHchZRSq9OQbsSzIT/OrI8YFFmRIng= github.com/bugsnag/osext v0.0.0-20130617224835-0dd3f918b21b h1:otBG+dV+YK+Soembjv71DPz3uX/V/6MMlSyD9JBQ6kQ= github.com/bugsnag/panicwrap v0.0.0-20151223152923-e2c28503fcd0 h1:nvj0OLI3YqYXer/kZD8Ri1aaunCxIEsOst1BVJswV0o= @@ -724,9 +727,11 @@ github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxF github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-ldap/ldap/v3 v3.4.5 h1:ekEKmaDrpvR2yf5Nc/DClsGG9lAmdDixe44mLzlW5r8= github.com/go-ldap/ldap/v3 v3.4.5/go.mod h1:bMGIq3AGbytbaMwf8wdv5Phdxz0FWHTIYMSzyrYgnQs= +github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= @@ -910,8 +915,12 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.14.1-0.20230409045903-ed5c185df419 h1:gMlTWagRJgCJ3EnISyF5+p9phYpFyWEI70Z56T+o2MY= github.com/google/go-containerregistry v0.14.1-0.20230409045903-ed5c185df419/go.mod h1:ETSJmRH9iO4Q0WQILIMkDUiKk+CaxItZW+gEDjyw8Ug= +github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM= github.com/google/go-github/v50 v50.2.0 h1:j2FyongEHlO9nxXLc+LP3wuBSVU9mVxfpdYUexMpIfk= github.com/google/go-github/v50 v50.2.0/go.mod h1:VBY8FB6yPIjrtKhozXv4FQupxKLS6H4m6xFZlT43q8Q= +github.com/google/go-github/v52 v52.0.0 h1:uyGWOY+jMQ8GVGSX8dkSwCzlehU3WfdxQ7GweO/JP7M= +github.com/google/go-github/v52 v52.0.0/go.mod h1:WJV6VEEUPuMo5pXqqa2ZCZEdbQqua4zAk2MZTIo+m+4= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -984,6 +993,12 @@ github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= +github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -1067,6 +1082,8 @@ github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b h1:ZGiXF8sz7PDk6RgkP+A/SFfUD0ZR/AgG6SpRNEDKZy8= github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b/go.mod h1:hQmNrgofl+IY/8L+n20H6E6PWBBTokdsv+q49j0QhsU= github.com/jellydator/ttlcache/v3 v3.0.1 h1:cHgCSMS7TdQcoprXnWUptJZzyFsqs18Lt8VVhRuZYVU= +github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= +github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -1230,6 +1247,8 @@ github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/migueleliasweb/go-github-mock v0.0.18 h1:0lWt9MYmZQGnQE2rFtjlft/YtD6hzxuN6JJRFpujzEI= +github.com/migueleliasweb/go-github-mock v0.0.18/go.mod h1:CcgXcbMoRnf3rRVHqGssuBquZDIcaplxL2W6G+xs7kM= github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -1391,6 +1410,8 @@ github.com/poy/onpar v0.0.0-20190519213022-ee068f8ea4d1 h1:oL4IBbcqwhhNWh31bjOX8 github.com/poy/onpar v0.0.0-20190519213022-ee068f8ea4d1/go.mod h1:nSbFQvMj97ZyhFRSJYtut+msi4sOY6zJDGCdSc+/rZU= github.com/proglottis/gpgme v0.1.3 h1:Crxx0oz4LKB3QXc5Ea0J19K/3ICfy3ftr5exgUK1AU0= github.com/proglottis/gpgme v0.1.3/go.mod h1:fPbW/EZ0LvwQtH8Hy7eixhp1eF3G39dtx7GUN+0Gmy0= +github.com/project-zot/mockoidc v0.0.0-20230307111146-f607b4b5fb97 h1:V6z9y0Yx2sQs4WSKx79mgkKJWwjbu/lHQg1yza5bmQE= +github.com/project-zot/mockoidc v0.0.0-20230307111146-f607b4b5fb97/go.mod h1:46X30UrCsiwicZcg5L098Pyilaj94AO39mvS5PEyPn8= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= @@ -1688,6 +1709,10 @@ github.com/zclconf/go-cty-yaml v1.0.2 h1:dNyg4QLTrv2IfJpm7Wtxi55ed5gLGOlPrZ6kMd5 github.com/zclconf/go-cty-yaml v1.0.2/go.mod h1:IP3Ylp0wQpYm50IHK8OZWKMu6sPJIUgKa8XhiVHura0= github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +github.com/zitadel/logging v0.3.4 h1:9hZsTjMMTE3X2LUi0xcF9Q9EdLo+FAezeu52ireBbHM= +github.com/zitadel/logging v0.3.4/go.mod h1:aPpLQhE+v6ocNK0TWrBrd363hZ95KcI17Q1ixAQwZF0= +github.com/zitadel/oidc v1.12.0 h1:JNUjbhuuQRxFsfwd9s1CP3GOI4IwZ70G8sa3FQHPy6Y= +github.com/zitadel/oidc v1.12.0/go.mod h1:uOTKHn4pqb0w7WbdGCMTxkmzZaBcyOmBKBNcptlF6f8= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= @@ -1872,6 +1897,7 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1900,6 +1926,7 @@ golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094/go.mod h1:h4gKUeWbJ4rQPri golang.org/x/oauth2 v0.0.0-20220909003341-f21342109be1/go.mod h1:h4gKUeWbJ4rQPri7E0u6Gs4e9Ri2zaLxzw5DI5XGrYg= golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783/go.mod h1:h4gKUeWbJ4rQPri7E0u6Gs4e9Ri2zaLxzw5DI5XGrYg= golang.org/x/oauth2 v0.1.0/go.mod h1:G9FE4dLTsbXUu90h/Pf85g4w1D+SSAgR+q46nJZ8M4A= +golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/oauth2 v0.9.0 h1:BPpt2kU7oMRq3kCHAA1tbSEshXRw1LpG2ztgDwrzuAs= golang.org/x/oauth2 v0.9.0/go.mod h1:qYgFZaFiu6Wg24azG8bdV52QJXJGbZzIIsRCdVKzbLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -2007,6 +2034,7 @@ golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -2029,6 +2057,7 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -2037,6 +2066,7 @@ golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -2052,6 +2082,7 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -2365,6 +2396,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= diff --git a/pkg/api/authn.go b/pkg/api/authn.go index af61cd65..f1b9a703 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -3,139 +3,315 @@ package api import ( "bufio" "context" + "crypto/sha256" "crypto/x509" "encoding/base64" + "encoding/gob" + "errors" "fmt" + "net" "net/http" "os" + "path" "strconv" "strings" "time" "github.com/chartmuseum/auth" + "github.com/google/go-github/v52/github" + "github.com/google/uuid" "github.com/gorilla/mux" + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + godigest "github.com/opencontainers/go-digest" + "github.com/zitadel/oidc/pkg/client/rp" + httphelper "github.com/zitadel/oidc/pkg/http" + "github.com/zitadel/oidc/pkg/oidc" "golang.org/x/crypto/bcrypt" + "golang.org/x/oauth2" + githubOAuth "golang.org/x/oauth2/github" - "zotregistry.io/zot/errors" + zerr "zotregistry.io/zot/errors" "zotregistry.io/zot/pkg/api/config" "zotregistry.io/zot/pkg/api/constants" apiErr "zotregistry.io/zot/pkg/api/errors" "zotregistry.io/zot/pkg/common" + "zotregistry.io/zot/pkg/log" localCtx "zotregistry.io/zot/pkg/requestcontext" + storageConstants "zotregistry.io/zot/pkg/storage/constants" ) const ( bearerAuthDefaultAccessEntryType = "repository" + issuedAtOffset = 5 * time.Second + relyingPartyCookieMaxAge = 120 ) -func AuthHandler(c *Controller) mux.MiddlewareFunc { - if isBearerAuthEnabled(c.Config) { - return bearerAuthHandler(c) - } - - return basicAuthHandler(c) +type AuthnMiddleware struct { + credMap map[string]string + ldapClient *LDAPClient } -func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { - authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{ - Realm: ctlr.Config.HTTP.Auth.Bearer.Realm, - Service: ctlr.Config.HTTP.Auth.Bearer.Service, - PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert, - AccessEntryType: bearerAuthDefaultAccessEntryType, - EmptyDefaultNamespace: true, - }) +func AuthHandler(ctlr *Controller) mux.MiddlewareFunc { + authnMiddleware := &AuthnMiddleware{} + + if isBearerAuthEnabled(ctlr.Config) { + return bearerAuthHandler(ctlr) + } + + return authnMiddleware.TryAuthnHandlers(ctlr) +} + +func (amw *AuthnMiddleware) sessionAuthn(ctlr *Controller, next http.Handler, response http.ResponseWriter, + request *http.Request, delay int, +) { + clientHeader := request.Header.Get(constants.SessionClientHeaderName) + if clientHeader != constants.SessionClientHeaderValue { + authFail(response, request, ctlr.Config.HTTP.Realm, delay) + + return + } + + identity, ok := common.GetAuthUserFromRequestSession(ctlr.CookieStore, request, ctlr.Log) + if !ok { + // let the client know that this session is invalid/expired + cookie := &http.Cookie{ + Name: "session", + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + + HttpOnly: true, + } + + http.SetCookie(response, cookie) + + authFail(response, request, ctlr.Config.HTTP.Realm, delay) + + return + } + + ctx := getReqContextWithAuthorization(identity, []string{}, request) + + groups, err := ctlr.RepoDB.GetUserGroups(ctx) if err != nil { - ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer") + if errors.Is(err, zerr.ErrUserDataNotFound) { + ctlr.Log.Err(err).Str("identity", identity).Msg("can not find user profile in DB") + + authFail(response, request, ctlr.Config.HTTP.Realm, delay) + + return + } + + ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user profile in DB") + + response.WriteHeader(http.StatusInternalServerError) + + return } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if request.Method == http.MethodOptions { - next.ServeHTTP(response, request) - response.WriteHeader(http.StatusNoContent) + ctx = getReqContextWithAuthorization(identity, groups, request) - return - } - vars := mux.Vars(request) - name := vars["name"] - - // we want to bypass auth for mgmt route - isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix - - header := request.Header.Get("Authorization") - - if (header == "" || header == "Basic Og==") && isMgmtRequested { - next.ServeHTTP(response, request) - - return - } - - action := auth.PullAction - if m := request.Method; m != http.MethodGet && m != http.MethodHead { - action = auth.PushAction - } - permissions, err := authorizer.Authorize(header, action, name) - if err != nil { - ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header") - response.Header().Set("Content-Type", "application/json") - common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED))) - - return - } - - if !permissions.Allowed { - authFail(response, permissions.WWWAuthenticateHeader, 0) - - return - } - - next.ServeHTTP(response, request) - }) - } + next.ServeHTTP(response, request.WithContext(ctx)) } -func noPasswdAuth(realm string, config *config.Config) mux.MiddlewareFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - if request.Method == http.MethodOptions { - next.ServeHTTP(response, request) - response.WriteHeader(http.StatusNoContent) +func (amw *AuthnMiddleware) basicAuthn(ctlr *Controller, response http.ResponseWriter, + request *http.Request, +) (bool, http.ResponseWriter, *http.Request, error) { + cookieStore := ctlr.CookieStore - return - } + // we want to bypass auth for mgmt route + isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix - // Process request + if request.Header.Get("Authorization") == "" { + if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested { ctx := getReqContextWithAuthorization("", []string{}, request) - next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck - }) + // Process request + + return true, response, request.WithContext(ctx), nil + } } + + identity, passphrase, err := getUsernamePasswordBasicAuth(request) + if err != nil { + ctlr.Log.Error().Err(err).Msg("failed to parse authorization header") + + return false, nil, nil, nil + } + + // some client tools might send Authorization: Basic Og== (decoded into ":") + // empty username and password + if identity == "" && passphrase == "" { + if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested { + ctx := getReqContextWithAuthorization("", []string{}, request) + + return true, response, request.WithContext(ctx), nil + } + } + + passphraseHash, ok := amw.credMap[identity] + if ok { + // first, HTTPPassword authN (which is local) + if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil { + // Process request + var groups []string + + if ctlr.Config.HTTP.AccessControl != nil { + ac := NewAccessController(ctlr.Config) + groups = ac.getUserGroups(identity) + } + + ctx := getReqContextWithAuthorization(identity, groups, request) + + // saved logged session + if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil { + return false, response, request, err + } + + if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil { + ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile") + + return false, response, request, err + } + + ctlr.Log.Info().Str("identity", identity).Msgf("user profile successfully set") + + return true, response, request.WithContext(ctx), nil + } + } + + // next, LDAP if configured (network-based which can lose connectivity) + if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil { + ok, _, ldapgroups, err := amw.ldapClient.Authenticate(identity, passphrase) + if ok && err == nil { + // Process request + var groups []string + + if ctlr.Config.HTTP.AccessControl != nil { + ac := NewAccessController(ctlr.Config) + groups = ac.getUserGroups(identity) + } + + groups = append(groups, ldapgroups...) + + ctx := getReqContextWithAuthorization(identity, groups, request) + + if err := saveUserLoggedSession(cookieStore, response, request, identity, ctlr.Log); err != nil { + return false, response, request, err + } + + if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil { + ctlr.Log.Error().Err(err).Str("identity", identity).Msg("couldn't update user profile") + + return false, response, request, err + } + + return true, response, request.WithContext(ctx), nil + } + } + + // last try API keys + if isAPIKeyEnabled(ctlr.Config) { + apiKey := passphrase + + if !strings.HasPrefix(apiKey, constants.APIKeysPrefix) { + ctlr.Log.Error().Msg("api token has invalid format") + + return false, nil, nil, nil + } + + trimmedAPIKey := strings.TrimPrefix(apiKey, constants.APIKeysPrefix) + + hashedKey := hashUUID(trimmedAPIKey) + + storedIdentity, err := ctlr.RepoDB.GetUserAPIKeyInfo(hashedKey) + if err != nil { + if errors.Is(err, zerr.ErrUserAPIKeyNotFound) { + ctlr.Log.Info().Err(err).Msgf("can not find any user info for hashed key %s in DB", hashedKey) + + return false, nil, nil, nil + } + + ctlr.Log.Error().Err(err).Msgf("can not get user info for hashed key %s in DB", hashedKey) + + return false, nil, nil, err + } + + if storedIdentity == identity { + ctx := getReqContextWithAuthorization(identity, []string{}, request) + + err := ctlr.RepoDB.UpdateUserAPIKeyLastUsed(ctx, hashedKey) + if err != nil { + ctlr.Log.Err(err).Str("identity", identity).Msg("can not update user profile in DB") + + return false, nil, nil, err + } + + groups, err := ctlr.RepoDB.GetUserGroups(ctx) + if err != nil { + ctlr.Log.Err(err).Str("identity", identity).Msg("can not get user's groups in DB") + + return false, nil, nil, err + } + + ctx = getReqContextWithAuthorization(identity, groups, request) + + return true, response, request.WithContext(ctx), nil + } + } + + return false, nil, nil, nil } -//nolint:gocyclo // we use closure making this a complex subroutine -func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { - realm := ctlr.Config.HTTP.Realm - if realm == "" { - realm = "Authorization Required" - } - - realm = "Basic realm=" + strconv.Quote(realm) - +func (amw *AuthnMiddleware) TryAuthnHandlers(ctlr *Controller) mux.MiddlewareFunc { //nolint: gocyclo // no password based authN, if neither LDAP nor HTTP BASIC is enabled if ctlr.Config.HTTP.Auth == nil || - (ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil) { - return noPasswdAuth(realm, ctlr.Config) + (ctlr.Config.HTTP.Auth.HTPasswd.Path == "" && ctlr.Config.HTTP.Auth.LDAP == nil && + ctlr.Config.HTTP.Auth.OpenID == nil) { + return noPasswdAuth(ctlr.Config) } - credMap := make(map[string]string) + amw.credMap = make(map[string]string) delay := ctlr.Config.HTTP.Auth.FailDelay - var ldapClient *LDAPClient + // setup sessions cookie store used to preserve logged in user in web sessions + if isAuthnEnabled(ctlr.Config) || isOpenIDAuthEnabled(ctlr.Config) { + // To store custom types in our cookies, + // we must first register them using gob.Register + gob.Register(map[string]interface{}{}) + cookieStoreHashKey := securecookie.GenerateRandomKey(64) + if cookieStoreHashKey == nil { + panic(zerr.ErrHashKeyNotCreated) + } + + // if storage is filesystem then use zot's rootDir to store sessions + if ctlr.Config.Storage.StorageDriver == nil { + sessionsDir := path.Join(ctlr.Config.Storage.RootDirectory, "_sessions") + if err := os.MkdirAll(sessionsDir, storageConstants.DefaultDirPerms); err != nil { + panic(err) + } + + cookieStore := sessions.NewFilesystemStore(sessionsDir, cookieStoreHashKey) + + cookieStore.MaxAge(cookiesMaxAge) + + ctlr.CookieStore = cookieStore + } else { + cookieStore := sessions.NewCookieStore(cookieStoreHashKey) + + cookieStore.MaxAge(cookiesMaxAge) + + ctlr.CookieStore = cookieStore + } + } + + // ldap and htpasswd based authN if ctlr.Config.HTTP.Auth != nil { if ctlr.Config.HTTP.Auth.LDAP != nil { ldapConfig := ctlr.Config.HTTP.Auth.LDAP - ldapClient = &LDAPClient{ + amw.ldapClient = &LDAPClient{ Host: ldapConfig.Address, Port: ldapConfig.Port, UseSSL: !ldapConfig.Insecure, @@ -160,18 +336,18 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCert) { - panic(errors.ErrBadCACert) + panic(zerr.ErrBadCACert) } - ldapClient.ClientCAs = caCertPool + amw.ldapClient.ClientCAs = caCertPool } else { // default to system cert pool caCertPool, err := x509.SystemCertPool() if err != nil { - panic(errors.ErrBadCACert) + panic(zerr.ErrBadCACert) } - ldapClient.ClientCAs = caCertPool + amw.ldapClient.ClientCAs = caCertPool } } @@ -188,12 +364,27 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { line := scanner.Text() if strings.Contains(line, ":") { tokens := strings.Split(scanner.Text(), ":") - credMap[tokens[0]] = tokens[1] + amw.credMap[tokens[0]] = tokens[1] } } } } + // openid based authN + if ctlr.Config.HTTP.Auth.OpenID != nil { + ctlr.RelyingParties = make(map[string]rp.RelyingParty) + + for provider := range ctlr.Config.HTTP.Auth.OpenID.Providers { + if IsOpenIDSupported(provider) { + rp := NewRelyingPartyOIDC(ctlr.Config, provider) + ctlr.RelyingParties[provider] = rp + } else if IsOauth2Supported(provider) { + rp := NewRelyingPartyGithub(ctlr.Config, provider) + ctlr.RelyingParties[provider] = rp + } + } + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { if request.Method == http.MethodOptions { @@ -203,84 +394,231 @@ func basicAuthHandler(ctlr *Controller) mux.MiddlewareFunc { return } - // we want to bypass auth for mgmt route - isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix - - if request.Header.Get("Authorization") == "" { - if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested { - // Process request - ctx := getReqContextWithAuthorization("", []string{}, request) - next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck - - return - } - } - - username, passphrase, err := getUsernamePasswordBasicAuth(request) + //nolint: contextcheck + authenticated, cloneResp, cloneReq, err := amw.basicAuthn(ctlr, response, request) if err != nil { - ctlr.Log.Error().Err(err).Msg("failed to parse authorization header") - authFail(response, realm, delay) + response.WriteHeader(http.StatusInternalServerError) return } - // some client tools might send Authorization: Basic Og== (decoded into ":") - // empty username and password - if username == "" && passphrase == "" { - if ctlr.Config.HTTP.AccessControl.AnonymousPolicyExists() || isMgmtRequested { - // Process request - ctx := getReqContextWithAuthorization("", []string{}, request) - next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck + if authenticated && cloneResp != nil && cloneReq != nil { + next.ServeHTTP(cloneResp, cloneReq) - return - } + return } - // first, HTTPPassword authN (which is local) - passphraseHash, ok := credMap[username] - if ok { - if err := bcrypt.CompareHashAndPassword([]byte(passphraseHash), []byte(passphrase)); err == nil { - // Process request - var userGroups []string - - if ctlr.Config.HTTP.AccessControl != nil { - ac := NewAccessController(ctlr.Config) - userGroups = ac.getUserGroups(username) - } - - ctx := getReqContextWithAuthorization(username, userGroups, request) - next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck - - return - } - } - - // next, LDAP if configured (network-based which can lose connectivity) - if ctlr.Config.HTTP.Auth != nil && ctlr.Config.HTTP.Auth.LDAP != nil { - ok, _, ldapgroups, err := ldapClient.Authenticate(username, passphrase) - if ok && err == nil { - // Process request - var userGroups []string - - if ctlr.Config.HTTP.AccessControl != nil { - ac := NewAccessController(ctlr.Config) - userGroups = ac.getUserGroups(username) - } - - userGroups = append(userGroups, ldapgroups...) - - ctx := getReqContextWithAuthorization(username, userGroups, request) - next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck - - return - } - } - - authFail(response, realm, delay) + //nolint: contextcheck + amw.sessionAuthn(ctlr, next, response, request, delay) }) } } +func bearerAuthHandler(ctlr *Controller) mux.MiddlewareFunc { + authorizer, err := auth.NewAuthorizer(&auth.AuthorizerOptions{ + Realm: ctlr.Config.HTTP.Auth.Bearer.Realm, + Service: ctlr.Config.HTTP.Auth.Bearer.Service, + PublicKeyPath: ctlr.Config.HTTP.Auth.Bearer.Cert, + AccessEntryType: bearerAuthDefaultAccessEntryType, + EmptyDefaultNamespace: true, + }) + if err != nil { + ctlr.Log.Panic().Err(err).Msg("error creating bearer authorizer") + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Method == http.MethodOptions { + next.ServeHTTP(response, request) + response.WriteHeader(http.StatusNoContent) + + return + } + acCtrlr := NewAccessController(ctlr.Config) + vars := mux.Vars(request) + name := vars["name"] + + // we want to bypass auth for mgmt route + isMgmtRequested := request.RequestURI == constants.FullMgmtPrefix + + header := request.Header.Get("Authorization") + + if (header == "" || header == "Basic Og==") && isMgmtRequested { + next.ServeHTTP(response, request) + + return + } + + action := auth.PullAction + if m := request.Method; m != http.MethodGet && m != http.MethodHead { + action = auth.PushAction + } + + permissions, err := authorizer.Authorize(header, action, name) + if err != nil { + ctlr.Log.Error().Err(err).Msg("issue parsing Authorization header") + response.Header().Set("Content-Type", "application/json") + common.WriteJSON(response, http.StatusInternalServerError, apiErr.NewErrorList(apiErr.NewError(apiErr.UNSUPPORTED))) + + return + } + + if !permissions.Allowed { + response.Header().Set("Content-Type", "application/json") + response.Header().Set("WWW-Authenticate", permissions.WWWAuthenticateHeader) + + common.WriteJSON(response, http.StatusUnauthorized, + apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED))) + + return + } + + amCtx := acCtrlr.getAuthnMiddlewareContext(BEARER, request) + next.ServeHTTP(response, request.WithContext(amCtx)) //nolint:contextcheck + }) + } +} + +func noPasswdAuth(config *config.Config) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if request.Method == http.MethodOptions { + next.ServeHTTP(response, request) + response.WriteHeader(http.StatusNoContent) + + return + } + + ctx := getReqContextWithAuthorization("", []string{}, request) + // Process request + next.ServeHTTP(response, request.WithContext(ctx)) //nolint:contextcheck + }) + } +} + +func (rh *RouteHandler) AuthURLHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + callbackUI := query.Get(constants.CallbackUIQueryParam) + + provider := query.Get("provider") + + client, ok := rh.c.RelyingParties[provider] + if !ok { + http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + response.WriteHeader(http.StatusBadRequest) + })(w, r) + } + + /* save cookie containing state to later verify it and + callback ui where we will redirect after openid/oauth2 logic is completed*/ + session, _ := rh.c.CookieStore.Get(r, "statecookie") + + session.Options.Secure = true + session.Options.HttpOnly = true + session.Options.SameSite = http.SameSiteDefaultMode + session.Options.Path = constants.CallbackBasePath + + state := uuid.New().String() + + session.Values["state"] = state + session.Values["callback"] = callbackUI + + // let the session set its own id + err := session.Save(r, w) + if err != nil { + rh.c.Log.Error().Err(err).Msg("unable to save http session") + + w.WriteHeader(http.StatusInternalServerError) + + return + } + + stateFunc := func() string { + return state + } + + rp.AuthURLHandler(stateFunc, client)(w, r) + } +} + +func NewRelyingPartyOIDC(config *config.Config, provider string) rp.RelyingParty { + issuer, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider) + + relyingParty, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) + if err != nil { + panic(err) + } + + return relyingParty +} + +func NewRelyingPartyGithub(config *config.Config, provider string) rp.RelyingParty { + _, clientID, clientSecret, redirectURI, scopes, options := getRelyingPartyArgs(config, provider) + + rpConfig := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: scopes, + Endpoint: githubOAuth.Endpoint, + } + + relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, options...) + if err != nil { + panic(err) + } + + return relyingParty +} + +func getRelyingPartyArgs(config *config.Config, provider string) ( + string, string, string, string, []string, []rp.Option, +) { + if _, ok := config.HTTP.Auth.OpenID.Providers[provider]; !ok { + panic(zerr.ErrOpenIDProviderDoesNotExist) + } + + scheme := "http" + if config.HTTP.TLS != nil { + scheme = "https" + } + + clientID := config.HTTP.Auth.OpenID.Providers[provider].ClientID + clientSecret := config.HTTP.Auth.OpenID.Providers[provider].ClientSecret + + scopes := config.HTTP.Auth.OpenID.Providers[provider].Scopes + // openid scope must be the first one in list + if !common.Contains(scopes, oidc.ScopeOpenID) && IsOpenIDSupported(provider) { + scopes = append([]string{oidc.ScopeOpenID}, scopes...) + } + + port := config.HTTP.Port + issuer := config.HTTP.Auth.OpenID.Providers[provider].Issuer + keyPath := config.HTTP.Auth.OpenID.Providers[provider].KeyPath + baseURL := net.JoinHostPort(config.HTTP.Address, port) + redirectURI := fmt.Sprintf("%s://%s%s", scheme, baseURL, constants.CallbackBasePath+fmt.Sprintf("/%s", provider)) + + options := []rp.Option{ + rp.WithVerifierOpts(rp.WithIssuedAtOffset(issuedAtOffset)), + } + + key := securecookie.GenerateRandomKey(32) //nolint: gomnd + + cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithMaxAge(relyingPartyCookieMaxAge)) + options = append(options, rp.WithCookieHandler(cookieHandler)) + + if clientSecret == "" { + options = append(options, rp.WithPKCE(cookieHandler)) + } + + if keyPath != "" { + options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) + } + + return issuer, clientID, clientSecret, redirectURI, scopes, options +} + func getReqContextWithAuthorization(username string, groups []string, request *http.Request) context.Context { acCtx := localCtx.AccessControlContext{ Username: username, @@ -314,9 +652,71 @@ func isBearerAuthEnabled(config *config.Config) bool { return false } -func authFail(w http.ResponseWriter, realm string, delay int) { +func isOpenIDAuthEnabled(config *config.Config) bool { + if config.HTTP.Auth != nil && + config.HTTP.Auth.OpenID != nil { + for provider := range config.HTTP.Auth.OpenID.Providers { + if isOpenIDAuthProviderEnabled(config, provider) { + return true + } + } + } + + return false +} + +func isAPIKeyEnabled(config *config.Config) bool { + if config.Extensions != nil && config.Extensions.APIKey != nil && + *config.Extensions.APIKey.Enable { + return true + } + + return false +} + +func isOpenIDAuthProviderEnabled(config *config.Config, provider string) bool { + if providerConfig, ok := config.HTTP.Auth.OpenID.Providers[provider]; ok { + if IsOpenIDSupported(provider) { + if providerConfig.ClientID != "" || providerConfig.Issuer != "" || + len(providerConfig.Scopes) > 0 { + return true + } + } else if IsOauth2Supported(provider) { + if providerConfig.ClientID != "" || len(providerConfig.Scopes) > 0 { + return true + } + } + } + + return false +} + +func IsOpenIDSupported(provider string) bool { + supported := []string{"google", "gitlab", "dex"} + + return common.Contains(supported, provider) +} + +func IsOauth2Supported(provider string) bool { + supported := []string{"github"} + + return common.Contains(supported, provider) +} + +func authFail(w http.ResponseWriter, r *http.Request, realm string, delay int) { time.Sleep(time.Duration(delay) * time.Second) - w.Header().Set("WWW-Authenticate", realm) + + // don't send auth headers if request is coming from UI + if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { + if realm == "" { + realm = "Authorization Required" + } + + realm = "Basic realm=" + strconv.Quote(realm) + + w.Header().Set("WWW-Authenticate", realm) + } + w.Header().Set("Content-Type", "application/json") common.WriteJSON(w, http.StatusUnauthorized, apiErr.NewErrorList(apiErr.NewError(apiErr.UNAUTHORIZED))) } @@ -325,12 +725,12 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error) basicAuth := request.Header.Get("Authorization") if basicAuth == "" { - return "", "", errors.ErrParsingAuthHeader + return "", "", zerr.ErrParsingAuthHeader } - splitStr := strings.SplitN(basicAuth, " ", 2) //nolint:gomnd + splitStr := strings.SplitN(basicAuth, " ", 2) //nolint: gomnd if len(splitStr) != 2 || strings.ToLower(splitStr[0]) != "basic" { - return "", "", errors.ErrParsingAuthHeader + return "", "", zerr.ErrParsingAuthHeader } decodedStr, err := base64.StdEncoding.DecodeString(splitStr[1]) @@ -338,9 +738,9 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error) return "", "", err } - pair := strings.SplitN(string(decodedStr), ":", 2) //nolint:gomnd - if len(pair) != 2 { //nolint:gomnd - return "", "", errors.ErrParsingAuthHeader + pair := strings.SplitN(string(decodedStr), ":", 2) //nolint: gomnd + if len(pair) != 2 { //nolint: gomnd + return "", "", zerr.ErrParsingAuthHeader } username := pair[0] @@ -348,3 +748,118 @@ func getUsernamePasswordBasicAuth(request *http.Request) (string, string, error) return username, passphrase, nil } + +func GetGithubUserInfo(ctx context.Context, client *github.Client, log log.Logger) (string, []string, error) { + var primaryEmail string + + userEmails, _, err := client.Users.ListEmails(ctx, nil) + if err != nil { + log.Error().Msg("couldn't set user record for empty email value") + + return "", []string{}, err + } + + if len(userEmails) != 0 { + for _, email := range userEmails { // should have at least one primary email, if any + if email.GetPrimary() { // check if it's primary email + primaryEmail = email.GetEmail() + + break + } + } + } + + orgs, _, err := client.Organizations.List(ctx, "", nil) + if err != nil { + log.Error().Msg("couldn't set user record for empty email value") + + return "", []string{}, err + } + + groups := []string{} + for _, org := range orgs { + groups = append(groups, *org.Login) + } + + return primaryEmail, groups, nil +} + +func saveUserLoggedSession(cookieStore sessions.Store, response http.ResponseWriter, + request *http.Request, identity string, log log.Logger, +) error { + session, _ := cookieStore.Get(request, "session") + + session.Options.Secure = true + session.Options.HttpOnly = true + session.Options.SameSite = http.SameSiteDefaultMode + session.Values["authStatus"] = true + session.Values["user"] = identity + + // let the session set its own id + err := session.Save(request, response) + if err != nil { + log.Error().Err(err).Str("identity", identity).Msg("unable to save http session") + + return err + } + + userInfoCookie := sessions.NewCookie("user", identity, &sessions.Options{ + Secure: true, + HttpOnly: false, + MaxAge: cookiesMaxAge, + SameSite: http.SameSiteDefaultMode, + Path: "/", + }) + + http.SetCookie(response, userInfoCookie) + + return nil +} + +// OAuth2Callback is the callback logic where openid/oauth2 will redirect back to our app. +func OAuth2Callback(ctlr *Controller, w http.ResponseWriter, r *http.Request, state, email string, + groups []string, +) (string, error) { + stateCookie, _ := ctlr.CookieStore.Get(r, "statecookie") + + stateOrigin, ok := stateCookie.Values["state"].(string) + if !ok { + ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: unable to get 'state' cookie from request") + + return "", zerr.ErrInvalidStateCookie + } + + if stateOrigin != state { + ctlr.Log.Error().Err(zerr.ErrInvalidStateCookie).Msg("openID: 'state' cookie differs from the actual one") + + return "", zerr.ErrInvalidStateCookie + } + + ctx := getReqContextWithAuthorization(email, groups, r) + + // if this line has been reached, then a new session should be created + // if the `session` key is already on the cookie, it's not a valid one + if err := saveUserLoggedSession(ctlr.CookieStore, w, r, email, ctlr.Log); err != nil { + return "", err + } + + if err := ctlr.RepoDB.SetUserGroups(ctx, groups); err != nil { + ctlr.Log.Error().Err(err).Str("identity", email).Msg("couldn't update the user profile") + + return "", err + } + + ctlr.Log.Info().Msgf("user profile set successfully for email %s", email) + + // redirect to UI + callbackUI, _ := stateCookie.Values["callback"].(string) + + return callbackUI, nil +} + +func hashUUID(uuid string) string { + digester := sha256.New() + digester.Write([]byte(uuid)) + + return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded() +} diff --git a/pkg/api/authz.go b/pkg/api/authz.go index 3755393d..6b3bcb3d 100644 --- a/pkg/api/authz.go +++ b/pkg/api/authz.go @@ -21,6 +21,9 @@ const ( Delete = "delete" // behaviour actions. DetectManifestCollision = "detectManifestCollision" + BASIC = "Basic" + BEARER = "Bearer" + OPENID = "OpenID" ) // AccessController authorizes users to act on resources. @@ -29,10 +32,17 @@ type AccessController struct { Log log.Logger } -func NewAccessController(config *config.Config) *AccessController { +func NewAccessController(conf *config.Config) *AccessController { + if conf.HTTP.AccessControl == nil { + return &AccessController{ + Config: &config.AccessControlConfig{}, + Log: log.NewLogger(conf.Log.Level, conf.Log.Output), + } + } + return &AccessController{ - Config: config.HTTP.AccessControl, - Log: log.NewLogger(config.Log.Level, config.Log.Output), + Config: conf.HTTP.AccessControl, + Log: log.NewLogger(conf.Log.Level, conf.Log.Output), } } @@ -171,6 +181,18 @@ func (ac *AccessController) getContext(acCtx *localCtx.AccessControlContext, req return ctx } +// getAuthnMiddlewareContext builds ac context(allowed to read repos and if user is admin) and returns it. +func (ac *AccessController) getAuthnMiddlewareContext(authnType string, request *http.Request) context.Context { + amwCtx := localCtx.AuthnMiddlewareContext{ + AuthnType: authnType, + } + + amwCtxKey := localCtx.GetAuthnMiddlewareCtxKey() + ctx := context.WithValue(request.Context(), amwCtxKey, amwCtx) + + return ctx +} + // isPermitted returns true if username can do action on a repository policy. func (ac *AccessController) isPermitted(userGroups []string, username, action string, policyGroup config.PolicyGroup, @@ -231,6 +253,14 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return } + // request comes from bearer authn, bypass it + authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context()) + if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) { + next.ServeHTTP(response, request) + + return + } + // bypass authz for /v2/ route if request.RequestURI == "/v2/" { next.ServeHTTP(response, request) @@ -242,8 +272,6 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { var identity string - var err error - // anonymous context acCtx := &localCtx.AccessControlContext{} @@ -252,7 +280,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // get access control context made in authn.go if authn is enabled acCtx, err = localCtx.GetAccessControlContext(request.Context()) if err != nil { // should never happen - authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) return } @@ -272,7 +300,7 @@ func BaseAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { // if we still don't have an identity if identity == "" { acCtrlr.Log.Info().Msg("couldn't get identity from TLS certificate") - authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) return } @@ -298,6 +326,14 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { return } + // request comes from bearer authn, bypass it + authnMwCtx, err := localCtx.GetAuthnMiddlewareContext(request.Context()) + if err != nil || (authnMwCtx != nil && authnMwCtx.AuthnType == BEARER) { + next.ServeHTTP(response, request) + + return + } + vars := mux.Vars(request) resource := vars["name"] reference, ok := vars["reference"] @@ -306,12 +342,10 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { var identity string - var err error - // get acCtx built in authn and previous authz middlewares acCtx, err := localCtx.GetAccessControlContext(request.Context()) if err != nil { // should never happen - authFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + authFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) return } @@ -344,7 +378,7 @@ func DistSpecAuthzHandler(ctlr *Controller) mux.MiddlewareFunc { can := acCtrlr.can(request.Context(), identity, action, resource) //nolint:contextcheck if !can { - common.AuthzFail(response, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) + common.AuthzFail(response, request, ctlr.Config.HTTP.Realm, ctlr.Config.HTTP.Auth.FailDelay) } else { next.ServeHTTP(response, request) //nolint:contextcheck } diff --git a/pkg/api/config/config.go b/pkg/api/config/config.go index d1a891cc..1e038d7e 100644 --- a/pkg/api/config/config.go +++ b/pkg/api/config/config.go @@ -45,6 +45,7 @@ type AuthConfig struct { HTPasswd AuthHTPasswd LDAP *LDAPConfig Bearer *BearerConfig + OpenID *OpenIDConfig } type BearerConfig struct { @@ -53,6 +54,18 @@ type BearerConfig struct { Cert string } +type OpenIDConfig struct { + Providers map[string]OpenIDProviderConfig +} + +type OpenIDProviderConfig struct { + ClientID string + ClientSecret string + KeyPath string + Issuer string + Scopes []string +} + type MethodRatelimitConfig struct { Method string Rate int @@ -63,6 +76,7 @@ type RatelimitConfig struct { Methods []MethodRatelimitConfig `mapstructure:",omitempty"` } +//nolint:maligned type HTTPConfig struct { Address string Port string diff --git a/pkg/api/constants/consts.go b/pkg/api/constants/consts.go index 82375ef7..d5cf28da 100644 --- a/pkg/api/constants/consts.go +++ b/pkg/api/constants/consts.go @@ -12,4 +12,11 @@ const ( DefaultMediaType = "application/json" BinaryMediaType = "application/octet-stream" DefaultMetricsExtensionRoute = "/metrics" + CallbackBasePath = "/auth/callback" + LoginPath = "/auth/login" + LogoutPath = "/auth/logout" + SessionClientHeaderName = "X-ZOT-API-CLIENT" + SessionClientHeaderValue = "zot-ui" + APIKeysPrefix = "zak_" + CallbackUIQueryParam = "callback_ui" ) diff --git a/pkg/api/constants/extensions.go b/pkg/api/constants/extensions.go index acd1c25a..ca490bc4 100644 --- a/pkg/api/constants/extensions.go +++ b/pkg/api/constants/extensions.go @@ -18,4 +18,7 @@ const ( ExtUserPreferences = "/userprefs" ExtUserPreferencesPrefix = ExtPrefix + ExtUserPreferences FullUserPreferencesPrefix = RoutePrefix + ExtUserPreferencesPrefix + ExtAPIKey = "/apikey" + ExtAPIKeyPrefix = ExtPrefix + ExtAPIKey //nolint: gosec + FullAPIKeyPrefix = RoutePrefix + ExtAPIKeyPrefix ) diff --git a/pkg/api/controller.go b/pkg/api/controller.go index dd1131c0..30025b58 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -16,6 +16,8 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/gorilla/sessions" + "github.com/zitadel/oidc/pkg/client/rp" "zotregistry.io/zot/errors" "zotregistry.io/zot/pkg/api/config" @@ -31,6 +33,7 @@ import ( const ( idleTimeout = 120 * time.Second readHeaderTimeout = 5 * time.Second + cookiesMaxAge = 86400 // seconds ) type Controller struct { @@ -44,6 +47,8 @@ type Controller struct { Metrics monitoring.MetricServer CveInfo ext.CveInfo SyncOnDemand SyncOnDemand + RelyingParties map[string]rp.RelyingParty + CookieStore sessions.Store // runtime params chosenPort int // kernel-chosen port } @@ -254,7 +259,9 @@ func (c *Controller) InitImageStore() error { } func (c *Controller) InitRepoDB(reloadCtx context.Context) error { - if c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable { + // init repoDB if search is enabled or authn enabled (need to store user profiles) or apikey ext is enabled + if (c.Config.Extensions != nil && c.Config.Extensions.Search != nil && *c.Config.Extensions.Search.Enable) || + isAuthnEnabled(c.Config) || isOpenIDAuthEnabled(c.Config) || isAPIKeyEnabled(c.Config) { driver, err := repodbfactory.New(c.Config.Storage.StorageConfig, c.Log) //nolint:contextcheck if err != nil { return err diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index cae0ee72..bc67fc81 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -1,5 +1,5 @@ -//go:build sync && scrub && metrics && search -// +build sync,scrub,metrics,search +//go:build sync && scrub && metrics && search && mgmt +// +build sync,scrub,metrics,search,mgmt package api_test @@ -19,18 +19,24 @@ import ( "net/url" "os" "path" + "path/filepath" "strconv" "strings" "testing" "time" + "github.com/google/go-github/v52/github" "github.com/gorilla/mux" + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/migueleliasweb/go-github-mock/src/mock" vldap "github.com/nmcclain/ldap" notreg "github.com/notaryproject/notation-go/registry" distext "github.com/opencontainers/distribution-spec/specs-go/v1/extensions" godigest "github.com/opencontainers/go-digest" ispec "github.com/opencontainers/image-spec/specs-go/v1" artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1" + "github.com/project-zot/mockoidc" "github.com/sigstore/cosign/v2/cmd/cosign/cli/generate" "github.com/sigstore/cosign/v2/cmd/cosign/cli/options" "github.com/sigstore/cosign/v2/cmd/cosign/cli/sign" @@ -55,6 +61,7 @@ import ( storageConstants "zotregistry.io/zot/pkg/storage/constants" "zotregistry.io/zot/pkg/test" "zotregistry.io/zot/pkg/test/inject" + "zotregistry.io/zot/pkg/test/mocks" ) const ( @@ -249,6 +256,7 @@ func TestCreateRepoDBDriver(t *testing.T) { "manifestdatatablename": "ManifestDataTable", "indexdatatablename": "IndexDataTable", "userdatatablename": "ZotUserDataTable", + "apikeytablename": "APIKeyTable", "versiontablename": "1", } @@ -356,6 +364,9 @@ func TestAutoPortSelection(t *testing.T) { func TestObjectStorageController(t *testing.T) { skipIt(t) + + bucket := "zot-storage-test" + Convey("Negative make a new object storage controller", t, func() { port := test.GetFreePort() conf := config.New() @@ -377,7 +388,6 @@ func TestObjectStorageController(t *testing.T) { conf := config.New() conf.HTTP.Port = port - bucket := "zot-storage-test" endpoint := os.Getenv("S3MOCK_ENDPOINT") storageDriverParams := map[string]interface{}{ @@ -389,6 +399,7 @@ func TestObjectStorageController(t *testing.T) { "secure": false, "skipverify": false, } + conf.Storage.StorageDriver = storageDriverParams ctlr := makeController(conf, "/", "") So(ctlr, ShouldNotBeNil) @@ -397,16 +408,92 @@ func TestObjectStorageController(t *testing.T) { cm.StartAndWait(port) defer cm.StopServer() }) + + Convey("Make a new object storage controller with openid", t, func() { + port := test.GetFreePort() + conf := config.New() + conf.HTTP.Port = port + + endpoint := os.Getenv("S3MOCK_ENDPOINT") + + storageDriverParams := map[string]interface{}{ + "rootdirectory": "/zot", + "name": storageConstants.S3StorageDriverName, + "region": "us-east-2", + "bucket": bucket, + "regionendpoint": endpoint, + "secure": false, + "skipverify": false, + } + conf.Storage.RemoteCache = true + conf.Storage.StorageDriver = storageDriverParams + + conf.Storage.CacheDriver = map[string]interface{}{ + "name": "dynamodb", + "endpoint": "http://localhost:4566", + "region": "us-east-2", + "cachetablename": "test", + "repometatablename": "RepoMetadataTable", + "manifestdatatablename": "ManifestDataTable", + "indexdatatablename": "IndexDataTable", + "userdatatablename": "ZotUserDataTable", + "apikeytablename": "APIKeyTable1", + "versiontablename": "Version", + } + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + + conf.HTTP.Auth = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + // create s3 bucket + _, err = resty.R().Put("http://" + os.Getenv("S3MOCK_ENDPOINT") + "/" + bucket) + if err != nil { + panic(err) + } + + ctlr := makeController(conf, "/", "") + So(ctlr, ShouldNotBeNil) + + cm := test.NewControllerManager(ctlr) + cm.StartAndWait(port) + defer cm.StopServer() + }) } func TestObjectStorageControllerSubPaths(t *testing.T) { skipIt(t) + + bucket := "zot-storage-test" + Convey("Make a new object storage controller", t, func() { port := test.GetFreePort() conf := config.New() conf.HTTP.Port = port - bucket := "zot-storage-test" endpoint := os.Getenv("S3MOCK_ENDPOINT") storageDriverParams := map[string]interface{}{ @@ -471,12 +558,12 @@ func TestHtpasswdSingleCred(t *testing.T) { So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusOK) - header := []string{"Authorization,content-type"} + header := []string{"Authorization,content-type," + constants.SessionClientHeaderName} resp, _ = resty.R().SetBasicAuth(user, password).Options(baseURL + "/v2/") So(resp, ShouldNotBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) - So(len(resp.Header()), ShouldEqual, 4) + So(len(resp.Header()), ShouldEqual, 5) So(resp.Header()["Access-Control-Allow-Headers"], ShouldResemble, header) So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "HEAD,GET,POST,OPTIONS") @@ -2184,6 +2271,29 @@ func TestBearerAuth(t *testing.T) { }) } +func TestBearerAuthWrongAuthorizer(t *testing.T) { + Convey("Make a new authorizer", t, func() { + port := test.GetFreePort() + + conf := config.New() + conf.HTTP.Port = port + conf.HTTP.Auth = &config.AuthConfig{ + Bearer: &config.BearerConfig{ + Cert: "bla", + Realm: "blabla", + Service: "blablabla", + }, + } + ctlr := makeController(conf, t.TempDir(), "") + cm := test.NewControllerManager(ctlr) + + So(func() { + ctx := context.Background() + cm.RunServer(ctx) + }, ShouldPanic) + }) +} + func TestBearerAuthWithAllowReadAccess(t *testing.T) { Convey("Make a new controller", t, func() { authTestServer := test.MakeAuthTestServer(ServerKey, UnauthorizedNamespace) @@ -2354,11 +2464,1004 @@ func TestBearerAuthWithAllowReadAccess(t *testing.T) { }) } +func TestNewRelyingPartyOIDC(t *testing.T) { + Convey("Test NewRelyingPartyOIDC", t, func() { + conf := config.New() + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + + conf.HTTP.Auth = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + Convey("provider not found in config", func() { + So(func() { _ = api.NewRelyingPartyOIDC(conf, "notDex") }, ShouldPanic) + }) + + Convey("key path not found on disk", func() { + dexProviderCfg := conf.HTTP.Auth.OpenID.Providers["dex"] + dexProviderCfg.KeyPath = "path/to/file" + conf.HTTP.Auth.OpenID.Providers["dex"] = dexProviderCfg + + So(func() { _ = api.NewRelyingPartyOIDC(conf, "dex") }, ShouldPanic) + }) + + Convey("https callback", func() { + conf.HTTP.TLS = &config.TLSConfig{ + Cert: ServerCert, + Key: ServerKey, + } + + rp := api.NewRelyingPartyOIDC(conf, "dex") + So(rp, ShouldNotBeNil) + }) + + Convey("no client secret in config", func() { + dexProvider := conf.HTTP.Auth.OpenID.Providers["dex"] + dexProvider.ClientSecret = "" + conf.HTTP.Auth.OpenID.Providers["dex"] = dexProvider + + rp := api.NewRelyingPartyOIDC(conf, "dex") + So(rp, ShouldNotBeNil) + }) + + Convey("provider issuer unreachable", func() { + dexProvider := conf.HTTP.Auth.OpenID.Providers["dex"] + dexProvider.Issuer = "" + conf.HTTP.Auth.OpenID.Providers["dex"] = dexProvider + + So(func() { _ = api.NewRelyingPartyOIDC(conf, "dex") }, ShouldPanic) + }) + }) +} + +func TestOpenIDMiddleware(t *testing.T) { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + defaultVal := true + + conf := config.New() + conf.HTTP.Port = port + + // need a username different than ldap one, to test both logic + htpasswdUsername := "htpasswduser" + content := fmt.Sprintf("%s:$2y$05$hlbSXDp6hzDLu6VwACS39ORvVRpr3OMR4RlJ31jtlaOEGnPjKZI1m\n", htpasswdUsername) + htpasswdPath := test.MakeHtpasswdFileFromString(content) + + defer os.Remove(htpasswdPath) + + ldapServer := newTestLDAPServer() + port = test.GetFreePort() + + ldapPort, err := strconv.Atoi(port) + if err != nil { + panic(err) + } + + ldapServer.Start(ldapPort) + defer ldapServer.Stop() + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + LDAP: &config.LDAPConfig{ + Insecure: true, + Address: LDAPAddress, + Port: ldapPort, + BindDN: LDAPBindDN, + BindPassword: LDAPBindPassword, + BaseDN: LDAPBaseDN, + UserAttribute: "uid", + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + // just for the constructor coverage + "github": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + mgmtConfg := &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + + conf.Extensions = &extconf.ExtensionConfig{ + Mgmt: mgmtConfg, + } + + ctlr := api.NewController(conf) + dir := t.TempDir() + + ctlr.Config.Storage.RootDirectory = dir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + Convey("browser client requests", t, func() { + Convey("login with no provider supplied", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "unknown"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) + }) + + Convey("login with openid and get catalog with session", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + Convey("with callback_ui value provided", func() { + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + SetQueryParam("callback_ui", baseURL+"/v2/"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + }) + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + client.SetCookies(resp.Cookies()) + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // logout with options method for coverage + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Options(baseURL + constants.LogoutPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + + // logout user + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.LogoutPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // calling endpoint should fail with unathorized access + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + //nolint: dupl + Convey("login with basic auth(htpasswd) and get catalog with session", func() { + client := resty.New() + + // without creds, should get access error + resp, err := client.R().Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + var e apiErr.Error + err = json.Unmarshal(resp.Body(), &e) + So(err, ShouldBeNil) + + // first login user + // with creds, should get expected status code + resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + + resp, err = client.R().SetBasicAuth(htpasswdUsername, passphrase).Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetBasicAuth(htpasswdUsername, passphrase). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + client.SetCookies(resp.Cookies()) + + // call endpoint with session, without credentials, (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // logout user + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.LogoutPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // calling endpoint should fail with unathorized access + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + //nolint: dupl + Convey("login with ldap and get catalog", func() { + client := resty.New() + + // without creds, should get access error + resp, err := client.R().Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + var e apiErr.Error + err = json.Unmarshal(resp.Body(), &e) + So(err, ShouldBeNil) + + // first login user + // with creds, should get expected status code + resp, err = client.R().SetBasicAuth(username, passphrase).Get(baseURL) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusNotFound) + + resp, err = client.R().SetBasicAuth(username, passphrase).Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetBasicAuth(username, passphrase). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + client.SetCookies(resp.Cookies()) + + // call endpoint with session, without credentials, (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // logout user + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.LogoutPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // calling endpoint should fail with unathorized access + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("unauthenticated catalog request", func() { + client := resty.New() + + // mgmt should work both unauthenticated and authenticated + resp, err := client.R(). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // call endpoint without session + resp, err = client.R(). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + }) +} + +func TestIsOpenIDEnabled(t *testing.T) { + Convey("make oidc server", t, func() { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + conf := config.New() + conf.HTTP.Port = port + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + rootDir := t.TempDir() + + Convey("Only OAuth2 provided", func() { + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "github": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"email", "groups"}, + }, + }, + }, + } + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = rootDir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + resp, err := resty.R(). + Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Unsupported provider", func() { + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "invalidProvider": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"email", "groups"}, + }, + }, + }, + } + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = rootDir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + resp, err := resty.R(). + Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + }) +} + +func TestAuthnSessionErrors(t *testing.T) { + Convey("make controller", t, func() { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + defaultVal := true + + conf := config.New() + conf.HTTP.Port = port + invalidSessionID := "sessionID" + + // need a username different than ldap one, to test both logic + htpasswdUsername := "htpasswduser" + content := fmt.Sprintf("%s:$2y$05$hlbSXDp6hzDLu6VwACS39ORvVRpr3OMR4RlJ31jtlaOEGnPjKZI1m\n", htpasswdUsername) + + htpasswdPath := test.MakeHtpasswdFileFromString(content) + defer os.Remove(htpasswdPath) + + ldapServer := newTestLDAPServer() + port = test.GetFreePort() + + ldapPort, err := strconv.Atoi(port) + if err != nil { + panic(err) + } + + ldapServer.Start(ldapPort) + defer ldapServer.Stop() + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + rootDir := t.TempDir() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + LDAP: &config.LDAPConfig{ + Insecure: true, + Address: LDAPAddress, + Port: ldapPort, + BindDN: LDAPBindDN, + BindPassword: LDAPBindPassword, + BaseDN: LDAPBaseDN, + UserAttribute: "uid", + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"email", "groups"}, + }, + }, + }, + } + + mgmtConfg := &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + + conf.Extensions = &extconf.ExtensionConfig{ + Mgmt: mgmtConfg, + } + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = rootDir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + Convey("trigger basic authn middle(htpasswd) error", func() { + client := resty.New() + + ctlr.RepoDB = mocks.RepoDBMock{ + SetUserGroupsFn: func(ctx context.Context, groups []string) error { + return ErrUnexpectedError + }, + } + + resp, err := client.R(). + SetBasicAuth(htpasswdUsername, passphrase). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger basic authn middle(ldap) error", func() { + client := resty.New() + + ctlr.RepoDB = mocks.RepoDBMock{ + SetUserGroupsFn: func(ctx context.Context, groups []string) error { + return ErrUnexpectedError + }, + } + + resp, err := client.R(). + SetBasicAuth(username, passphrase). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger updateUserData error", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + ctlr.RepoDB = mocks.RepoDBMock{ + SetUserGroupsFn: func(ctx context.Context, groups []string) error { + return ErrUnexpectedError + }, + } + + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger session middle repoDB errors", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + user := mockoidc.DefaultUser() + user.Groups = []string{"group1", "group2"} + + mockOIDCServer.QueueUser(user) + + ctlr.RepoDB = mocks.RepoDBMock{} + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + Convey("trigger session middle error internal server error", func() { + cookies := resp.Cookies() + + client.SetCookies(cookies) + + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserGroupsFn: func(ctx context.Context) ([]string, error) { + return []string{}, ErrUnexpectedError + }, + } + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger session middle error GetUserGroups not found", func() { + cookies := resp.Cookies() + + client.SetCookies(cookies) + + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserGroupsFn: func(ctx context.Context) ([]string, error) { + return []string{}, errors.ErrUserDataNotFound + }, + } + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + }) + + Convey("trigger no email error in routes(callback)", func() { + user := mockoidc.DefaultUser() + user.Email = "" + + mockOIDCServer.QueueUser(user) + + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + client.SetCookie(&http.Cookie{Name: "session"}) + + // call endpoint with session (added to client after previous request) + resp, err := client.R(). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("trigger session save error in routes(callback)", func() { + err := os.Chmod(rootDir, 0o000) + So(err, ShouldBeNil) + + defer func() { + err := os.Chmod(rootDir, storageConstants.DefaultDirPerms) + So(err, ShouldBeNil) + }() + + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger session save error in basicAuthn", func() { + err := os.Chmod(rootDir, 0o000) + So(err, ShouldBeNil) + + defer func() { + err := os.Chmod(rootDir, storageConstants.DefaultDirPerms) + So(err, ShouldBeNil) + }() + + client := resty.New() + + // first htpasswd saveSessionLoggedUser() error + resp, err := client.R(). + SetBasicAuth(htpasswdUsername, passphrase). + Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + + // second ldap saveSessionLoggedUser() error + resp, err = client.R(). + SetBasicAuth(username, passphrase). + Get(baseURL + "/v2/") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger session middle errors", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + user := mockoidc.DefaultUser() + user.Groups = []string{"group1", "group2"} + + mockOIDCServer.QueueUser(user) + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + Convey("trigger bad session encoding error in authn", func() { + cookies := resp.Cookies() + for _, cookie := range cookies { + if cookie.Name == "session" { + cookie.Value = "badSessionValue" + } + } + + client.SetCookies(cookies) + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("web request without cookies", func() { + client.SetCookie(&http.Cookie{}) + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("web request with userless cookie", func() { + // first get session + session, err := ctlr.CookieStore.Get(resp.RawResponse.Request, "session") + So(err, ShouldBeNil) + + session.ID = invalidSessionID + session.IsNew = false + session.Values["authStatus"] = true + + cookieStore, ok := ctlr.CookieStore.(*sessions.FilesystemStore) + So(ok, ShouldBeTrue) + + // first encode sessionID + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, + cookieStore.Codecs...) + So(err, ShouldBeNil) + + // save cookie + cookie := sessions.NewCookie(session.Name(), encoded, session.Options) + client.SetCookie(cookie) + + // encode session values and save on disk + encoded, err = securecookie.EncodeMulti(session.Name(), session.Values, + cookieStore.Codecs...) + So(err, ShouldBeNil) + + filename := filepath.Join(rootDir, "session_"+session.ID) + + err = os.WriteFile(filename, []byte(encoded), 0o600) + So(err, ShouldBeNil) + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("web request with authStatus false cookie", func() { + // first get session + session, err := ctlr.CookieStore.Get(resp.RawResponse.Request, "session") + So(err, ShouldBeNil) + + session.ID = invalidSessionID + session.IsNew = false + session.Values["authStatus"] = false + session.Values["username"] = username + + cookieStore, ok := ctlr.CookieStore.(*sessions.FilesystemStore) + So(ok, ShouldBeTrue) + + // first encode sessionID + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, + cookieStore.Codecs...) + So(err, ShouldBeNil) + + // save cookie + cookie := sessions.NewCookie(session.Name(), encoded, session.Options) + client.SetCookie(cookie) + + // encode session values and save on disk + encoded, err = securecookie.EncodeMulti(session.Name(), session.Values, + cookieStore.Codecs...) + So(err, ShouldBeNil) + + filename := filepath.Join(rootDir, "session_"+session.ID) + + err = os.WriteFile(filename, []byte(encoded), 0o600) + So(err, ShouldBeNil) + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + }) + }) +} + +func TestAuthnRepoDBErrors(t *testing.T) { + Convey("make controller", t, func() { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + conf := config.New() + conf.HTTP.Port = port + + htpasswdPath := test.MakeHtpasswdFile() + defer os.Remove(htpasswdPath) + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + rootDir := t.TempDir() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = rootDir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + Convey("trigger basic authn middle(htpasswd) error", func() { + client := resty.New() + + ctlr.RepoDB = mocks.RepoDBMock{ + SetUserGroupsFn: func(ctx context.Context, groups []string) error { + return ErrUnexpectedError + }, + } + + resp, err := client.R(). + SetBasicAuth(username, passphrase). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + + Convey("trigger session middle repoDB errors", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + user := mockoidc.DefaultUser() + user.Groups = []string{"group1", "group2"} + + mockOIDCServer.QueueUser(user) + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + Convey("trigger session middle error", func() { + cookies := resp.Cookies() + + client.SetCookies(cookies) + + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserGroupsFn: func(ctx context.Context) ([]string, error) { + return []string{}, ErrUnexpectedError + }, + } + + // call endpoint with session (added to client after previous request) + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + }) + }) + }) +} + func TestAuthorizationWithBasicAuth(t *testing.T) { Convey("Make a new controller", t, func() { port := test.GetFreePort() baseURL := test.GetBaseURL(port) - conf := config.New() conf.HTTP.Port = port htpasswdPath := test.MakeHtpasswdFile() @@ -3325,17 +4428,18 @@ func TestInvalidCases(t *testing.T) { }, } - ctlr := makeController(conf, "oci-repo-test", "") - - err := os.Mkdir("oci-repo-test", 0o000) - if err != nil { - panic(err) - } + dir := t.TempDir() + ctlr := makeController(conf, dir, "") cm := test.NewControllerManager(ctlr) cm.StartAndWait(port) defer func(ctrl *api.Controller) { - err := ctrl.Server.Shutdown(context.Background()) + err := os.Chmod(dir, 0o755) + if err != nil { + panic(err) + } + + err = ctrl.Server.Shutdown(context.Background()) if err != nil { panic(err) } @@ -3346,6 +4450,11 @@ func TestInvalidCases(t *testing.T) { } }(ctlr) + err := os.Chmod(dir, 0o000) + if err != nil { + panic(err) + } + digest := test.GetTestBlobDigest("zot-cve-test", "config").String() name := "zot-c-test" @@ -3439,10 +4548,9 @@ func TestCrossRepoMount(t *testing.T) { } dir := t.TempDir() - err := os.MkdirAll(path.Join(dir, "zot-cve-test"), storageConstants.DefaultDirPerms) - So(err, ShouldBeNil) + ctlr := api.NewController(conf) - ctlr := makeController(conf, path.Join(dir, "zot-cve-test"), "../../test/data/zot-cve-test") + test.CopyTestFiles("../../test/data/zot-cve-test", path.Join(dir, "zot-cve-test")) ctlr.Config.Storage.RootDirectory = dir ctlr.Config.Storage.RemoteCache = false @@ -3454,7 +4562,7 @@ func TestCrossRepoMount(t *testing.T) { params := make(map[string]string) var manifestDigest godigest.Digest - manifestDigest, _, _ = test.GetOciLayoutDigests("../../test/data/zot-cve-test") + manifestDigest, _, _ = test.GetOciLayoutDigests(path.Join(dir, "zot-cve-test")) dgst := manifestDigest name := "zot-cve-test" @@ -3550,10 +4658,18 @@ func TestCrossRepoMount(t *testing.T) { // in cache, now try mount blob request status and it should be 201 because now blob is present in cache // and it should do hard link. - // restart server with dedupe enabled + // make a new server with dedupe on and same rootDir (can't restart because of repodb - boltdb being open) + newDir := t.TempDir() + err = test.CopyFiles(dir, newDir) + So(err, ShouldBeNil) + cm.StopServer() + ctlr.Config.Storage.Dedupe = true + ctlr.Config.Storage.RootDirectory = newDir + cm = test.NewControllerManager(ctlr) //nolint: varnamelen cm.StartAndWait(port) + defer cm.StopServer() // wait for dedupe task to run time.Sleep(10 * time.Second) @@ -3976,13 +5092,6 @@ func TestHardLink(t *testing.T) { port := test.GetFreePort() conf := config.New() conf.HTTP.Port = port - htpasswdPath := test.MakeHtpasswdFileFromString(getCredString(username, passphrase)) - - conf.HTTP.Auth = &config.AuthConfig{ - HTPasswd: config.AuthHTPasswd{ - Path: htpasswdPath, - }, - } dir := t.TempDir() @@ -7390,6 +8499,84 @@ func TestHTTPOptionsResponse(t *testing.T) { }) } +func TestGetGithubUserInfo(t *testing.T) { + Convey("github api calls works", t, func() { + mockedHTTPClient := mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUserEmails, + []github.UserEmail{ + { + Email: github.String("test@test"), + Primary: github.Bool(true), + }, + }, + ), + mock.WithRequestMatch( + mock.GetUserOrgs, + []github.Organization{ + { + Login: github.String("testOrg"), + }, + }, + ), + ) + + client := github.NewClient(mockedHTTPClient) + + _, _, err := api.GetGithubUserInfo(context.Background(), client, log.Logger{}) + So(err, ShouldBeNil) + }) + + Convey("github ListEmails error", t, func() { + mockedHTTPClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetUserEmails, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock.WriteError( + w, + http.StatusInternalServerError, + "github error", + ) + }), + ), + ) + + client := github.NewClient(mockedHTTPClient) + + _, _, err := api.GetGithubUserInfo(context.Background(), client, log.Logger{}) + So(err, ShouldNotBeNil) + }) + + Convey("github ListEmails error", t, func() { + mockedHTTPClient := mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetUserEmails, + []github.UserEmail{ + { + Email: github.String("test@test"), + Primary: github.Bool(true), + }, + }, + ), + mock.WithRequestMatchHandler( + mock.GetUserOrgs, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock.WriteError( + w, + http.StatusInternalServerError, + "github error", + ) + }), + ), + ) + + client := github.NewClient(mockedHTTPClient) + + _, _, err := api.GetGithubUserInfo(context.Background(), client, log.Logger{}) + So(err, ShouldNotBeNil) + }) +} + func getAllBlobs(imagePath string) []string { blobList := make([]string, 0) diff --git a/pkg/api/routes.go b/pkg/api/routes.go index c159defd..792e660f 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -20,11 +20,14 @@ import ( "strconv" "strings" + "github.com/google/go-github/v52/github" "github.com/gorilla/mux" "github.com/opencontainers/distribution-spec/specs-go/v1/extensions" godigest "github.com/opencontainers/go-digest" ispec "github.com/opencontainers/image-spec/specs-go/v1" artifactspec "github.com/oras-project/artifacts-spec/specs-go/v1" + "github.com/zitadel/oidc/pkg/client/rp" + "github.com/zitadel/oidc/pkg/oidc" zerr "zotregistry.io/zot/errors" "zotregistry.io/zot/pkg/api/constants" @@ -55,13 +58,38 @@ func NewRouteHandler(c *Controller) *RouteHandler { } func (rh *RouteHandler) SetupRoutes() { + // first get Auth middleware in order to first setup openid/ldap/htpasswd, before oidc provider routes are setup + authHandler := AuthHandler(rh.c) + + applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin) + + if isOpenIDAuthEnabled(rh.c.Config) { + // login path for openID + rh.c.Router.HandleFunc(constants.LoginPath, rh.AuthURLHandler()) + + // logout path for openID + rh.c.Router.HandleFunc(constants.LogoutPath, applyCORSHeaders(rh.Logout)). + Methods(zcommon.AllowedMethods("POST")...) + + // callback path for openID + for provider, relyingParty := range rh.c.RelyingParties { + if IsOauth2Supported(provider) { + rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider), + rp.CodeExchangeHandler(rh.GithubCodeExchangeCallback(), relyingParty)) + } else if IsOpenIDSupported(provider) { + rh.c.Router.HandleFunc(constants.CallbackBasePath+fmt.Sprintf("/%s", provider), + rp.CodeExchangeHandler(rp.UserinfoCallback(rh.OpenIDCodeExchangeCallback()), relyingParty)) + } + } + } + prefixedRouter := rh.c.Router.PathPrefix(constants.RoutePrefix).Subrouter() - prefixedRouter.Use(AuthHandler(rh.c)) + prefixedRouter.Use(authHandler) prefixedDistSpecRouter := prefixedRouter.NewRoute().Subrouter() // authz is being enabled if AccessControl is specified // if Authn is not present AccessControl will have only default policies - if rh.c.Config.HTTP.AccessControl != nil && !isBearerAuthEnabled(rh.c.Config) { + if rh.c.Config.HTTP.AccessControl != nil { if isAuthnEnabled(rh.c.Config) { rh.c.Log.Info().Msg("access control is being enabled") } else { @@ -72,8 +100,6 @@ func (rh *RouteHandler) SetupRoutes() { prefixedDistSpecRouter.Use(DistSpecAuthzHandler(rh.c)) } - applyCORSHeaders := getCORSHeadersHandler(rh.c.Config.HTTP.AllowOrigin) - // https://github.com/opencontainers/distribution-spec/blob/main/spec.md#endpoints { prefixedDistSpecRouter.HandleFunc(fmt.Sprintf("/{name:%s}/tags/list", zreg.NameRegexp.String()), @@ -118,7 +144,7 @@ func (rh *RouteHandler) SetupRoutes() { constants.ArtifactSpecRoutePrefix, zreg.NameRegexp.String()), rh.GetOrasReferrers).Methods("GET") // swagger - debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, AuthHandler(rh.c), rh.c.Log) + debug.SetupSwaggerRoutes(rh.c.Config, rh.c.Router, authHandler, rh.c.Log) // Setup Extensions Routes if rh.c.Config != nil { @@ -135,8 +161,8 @@ func (rh *RouteHandler) SetupRoutes() { rh.c.Log) ext.SetupUserPreferencesRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.StoreController, rh.c.RepoDB, rh.c.CveInfo, rh.c.Log) - - ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, AuthHandler(rh.c), rh.c.Log) + ext.SetupAPIKeyRoutes(rh.c.Config, prefixedExtensionsRouter, rh.c.RepoDB, rh.c.CookieStore, rh.c.Log) + ext.SetupMetricsRoutes(rh.c.Config, rh.c.Router, rh.c.StoreController, authHandler, rh.c.Log) gqlPlayground.SetupGQLPlaygroundRoutes(rh.c.Config, prefixedRouter, rh.c.StoreController, rh.c.Log) @@ -185,7 +211,8 @@ func addCORSHeaders(allowOrigin string, response http.ResponseWriter) { // @Success 200 {string} string "ok". func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -195,10 +222,13 @@ func (rh *RouteHandler) CheckVersionSupport(response http.ResponseWriter, reques // NOTE: compatibility workaround - return this header in "allowed-read" mode to allow for clients to // work correctly if rh.c.Config.HTTP.Auth != nil { - if rh.c.Config.HTTP.Auth.Bearer != nil { - response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm)) - } else { - response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm)) + // don't send auth headers if request is coming from UI + if request.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { + if rh.c.Config.HTTP.Auth.Bearer != nil { + response.Header().Set("WWW-Authenticate", fmt.Sprintf("bearer realm=%s", rh.c.Config.HTTP.Auth.Bearer.Realm)) + } else { + response.Header().Set("WWW-Authenticate", fmt.Sprintf("basic realm=%s", rh.c.Config.HTTP.Realm)) + } } } @@ -224,7 +254,8 @@ type ImageTags struct { // @Failure 400 {string} string "bad request". func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -355,7 +386,8 @@ func (rh *RouteHandler) ListTags(response http.ResponseWriter, request *http.Req // @Failure 500 {string} string "internal server error". func (rh *RouteHandler) CheckManifest(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -427,7 +459,8 @@ type ExtensionList struct { // @Router /v2/{name}/manifests/{reference} [get]. func (rh *RouteHandler) GetManifest(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -527,7 +560,8 @@ func getReferrers(routeHandler *RouteHandler, // @Router /v2/{name}/referrers/{digest} [get]. func (rh *RouteHandler) GetReferrers(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -1576,7 +1610,8 @@ type RepositoryList struct { // @Router /v2/_catalog [get]. func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request *http.Request) { response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") if request.Method == http.MethodOptions { return @@ -1642,7 +1677,8 @@ func (rh *RouteHandler) ListRepositories(response http.ResponseWriter, request * // @Router /v2/_oci/ext/discover [get]. func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + w.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + w.Header().Set("Access-Control-Allow-Credentials", "true") if r.Method == http.MethodOptions { return @@ -1653,6 +1689,116 @@ func (rh *RouteHandler) ListExtensions(w http.ResponseWriter, r *http.Request) { zcommon.WriteJSON(w, http.StatusOK, extensionList) } +// The following routes are specific to zot and NOT part of the OCI dist-spec + +// Logout godoc +// @Summary Logout by removing current session +// @Description Logout by removing current session +// @Router /openid/auth/logout [post] +// @Accept json +// @Produce json +// @Success 200 {string} string "ok". +// @Failure 500 {string} string "internal server error". +func (rh *RouteHandler) Logout(response http.ResponseWriter, request *http.Request) { + response.Header().Set("Access-Control-Allow-Methods", "HEAD,GET,POST,OPTIONS") + response.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + response.Header().Set("Access-Control-Allow-Credentials", "true") + + if request.Method == http.MethodOptions { + return + } + + session, _ := rh.c.CookieStore.Get(request, "session") + session.Options.MaxAge = -1 + + err := session.Save(request, response) + if err != nil { + response.WriteHeader(http.StatusInternalServerError) + + return + } + + response.WriteHeader(http.StatusOK) +} + +// github Oauth2 CodeExchange callback. +func (rh *RouteHandler) GithubCodeExchangeCallback() rp.CodeExchangeCallback { + return func(w http.ResponseWriter, r *http.Request, + tokens *oidc.Tokens, state string, relyingParty rp.RelyingParty, + ) { + ctx := r.Context() + + client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, tokens.Token)) + + email, groups, err := GetGithubUserInfo(ctx, client, rh.c.Log) + if email == "" || err != nil { + w.WriteHeader(http.StatusUnauthorized) + + return + } + + callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) //nolint: contextcheck + if err != nil { + if errors.Is(err, zerr.ErrInvalidStateCookie) { + w.WriteHeader(http.StatusUnauthorized) + } + + w.WriteHeader(http.StatusInternalServerError) + } + + if callbackUI != "" { + http.Redirect(w, r, callbackUI, http.StatusFound) + + return + } + + w.WriteHeader(http.StatusCreated) + } +} + +// Openid CodeExchange callback. +func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCallback { + return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, + relyingParty rp.RelyingParty, info oidc.UserInfo, + ) { + email := info.GetEmail() + if email == "" { + rh.c.Log.Error().Msg("couldn't set user record for empty email value") + w.WriteHeader(http.StatusUnauthorized) + + return + } + + var groups []string + + val, ok := info.GetClaim("groups").([]interface{}) + if !ok { + rh.c.Log.Info().Msgf("couldn't find any 'groups' claim for user %s", email) + } + + for _, group := range val { + groups = append(groups, fmt.Sprint(group)) + } + + callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) + if err != nil { + if errors.Is(err, zerr.ErrInvalidStateCookie) { + w.WriteHeader(http.StatusUnauthorized) + } + + w.WriteHeader(http.StatusInternalServerError) + } + + if callbackUI != "" { + http.Redirect(w, r, callbackUI, http.StatusFound) + + return + } + + w.WriteHeader(http.StatusCreated) + } +} + func (rh *RouteHandler) GetMetrics(w http.ResponseWriter, r *http.Request) { m := rh.c.Metrics.ReceiveMetrics() zcommon.WriteJSON(w, http.StatusOK, m) diff --git a/pkg/api/routes_test.go b/pkg/api/routes_test.go index a4ffc684..cf23b8e1 100644 --- a/pkg/api/routes_test.go +++ b/pkg/api/routes_test.go @@ -1,26 +1,36 @@ -//go:build sync && scrub && metrics && search && lint -// +build sync,scrub,metrics,search,lint +//go:build sync && scrub && metrics && search && lint && apikey +// +build sync,scrub,metrics,search,lint,apikey package api_test import ( "bytes" "context" + "encoding/json" "errors" "io" "net/http" "net/http/httptest" + "os" "testing" + "github.com/google/uuid" "github.com/gorilla/mux" godigest "github.com/opencontainers/go-digest" ispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/project-zot/mockoidc" . "github.com/smartystreets/goconvey/convey" + "github.com/zitadel/oidc/pkg/client/rp" + "github.com/zitadel/oidc/pkg/oidc" + "golang.org/x/oauth2" zerr "zotregistry.io/zot/errors" "zotregistry.io/zot/pkg/api" "zotregistry.io/zot/pkg/api/config" "zotregistry.io/zot/pkg/api/constants" + "zotregistry.io/zot/pkg/extensions" + extconf "zotregistry.io/zot/pkg/extensions/config" + "zotregistry.io/zot/pkg/meta/repodb" localCtx "zotregistry.io/zot/pkg/requestcontext" storageTypes "zotregistry.io/zot/pkg/storage/types" "zotregistry.io/zot/pkg/test" @@ -29,6 +39,8 @@ import ( var ErrUnexpectedError = errors.New("error: unexpected error") +const sessionStr = "session" + func TestRoutes(t *testing.T) { Convey("Make a new controller", t, func() { port := test.GetFreePort() @@ -36,6 +48,45 @@ func TestRoutes(t *testing.T) { conf := config.New() conf.HTTP.Port = port + htpasswdPath := test.MakeHtpasswdFile() + defer os.Remove(htpasswdPath) + mockOIDCServer, err := mockoidc.Run() + if err != nil { + panic(err) + } + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + defaultVal := true + apiKeyConfig := &extconf.APIKeyConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + conf.Extensions = &extconf.ExtensionConfig{ + APIKey: apiKeyConfig, + } + ctlr := api.NewController(conf) ctlr.Config.Storage.RootDirectory = t.TempDir() @@ -50,6 +101,52 @@ func TestRoutes(t *testing.T) { // NOTE: the url or method itself doesn't matter below since we are calling the handlers directly, // so path routing is bypassed + Convey("Test GithubCodeExchangeCallback", func() { + callback := rthdlr.GithubCodeExchangeCallback() + ctx := context.TODO() + + request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) + response := httptest.NewRecorder() + + tokens := &oidc.Tokens{} + relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{}) + So(err, ShouldBeNil) + + callback(response, request, tokens, "state", relyingParty) + + resp := response.Result() + defer resp.Body.Close() + So(resp, ShouldNotBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Test OAuth2Callback errors", func() { + ctx := context.TODO() + + request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) + response := httptest.NewRecorder() + + _, err := api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"}) + So(err, ShouldEqual, zerr.ErrInvalidStateCookie) + + session, _ := ctlr.CookieStore.Get(request, "statecookie") + + session.Options.Secure = true + session.Options.HttpOnly = true + session.Options.SameSite = http.SameSiteDefaultMode + + state := uuid.New().String() + + session.Values["state"] = state + + // let the session set its own id + err = session.Save(request, response) + So(err, ShouldBeNil) + + _, err = api.OAuth2Callback(ctlr, response, request, "state", "email", []string{"group"}) + So(err, ShouldEqual, zerr.ErrInvalidStateCookie) + }) + Convey("List repositories authz error", func() { var invalid struct{} @@ -575,7 +672,7 @@ func TestRoutes(t *testing.T) { }, &mocks.MockedImageStore{ FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) { - return "session", 0, zerr.ErrBadBlobDigest + return sessionStr, 0, zerr.ErrBadBlobDigest }, }) So(statusCode, ShouldEqual, http.StatusInternalServerError) @@ -591,7 +688,7 @@ func TestRoutes(t *testing.T) { }, &mocks.MockedImageStore{ FullBlobUploadFn: func(repo string, body io.Reader, digest godigest.Digest) (string, int64, error) { - return "session", 20, nil + return sessionStr, 20, nil }, }) So(statusCode, ShouldEqual, http.StatusInternalServerError) @@ -1327,6 +1424,80 @@ func TestRoutes(t *testing.T) { So(resp.StatusCode, ShouldEqual, http.StatusOK) }) + Convey("Test API keys", func() { + var invalid struct{} + + ctx := context.TODO() + key := localCtx.GetContextKey() + ctx = context.WithValue(ctx, key, invalid) + + request, _ := http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{})) + response := httptest.NewRecorder() + + extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log) + + resp := response.Result() + defer resp.Body.Close() + So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) + + acCtx := localCtx.AccessControlContext{ + Username: username, + } + + ctx = context.TODO() + key = localCtx.GetContextKey() + ctx = context.WithValue(ctx, key, acCtx) + + request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader([]byte{})) + response = httptest.NewRecorder() + + extensions.CreateAPIKey(response, request, ctlr.RepoDB, ctlr.CookieStore, ctlr.Log) + + resp = response.Result() + defer resp.Body.Close() + + So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) + + payload := extensions.APIKeyPayload{ + Label: "test", + Scopes: []string{"test"}, + } + reqBody, err := json.Marshal(payload) + So(err, ShouldBeNil) + + request, _ = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(reqBody)) + response = httptest.NewRecorder() + + extensions.CreateAPIKey(response, request, mocks.RepoDBMock{ + AddUserAPIKeyFn: func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error { + return ErrUnexpectedError + }, + }, ctlr.CookieStore, ctlr.Log) + + resp = response.Result() + defer resp.Body.Close() + + So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) + + request, _ = http.NewRequestWithContext(ctx, http.MethodDelete, baseURL, bytes.NewReader([]byte{})) + response = httptest.NewRecorder() + + q := request.URL.Query() + q.Add("id", "apikeyid") + request.URL.RawQuery = q.Encode() + + extensions.RevokeAPIKey(response, request, mocks.RepoDBMock{ + DeleteUserAPIKeyFn: func(ctx context.Context, id string) error { + return ErrUnexpectedError + }, + }, ctlr.CookieStore, ctlr.Log) + + resp = response.Result() + defer resp.Body.Close() + + So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) + }) + Convey("Helper functions", func() { testUpdateBlobUpload := func( query []struct{ k, v string }, diff --git a/pkg/cli/extensions_test.go b/pkg/cli/extensions_test.go index b28b6d01..4ae19ecc 100644 --- a/pkg/cli/extensions_test.go +++ b/pkg/cli/extensions_test.go @@ -1,5 +1,5 @@ -//go:build sync && scrub && metrics && search -// +build sync,scrub,metrics,search +//go:build sync && scrub && metrics && search && apikey +// +build sync,scrub,metrics,search,apikey package cli_test @@ -857,6 +857,67 @@ func TestServeMgmtExtension(t *testing.T) { }) } +func TestServeAPIKeyExtension(t *testing.T) { + oldArgs := os.Args + + defer func() { os.Args = oldArgs }() + + Convey("apikey implicitly enabled", t, func(c C) { + content := `{ + "storage": { + "rootDirectory": "%s" + }, + "http": { + "address": "127.0.0.1", + "port": "%s" + }, + "log": { + "level": "debug", + "output": "%s" + }, + "extensions": { + "apikey": { + } + } + }` + + logPath, err := runCLIWithConfig(t.TempDir(), content) + So(err, ShouldBeNil) + data, err := os.ReadFile(logPath) + So(err, ShouldBeNil) + defer os.Remove(logPath) // clean up + So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":true}") + }) + + Convey("apikey disabled", t, func(c C) { + content := `{ + "storage": { + "rootDirectory": "%s" + }, + "http": { + "address": "127.0.0.1", + "port": "%s" + }, + "log": { + "level": "debug", + "output": "%s" + }, + "extensions": { + "apikey": { + "enable": "false" + } + } + }` + + logPath, err := runCLIWithConfig(t.TempDir(), content) + So(err, ShouldBeNil) + data, err := os.ReadFile(logPath) + So(err, ShouldBeNil) + defer os.Remove(logPath) // clean up + So(string(data), ShouldContainSubstring, "\"APIKey\":{\"Enable\":false}") + }) +} + func readLogFileAndSearchString(logPath string, stringToMatch string, timeout time.Duration) (bool, error) { //nolint:unparam,lll ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) defer cancelFunc() diff --git a/pkg/cli/root.go b/pkg/cli/root.go index 744f43c0..ee118827 100644 --- a/pkg/cli/root.go +++ b/pkg/cli/root.go @@ -361,6 +361,10 @@ func validateConfiguration(config *config.Config) error { return err } + if err := validateOpenIDConfig(config); err != nil { + return err + } + if err := validateSync(config); err != nil { return err } @@ -377,7 +381,7 @@ func validateConfiguration(config *config.Config) error { return err } - // check authorization config, it should have basic auth enabled or ldap + // check authorization config, it should have basic auth enabled or ldap, api keys or OpenID if config.HTTP.AccessControl != nil { // checking for anonymous policy only authorization config: no users, no policies but anonymous policy if err := validateAuthzPolicies(config); err != nil { @@ -435,11 +439,42 @@ func validateConfiguration(config *config.Config) error { return nil } +func validateOpenIDConfig(config *config.Config) error { + if config.HTTP.Auth != nil && config.HTTP.Auth.OpenID != nil { + for provider, providerConfig := range config.HTTP.Auth.OpenID.Providers { + //nolint: gocritic + if api.IsOpenIDSupported(provider) { + if providerConfig.ClientID == "" || providerConfig.Issuer == "" || + len(providerConfig.Scopes) == 0 { + log.Error().Err(errors.ErrBadConfig). + Msg("OpenID provider config requires clientid, issuer and scopes parameters") + + return errors.ErrBadConfig + } + } else if api.IsOauth2Supported(provider) { + if providerConfig.ClientID == "" || len(providerConfig.Scopes) == 0 { + log.Error().Err(errors.ErrBadConfig). + Msg("OAuth2 provider config requires clientid and scopes parameters") + + return errors.ErrBadConfig + } + } else { + log.Error().Err(errors.ErrBadConfig). + Msg("unsupported openid/oauth2 provider") + + return errors.ErrBadConfig + } + } + } + + return nil +} + func validateAuthzPolicies(config *config.Config) error { - if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil)) && - !authzContainsOnlyAnonymousPolicy(config) { + if (config.HTTP.Auth == nil || (config.HTTP.Auth.HTPasswd.Path == "" && config.HTTP.Auth.LDAP == nil && + config.HTTP.Auth.OpenID == nil)) && !authzContainsOnlyAnonymousPolicy(config) { log.Error().Err(errors.ErrBadConfig). - Msg("access control config requires httpasswd, ldap authentication " + + Msg("access control config requires one of httpasswd, ldap or openid authentication " + "or using only 'anonymousPolicy' policies") return errors.ErrBadConfig @@ -484,6 +519,13 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) { // Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here config.Extensions.Mgmt = &extconf.MgmtConfig{} } + + _, ok = extMap["apikey"] + if ok { + // we found a config like `"extensions": {"mgmt:": {}}` + // Note: In case mgmt is not empty the config.Extensions will not be nil and we will not reach here + config.Extensions.APIKey = &extconf.APIKeyConfig{} + } } if config.Extensions != nil { @@ -550,6 +592,12 @@ func applyDefaultValues(config *config.Config, viperInstance *viper.Viper) { } } + if config.Extensions.APIKey != nil { + if config.Extensions.APIKey.Enable == nil { + config.Extensions.APIKey.Enable = &defaultVal + } + } + if config.Extensions.Scrub != nil { if config.Extensions.Scrub.Enable == nil { config.Extensions.Scrub.Enable = &defaultVal diff --git a/pkg/cli/root_test.go b/pkg/cli/root_test.go index 599ae903..ef57e7ac 100644 --- a/pkg/cli/root_test.go +++ b/pkg/cli/root_test.go @@ -952,6 +952,71 @@ func TestVerify(t *testing.T) { So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic) }) + Convey("Test verify openid config with missing parameter", t, func(c C) { + tmpfile, err := os.CreateTemp("", "zot-test*.json") + So(err, ShouldBeNil) + defer os.Remove(tmpfile.Name()) // clean up + content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"}, + "http":{"address":"127.0.0.1","port":"8080","realm":"zot", + "auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex"}}}}}, + "log":{"level":"debug"}}`) + _, err = tmpfile.Write(content) + So(err, ShouldBeNil) + err = tmpfile.Close() + So(err, ShouldBeNil) + os.Args = []string{"cli_test", "verify", tmpfile.Name()} + So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic) + }) + + Convey("Test verify oauth2 config with missing parameter", t, func(c C) { + tmpfile, err := os.CreateTemp("", "zot-test*.json") + So(err, ShouldBeNil) + defer os.Remove(tmpfile.Name()) // clean up + content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"}, + "http":{"address":"127.0.0.1","port":"8080","realm":"zot", + "auth":{"openid":{"providers":{"github":{"clientid":"client_id"}}}}}, + "log":{"level":"debug"}}`) + _, err = tmpfile.Write(content) + So(err, ShouldBeNil) + err = tmpfile.Close() + So(err, ShouldBeNil) + os.Args = []string{"cli_test", "verify", tmpfile.Name()} + So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic) + }) + + Convey("Test verify openid config with unsupported provider", t, func(c C) { + tmpfile, err := os.CreateTemp("", "zot-test*.json") + So(err, ShouldBeNil) + defer os.Remove(tmpfile.Name()) // clean up + content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"}, + "http":{"address":"127.0.0.1","port":"8080","realm":"zot", + "auth":{"openid":{"providers":{"unsupported":{"issuer":"http://127.0.0.1:5556/dex"}}}}}, + "log":{"level":"debug"}}`) + _, err = tmpfile.Write(content) + So(err, ShouldBeNil) + err = tmpfile.Close() + So(err, ShouldBeNil) + os.Args = []string{"cli_test", "verify", tmpfile.Name()} + So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldPanic) + }) + + Convey("Test verify openid config without apikey extension enabled", t, func(c C) { + tmpfile, err := os.CreateTemp("", "zot-test*.json") + So(err, ShouldBeNil) + defer os.Remove(tmpfile.Name()) // clean up + content := []byte(`{"distSpecVersion":"1.1.0-dev","storage":{"rootDirectory":"/tmp/zot"}, + "http":{"address":"127.0.0.1","port":"8080","realm":"zot", + "auth":{"openid":{"providers":{"dex":{"issuer":"http://127.0.0.1:5556/dex", + "clientid":"client_id","scopes":["openid"]}}}}}, + "log":{"level":"debug"}}`) + _, err = tmpfile.Write(content) + So(err, ShouldBeNil) + err = tmpfile.Close() + So(err, ShouldBeNil) + os.Args = []string{"cli_test", "verify", tmpfile.Name()} + So(func() { _ = cli.NewServerRootCmd().Execute() }, ShouldNotPanic) + }) + Convey("Test verify config with missing basedn key", t, func(c C) { tmpfile, err := os.CreateTemp("", "zot-test*.json") So(err, ShouldBeNil) diff --git a/pkg/common/http_server.go b/pkg/common/http_server.go index e90db94e..fd075039 100644 --- a/pkg/common/http_server.go +++ b/pkg/common/http_server.go @@ -2,14 +2,17 @@ package common import ( "net/http" + "strconv" "strings" "time" "github.com/gorilla/mux" + "github.com/gorilla/sessions" jsoniter "github.com/json-iterator/go" "zotregistry.io/zot/pkg/api/constants" apiErr "zotregistry.io/zot/pkg/api/errors" + "zotregistry.io/zot/pkg/log" ) func AllowedMethods(methods ...string) []string { @@ -32,7 +35,8 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { resp.Header().Set("Access-Control-Allow-Methods", headerValue) - resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type") + resp.Header().Set("Access-Control-Allow-Headers", "Authorization,content-type,"+constants.SessionClientHeaderName) + resp.Header().Set("Access-Control-Allow-Credentials", "true") if req.Method == http.MethodOptions { return @@ -43,9 +47,20 @@ func ACHeadersHandler(allowedMethods ...string) mux.MiddlewareFunc { } } -func AuthzFail(w http.ResponseWriter, realm string, delay int) { +func AuthzFail(w http.ResponseWriter, r *http.Request, realm string, delay int) { time.Sleep(time.Duration(delay) * time.Second) - w.Header().Set("WWW-Authenticate", realm) + + // don't send auth headers if request is coming from UI + if r.Header.Get(constants.SessionClientHeaderName) != constants.SessionClientHeaderValue { + if realm == "" { + realm = "Authorization Required" + } + + realm = "Basic realm=" + strconv.Quote(realm) + + w.Header().Set("WWW-Authenticate", realm) + } + w.Header().Set("Content-Type", "application/json") WriteJSON(w, http.StatusForbidden, apiErr.NewErrorList(apiErr.NewError(apiErr.DENIED))) } @@ -66,3 +81,39 @@ func WriteData(w http.ResponseWriter, status int, mediaType string, data []byte) w.WriteHeader(status) _, _ = w.Write(data) } + +/* +GetAuthUserFromRequestSession returns identity +and auth status if on the request's cookie session is a logged in user. +*/ +func GetAuthUserFromRequestSession(cookieStore sessions.Store, request *http.Request, log log.Logger, +) (string, bool) { + session, err := cookieStore.Get(request, "session") + if err != nil { + log.Error().Err(err).Msg("can not decode existing session") + // expired cookie, no need to return err + return "", false + } + + // at this point we should have a session set on cookie. + // if created in the earlier Get() call then user is not logged in with sessions. + if session.IsNew { + return "", false + } + + authenticated := session.Values["authStatus"] + if authenticated != true { + log.Error().Msg("can not get `user` session value") + + return "", false + } + + identity, ok := session.Values["user"].(string) + if !ok { + log.Error().Msg("can not get `user` session value") + + return "", false + } + + return identity, true +} diff --git a/pkg/extensions/README_apikey.md b/pkg/extensions/README_apikey.md new file mode 100644 index 00000000..c8ad0fb2 --- /dev/null +++ b/pkg/extensions/README_apikey.md @@ -0,0 +1,66 @@ +# `API keys` + +zot allows authentication for REST API calls using your API key as an alternative to your password. + +* User can create/revoke his API key. + +* Can not be retrieved, it is shown to the user only the first time is created. + +* An API key has the same rights as the user who generated it. + +## API keys REST API + + +### Create API Key +**Description**: Create an API key for the current user. + +**Usage**: POST /v2/_zot/ext/apikey + +**Produces**: application/json + +**Sample input**: +``` +POST /api/security/apiKey +Body: {"label": "git", "scopes": ["repo1", "repo2"]}' +``` + +**Example cURL** +``` +curl -u user:password -X POST http://localhost:8080/v2/_zot/ext/apikey -d '{"label": "myLabel", "scopes": ["repo1", "repo2"]}' +``` + +**Sample output**: +```json +{ + "createdAt": "2023-05-05T15:39:28.420926+03:00", + "creatorUa": "curl/7.68.0", + "generatedBy": "manual", + "lastUsed": "2023-05-05T15:39:28.4209282+03:00", + "label": "git", + "scopes": [ + "repo1", + "repo2" + ], + "uuid": "46a45ce7-5d92-498a-a9cb-9654b1da3da1", + "apiKey": "zak_e77bcb9e9f634f1581756abbf9ecd269" +} +``` + +**Using API keys cURL** +``` +curl -u user:zak_e77bcb9e9f634f1581756abbf9ecd269 http://localhost:8080/v2/_catalog +``` + + +### Revoke API Key +**Description**: Revokes one current user API key by api key UUID + +**Usage**: DELETE /api/security/apiKey?id=$uuid + +**Produces**: application/json + + +**Example cURL** +``` +curl -u user:password -X DELETE http://localhost:8080/v2/_zot/ext/apikey?id=46a45ce7-5d92-498a-a9cb-9654b1da3da1 +``` diff --git a/pkg/extensions/mgmt.md b/pkg/extensions/README_mgmt.md similarity index 100% rename from pkg/extensions/mgmt.md rename to pkg/extensions/README_mgmt.md diff --git a/pkg/extensions/_zot.md b/pkg/extensions/_zot.md index 8f2af16f..16c4d974 100644 --- a/pkg/extensions/_zot.md +++ b/pkg/extensions/_zot.md @@ -8,6 +8,7 @@ Component | Endpoint | Description [`search`](search/search.md) | `/v2/_zot/ext/search` | efficient and enhanced registry search capabilities using graphQL backend [`mgmt`](mgmt.md) | `/v2/_zot/ext/mgmt` | config management [`userprefs`](userprefs.md) | `/v2/_zot/ext/userprefs` | change user preferences +[`apikey`](README_apikey.md) | `/v2/_zot/ext/apikey` | user api keys management # References diff --git a/pkg/extensions/config/config.go b/pkg/extensions/config/config.go index c84de457..b13c015c 100644 --- a/pkg/extensions/config/config.go +++ b/pkg/extensions/config/config.go @@ -19,6 +19,11 @@ type ExtensionConfig struct { Lint *LintConfig UI *UIConfig Mgmt *MgmtConfig + APIKey *APIKeyConfig +} + +type APIKeyConfig struct { + BaseConfig `mapstructure:",squash"` } type MgmtConfig struct { diff --git a/pkg/extensions/extension_api_key.go b/pkg/extensions/extension_api_key.go new file mode 100644 index 00000000..f0e8dc9e --- /dev/null +++ b/pkg/extensions/extension_api_key.go @@ -0,0 +1,197 @@ +//go:build apikey +// +build apikey + +package extensions + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + guuid "github.com/gofrs/uuid" + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + jsoniter "github.com/json-iterator/go" + godigest "github.com/opencontainers/go-digest" + + "zotregistry.io/zot/pkg/api/config" + "zotregistry.io/zot/pkg/api/constants" + zcommon "zotregistry.io/zot/pkg/common" + "zotregistry.io/zot/pkg/log" + "zotregistry.io/zot/pkg/meta/repodb" +) + +func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB, + cookieStore sessions.Store, log log.Logger, +) { + if config.Extensions.APIKey != nil && *config.Extensions.APIKey.Enable { + log.Info().Msg("setting up api key routes") + + allowedMethods := zcommon.AllowedMethods(http.MethodPost, http.MethodDelete) + + apiKeyRouter := router.PathPrefix(constants.ExtAPIKey).Subrouter() + apiKeyRouter.Use(zcommon.ACHeadersHandler(allowedMethods...)) + apiKeyRouter.Use(zcommon.AddExtensionSecurityHeaders()) + apiKeyRouter.Methods(allowedMethods...).Handler(HandleAPIKeyRequest(repoDB, cookieStore, log)) + } +} + +type APIKeyPayload struct { //nolint:revive + Label string `json:"label"` + Scopes []string `json:"scopes"` +} + +func HandleAPIKeyRequest(repoDB repodb.RepoDB, cookieStore sessions.Store, + log log.Logger, +) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodPost: + CreateAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck + + return + case http.MethodDelete: + RevokeAPIKey(resp, req, repoDB, cookieStore, log) //nolint:contextcheck + + return + } + }) +} + +// CreateAPIKey godoc +// @Summary Create an API key for the current user +// @Description Can create an api key for a logged in user, based on the provided label and scopes. +// @Accept json +// @Produce json +// @Success 201 {string} string "created" +// @Failure 401 {string} string "unauthorized" +// @Failure 500 {string} string "internal server error" +// @Router /v2/_zot/ext/apikey [post]. +func CreateAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB, + cookieStore sessions.Store, log log.Logger, +) { + var payload APIKeyPayload + + body, err := io.ReadAll(req.Body) + if err != nil { + log.Error().Msg("unable to read request body") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + err = json.Unmarshal(body, &payload) + if err != nil { + log.Error().Err(err).Msg("unable to unmarshal body") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + apiKeyBase, err := guuid.NewV4() + if err != nil { + log.Error().Err(err).Msg("unable to generate uuid") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + apiKey := strings.ReplaceAll(apiKeyBase.String(), "-", "") + + hashedAPIKey := hashUUID(apiKey) + + // will be used for identifying a specific api key + apiKeyID, err := guuid.NewV4() + if err != nil { + log.Error().Err(err).Msg("unable to generate uuid") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + apiKeyDetails := &repodb.APIKeyDetails{ + CreatedAt: time.Now(), + LastUsed: time.Now(), + CreatorUA: req.UserAgent(), + GeneratedBy: "manual", + Label: payload.Label, + Scopes: payload.Scopes, + UUID: apiKeyID.String(), + } + + err = repoDB.AddUserAPIKey(req.Context(), hashedAPIKey, apiKeyDetails) + if err != nil { + log.Error().Err(err).Msg("error storing API key") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + apiKeyResponse := struct { + repodb.APIKeyDetails + APIKey string `json:"apiKey"` + }{ + APIKey: fmt.Sprintf("%s%s", constants.APIKeysPrefix, apiKey), + APIKeyDetails: *apiKeyDetails, + } + + json := jsoniter.ConfigCompatibleWithStandardLibrary + + data, err := json.Marshal(apiKeyResponse) + if err != nil { + log.Error().Err(err).Msg("unable to marshal api key response") + + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + resp.Header().Set("Content-Type", constants.DefaultMediaType) + resp.WriteHeader(http.StatusCreated) + _, _ = resp.Write(data) +} + +// RevokeAPIKey godoc +// @Summary Revokes one current user API key +// @Description Revokes one current user API key based on given key ID +// @Accept json +// @Produce json +// @Param id path string true "api token id (UUID)" +// @Success 200 {string} string "ok" +// @Failure 500 {string} string "internal server error" +// @Failure 401 {string} string "unauthorized" +// @Failure 400 {string} string "bad request" +// @Router /v2/_zot/ext/apikey?id=UUID [delete]. +func RevokeAPIKey(resp http.ResponseWriter, req *http.Request, repoDB repodb.RepoDB, + cookieStore sessions.Store, log log.Logger, +) { + ids, ok := req.URL.Query()["id"] + if !ok || len(ids) != 1 { + resp.WriteHeader(http.StatusBadRequest) + + return + } + + keyID := ids[0] + + err := repoDB.DeleteUserAPIKey(req.Context(), keyID) + if err != nil { + log.Error().Err(err).Str("keyID", keyID).Msg("error deleting API key") + resp.WriteHeader(http.StatusInternalServerError) + + return + } + + resp.WriteHeader(http.StatusOK) +} + +func hashUUID(uuid string) string { + digester := sha256.New() + digester.Write([]byte(uuid)) + + return godigest.NewDigestFromEncoded(godigest.SHA256, fmt.Sprintf("%x", digester.Sum(nil))).Encoded() +} diff --git a/pkg/extensions/extension_api_key_disabled.go b/pkg/extensions/extension_api_key_disabled.go new file mode 100644 index 00000000..9fa3fc09 --- /dev/null +++ b/pkg/extensions/extension_api_key_disabled.go @@ -0,0 +1,20 @@ +//go:build !apikey +// +build !apikey + +package extensions + +import ( + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + + "zotregistry.io/zot/pkg/api/config" + "zotregistry.io/zot/pkg/log" + "zotregistry.io/zot/pkg/meta/repodb" +) + +func SetupAPIKeyRoutes(config *config.Config, router *mux.Router, repoDB repodb.RepoDB, + cookieStore sessions.Store, log log.Logger, +) { + log.Warn().Msg("skipping setting up API key routes because given zot binary doesn't include this feature," + + "please build a binary that does so") +} diff --git a/pkg/extensions/extension_api_key_test.go b/pkg/extensions/extension_api_key_test.go new file mode 100644 index 00000000..6c43f795 --- /dev/null +++ b/pkg/extensions/extension_api_key_test.go @@ -0,0 +1,531 @@ +//go:build apikey +// +build apikey + +package extensions_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "os" + "testing" + + "github.com/project-zot/mockoidc" + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/resty.v1" + + "zotregistry.io/zot/pkg/api" + "zotregistry.io/zot/pkg/api/config" + "zotregistry.io/zot/pkg/api/constants" + "zotregistry.io/zot/pkg/extensions" + extconf "zotregistry.io/zot/pkg/extensions/config" + "zotregistry.io/zot/pkg/meta/repodb" + localCtx "zotregistry.io/zot/pkg/requestcontext" + "zotregistry.io/zot/pkg/test" + "zotregistry.io/zot/pkg/test/mocks" +) + +type ( + apiKeyResponse struct { + repodb.APIKeyDetails + APIKey string `json:"apiKey"` + } +) + +var ErrUnexpectedError = errors.New("unexpected err") + +func TestAPIKeys(t *testing.T) { + Convey("Make a new controller", t, func() { + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + conf := config.New() + conf.HTTP.Port = port + + htpasswdPath := test.MakeHtpasswdFile() + defer os.Remove(htpasswdPath) + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email", "groups"}, + }, + }, + }, + } + + conf.HTTP.AccessControl = &config.AccessControlConfig{} + + defaultVal := true + apiKeyConfig := &extconf.APIKeyConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + + mgmtConfg := &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + + conf.Extensions = &extconf.ExtensionConfig{ + APIKey: apiKeyConfig, + Mgmt: mgmtConfg, + } + + ctlr := api.NewController(conf) + dir := t.TempDir() + + ctlr.Config.Storage.RootDirectory = dir + + cm := test.NewControllerManager(ctlr) + + cm.StartServer() + defer cm.StopServer() + test.WaitTillServerReady(baseURL) + + payload := extensions.APIKeyPayload{ + Label: "test", + Scopes: []string{"test"}, + } + reqBody, err := json.Marshal(payload) + So(err, ShouldBeNil) + + Convey("API key retrieved with basic auth", func() { + // call endpoint with session ( added to client after previous request) + resp, err := resty.R(). + SetBody(reqBody). + SetBasicAuth("test", "test"). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + user := mockoidc.DefaultUser() + + // get API key and email from apikey route response + var apiKeyResponse apiKeyResponse + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + email := user.Email + So(email, ShouldNotBeEmpty) + + resp, err = resty.R(). + SetBasicAuth("test", apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // add another one + resp, err = resty.R(). + SetBody(reqBody). + SetBasicAuth("test", "test"). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + resp, err = resty.R(). + SetBasicAuth("test", apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + }) + + Convey("API key retrieved with openID", func() { + client := resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + // first login user + resp, err := client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + + cookies := resp.Cookies() + + // call endpoint without session + resp, err = client.R(). + SetBody(reqBody). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + client.SetCookies(cookies) + + // call endpoint with session ( added to client after previous request) + resp, err = client.R(). + SetBody(reqBody). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + user := mockoidc.DefaultUser() + + // get API key and email from apikey route response + var apiKeyResponse apiKeyResponse + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + email := user.Email + So(email, ShouldNotBeEmpty) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // trigger errors + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) { + return "", ErrUnexpectedError + }, + } + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) { + return user.Email, nil + }, + GetUserGroupsFn: func(ctx context.Context) ([]string, error) { + return []string{}, ErrUnexpectedError + }, + } + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + + ctlr.RepoDB = mocks.RepoDBMock{ + GetUserAPIKeyInfoFn: func(hashedKey string) (string, error) { + return user.Email, nil + }, + UpdateUserAPIKeyLastUsedFn: func(ctx context.Context, hashedKey string) error { + return ErrUnexpectedError + }, + } + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + + client = resty.New() + + // call endpoint without session + resp, err = client.R(). + SetBody(reqBody). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + }) + + Convey("Login with openid and create API key", func() { + client := resty.New() + + // mgmt should work both unauthenticated and authenticated + resp, err := client.R(). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + // first login user + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + client.SetCookies(resp.Cookies()) + + // call endpoint with session ( added to client after previous request) + resp, err = client.R(). + SetBody(reqBody). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + var apiKeyResponse apiKeyResponse + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + user := mockoidc.DefaultUser() + email := user.Email + So(email, ShouldNotBeEmpty) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // auth with API key + // we need new client without session cookie set + client = resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // invalid api keys + resp, err = client.R(). + SetBasicAuth("invalidEmail", apiKeyResponse.APIKey). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + resp, err = client.R(). + SetBasicAuth(email, "noprefixAPIKey"). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + resp, err = client.R(). + SetBasicAuth(email, "zak_notworkingAPIKey"). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: email, + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = ctlr.RepoDB.DeleteUserData(ctx) + So(err, ShouldBeNil) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusInternalServerError) + + client = resty.New() + client.SetRedirectPolicy(test.CustomRedirectPolicy(20)) + + // without creds should work + resp, err = client.R(). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // login again + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("provider", "dex"). + Get(baseURL + constants.LoginPath) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + client.SetCookies(resp.Cookies()) + + resp, err = client.R(). + SetBody(reqBody). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Post(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusCreated) + + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + // should work with session + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // should work with api key + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + err = json.Unmarshal(resp.Body(), &apiKeyResponse) + So(err, ShouldBeNil) + + // delete api key + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + SetQueryParam("id", apiKeyResponse.UUID). + Delete(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + resp, err = client.R(). + SetHeader(constants.SessionClientHeaderName, constants.SessionClientHeaderValue). + Delete(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusBadRequest) + + resp, err = client.R(). + SetBasicAuth(email, apiKeyResponse.APIKey). + Get(baseURL + "/v2/_catalog") + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusUnauthorized) + + resp, err = client.R(). + SetBasicAuth("test", "test"). + SetQueryParam("id", apiKeyResponse.UUID). + Delete(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // unsupported method + resp, err = client.R(). + Put(baseURL + constants.FullAPIKeyPrefix) + So(err, ShouldBeNil) + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed) + }) + }) +} + +func TestAPIKeysOpenDBError(t *testing.T) { + Convey("Test API keys - unable to create database", t, func() { + conf := config.New() + htpasswdPath := test.MakeHtpasswdFile() + defer os.Remove(htpasswdPath) + + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + conf.HTTP.Auth = &config.AuthConfig{ + HTPasswd: config.AuthHTPasswd{ + Path: htpasswdPath, + }, + + OpenID: &config.OpenIDConfig{ + Providers: map[string]config.OpenIDProviderConfig{ + "dex": { + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + KeyPath: "", + Issuer: mockOIDCConfig.Issuer, + Scopes: []string{"openid", "email"}, + }, + }, + }, + } + + defaultVal := true + apiKeyConfig := &extconf.APIKeyConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + } + conf.Extensions = &extconf.ExtensionConfig{ + APIKey: apiKeyConfig, + } + + ctlr := api.NewController(conf) + dir := t.TempDir() + + err = os.Chmod(dir, 0o000) + So(err, ShouldBeNil) + + ctlr.Config.Storage.RootDirectory = dir + cm := test.NewControllerManager(ctlr) + + So(func() { + cm.StartServer() + }, ShouldPanic) + }) +} diff --git a/pkg/extensions/extension_mgmt.go b/pkg/extensions/extension_mgmt.go index 8c3b31b1..f06e5e7b 100644 --- a/pkg/extensions/extension_mgmt.go +++ b/pkg/extensions/extension_mgmt.go @@ -36,12 +36,19 @@ type BearerConfig struct { Service string `json:"service,omitempty"` } +type OpenIDProviderConfig struct{} + +type OpenIDConfig struct { + Providers map[string]OpenIDProviderConfig `json:"providers,omitempty" mapstructure:"providers"` +} + type Auth struct { HTPasswd *HTPasswd `json:"htpasswd,omitempty" mapstructure:"htpasswd"` Bearer *BearerConfig `json:"bearer,omitempty" mapstructure:"bearer"` LDAP *struct { Address string `json:"address,omitempty" mapstructure:"address"` } `json:"ldap,omitempty" mapstructure:"ldap"` + OpenID *OpenIDConfig `json:"openid,omitempty" mapstructure:"openid"` } type StrippedConfig struct { @@ -60,8 +67,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) { type localAuth Auth if auth.Bearer == nil && auth.LDAP == nil && - auth.HTPasswd.Path == "" { + auth.HTPasswd.Path == "" && + (auth.OpenID == nil || len(auth.OpenID.Providers) == 0) { auth.HTPasswd = nil + auth.OpenID = nil return json.Marshal((localAuth)(auth)) } @@ -72,6 +81,10 @@ func (auth Auth) MarshalJSON() ([]byte, error) { auth.HTPasswd.Path = "" } + if auth.OpenID != nil && len(auth.OpenID.Providers) == 0 { + auth.OpenID = nil + } + auth.LDAP = nil return json.Marshal((localAuth)(auth)) diff --git a/pkg/extensions/extension_userprefs.go b/pkg/extensions/extension_userprefs.go index 0a9404ec..5bee652c 100644 --- a/pkg/extensions/extension_userprefs.go +++ b/pkg/extensions/extension_userprefs.go @@ -39,6 +39,7 @@ func SetupUserPreferencesRoutes(config *config.Config, router *mux.Router, store userprefsRouter := router.PathPrefix(constants.ExtUserPreferences).Subrouter() userprefsRouter.Use(zcommon.ACHeadersHandler(allowedMethods...)) userprefsRouter.Use(zcommon.AddExtensionSecurityHeaders()) + userprefsRouter.HandleFunc("", HandleUserPrefs(repoDB, log)).Methods(allowedMethods...) } } diff --git a/pkg/extensions/extensions_test.go b/pkg/extensions/extensions_test.go index 73d79b0d..13059fc6 100644 --- a/pkg/extensions/extensions_test.go +++ b/pkg/extensions/extensions_test.go @@ -1,5 +1,5 @@ -//go:build sync || metrics || mgmt -// +build sync metrics mgmt +//go:build sync || metrics || mgmt || apikey +// +build sync metrics mgmt apikey package extensions_test @@ -128,6 +128,20 @@ func TestMgmtExtension(t *testing.T) { defaultValue := true + mockOIDCServer, err := test.MockOIDCRun() + if err != nil { + panic(err) + } + + defer func() { + err := mockOIDCServer.Shutdown() + if err != nil { + panic(err) + } + }() + + mockOIDCConfig := mockOIDCServer.Config() + Convey("Verify mgmt route enabled with htpasswd", t, func() { htpasswdPath := test.MakeHtpasswdFile() conf.HTTP.Auth.HTPasswd.Path = htpasswdPath @@ -145,7 +159,7 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} + subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()} ctlr.Config.Storage.RootDirectory = globalDir ctlr.Config.Storage.SubPaths = subPaths @@ -158,6 +172,13 @@ func TestMgmtExtension(t *testing.T) { So(string(data), ShouldContainSubstring, "setting up mgmt routes") + Convey("unsupported http method call", func() { + // without credentials + resp, err := resty.R().Patch(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusMethodNotAllowed) + }) + // without credentials resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix) So(err, ShouldBeNil) @@ -210,9 +231,9 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} + subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()} - ctlr.Config.Storage.RootDirectory = globalDir + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlr.Config.Storage.SubPaths = subPaths ctlrManager := test.NewControllerManager(ctlr) @@ -259,9 +280,9 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} + subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()} - ctlr.Config.Storage.RootDirectory = globalDir + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlr.Config.Storage.SubPaths = subPaths ctlrManager := test.NewControllerManager(ctlr) @@ -325,11 +346,7 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) - subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} - - ctlr.Config.Storage.RootDirectory = globalDir - ctlr.Config.Storage.SubPaths = subPaths + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlrManager := test.NewControllerManager(ctlr) ctlrManager.StartAndWait(port) @@ -396,9 +413,9 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} + subPaths["/a"] = config.StorageConfig{RootDirectory: t.TempDir()} - ctlr.Config.Storage.RootDirectory = globalDir + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlr.Config.Storage.SubPaths = subPaths ctlrManager := test.NewControllerManager(ctlr) @@ -445,11 +462,7 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) - subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} - - ctlr.Config.Storage.RootDirectory = globalDir - ctlr.Config.Storage.SubPaths = subPaths + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlrManager := test.NewControllerManager(ctlr) ctlrManager.StartAndWait(port) @@ -474,6 +487,110 @@ func TestMgmtExtension(t *testing.T) { So(mgmtResp.HTTP.Auth.Bearer.Service, ShouldEqual, "service") }) + Convey("Verify mgmt route enabled with openID", t, func() { + conf.HTTP.Auth.HTPasswd.Path = "" + conf.HTTP.Auth.LDAP = nil + conf.HTTP.Auth.Bearer = nil + + openIDProviders := make(map[string]config.OpenIDProviderConfig) + openIDProviders["dex"] = config.OpenIDProviderConfig{ + ClientID: mockOIDCConfig.ClientID, + ClientSecret: mockOIDCConfig.ClientSecret, + Issuer: mockOIDCConfig.Issuer, + } + + conf.HTTP.Auth.OpenID = &config.OpenIDConfig{ + Providers: openIDProviders, + } + + conf.Extensions = &extconf.ExtensionConfig{} + conf.Extensions.Mgmt = &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &defaultValue, + }, + } + + conf.Log.Output = logFile.Name() + defer os.Remove(logFile.Name()) // cleanup + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = t.TempDir() + + ctlrManager := test.NewControllerManager(ctlr) + ctlrManager.StartAndWait(port) + defer ctlrManager.StopServer() + + data, _ := os.ReadFile(logFile.Name()) + + So(string(data), ShouldContainSubstring, "setting up mgmt routes") + + // without credentials + resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + mgmtResp := extensions.StrippedConfig{} + err = json.Unmarshal(resp.Body(), &mgmtResp) + t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID) + So(err, ShouldBeNil) + So(mgmtResp.HTTP.Auth.HTPasswd, ShouldBeNil) + So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil) + So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil) + So(mgmtResp.HTTP.Auth.OpenID, ShouldNotBeNil) + So(mgmtResp.HTTP.Auth.OpenID.Providers, ShouldNotBeEmpty) + }) + + Convey("Verify mgmt route enabled with empty openID provider list", t, func() { + htpasswdPath := test.MakeHtpasswdFile() + + conf.HTTP.Auth.HTPasswd.Path = htpasswdPath + conf.HTTP.Auth.LDAP = nil + conf.HTTP.Auth.Bearer = nil + + openIDProviders := make(map[string]config.OpenIDProviderConfig) + + conf.HTTP.Auth.OpenID = &config.OpenIDConfig{ + Providers: openIDProviders, + } + + conf.Extensions = &extconf.ExtensionConfig{} + conf.Extensions.Mgmt = &extconf.MgmtConfig{ + BaseConfig: extconf.BaseConfig{ + Enable: &defaultValue, + }, + } + + conf.Log.Output = logFile.Name() + defer os.Remove(logFile.Name()) // cleanup + + ctlr := api.NewController(conf) + + ctlr.Config.Storage.RootDirectory = t.TempDir() + + ctlrManager := test.NewControllerManager(ctlr) + ctlrManager.StartAndWait(port) + defer ctlrManager.StopServer() + + data, _ := os.ReadFile(logFile.Name()) + + So(string(data), ShouldContainSubstring, "setting up mgmt routes") + + // without credentials + resp, err := resty.R().Get(baseURL + constants.FullMgmtPrefix) + So(err, ShouldBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + mgmtResp := extensions.StrippedConfig{} + err = json.Unmarshal(resp.Body(), &mgmtResp) + t.Logf("resp: %v", mgmtResp.HTTP.Auth.OpenID) + So(err, ShouldBeNil) + So(mgmtResp.HTTP.Auth.HTPasswd, ShouldNotBeNil) + So(mgmtResp.HTTP.Auth.LDAP, ShouldBeNil) + So(mgmtResp.HTTP.Auth.Bearer, ShouldBeNil) + So(mgmtResp.HTTP.Auth.OpenID, ShouldBeNil) + }) + Convey("Verify mgmt route enabled without any auth", t, func() { globalDir := t.TempDir() conf := config.New() @@ -499,11 +616,7 @@ func TestMgmtExtension(t *testing.T) { ctlr := api.NewController(conf) - subPaths := make(map[string]config.StorageConfig) - subPaths["/a"] = config.StorageConfig{} - - ctlr.Config.Storage.RootDirectory = globalDir - ctlr.Config.Storage.SubPaths = subPaths + ctlr.Config.Storage.RootDirectory = t.TempDir() ctlrManager := test.NewControllerManager(ctlr) ctlrManager.StartAndWait(port) @@ -856,3 +969,32 @@ func TestAllowedMethodsHeaderMgmt(t *testing.T) { So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) }) } + +func TestAllowedMethodsHeaderAPIKey(t *testing.T) { + defaultVal := true + + Convey("Test http options response", t, func() { + conf := config.New() + port := test.GetFreePort() + conf.HTTP.Port = port + conf.Extensions = &extconf.ExtensionConfig{ + APIKey: &extconf.APIKeyConfig{ + BaseConfig: extconf.BaseConfig{Enable: &defaultVal}, + }, + } + baseURL := test.GetBaseURL(port) + + ctlr := api.NewController(conf) + ctlr.Config.Storage.RootDirectory = t.TempDir() + + ctrlManager := test.NewControllerManager(ctlr) + + ctrlManager.StartAndWait(port) + defer ctrlManager.StopServer() + + resp, _ := resty.R().Options(baseURL + constants.FullAPIKeyPrefix) + So(resp, ShouldNotBeNil) + So(resp.Header().Get("Access-Control-Allow-Methods"), ShouldResemble, "POST,DELETE,OPTIONS") + So(resp.StatusCode(), ShouldEqual, http.StatusNoContent) + }) +} diff --git a/pkg/meta/bolt/buckets.go b/pkg/meta/bolt/buckets.go index 388daf96..72ce7872 100644 --- a/pkg/meta/bolt/buckets.go +++ b/pkg/meta/bolt/buckets.go @@ -7,6 +7,5 @@ const ( RepoMetadataBucket = "RepoMetadata" UserDataBucket = "UserData" VersionBucket = "Version" - StarredReposKey = "StarredReposKey" - BookmarkedReposKey = "BookmarkedReposKey" + UserAPIKeysBucket = "UserAPIKeys" ) diff --git a/pkg/meta/dynamo/parameters.go b/pkg/meta/dynamo/parameters.go index cbe71729..a5355799 100644 --- a/pkg/meta/dynamo/parameters.go +++ b/pkg/meta/dynamo/parameters.go @@ -10,7 +10,7 @@ import ( type DBDriverParameters struct { Endpoint, Region, RepoMetaTablename, ManifestDataTablename, IndexDataTablename, - VersionTablename, UserDataTablename string + UserDataTablename, APIKeyTablename, VersionTablename string } func GetDynamoClient(params DBDriverParameters) (*dynamodb.Client, error) { diff --git a/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper.go b/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper.go index a49d00c8..1a32e06c 100644 --- a/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper.go +++ b/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper.go @@ -3,6 +3,7 @@ package bolt import ( "context" "encoding/json" + "errors" "fmt" "strings" "time" @@ -60,6 +61,11 @@ func NewBoltDBWrapper(boltDB *bbolt.DB, log log.Logger) (*DBWrapper, error) { return err } + _, err = transaction.CreateBucketIfNotExists([]byte(bolt.UserAPIKeysBucket)) + if err != nil { + return err + } + return nil }) if err != nil { @@ -1680,39 +1686,26 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T var res repodb.ToggleState if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen - userdb := tx.Bucket([]byte(bolt.UserDataBucket)) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid)) - if err != nil { - // this is a serious failure - return zerr.ErrUnableToCreateUserBucket + var userData repodb.UserData + + err := bdw.getUserData(userid, tx, &userData) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return err } - mdata := userBucket.Get([]byte(bolt.StarredReposKey)) - unpacked := []string{} - if mdata != nil { - if err = json.Unmarshal(mdata, &unpacked); err != nil { - return zerr.ErrInvalidOldUserStarredRepos - } - } - - isRepoStarred := zcommon.Contains(unpacked, repo) + isRepoStarred := zcommon.Contains(userData.StarredRepos, repo) if isRepoStarred { res = repodb.Removed - unpacked = zcommon.RemoveFrom(unpacked, repo) + userData.StarredRepos = zcommon.RemoveFrom(userData.StarredRepos, repo) } else { res = repodb.Added - unpacked = append(unpacked, repo) + userData.StarredRepos = append(userData.StarredRepos, repo) } - var repacked []byte - if repacked, err = json.Marshal(unpacked); err != nil { - return zerr.ErrCouldNotMarshalStarredRepos - } - - err = userBucket.Put([]byte(bolt.StarredReposKey), repacked) + err = bdw.setUserData(userid, tx, userData) if err != nil { - return zerr.ErrCouldNotPersistData + return err } repoBuck := tx.Bucket([]byte(bolt.RepoMetadataBucket)) @@ -1755,46 +1748,12 @@ func (bdw *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) (repodb.T } func (bdw *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) { - starredRepos := make([]string, 0) - - acCtx, err := localCtx.GetAccessControlContext(ctx) - if err != nil { - return starredRepos, err + userData, err := bdw.GetUserData(ctx) + if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) { + return []string{}, nil } - userid := localCtx.GetUsernameFromContext(acCtx) - - err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl - if userid == "" { - return nil - } - - userdb := tx.Bucket([]byte(bolt.UserDataBucket)) - userBucket := userdb.Bucket([]byte(userid)) - - if userBucket == nil { - return nil - } - - mdata := userBucket.Get([]byte(bolt.StarredReposKey)) - if mdata == nil { - return nil - } - - if err := json.Unmarshal(mdata, &starredRepos); err != nil { - bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error") - - return zerr.ErrInvalidOldUserStarredRepos - } - - if starredRepos == nil { - starredRepos = make([]string, 0) - } - - return nil - }) - - return starredRepos, err + return userData.StarredRepos, err } func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repodb.ToggleState, error) { @@ -1815,43 +1774,25 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo var res repodb.ToggleState - if err := bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:dupl - userdb := tx.Bucket([]byte(bolt.UserDataBucket)) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(userid)) - if err != nil { - // this is a serious failure - return zerr.ErrUnableToCreateUserBucket + if err := bdw.DB.Update(func(transaction *bbolt.Tx) error { //nolint:dupl + var userData repodb.UserData + + err := bdw.getUserData(userid, transaction, &userData) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return err } - mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey)) - unpacked := []string{} - if mdata != nil { - if err = json.Unmarshal(mdata, &unpacked); err != nil { - return zerr.ErrInvalidOldUserBookmarkedRepos - } - } - - isRepoBookmarked := zcommon.Contains(unpacked, repo) + isRepoBookmarked := zcommon.Contains(userData.BookmarkedRepos, repo) if isRepoBookmarked { res = repodb.Removed - unpacked = zcommon.RemoveFrom(unpacked, repo) + userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo) } else { res = repodb.Added - unpacked = append(unpacked, repo) + userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo) } - var repacked []byte - if repacked, err = json.Marshal(unpacked); err != nil { - return zerr.ErrCouldNotMarshalBookmarkedRepos - } - - err = userBucket.Put([]byte(bolt.BookmarkedReposKey), repacked) - if err != nil { - return zerr.ErrUnableToCreateUserBucket - } - - return nil + return bdw.setUserData(userid, transaction, userData) }); err != nil { return repodb.NotChanged, err } @@ -1860,46 +1801,12 @@ func (bdw *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) (repo } func (bdw *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) { - bookmarkedRepos := []string{} - - acCtx, err := localCtx.GetAccessControlContext(ctx) - if err != nil { - return bookmarkedRepos, err + userData, err := bdw.GetUserData(ctx) + if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) { + return []string{}, nil } - userid := localCtx.GetUsernameFromContext(acCtx) - - err = bdw.DB.View(func(tx *bbolt.Tx) error { //nolint:dupl - if userid == "" { - return nil - } - - userdb := tx.Bucket([]byte(bolt.UserDataBucket)) - userBucket := userdb.Bucket([]byte(userid)) - - if userBucket == nil { - return nil - } - - mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey)) - if mdata == nil { - return nil - } - - if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil { - bdw.Log.Info().Str("user", userid).Err(err).Msg("unmarshal error") - - return zerr.ErrInvalidOldUserBookmarkedRepos - } - - if bookmarkedRepos == nil { - bookmarkedRepos = make([]string, 0) - } - - return nil - }) - - return bookmarkedRepos, err + return userData.BookmarkedRepos, err } func (bdw *DBWrapper) PatchDB() error { @@ -1940,30 +1847,25 @@ func getUserStars(ctx context.Context, transaction *bbolt.Tx) []string { } var ( - userid = localCtx.GetUsernameFromContext(acCtx) - starredRepos = []string{} - userdb = transaction.Bucket([]byte(bolt.UserDataBucket)) - userBucket = userdb.Bucket([]byte(userid)) + userData repodb.UserData + userid = localCtx.GetUsernameFromContext(acCtx) + userdb = transaction.Bucket([]byte(bolt.UserDataBucket)) ) - if userid == "" { + if userid == "" || userdb == nil { return []string{} } - if userBucket == nil { - return []string{} - } - - mdata := userBucket.Get([]byte(bolt.StarredReposKey)) + mdata := userdb.Get([]byte(userid)) if mdata == nil { return []string{} } - if err := json.Unmarshal(mdata, &starredRepos); err != nil { + if err := json.Unmarshal(mdata, &userData); err != nil { return []string{} } - return starredRepos + return userData.StarredRepos } func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string { @@ -1973,28 +1875,309 @@ func getUserBookmarks(ctx context.Context, transaction *bbolt.Tx) []string { } var ( - userid = localCtx.GetUsernameFromContext(acCtx) - bookmarkedRepos = []string{} - userdb = transaction.Bucket([]byte(bolt.UserDataBucket)) - userBucket = userdb.Bucket([]byte(userid)) + userData repodb.UserData + userid = localCtx.GetUsernameFromContext(acCtx) + userdb = transaction.Bucket([]byte(bolt.UserDataBucket)) ) - if userid == "" { + if userid == "" || userdb == nil { return []string{} } - if userBucket == nil { - return []string{} - } - - mdata := userBucket.Get([]byte(bolt.BookmarkedReposKey)) + mdata := userdb.Get([]byte(userid)) if mdata == nil { return []string{} } - if err := json.Unmarshal(mdata, &bookmarkedRepos); err != nil { + if err := json.Unmarshal(mdata, &userData); err != nil { return []string{} } - return bookmarkedRepos + return userData.BookmarkedRepos +} + +func (bdw *DBWrapper) SetUserGroups(ctx context.Context, groups []string) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen + var userData repodb.UserData + + err := bdw.getUserData(userid, tx, &userData) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return err + } + + userData.Groups = append(userData.Groups, groups...) + + err = bdw.setUserData(userid, tx, userData) + + return err + }) + + return err +} + +func (bdw *DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) { + userData, err := bdw.GetUserData(ctx) + + return userData.Groups, err +} + +func (bdw *DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(tx *bbolt.Tx) error { //nolint:varnamelen + var userData repodb.UserData + + err := bdw.getUserData(userid, tx, &userData) + if err != nil { + return err + } + + apiKeyDetails := userData.APIKeys[hashedKey] + apiKeyDetails.LastUsed = time.Now() + + userData.APIKeys[hashedKey] = apiKeyDetails + + err = bdw.setUserData(userid, tx, userData) + + return err + }) + + return err +} + +func (bdw *DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(transaction *bbolt.Tx) error { + var userData repodb.UserData + + apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket)) + if apiKeysbuck == nil { + return zerr.ErrBucketDoesNotExist + } + + err := apiKeysbuck.Put([]byte(hashedKey), []byte(userid)) + if err != nil { + return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err) + } + + err = bdw.getUserData(userid, transaction, &userData) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return err + } + + if userData.APIKeys == nil { + userData.APIKeys = make(map[string]repodb.APIKeyDetails) + } + + userData.APIKeys[hashedKey] = *apiKeyDetails + + err = bdw.setUserData(userid, transaction, userData) + + return err + }) + + return err +} + +func (bdw *DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(transaction *bbolt.Tx) error { + var userData repodb.UserData + + apiKeysbuck := transaction.Bucket([]byte(bolt.UserAPIKeysBucket)) + if apiKeysbuck == nil { + return zerr.ErrBucketDoesNotExist + } + + err := bdw.getUserData(userid, transaction, &userData) + if err != nil { + return err + } + + for hash, apiKeyDetails := range userData.APIKeys { + if apiKeyDetails.UUID == keyID { + delete(userData.APIKeys, hash) + + err := apiKeysbuck.Delete([]byte(hash)) + if err != nil { + return fmt.Errorf("userDB: error while deleting userAPIKey entry for hash %s %w", hash, err) + } + } + } + + return bdw.setUserData(userid, transaction, userData) + }) + + return err +} + +func (bdw *DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) { + var userid string + err := bdw.DB.View(func(tx *bbolt.Tx) error { + buck := tx.Bucket([]byte(bolt.UserAPIKeysBucket)) + if buck == nil { + return zerr.ErrBucketDoesNotExist + } + + uiBlob := buck.Get([]byte(hashedKey)) + if len(uiBlob) == 0 { + return zerr.ErrUserAPIKeyNotFound + } + + userid = string(uiBlob) + + return nil + }) + + return userid, err +} + +func (bdw *DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) { + var userData repodb.UserData + + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return userData, err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return userData, zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.View(func(tx *bbolt.Tx) error { + return bdw.getUserData(userid, tx, &userData) + }) + + return userData, err +} + +func (bdw *DBWrapper) getUserData(userid string, transaction *bbolt.Tx, res *repodb.UserData) error { + buck := transaction.Bucket([]byte(bolt.UserDataBucket)) + if buck == nil { + return zerr.ErrBucketDoesNotExist + } + + upBlob := buck.Get([]byte(userid)) + + if len(upBlob) == 0 { + return zerr.ErrUserDataNotFound + } + + err := json.Unmarshal(upBlob, res) + if err != nil { + return err + } + + return nil +} + +func (bdw *DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(tx *bbolt.Tx) error { + return bdw.setUserData(userid, tx, userData) + }) + + return err +} + +func (bdw *DBWrapper) setUserData(userid string, transaction *bbolt.Tx, userData repodb.UserData) error { + buck := transaction.Bucket([]byte(bolt.UserDataBucket)) + if buck == nil { + return zerr.ErrBucketDoesNotExist + } + + upBlob, err := json.Marshal(userData) + if err != nil { + return err + } + + err = buck.Put([]byte(userid), upBlob) + if err != nil { + return fmt.Errorf("repoDB: error while setting userData for identity %s %w", userid, err) + } + + return nil +} + +func (bdw *DBWrapper) DeleteUserData(ctx context.Context) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + err = bdw.DB.Update(func(tx *bbolt.Tx) error { + buck := tx.Bucket([]byte(bolt.UserDataBucket)) + if buck == nil { + return zerr.ErrBucketDoesNotExist + } + + err := buck.Delete([]byte(userid)) + if err != nil { + return fmt.Errorf("repoDB: error while deleting userData for identity %s %w", userid, err) + } + + return nil + }) + + return err } diff --git a/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper_test.go b/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper_test.go index e9e26701..d71d3176 100644 --- a/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper_test.go +++ b/pkg/meta/repodb/boltdb-wrapper/boltdb_wrapper_test.go @@ -2,7 +2,10 @@ package bolt_test import ( "context" + "crypto/rand" + "encoding/base64" "encoding/json" + "math" "testing" "github.com/opencontainers/go-digest" @@ -10,6 +13,7 @@ import ( . "github.com/smartystreets/goconvey/convey" "go.etcd.io/bbolt" + zerr "zotregistry.io/zot/errors" "zotregistry.io/zot/pkg/log" "zotregistry.io/zot/pkg/meta/bolt" "zotregistry.io/zot/pkg/meta/repodb" @@ -21,7 +25,6 @@ import ( func TestWrapperErrors(t *testing.T) { Convey("Errors", t, func() { - ctx := context.Background() tmpDir := t.TempDir() boltDBParams := bolt.DBParameters{RootDir: tmpDir} boltDriver, err := bolt.GetBoltDriver(boltDBParams) @@ -41,6 +44,231 @@ func TestWrapperErrors(t *testing.T) { repoMetaBlob, err := json.Marshal(repoMeta) So(err, ShouldBeNil) + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "test", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + Convey("AddUserAPIKey", func() { + Convey("no userid found", func() { + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + }) + + err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserDataBucket)) + }) + So(err, ShouldBeNil) + + err = boltdbWrapper.AddUserAPIKey(ctx, "test", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket)) + }) + + So(err, ShouldBeNil) + + err = boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{}) + So(err, ShouldEqual, zerr.ErrBucketDoesNotExist) + }) + + Convey("UpdateUserAPIKey", func() { + err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "") + So(err, ShouldNotBeNil) + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + err = boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "") //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserAPIKey", func() { + err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) + So(err, ShouldBeNil) + + err = boltdbWrapper.AddUserAPIKey(ctx, "hashedKey", &repodb.APIKeyDetails{}) + So(err, ShouldBeNil) + + Convey("no such bucket", func() { + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket)) + }) + So(err, ShouldBeNil) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "test", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.DeleteUserAPIKey(ctx, "") + So(err, ShouldEqual, zerr.ErrBucketDoesNotExist) + }) + + Convey("userdata not found", func() { + authzCtxKey := localCtx.GetContextKey() + acCtx := localCtx.AccessControlContext{ + Username: "test", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + err := boltdbWrapper.DeleteUserData(ctx) + So(err, ShouldBeNil) + + err = boltdbWrapper.DeleteUserAPIKey(ctx, "") + So(err, ShouldNotBeNil) + }) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) //nolint: contextcheck + + err = boltdbWrapper.DeleteUserAPIKey(ctx, "test") //nolint: contextcheck + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserDataBucket)) + }) + So(err, ShouldBeNil) + + err = boltdbWrapper.DeleteUserAPIKey(ctx, "") //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + + Convey("GetUserAPIKeyInfo", func() { + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket)) + }) + So(err, ShouldBeNil) + + _, err = boltdbWrapper.GetUserAPIKeyInfo("") + So(err, ShouldNotBeNil) + }) + + Convey("GetUserData", func() { + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + buck := tx.Bucket([]byte(bolt.UserDataBucket)) + So(buck, ShouldNotBeNil) + + return buck.Put([]byte("test"), []byte("dsa8")) + }) + + So(err, ShouldBeNil) + + _, err = boltdbWrapper.GetUserData(ctx) + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserAPIKeysBucket)) + }) + So(err, ShouldBeNil) + + _, err = boltdbWrapper.GetUserData(ctx) + So(err, ShouldNotBeNil) + }) + + Convey("SetUserData", func() { + acCtx = localCtx.AccessControlContext{ + Username: "", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) + So(err, ShouldNotBeNil) + + buff := make([]byte, int(math.Ceil(float64(1000000)/float64(1.33333333333)))) + _, err := rand.Read(buff) + So(err, ShouldBeNil) + + longString := base64.RawURLEncoding.EncodeToString(buff) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: longString, + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserDataBucket)) + }) + So(err, ShouldBeNil) + + acCtx = localCtx.AccessControlContext{ + Username: "test", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserData", func() { + acCtx = localCtx.AccessControlContext{ + Username: "", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.DeleteUserData(ctx) + So(err, ShouldNotBeNil) + + err = boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { + return tx.DeleteBucket([]byte(bolt.UserDataBucket)) + }) + So(err, ShouldBeNil) + + acCtx = localCtx.AccessControlContext{ + Username: "test", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = boltdbWrapper.DeleteUserData(ctx) + So(err, ShouldNotBeNil) + }) + + Convey("GetUserGroups and SetUserGroups", func() { + acCtx = localCtx.AccessControlContext{ + Username: "", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + _, err := boltdbWrapper.GetUserGroups(ctx) + So(err, ShouldNotBeNil) + + err = boltdbWrapper.SetUserGroups(ctx, []string{}) + So(err, ShouldNotBeNil) + }) + Convey("GetManifestData", func() { err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { dataBuck := tx.Bucket([]byte(bolt.ManifestDataBucket)) @@ -732,60 +960,6 @@ func TestWrapperErrors(t *testing.T) { So(err, ShouldNotBeNil) }) - Convey("ToggleStarRepo, getting StarredRepoKey from bucket fails", func() { - acCtx := localCtx.AccessControlContext{ - ReadGlobPatterns: map[string]bool{ - "repo": true, - }, - Username: "username", - } - authzCtxKey := localCtx.GetContextKey() - ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) - - err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { - userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket)) - So(err, ShouldBeNil) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username)) - So(err, ShouldBeNil) - - err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array")) - So(err, ShouldBeNil) - - return nil - }) - So(err, ShouldBeNil) - - _, err = boltdbWrapper.ToggleStarRepo(ctx, "repo") - So(err, ShouldNotBeNil) - }) - - Convey("ToggleBookmarkRepo, unmarshal error", func() { - acCtx := localCtx.AccessControlContext{ - ReadGlobPatterns: map[string]bool{ - "repo": true, - }, - Username: "username", - } - authzCtxKey := localCtx.GetContextKey() - ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) - - err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { - userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket)) - So(err, ShouldBeNil) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username)) - So(err, ShouldBeNil) - - err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array")) - So(err, ShouldBeNil) - - return nil - }) - So(err, ShouldBeNil) - - _, err = boltdbWrapper.ToggleBookmarkRepo(ctx, "repo") - So(err, ShouldNotBeNil) - }) - Convey("ToggleStarRepo, no repoMeta found", func() { acCtx := localCtx.AccessControlContext{ ReadGlobPatterns: map[string]bool{ @@ -832,6 +1006,73 @@ func TestWrapperErrors(t *testing.T) { So(err, ShouldNotBeNil) }) + Convey("GetUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + _, err := boltdbWrapper.GetUserData(ctx) + So(err, ShouldNotBeNil) + }) + + Convey("SetUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.SetUserData(ctx, repodb.UserData{}) + So(err, ShouldNotBeNil) + }) + + Convey("GetUserGroups bad context errors", func() { + _, err := boltdbWrapper.GetUserGroups(ctx) + So(err, ShouldNotBeNil) + + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + _, err = boltdbWrapper.GetUserGroups(ctx) //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + + Convey("SetUserGroups bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.SetUserGroups(ctx, []string{}) + So(err, ShouldNotBeNil) + }) + + Convey("AddUserAPIKey bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserAPIKey bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.DeleteUserAPIKey(ctx, "") + So(err, ShouldNotBeNil) + }) + + Convey("UpdateUserAPIKeyLastUsed bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.UpdateUserAPIKeyLastUsed(ctx, "") + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := boltdbWrapper.DeleteUserData(ctx) + So(err, ShouldNotBeNil) + }) + Convey("GetStarredRepos bad context errors", func() { authzCtxKey := localCtx.GetContextKey() ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") @@ -840,60 +1081,6 @@ func TestWrapperErrors(t *testing.T) { So(err, ShouldNotBeNil) }) - Convey("GetStarredRepos user data unmarshal error", func() { - acCtx := localCtx.AccessControlContext{ - ReadGlobPatterns: map[string]bool{ - "repo": true, - }, - Username: "username", - } - authzCtxKey := localCtx.GetContextKey() - ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) - - err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { - userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket)) - So(err, ShouldBeNil) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username)) - So(err, ShouldBeNil) - - err = userBucket.Put([]byte(bolt.StarredReposKey), []byte("bad array")) - So(err, ShouldBeNil) - - return nil - }) - So(err, ShouldBeNil) - - _, err = boltdbWrapper.GetStarredRepos(ctx) - So(err, ShouldNotBeNil) - }) - - Convey("GetBookmarkedRepos user data unmarshal error", func() { - acCtx := localCtx.AccessControlContext{ - ReadGlobPatterns: map[string]bool{ - "repo": true, - }, - Username: "username", - } - authzCtxKey := localCtx.GetContextKey() - ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) - - err := boltdbWrapper.DB.Update(func(tx *bbolt.Tx) error { - userdb, err := tx.CreateBucketIfNotExists([]byte(bolt.UserDataBucket)) - So(err, ShouldBeNil) - userBucket, err := userdb.CreateBucketIfNotExists([]byte(acCtx.Username)) - So(err, ShouldBeNil) - - err = userBucket.Put([]byte(bolt.BookmarkedReposKey), []byte("bad array")) - So(err, ShouldBeNil) - - return nil - }) - So(err, ShouldBeNil) - - _, err = boltdbWrapper.GetBookmarkedRepos(ctx) - So(err, ShouldNotBeNil) - }) - Convey("GetBookmarkedRepos bad context errors", func() { authzCtxKey := localCtx.GetContextKey() ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") diff --git a/pkg/meta/repodb/dynamodb-wrapper/dynamo_internal_test.go b/pkg/meta/repodb/dynamodb-wrapper/dynamo_internal_test.go index 079340af..a859b3aa 100644 --- a/pkg/meta/repodb/dynamodb-wrapper/dynamo_internal_test.go +++ b/pkg/meta/repodb/dynamodb-wrapper/dynamo_internal_test.go @@ -31,6 +31,7 @@ func TestWrapperErrors(t *testing.T) { manifestDataTablename := "ManifestDataTable" + uuid.String() indexDataTablename := "IndexDataTable" + uuid.String() userDataTablename := "UserDataTable" + uuid.String() + apiKeyTablename := "ApiKeyTable" + uuid.String() versionTablename := "Version" + uuid.String() @@ -58,6 +59,7 @@ func TestWrapperErrors(t *testing.T) { IndexDataTablename: indexDataTablename, VersionTablename: versionTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, Patches: version.GetDynamoDBPatches(), Log: log.Logger{Logger: zerolog.New(os.Stdout)}, } @@ -74,6 +76,9 @@ func TestWrapperErrors(t *testing.T) { err = dynamoWrapper.createVersionTable() So(err, ShouldNotBeNil) + + err = dynamoWrapper.createAPIKeyTable() + So(err, ShouldNotBeNil) }) Convey("Delete table errors", t, func() { diff --git a/pkg/meta/repodb/dynamodb-wrapper/dynamo_test.go b/pkg/meta/repodb/dynamodb-wrapper/dynamo_test.go index 1f8deb44..041d9f2b 100644 --- a/pkg/meta/repodb/dynamodb-wrapper/dynamo_test.go +++ b/pkg/meta/repodb/dynamodb-wrapper/dynamo_test.go @@ -43,6 +43,7 @@ func TestIterator(t *testing.T) { versionTablename := "Version" + uuid.String() indexDataTablename := "IndexDataTable" + uuid.String() userDataTablename := "UserDataTable" + uuid.String() + apiKeyTablename := "ApiKeyTable" + uuid.String() log := log.NewLogger("debug", "") @@ -54,6 +55,7 @@ func TestIterator(t *testing.T) { ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, VersionTablename: versionTablename, + APIKeyTablename: apiKeyTablename, UserDataTablename: userDataTablename, } client, err := dynamo.GetDynamoClient(params) @@ -144,8 +146,8 @@ func TestWrapperErrors(t *testing.T) { versionTablename := "Version" + uuid.String() indexDataTablename := "IndexDataTable" + uuid.String() userDataTablename := "UserDataTable" + uuid.String() - - ctx := context.Background() + apiKeyTablename := "ApiKeyTable" + uuid.String() + wrongTableName := "WRONG Tables" log := log.NewLogger("debug", "") @@ -157,6 +159,7 @@ func TestWrapperErrors(t *testing.T) { ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: versionTablename, } client, err := dynamo.GetDynamoClient(params) //nolint:contextcheck @@ -168,6 +171,61 @@ func TestWrapperErrors(t *testing.T) { So(dynamoWrapper.ResetManifestDataTable(), ShouldBeNil) //nolint:contextcheck So(dynamoWrapper.ResetRepoMetaTable(), ShouldBeNil) //nolint:contextcheck + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "test", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + Convey("SetUserData", func() { + hashKey := "id" + apiKeys := make(map[string]repodb.APIKeyDetails) + apiKeyDetails := repodb.APIKeyDetails{ + Label: "apiKey", + Scopes: []string{"repo"}, + UUID: hashKey, + } + + apiKeys[hashKey] = apiKeyDetails + + userProfileSrc := repodb.UserData{ + Groups: []string{"group1", "group2"}, + APIKeys: apiKeys, + } + + err := dynamoWrapper.SetUserData(ctx, userProfileSrc) + So(err, ShouldBeNil) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = dynamoWrapper.SetUserData(ctx, repodb.UserData{}) //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserData", func() { + err := dynamoWrapper.DeleteUserData(ctx) + So(err, ShouldBeNil) + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = dynamoWrapper.DeleteUserData(ctx) //nolint: contextcheck + So(err, ShouldNotBeNil) + }) + Convey("ToggleBookmarkRepo no access", func() { acCtx := localCtx.AccessControlContext{ ReadGlobPatterns: map[string]bool{ @@ -290,17 +348,17 @@ func TestWrapperErrors(t *testing.T) { So(err, ShouldNotBeNil) }) - Convey("GetUserMeta bad context", func() { + Convey("GetUserData bad context", func() { authzCtxKey := localCtx.GetContextKey() ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") - userData, err := dynamoWrapper.GetUserMeta(ctx) + userData, err := dynamoWrapper.GetUserData(ctx) So(err, ShouldNotBeNil) So(userData.BookmarkedRepos, ShouldBeEmpty) So(userData.StarredRepos, ShouldBeEmpty) }) - Convey("GetUserMeta client error", func() { + Convey("GetUserData client error", func() { acCtx := localCtx.AccessControlContext{ ReadGlobPatterns: map[string]bool{ "repo": true, @@ -312,7 +370,7 @@ func TestWrapperErrors(t *testing.T) { dynamoWrapper.UserDataTablename = badTablename - _, err := dynamoWrapper.GetUserMeta(ctx) + _, err := dynamoWrapper.GetUserData(ctx) So(err, ShouldNotBeNil) }) @@ -329,27 +387,155 @@ func TestWrapperErrors(t *testing.T) { err := setBadUserData(dynamoWrapper.Client, userDataTablename, acCtx.Username) So(err, ShouldBeNil) - _, err = dynamoWrapper.GetUserMeta(ctx) + _, err = dynamoWrapper.GetUserData(ctx) So(err, ShouldNotBeNil) }) - Convey("SetUserMeta bad context", func() { + Convey("SetUserData bad context", func() { authzCtxKey := localCtx.GetContextKey() ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") - err := dynamoWrapper.SetUserMeta(ctx, repodb.UserData{}) + err := dynamoWrapper.SetUserData(ctx, repodb.UserData{}) + So(err, ShouldNotBeNil) + }) + + Convey("GetUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + _, err := dynamoWrapper.GetUserData(ctx) + So(err, ShouldNotBeNil) + }) + + Convey("SetUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := dynamoWrapper.SetUserData(ctx, repodb.UserData{}) + So(err, ShouldNotBeNil) + }) + + Convey("AddUserAPIKey bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := dynamoWrapper.AddUserAPIKey(ctx, "", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserAPIKey bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := dynamoWrapper.DeleteUserAPIKey(ctx, "") + So(err, ShouldNotBeNil) + }) + + Convey("UpdateUserAPIKeyLastUsed bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := dynamoWrapper.UpdateUserAPIKeyLastUsed(ctx, "") + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserData bad context errors", func() { + authzCtxKey := localCtx.GetContextKey() + ctx := context.WithValue(context.Background(), authzCtxKey, "bad context") + + err := dynamoWrapper.DeleteUserData(ctx) + So(err, ShouldNotBeNil) + }) + + Convey("DeleteUserAPIKey returns nil", func() { + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "email", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + apiKeyDetails := make(map[string]repodb.APIKeyDetails) + apiKeyDetails["id"] = repodb.APIKeyDetails{ + UUID: "id", + } + err := dynamoWrapper.SetUserData(ctx, repodb.UserData{ + APIKeys: apiKeyDetails, + }) + So(err, ShouldBeNil) + + dynamoWrapper.APIKeyTablename = wrongTableName + err = dynamoWrapper.DeleteUserAPIKey(ctx, "id") + So(err, ShouldNotBeNil) + }) + + Convey("AddUserAPIKey", func() { + Convey("no userid found", func() { + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + }) + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "email", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err := dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{}) + So(err, ShouldBeNil) + + dynamoWrapper.APIKeyTablename = wrongTableName + err = dynamoWrapper.AddUserAPIKey(ctx, "key", &repodb.APIKeyDetails{}) + So(err, ShouldNotBeNil) + }) + + Convey("GetUserAPIKeyInfo", func() { + dynamoWrapper.APIKeyTablename = wrongTableName + _, err := dynamoWrapper.GetUserAPIKeyInfo("key") + So(err, ShouldNotBeNil) + }) + + Convey("GetUserData", func() { + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + _, err := dynamoWrapper.GetUserData(ctx) + So(err, ShouldNotBeNil) + + acCtx = localCtx.AccessControlContext{ + Username: "email", + } + + ctx = context.WithValue(context.Background(), authzCtxKey, acCtx) + + dynamoWrapper.UserDataTablename = wrongTableName + _, err = dynamoWrapper.GetUserData(ctx) So(err, ShouldNotBeNil) }) Convey("SetManifestData", func() { - dynamoWrapper.ManifestDataTablename = "WRONG tables" + dynamoWrapper.ManifestDataTablename = wrongTableName err := dynamoWrapper.SetManifestData("dig", repodb.ManifestData{}) So(err, ShouldNotBeNil) }) Convey("GetManifestData", func() { - dynamoWrapper.ManifestDataTablename = "WRONG table" + dynamoWrapper.ManifestDataTablename = wrongTableName _, err := dynamoWrapper.GetManifestData("dig") So(err, ShouldNotBeNil) @@ -364,7 +550,7 @@ func TestWrapperErrors(t *testing.T) { }) Convey("GetIndexData", func() { - dynamoWrapper.IndexDataTablename = "WRONG table" + dynamoWrapper.IndexDataTablename = wrongTableName _, err := dynamoWrapper.GetIndexData("dig") So(err, ShouldNotBeNil) @@ -1091,6 +1277,7 @@ func TestWrapperErrors(t *testing.T) { ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: versionTablename, } client, err := dynamo.GetDynamoClient(params) @@ -1106,6 +1293,7 @@ func TestWrapperErrors(t *testing.T) { ManifestDataTablename: "", IndexDataTablename: indexDataTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: versionTablename, } client, err = dynamo.GetDynamoClient(params) @@ -1121,6 +1309,7 @@ func TestWrapperErrors(t *testing.T) { ManifestDataTablename: manifestDataTablename, IndexDataTablename: "", UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: versionTablename, } client, err = dynamo.GetDynamoClient(params) @@ -1136,6 +1325,7 @@ func TestWrapperErrors(t *testing.T) { ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: "", } client, err = dynamo.GetDynamoClient(params) @@ -1150,8 +1340,41 @@ func TestWrapperErrors(t *testing.T) { RepoMetaTablename: repoMetaTablename, ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, - UserDataTablename: "", VersionTablename: versionTablename, + UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, + } + client, err = dynamo.GetDynamoClient(params) + So(err, ShouldBeNil) + + _, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log) + So(err, ShouldBeNil) + + params = dynamo.DBDriverParameters{ //nolint:contextcheck + Endpoint: endpoint, + Region: region, + RepoMetaTablename: repoMetaTablename, + ManifestDataTablename: manifestDataTablename, + IndexDataTablename: indexDataTablename, + VersionTablename: versionTablename, + UserDataTablename: "", + APIKeyTablename: apiKeyTablename, + } + client, err = dynamo.GetDynamoClient(params) + So(err, ShouldBeNil) + + _, err = dynamoWrapper.NewDynamoDBWrapper(client, params, log) + So(err, ShouldNotBeNil) + + params = dynamo.DBDriverParameters{ //nolint:contextcheck + Endpoint: endpoint, + Region: region, + RepoMetaTablename: repoMetaTablename, + ManifestDataTablename: manifestDataTablename, + IndexDataTablename: indexDataTablename, + VersionTablename: versionTablename, + UserDataTablename: userDataTablename, + APIKeyTablename: "", } client, err = dynamo.GetDynamoClient(params) So(err, ShouldBeNil) @@ -1250,7 +1473,7 @@ func setBadUserData(client *dynamodb.Client, userDataTablename, userID string) e ":UserData": userAttributeValue, }, Key: map[string]types.AttributeValue{ - "UserID": &types.AttributeValueMemberS{ + "Identity": &types.AttributeValueMemberS{ Value: userID, }, }, diff --git a/pkg/meta/repodb/dynamodb-wrapper/dynamo_wrapper.go b/pkg/meta/repodb/dynamodb-wrapper/dynamo_wrapper.go index 302bd3ce..3f89cb76 100644 --- a/pkg/meta/repodb/dynamodb-wrapper/dynamo_wrapper.go +++ b/pkg/meta/repodb/dynamodb-wrapper/dynamo_wrapper.go @@ -30,6 +30,7 @@ var errRepodb = errors.New("repodb: error while constructing manifest meta") type DBWrapper struct { Client *dynamodb.Client + APIKeyTablename string RepoMetaTablename string IndexDataTablename string ManifestDataTablename string @@ -47,6 +48,7 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter IndexDataTablename: params.IndexDataTablename, VersionTablename: params.VersionTablename, UserDataTablename: params.UserDataTablename, + APIKeyTablename: params.APIKeyTablename, Patches: version.GetDynamoDBPatches(), Log: log, } @@ -76,6 +78,11 @@ func NewDynamoDBWrapper(client *dynamodb.Client, params dynamo.DBDriverParameter return nil, err } + err = dynamoWrapper.createAPIKeyTable() + if err != nil { + return nil, err + } + // Using the Config value, create the DynamoDB client return &dynamoWrapper, nil } @@ -580,13 +587,13 @@ func (dwr *DBWrapper) GetUserRepoMeta(ctx context.Context, repo string) (repodb. return repodb.RepoMetadata{}, err } - userMeta, err := dwr.GetUserMeta(ctx) + userData, err := dwr.GetUserData(ctx) if err != nil { return repodb.RepoMetadata{}, err } - repoMeta.IsBookmarked = zcommon.Contains(userMeta.BookmarkedRepos, repo) - repoMeta.IsStarred = zcommon.Contains(userMeta.StarredRepos, repo) + repoMeta.IsBookmarked = zcommon.Contains(userData.BookmarkedRepos, repo) + repoMeta.IsStarred = zcommon.Contains(userData.StarredRepos, repo) return repoMeta, nil } @@ -1802,7 +1809,7 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) ( return res, zerr.ErrUserDataNotAllowed } - userMeta, err := dwr.GetUserMeta(ctx) + userData, err := dwr.GetUserData(ctx) if err != nil { if errors.Is(err, zerr.ErrUserDataNotFound) { return repodb.NotChanged, nil @@ -1811,16 +1818,16 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) ( return res, err } - if !zcommon.Contains(userMeta.BookmarkedRepos, repo) { - userMeta.BookmarkedRepos = append(userMeta.BookmarkedRepos, repo) + if !zcommon.Contains(userData.BookmarkedRepos, repo) { + userData.BookmarkedRepos = append(userData.BookmarkedRepos, repo) res = repodb.Added } else { - userMeta.BookmarkedRepos = zcommon.RemoveFrom(userMeta.BookmarkedRepos, repo) + userData.BookmarkedRepos = zcommon.RemoveFrom(userData.BookmarkedRepos, repo) res = repodb.Removed } if res != repodb.NotChanged { - err = dwr.SetUserMeta(ctx, userMeta) + err = dwr.SetUserData(ctx, userData) } if err != nil { @@ -1833,9 +1840,9 @@ func (dwr *DBWrapper) ToggleBookmarkRepo(ctx context.Context, repo string) ( } func (dwr *DBWrapper) GetBookmarkedRepos(ctx context.Context) ([]string, error) { - userMeta, err := dwr.GetUserMeta(ctx) + userMeta, err := dwr.GetUserData(ctx) - if errors.Is(err, zerr.ErrUserDataNotFound) { + if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) { return []string{}, nil } @@ -1863,7 +1870,7 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) ( return res, zerr.ErrUserDataNotAllowed } - userData, err := dwr.GetUserMeta(ctx) + userData, err := dwr.GetUserData(ctx) if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { return res, err } @@ -1902,21 +1909,21 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) ( _, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{ TransactItems: []types.TransactWriteItem{ { - // Update User Meta + // Update User Profile Update: &types.Update{ ExpressionAttributeNames: map[string]string{ - "#UM": "UserData", + "#UP": "UserData", }, ExpressionAttributeValues: map[string]types.AttributeValue{ ":UserData": userAttributeValue, }, Key: map[string]types.AttributeValue{ - "UserID": &types.AttributeValueMemberS{ + "Identity": &types.AttributeValueMemberS{ Value: userid, }, }, TableName: aws.String(dwr.UserDataTablename), - UpdateExpression: aws.String("SET #UM = :UserData"), + UpdateExpression: aws.String("SET #UP = :UserData"), }, }, { @@ -1948,64 +1955,27 @@ func (dwr *DBWrapper) ToggleStarRepo(ctx context.Context, repo string) ( } func (dwr *DBWrapper) GetStarredRepos(ctx context.Context) ([]string, error) { - userMeta, err := dwr.GetUserMeta(ctx) + userMeta, err := dwr.GetUserData(ctx) - if errors.Is(err, zerr.ErrUserDataNotFound) { + if errors.Is(err, zerr.ErrUserDataNotFound) || errors.Is(err, zerr.ErrUserDataNotAllowed) { return []string{}, nil } return userMeta.StarredRepos, err } -func (dwr *DBWrapper) GetUserMeta(ctx context.Context) (repodb.UserData, error) { - acCtx, err := localCtx.GetAccessControlContext(ctx) - if err != nil { - return repodb.UserData{}, err - } - - userid := localCtx.GetUsernameFromContext(acCtx) - - if userid == "" { - // empty user is anonymous, it has no data - return repodb.UserData{}, nil - } - - resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{ - TableName: aws.String(dwr.UserDataTablename), - Key: map[string]types.AttributeValue{ - "UserID": &types.AttributeValueMemberS{Value: userid}, - }, - }) - if err != nil { - return repodb.UserData{}, err - } - - if resp.Item == nil { - return repodb.UserData{}, zerr.ErrUserDataNotFound - } - - var userMeta repodb.UserData - - err = attributevalue.Unmarshal(resp.Item["UserData"], &userMeta) - if err != nil { - return repodb.UserData{}, err - } - - return userMeta, nil -} - func (dwr *DBWrapper) createUserDataTable() error { _, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{ TableName: aws.String(dwr.UserDataTablename), AttributeDefinitions: []types.AttributeDefinition{ { - AttributeName: aws.String("UserID"), + AttributeName: aws.String("Identity"), AttributeType: types.ScalarAttributeTypeS, }, }, KeySchema: []types.KeySchemaElement{ { - AttributeName: aws.String("UserID"), + AttributeName: aws.String("Identity"), KeyType: types.KeyTypeHash, }, }, @@ -2019,38 +1989,279 @@ func (dwr *DBWrapper) createUserDataTable() error { return dwr.waitTableToBeCreated(dwr.UserDataTablename) } -func (dwr *DBWrapper) SetUserMeta(ctx context.Context, userMeta repodb.UserData) error { +func (dwr DBWrapper) createAPIKeyTable() error { + _, err := dwr.Client.CreateTable(context.Background(), &dynamodb.CreateTableInput{ + TableName: aws.String(dwr.APIKeyTablename), + AttributeDefinitions: []types.AttributeDefinition{ + { + AttributeName: aws.String("HashedKey"), + AttributeType: types.ScalarAttributeTypeS, + }, + }, + KeySchema: []types.KeySchemaElement{ + { + AttributeName: aws.String("HashedKey"), + KeyType: types.KeyTypeHash, + }, + }, + BillingMode: types.BillingModePayPerRequest, + }) + + if err != nil && !strings.Contains(err.Error(), "Table already exists") { + return err + } + + return dwr.waitTableToBeCreated(dwr.APIKeyTablename) +} + +func (dwr DBWrapper) SetUserGroups(ctx context.Context, groups []string) error { + userData, err := dwr.GetUserData(ctx) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return err + } + + userData.Groups = append(userData.Groups, groups...) + + return dwr.SetUserData(ctx, userData) +} + +func (dwr DBWrapper) GetUserGroups(ctx context.Context) ([]string, error) { + userData, err := dwr.GetUserData(ctx) + + return userData.Groups, err +} + +func (dwr DBWrapper) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error { + userData, err := dwr.GetUserData(ctx) + if err != nil { + return err + } + + apiKeyDetails := userData.APIKeys[hashedKey] + apiKeyDetails.LastUsed = time.Now() + + userData.APIKeys[hashedKey] = apiKeyDetails + + err = dwr.SetUserData(ctx, userData) + + return err +} + +func (dwr DBWrapper) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error { acCtx, err := localCtx.GetAccessControlContext(ctx) if err != nil { return err } userid := localCtx.GetUsernameFromContext(acCtx) - if userid == "" { - // empty user is anonymous, it has no data + // empty user is anonymous return zerr.ErrUserDataNotAllowed } - userAttributeValue, err := attributevalue.Marshal(userMeta) + userData, err := dwr.GetUserData(ctx) + if err != nil && !errors.Is(err, zerr.ErrUserDataNotFound) { + return fmt.Errorf("repoDB: error while getting userData for identity %s %w", userid, err) + } + + if userData.APIKeys == nil { + userData.APIKeys = make(map[string]repodb.APIKeyDetails) + } + + userData.APIKeys[hashedKey] = *apiKeyDetails + + userAttributeValue, err := attributevalue.Marshal(userData) + if err != nil { + return err + } + + _, err = dwr.Client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{ + TransactItems: []types.TransactWriteItem{ + { + // Update UserData + Update: &types.Update{ + ExpressionAttributeNames: map[string]string{ + "#UP": "UserData", + }, + ExpressionAttributeValues: map[string]types.AttributeValue{ + ":UserData": userAttributeValue, + }, + Key: map[string]types.AttributeValue{ + "Identity": &types.AttributeValueMemberS{ + Value: userid, + }, + }, + TableName: aws.String(dwr.UserDataTablename), + UpdateExpression: aws.String("SET #UP = :UserData"), + }, + }, + { + // Update APIKeyInfo + Update: &types.Update{ + ExpressionAttributeNames: map[string]string{ + "#EM": "Identity", + }, + ExpressionAttributeValues: map[string]types.AttributeValue{ + ":Identity": &types.AttributeValueMemberS{Value: userid}, + }, + Key: map[string]types.AttributeValue{ + "HashedKey": &types.AttributeValueMemberS{ + Value: hashedKey, + }, + }, + TableName: aws.String(dwr.APIKeyTablename), + UpdateExpression: aws.String("SET #EM = :Identity"), + }, + }, + }, + }) + + return err +} + +func (dwr DBWrapper) DeleteUserAPIKey(ctx context.Context, keyID string) error { + userData, err := dwr.GetUserData(ctx) + if err != nil { + return fmt.Errorf("repoDB: error while getting userData %w", err) + } + + for hash, apiKeyDetails := range userData.APIKeys { + if apiKeyDetails.UUID == keyID { + delete(userData.APIKeys, hash) + + _, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{ + TableName: aws.String(dwr.APIKeyTablename), + Key: map[string]types.AttributeValue{ + "HashedKey": &types.AttributeValueMemberS{Value: hash}, + }, + }) + if err != nil { + return fmt.Errorf("repoDB: error while deleting userAPIKey entry for hash %s %w", hash, err) + } + + err := dwr.SetUserData(ctx, userData) + + return err + } + } + + return nil +} + +func (dwr DBWrapper) GetUserAPIKeyInfo(hashedKey string) (string, error) { + var userid string + + resp, err := dwr.Client.GetItem(context.Background(), &dynamodb.GetItemInput{ + TableName: aws.String(dwr.APIKeyTablename), + Key: map[string]types.AttributeValue{ + "HashedKey": &types.AttributeValueMemberS{Value: hashedKey}, + }, + }) + if err != nil { + return "", err + } + + if resp.Item == nil { + return "", zerr.ErrUserAPIKeyNotFound + } + + err = attributevalue.Unmarshal(resp.Item["Identity"], &userid) + if err != nil { + return "", err + } + + return userid, nil +} + +func (dwr DBWrapper) GetUserData(ctx context.Context) (repodb.UserData, error) { + var userData repodb.UserData + + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return userData, err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return userData, zerr.ErrUserDataNotAllowed + } + + resp, err := dwr.Client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(dwr.UserDataTablename), + Key: map[string]types.AttributeValue{ + "Identity": &types.AttributeValueMemberS{Value: userid}, + }, + }) + if err != nil { + return repodb.UserData{}, err + } + + if resp.Item == nil { + return repodb.UserData{}, zerr.ErrUserDataNotFound + } + + err = attributevalue.Unmarshal(resp.Item["UserData"], &userData) + if err != nil { + return repodb.UserData{}, err + } + + return userData, nil +} + +func (dwr DBWrapper) SetUserData(ctx context.Context, userData repodb.UserData) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + userAttributeValue, err := attributevalue.Marshal(userData) if err != nil { return err } _, err = dwr.Client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ ExpressionAttributeNames: map[string]string{ - "#UM": "UserData", + "#UP": "UserData", }, ExpressionAttributeValues: map[string]types.AttributeValue{ ":UserData": userAttributeValue, }, Key: map[string]types.AttributeValue{ - "UserID": &types.AttributeValueMemberS{ + "Identity": &types.AttributeValueMemberS{ Value: userid, }, }, TableName: aws.String(dwr.UserDataTablename), - UpdateExpression: aws.String("SET #UM = :UserData"), + UpdateExpression: aws.String("SET #UP = :UserData"), + }) + + return err +} + +func (dwr DBWrapper) DeleteUserData(ctx context.Context) error { + acCtx, err := localCtx.GetAccessControlContext(ctx) + if err != nil { + return err + } + + userid := localCtx.GetUsernameFromContext(acCtx) + if userid == "" { + // empty user is anonymous + return zerr.ErrUserDataNotAllowed + } + + _, err = dwr.Client.DeleteItem(ctx, &dynamodb.DeleteItemInput{ + TableName: aws.String(dwr.UserDataTablename), + Key: map[string]types.AttributeValue{ + "Identity": &types.AttributeValueMemberS{Value: userid}, + }, }) return err diff --git a/pkg/meta/repodb/repodb.go b/pkg/meta/repodb/repodb.go index f5776b3f..17cc419c 100644 --- a/pkg/meta/repodb/repodb.go +++ b/pkg/meta/repodb/repodb.go @@ -24,6 +24,7 @@ type ( ) type RepoDB interface { //nolint:interfacebloat + UserDB // IncrementRepoStars adds 1 to the star count of an image IncrementRepoStars(repo string) error @@ -111,6 +112,10 @@ type RepoDB interface { //nolint:interfacebloat FilterTags(ctx context.Context, filter FilterFunc, requestedPage PageInput) ([]RepoMetadata, map[string]ManifestMetadata, map[string]IndexData, common.PageInfo, error) + PatchDB() error +} + +type UserDB interface { //nolint:interfacebloat // GetStarredRepos returns starred repos and takes current user in consideration GetStarredRepos(ctx context.Context) ([]string, error) @@ -123,7 +128,24 @@ type RepoDB interface { //nolint:interfacebloat // ToggleBookmarkRepo adds/removes bookmarks on repos ToggleBookmarkRepo(ctx context.Context, reponame string) (ToggleState, error) - PatchDB() error + // UserDB profile/api key CRUD + GetUserData(ctx context.Context) (UserData, error) + + SetUserData(ctx context.Context, userData UserData) error + + SetUserGroups(ctx context.Context, groups []string) error + + GetUserGroups(ctx context.Context) ([]string, error) + + DeleteUserData(ctx context.Context) error + + GetUserAPIKeyInfo(hashedKey string) (identity string, err error) + + AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *APIKeyDetails) error + + UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error + + DeleteUserAPIKey(ctx context.Context, id string) error } type ManifestMetadata struct { @@ -195,12 +217,6 @@ type SignatureMetadata struct { LayersInfo []LayerInfo } -type UserData struct { - // data for each user. - StarredRepos []string - BookmarkedRepos []string -} - type SortCriteria string const ( @@ -235,3 +251,20 @@ type FilterData struct { IsStarred bool IsBookmarked bool } + +type UserData struct { + StarredRepos []string + BookmarkedRepos []string + Groups []string + APIKeys map[string]APIKeyDetails +} + +type APIKeyDetails struct { + CreatedAt time.Time `json:"createdAt"` + CreatorUA string `json:"creatorUa"` + GeneratedBy string `json:"generatedBy"` + LastUsed time.Time `json:"lastUsed"` + Label string `json:"label"` + Scopes []string `json:"scopes"` + UUID string `json:"uuid"` +} diff --git a/pkg/meta/repodb/repodb_test.go b/pkg/meta/repodb/repodb_test.go index 48c98898..4ac1498a 100644 --- a/pkg/meta/repodb/repodb_test.go +++ b/pkg/meta/repodb/repodb_test.go @@ -92,6 +92,7 @@ func TestDynamoDBWrapper(t *testing.T) { versionTablename := "Version" + uuid.String() indexDataTablename := "IndexDataTable" + uuid.String() userDataTablename := "UserDataTable" + uuid.String() + apiKeyTablename := "ApiKeyTable" + uuid.String() Convey("DynamoDB Wrapper", t, func() { dynamoDBDriverParams := dynamo.DBDriverParameters{ @@ -101,6 +102,7 @@ func TestDynamoDBWrapper(t *testing.T) { IndexDataTablename: indexDataTablename, VersionTablename: versionTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, Region: "us-east-2", } @@ -137,6 +139,118 @@ func RunRepoDBTests(t *testing.T, repoDB repodb.RepoDB, preparationFuncs ...func So(err, ShouldBeNil) } + Convey("Test CRUD operations on UserData and API keys", func() { + hashKey1 := "id" + hashKey2 := "key" + apiKeys := make(map[string]repodb.APIKeyDetails) + apiKeyDetails := repodb.APIKeyDetails{ + Label: "apiKey", + Scopes: []string{"repo"}, + UUID: hashKey1, + } + + apiKeys[hashKey1] = apiKeyDetails + + userProfileSrc := repodb.UserData{ + Groups: []string{"group1", "group2"}, + APIKeys: apiKeys, + } + + authzCtxKey := localCtx.GetContextKey() + + acCtx := localCtx.AccessControlContext{ + Username: "test", + } + + ctx := context.WithValue(context.Background(), authzCtxKey, acCtx) + + err := repoDB.AddUserAPIKey(ctx, hashKey1, &apiKeyDetails) + So(err, ShouldBeNil) + + err = repoDB.SetUserData(ctx, userProfileSrc) + So(err, ShouldBeNil) + + userProfile, err := repoDB.GetUserData(ctx) + So(err, ShouldBeNil) + So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups) + So(userProfile.APIKeys, ShouldContainKey, hashKey1) + So(userProfile.APIKeys[hashKey1].Label, ShouldEqual, apiKeyDetails.Label) + So(userProfile.APIKeys[hashKey1].Scopes, ShouldResemble, apiKeyDetails.Scopes) + + lastUsed := userProfile.APIKeys[hashKey1].LastUsed + + err = repoDB.UpdateUserAPIKeyLastUsed(ctx, hashKey1) + So(err, ShouldBeNil) + + userProfile, err = repoDB.GetUserData(ctx) + So(err, ShouldBeNil) + So(userProfile.APIKeys[hashKey1].LastUsed, ShouldHappenAfter, lastUsed) + + userGroups, err := repoDB.GetUserGroups(ctx) + So(err, ShouldBeNil) + So(userGroups, ShouldResemble, userProfileSrc.Groups) + + apiKeyDetails.UUID = hashKey2 + err = repoDB.AddUserAPIKey(ctx, hashKey2, &apiKeyDetails) + So(err, ShouldBeNil) + + userProfile, err = repoDB.GetUserData(ctx) + So(err, ShouldBeNil) + So(userProfile.Groups, ShouldResemble, userProfileSrc.Groups) + So(userProfile.APIKeys, ShouldContainKey, hashKey2) + So(userProfile.APIKeys[hashKey2].Label, ShouldEqual, apiKeyDetails.Label) + So(userProfile.APIKeys[hashKey2].Scopes, ShouldResemble, apiKeyDetails.Scopes) + + email, err := repoDB.GetUserAPIKeyInfo(hashKey2) + So(err, ShouldBeNil) + So(email, ShouldEqual, "test") + + err = repoDB.DeleteUserAPIKey(ctx, hashKey1) + So(err, ShouldBeNil) + + userProfile, err = repoDB.GetUserData(ctx) + So(err, ShouldBeNil) + So(len(userProfile.APIKeys), ShouldEqual, 1) + So(userProfile.APIKeys, ShouldNotContainKey, hashKey1) + + err = repoDB.DeleteUserAPIKey(ctx, hashKey2) + So(err, ShouldBeNil) + + userProfile, err = repoDB.GetUserData(ctx) + So(err, ShouldBeNil) + So(len(userProfile.APIKeys), ShouldEqual, 0) + So(userProfile.APIKeys, ShouldNotContainKey, hashKey2) + + // delete non existent api key + err = repoDB.DeleteUserAPIKey(ctx, hashKey2) + So(err, ShouldBeNil) + + err = repoDB.DeleteUserData(ctx) + So(err, ShouldBeNil) + + email, err = repoDB.GetUserAPIKeyInfo(hashKey2) + So(err, ShouldNotBeNil) + So(email, ShouldBeEmpty) + + email, err = repoDB.GetUserAPIKeyInfo(hashKey1) + So(err, ShouldNotBeNil) + So(email, ShouldBeEmpty) + + _, err = repoDB.GetUserData(ctx) + So(err, ShouldNotBeNil) + + userGroups, err = repoDB.GetUserGroups(ctx) + So(err, ShouldNotBeNil) + So(userGroups, ShouldBeEmpty) + + err = repoDB.SetUserGroups(ctx, userProfileSrc.Groups) + So(err, ShouldBeNil) + + userGroups, err = repoDB.GetUserGroups(ctx) + So(err, ShouldBeNil) + So(userGroups, ShouldResemble, userProfileSrc.Groups) + }) + Convey("Test SetManifestData and GetManifestData", func() { configBlob, manifestBlob, err := generateTestImage() So(err, ShouldBeNil) diff --git a/pkg/meta/repodb/repodbfactory/repodb_factory.go b/pkg/meta/repodb/repodbfactory/repodb_factory.go index b4fd878f..cd270702 100644 --- a/pkg/meta/repodb/repodbfactory/repodb_factory.go +++ b/pkg/meta/repodb/repodbfactory/repodb_factory.go @@ -95,6 +95,9 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d indexDataTablename, ok := toStringIfOk(cacheDriverConfig, "indexdatatablename", log) allParametersOk = allParametersOk && ok + apiKeyTablename, ok := toStringIfOk(cacheDriverConfig, "apikeytablename", log) + allParametersOk = allParametersOk && ok + versionTablename, ok := toStringIfOk(cacheDriverConfig, "versiontablename", log) allParametersOk = allParametersOk && ok @@ -112,6 +115,7 @@ func getDynamoParams(cacheDriverConfig map[string]interface{}, log log.Logger) d ManifestDataTablename: manifestDataTablename, IndexDataTablename: indexDataTablename, UserDataTablename: userDataTablename, + APIKeyTablename: apiKeyTablename, VersionTablename: versionTablename, } } diff --git a/pkg/meta/repodb/repodbfactory/repodb_factory_test.go b/pkg/meta/repodb/repodbfactory/repodb_factory_test.go index b3930207..ee8222bd 100644 --- a/pkg/meta/repodb/repodbfactory/repodb_factory_test.go +++ b/pkg/meta/repodb/repodbfactory/repodb_factory_test.go @@ -25,6 +25,7 @@ func TestCreateDynamo(t *testing.T) { ManifestDataTablename: "ManifestDataTable", IndexDataTablename: "IndexDataTable", UserDataTablename: "UserDataTable", + APIKeyTablename: "ApiKeyTable", VersionTablename: "Version", Region: "us-east-2", } diff --git a/pkg/meta/repodb/storage_parsing_test.go b/pkg/meta/repodb/storage_parsing_test.go index 4b5e289d..f6516efc 100644 --- a/pkg/meta/repodb/storage_parsing_test.go +++ b/pkg/meta/repodb/storage_parsing_test.go @@ -390,6 +390,7 @@ func TestParseStorageDynamoWrapper(t *testing.T) { ManifestDataTablename: "ManifestDataTable", IndexDataTablename: "IndexDataTable", UserDataTablename: "UserDataTable", + APIKeyTablename: "ApiKeyTable", VersionTablename: "Version", } diff --git a/pkg/meta/version/version_test.go b/pkg/meta/version/version_test.go index c0f0e1a0..12ee2ba0 100644 --- a/pkg/meta/version/version_test.go +++ b/pkg/meta/version/version_test.go @@ -127,6 +127,7 @@ func TestVersioningDynamoDB(t *testing.T) { ManifestDataTablename: "ManifestDataTable", IndexDataTablename: "IndexDataTable", UserDataTablename: "UserDataTable", + APIKeyTablename: "ApiKeyTable", VersionTablename: "Version", } diff --git a/pkg/requestcontext/context.go b/pkg/requestcontext/context.go index 74062e44..685f6cec 100644 --- a/pkg/requestcontext/context.go +++ b/pkg/requestcontext/context.go @@ -92,3 +92,29 @@ func (acCtx *AccessControlContext) matchesRepo(globPatterns map[string]bool, rep return allowed } + +// request-local context key. +var amwCtxKey = Key(1) //nolint: gochecknoglobals + +// pointer needed for use in context.WithValue. +func GetAuthnMiddlewareCtxKey() *Key { + return &amwCtxKey +} + +type AuthnMiddlewareContext struct { + AuthnType string +} + +func GetAuthnMiddlewareContext(ctx context.Context) (*AuthnMiddlewareContext, error) { + authnMiddlewareCtxKey := GetAuthnMiddlewareCtxKey() + if authnMiddlewareCtx := ctx.Value(authnMiddlewareCtxKey); authnMiddlewareCtx != nil { + amCtx, ok := authnMiddlewareCtx.(AuthnMiddlewareContext) + if !ok { + return nil, errors.ErrBadType + } + + return &amCtx, nil + } + + return nil, nil //nolint: nilnil +} diff --git a/pkg/test/common.go b/pkg/test/common.go index 409eb0d0..ca94b2a5 100644 --- a/pkg/test/common.go +++ b/pkg/test/common.go @@ -15,6 +15,7 @@ import ( "log" "math" "math/big" + "net" "net/http" "net/url" "os" @@ -38,6 +39,7 @@ import ( ispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/opencontainers/umoci" "github.com/phayes/freeport" + "github.com/project-zot/mockoidc" "github.com/sigstore/cosign/v2/cmd/cosign/cli/generate" "github.com/sigstore/cosign/v2/cmd/cosign/cli/options" "github.com/sigstore/cosign/v2/cmd/cosign/cli/sign" @@ -1967,3 +1969,55 @@ func GetIndexBlobWithManifests(manifestDigests []godigest.Digest) ([]byte, error return json.Marshal(indexContent) } + +func MockOIDCRun() (*mockoidc.MockOIDC, error) { + // Create a fresh RSA Private Key for token signing + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) //nolint: gomnd + + // Create an unstarted MockOIDC server + mockServer, _ := mockoidc.NewServer(rsaKey) + + // Create the net.Listener, kernel will chose a valid port + listener, _ := net.Listen("tcp", "127.0.0.1:0") + + bearerMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, req *http.Request) { + // stateVal := req.Form.Get("state") + header := req.Header.Get("Authorization") + parts := strings.SplitN(header, " ", 2) //nolint: gomnd + if header != "" { + if strings.ToLower(parts[0]) == "bearer" { + req.Header.Set("Authorization", strings.Join([]string{"Bearer", parts[1]}, " ")) + } + } + + next.ServeHTTP(response, req) + }) + } + + err := mockServer.AddMiddleware(bearerMiddleware) + if err != nil { + return mockServer, err + } + // tlsConfig can be nil if you want HTTP + return mockServer, mockServer.Start(listener, nil) +} + +func CustomRedirectPolicy(noOfRedirect int) resty.RedirectPolicy { + return resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + if len(via) >= noOfRedirect { + return fmt.Errorf("stopped after %d redirects", noOfRedirect) //nolint: goerr113 + } + + for key, val := range via[len(via)-1].Header { + req.Header[key] = val + } + + respCookies := req.Response.Cookies() + for _, cookie := range respCookies { + req.AddCookie(cookie) + } + + return nil + }) +} diff --git a/pkg/test/mocks/repo_db_mock.go b/pkg/test/mocks/repo_db_mock.go index 3f53b6cc..99cd9512 100644 --- a/pkg/test/mocks/repo_db_mock.go +++ b/pkg/test/mocks/repo_db_mock.go @@ -95,6 +95,24 @@ type RepoDBMock struct { ToggleBookmarkRepoFn func(ctx context.Context, repo string) (repodb.ToggleState, error) + GetUserDataFn func(ctx context.Context) (repodb.UserData, error) + + SetUserDataFn func(ctx context.Context, userProfile repodb.UserData) error + + SetUserGroupsFn func(ctx context.Context, groups []string) error + + GetUserGroupsFn func(ctx context.Context) ([]string, error) + + DeleteUserDataFn func(ctx context.Context) error + + GetUserAPIKeyInfoFn func(hashedKey string) (string, error) + + AddUserAPIKeyFn func(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error + + UpdateUserAPIKeyLastUsedFn func(ctx context.Context, hashedKey string) error + + DeleteUserAPIKeyFn func(ctx context.Context, id string) error + PatchDBFn func() error } @@ -414,3 +432,75 @@ func (sdm RepoDBMock) ToggleBookmarkRepo(ctx context.Context, repo string) (repo return repodb.NotChanged, nil } + +func (sdm RepoDBMock) GetUserData(ctx context.Context) (repodb.UserData, error) { + if sdm.GetUserDataFn != nil { + return sdm.GetUserDataFn(ctx) + } + + return repodb.UserData{}, nil +} + +func (sdm RepoDBMock) SetUserData(ctx context.Context, userProfile repodb.UserData) error { + if sdm.SetUserDataFn != nil { + return sdm.SetUserDataFn(ctx, userProfile) + } + + return nil +} + +func (sdm RepoDBMock) SetUserGroups(ctx context.Context, groups []string) error { + if sdm.SetUserGroupsFn != nil { + return sdm.SetUserGroupsFn(ctx, groups) + } + + return nil +} + +func (sdm RepoDBMock) GetUserGroups(ctx context.Context) ([]string, error) { + if sdm.GetUserGroupsFn != nil { + return sdm.GetUserGroupsFn(ctx) + } + + return []string{}, nil +} + +func (sdm RepoDBMock) DeleteUserData(ctx context.Context) error { + if sdm.DeleteUserDataFn != nil { + return sdm.DeleteUserDataFn(ctx) + } + + return nil +} + +func (sdm RepoDBMock) GetUserAPIKeyInfo(hashedKey string) (string, error) { + if sdm.GetUserAPIKeyInfoFn != nil { + return sdm.GetUserAPIKeyInfoFn(hashedKey) + } + + return "", nil +} + +func (sdm RepoDBMock) AddUserAPIKey(ctx context.Context, hashedKey string, apiKeyDetails *repodb.APIKeyDetails) error { + if sdm.AddUserAPIKeyFn != nil { + return sdm.AddUserAPIKeyFn(ctx, hashedKey, apiKeyDetails) + } + + return nil +} + +func (sdm RepoDBMock) UpdateUserAPIKeyLastUsed(ctx context.Context, hashedKey string) error { + if sdm.UpdateUserAPIKeyLastUsedFn != nil { + return sdm.UpdateUserAPIKeyLastUsedFn(ctx, hashedKey) + } + + return nil +} + +func (sdm RepoDBMock) DeleteUserAPIKey(ctx context.Context, id string) error { + if sdm.DeleteUserAPIKeyFn != nil { + return sdm.DeleteUserAPIKeyFn(ctx, id) + } + + return nil +} diff --git a/test/blackbox/cloud-only.bats b/test/blackbox/cloud-only.bats index d4b03c9b..6429100a 100644 --- a/test/blackbox/cloud-only.bats +++ b/test/blackbox/cloud-only.bats @@ -39,12 +39,34 @@ function setup() { "manifestDataTablename": "ManifestDataTable", "indexDataTablename": "IndexDataTable", "userDataTablename": "UserDataTable", + "apiKeyTablename":"ApiKeyTable", "versionTablename": "Version" } }, "http": { "address": "127.0.0.1", - "port": "8080" + "port": "8080", + "realm": "zot", + "auth": { + "openid": { + "providers": { + "dex": { + "issuer": "http://127.0.0.1:5556/dex", + "clientid": "zot-client", + "clientsecret": "ZXhhbXBsZS1hcHAtc2VjcmV0", + "scopes": ["openid", "email", "groups"] + } + } + }, + "failDelay": 5 + }, + "accessControl": { + "repositories": { + "**": { + "anonymousPolicy": ["read", "create"] + } + } + } }, "log": { "level": "debug" @@ -80,6 +102,17 @@ function teardown() { awslocal dynamodb --region "us-east-2" delete-table --table-name "BlobTable" } +dex_session () { + STATE=$(curl -L -f -s http://localhost:8080/openid/auth/login?provider=dex | grep -m 1 -oP '(?<=state=)[^ ]*"' | cut -d \" -f1) + echo $STATE >&3 + curl -L -f -s "http://127.0.0.1:5556/dex/auth/mock?client_id=zot-client&redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Fopenid%2Fauth%2Fcallback%2Fdex&response_type=code&scope=profile+email+groups+openid&state=$STATE" +} + +@test "check dex is working" { + run dex_session + [ "$status" -eq 0 ] +} + @test "check for local disk writes" { run skopeo --insecure-policy copy --dest-tls-verify=false \ docker://centos:centos8 docker://localhost:8080/centos:8 diff --git a/test/dex/config-dev.yaml b/test/dex/config-dev.yaml new file mode 100644 index 00000000..d6eb3634 --- /dev/null +++ b/test/dex/config-dev.yaml @@ -0,0 +1,28 @@ +issuer: http://127.0.0.1:5556/dex + +storage: + type: sqlite3 + config: + file: dex.db + +web: + http: 127.0.0.1:5556 + +telemetry: + http: 127.0.0.1:5558 + +grpc: + addr: 127.0.0.1:5557 + +staticClients: + - id: zot-client + redirectURIs: + - 'http://127.0.0.1:8080/openid/auth/callback/dex' + name: 'zot' + secret: ZXhhbXBsZS1hcHAtc2VjcmV0 + +connectors: + - type: mockCallback + id: mock + name: Example +enablePasswordDB: true