Files
zot/pkg/cel/claim_processor.go
T

262 lines
6.8 KiB
Go

package cel
import (
"context"
"fmt"
"slices"
"strings"
"github.com/google/cel-go/common/types"
zerr "zotregistry.dev/zot/v2/errors"
"zotregistry.dev/zot/v2/pkg/api/config"
)
// defaultUsernameExpr is the CEL expression used to build the username when the
// configuration does not provide one: the token issuer and subject joined by "/".
const (
	defaultUsernameExpr = "claims.iss + '/' + claims.sub"
)
// ClaimResult holds the result of processing OIDC claims: the identity that
// ClaimProcessor.Process maps a token's claims to.
type ClaimResult struct {
	// Username is the string produced by the configured username expression.
	Username string
	// Groups is the list produced by the configured groups expression; it is
	// nil when no groups expression is configured.
	Groups []string
	// Claims is the raw OIDC claim set. Carried through so authorization-time
	// CEL expressions can reference token claims directly via `req.claims`.
	// It is the same map that was passed to Process, not a copy.
	Claims map[string]any
}
// ClaimProcessor processes OIDC claims using CEL expressions.
// It validates and maps claims to Zot identities.
// Construct it with NewClaimProcessor, which compiles every expression up front.
type ClaimProcessor struct {
	// variables are evaluated in order; each one can read the claims and the
	// variables evaluated before it.
	variables []variable
	// validations must all evaluate to true for the claims to be accepted.
	validations []validation
	// audiences is the trimmed, non-empty list of accepted 'aud' values.
	audiences []string
	// username maps the claims (and variables) to the Zot username.
	username *Expression
	// groups maps the claims (and variables) to group names; nil when no
	// groups expression is configured.
	groups *Expression
}
type (
	// variable pairs a name with a compiled CEL expression. Process stores the
	// expression's result in the vars map under that name, so expressions
	// evaluated afterwards can read it.
	variable struct {
		name string
		expr *Expression
	}

	// validation pairs a compiled CEL expression (compiled with a boolean output
	// type) with the message reported when it evaluates to false.
	validation struct {
		expr *Expression
		msg  string
	}
)
// NewClaimProcessor creates a new ClaimProcessor.
//
// audiences must contain at least one entry. Each entry is trimmed of
// surrounding whitespace and must be non-empty afterwards. The slice is cloned,
// so the caller's slice is never modified.
//
// conf may be nil, which behaves like an empty configuration. When conf is nil
// or conf.Username is empty, defaultUsernameExpr is used for the username.
// conf itself is never modified.
//
// Every variable must have a name and every validation a message. All
// expressions are compiled here, so configuration errors surface at
// construction time rather than while processing tokens. Validation
// expressions are compiled with a boolean output type.
func NewClaimProcessor(audiences []string, conf *config.CELClaimValidationAndMapping) (*ClaimProcessor, error) {
	// Sanitize and validate audiences on a private copy.
	audiences = slices.Clone(audiences)
	if len(audiences) == 0 {
		return nil, zerr.ErrOIDCNoAudiences
	}

	for i := range audiences {
		audiences[i] = strings.TrimSpace(audiences[i])
		if audiences[i] == "" {
			return nil, fmt.Errorf("audience[%d]: %w", i, zerr.ErrOIDCEmptyAudience)
		}
	}

	// Apply defaults without writing them back into conf: the caller owns conf
	// and may share or reuse it, so mutating it here would be a hidden side
	// effect (and a data race if it is read concurrently).
	if conf == nil {
		conf = &config.CELClaimValidationAndMapping{}
	}

	usernameExpr := conf.Username
	if usernameExpr == "" {
		usernameExpr = defaultUsernameExpr
	}

	// Parse variable expressions.
	variables := make([]variable, 0, len(conf.Variables))

	for i, varConf := range conf.Variables {
		if varConf.Name == "" {
			return nil, fmt.Errorf("variable[%d]: %w", i, zerr.ErrOIDCEmptyVariableName)
		}

		expr, err := NewExpression(varConf.Expression,
			WithCompile(),
			WithStructVariables("claims", "vars"))
		if err != nil {
			return nil, fmt.Errorf("failed to parse CEL expression for variable[%d] (name: %s): %w",
				i, varConf.Name, err)
		}

		variables = append(variables, variable{
			name: varConf.Name,
			expr: expr,
		})
	}

	// Parse validation expressions; each must yield a boolean.
	validations := make([]validation, 0, len(conf.Validations))

	for i, valConf := range conf.Validations {
		if valConf.Message == "" {
			return nil, fmt.Errorf("validation[%d]: %w", i, zerr.ErrOIDCEmptyValidationMsg)
		}

		expr, err := NewExpression(valConf.Expression,
			WithCompile(),
			WithStructVariables("claims", "vars"),
			WithOutputType(types.BoolType))
		if err != nil {
			return nil, fmt.Errorf("failed to parse CEL expression for validation[%d]: %w", i, err)
		}

		validations = append(validations, validation{
			expr: expr,
			msg:  valConf.Message,
		})
	}

	// Parse username expression.
	username, err := NewExpression(usernameExpr,
		WithCompile(),
		WithStructVariables("claims", "vars"))
	if err != nil {
		return nil, fmt.Errorf("failed to parse CEL expression for username: %w", err)
	}

	// Parse groups expression if provided.
	var groups *Expression

	if conf.Groups != "" {
		groups, err = NewExpression(conf.Groups,
			WithCompile(),
			WithStructVariables("claims", "vars"))
		if err != nil {
			return nil, fmt.Errorf("failed to parse CEL expression for groups: %w", err)
		}
	}

	return &ClaimProcessor{
		variables:   variables,
		validations: validations,
		audiences:   audiences,
		username:    username,
		groups:      groups,
	}, nil
}
// Process validates the OIDC claims and maps them to a Zot identity.
//
// The steps run in this order and stop at the first error:
//  1. the 'aud' claim must name at least one configured audience;
//  2. variables are evaluated in configuration order, each seeing the claims
//     and every variable computed before it;
//  3. every validation must evaluate to true, otherwise the returned error
//     wraps zerr.ErrOIDCValidationFailed and carries the configured message;
//  4. the username expression is evaluated;
//  5. the groups expression, if configured, is evaluated.
//
// The returned ClaimResult also carries the claims map that was passed in.
func (c *ClaimProcessor) Process(ctx context.Context, claims map[string]any) (*ClaimResult, error) {
	if err := c.validateAudience(claims); err != nil {
		return nil, err
	}

	// The activation holds a reference to extracted, so a variable stored
	// below is immediately visible to every expression evaluated after it.
	extracted := map[string]any{}
	activation := map[string]any{
		"claims": claims,
		"vars":   extracted,
	}

	for _, v := range c.variables {
		value, err := v.expr.Evaluate(ctx, activation)
		if err != nil {
			return nil, fmt.Errorf("failed to evaluate variable '%s': %w", v.name, err)
		}

		extracted[v.name] = value
	}

	for _, check := range c.validations {
		passed, err := check.expr.EvaluateBoolean(ctx, activation)

		switch {
		case err != nil:
			return nil, fmt.Errorf("failed to evaluate validation expression: %w", err)
		case !passed:
			return nil, fmt.Errorf("%w: %s", zerr.ErrOIDCValidationFailed, check.msg)
		}
	}

	username, err := c.username.EvaluateString(ctx, activation)
	if err != nil {
		return nil, fmt.Errorf("failed to evaluate username expression: %w", err)
	}

	result := &ClaimResult{
		Username: username,
		Claims:   claims,
	}

	// Without a groups expression the result keeps a nil Groups slice.
	if c.groups == nil {
		return result, nil
	}

	groups, err := c.groups.EvaluateStringSlice(ctx, activation)
	if err != nil {
		return nil, fmt.Errorf("failed to evaluate groups expression: %w", err)
	}

	result.Groups = groups

	return result, nil
}
// validateAudience checks that the token's 'aud' claim names at least one of
// the audiences this processor was configured with.
//
// The claim may be a single string, a []string, or a []any whose elements are
// all strings. The returned error wraps:
//   - zerr.ErrOIDCNoAudiences when the claim is absent;
//   - zerr.ErrOIDCInvalidAudiences when the claim has any other type, or a
//     []any contains a non-string element;
//   - zerr.ErrOIDCAudienceMismatch when no token audience is an expected one.
func (c *ClaimProcessor) validateAudience(claims map[string]any) error {
	audValue, ok := claims["aud"]
	if !ok {
		return fmt.Errorf("%w: missing 'aud' claim", zerr.ErrOIDCNoAudiences)
	}

	var tokenAudiences []string

	switch aud := audValue.(type) {
	case string:
		tokenAudiences = []string{aud}
	case []string:
		tokenAudiences = aud
	case []any:
		tokenAudiences = make([]string, 0, len(aud))

		for _, elem := range aud {
			s, isString := elem.(string)
			if !isString {
				return fmt.Errorf("%w: 'aud' claim contains non-string value", zerr.ErrOIDCInvalidAudiences)
			}

			tokenAudiences = append(tokenAudiences, s)
		}
	default:
		// Any other shape (number, object, null, ...) is a malformed claim, not
		// merely an audience the processor does not accept.
		return fmt.Errorf("%w: 'aud' claim has unsupported type %T", zerr.ErrOIDCInvalidAudiences, audValue)
	}

	// Both lists are tiny in practice, so a linear scan is cheaper and simpler
	// than building a set.
	for _, expected := range c.audiences {
		if slices.Contains(tokenAudiences, expected) {
			return nil
		}
	}

	return fmt.Errorf("%w: token=%v, expected=%v", zerr.ErrOIDCAudienceMismatch, tokenAudiences, c.audiences)
}