diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 6ae327ca..d24226b1 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -17,6 +17,7 @@ import ( "net/url" "path" "regexp" + "slices" "sort" "strconv" "strings" @@ -1992,13 +1993,25 @@ func (rh *RouteHandler) OpenIDCodeExchangeCallback() rp.CodeExchangeUserinfoCall val, ok := info.Claims["groups"].([]interface{}) if !ok { - rh.c.Log.Info().Msgf("failed to find any 'groups' claim for user %s", email) + rh.c.Log.Info().Msgf("failed to find any 'groups' claim for user %s in UserInfo", email) } for _, group := range val { groups = append(groups, fmt.Sprint(group)) } + val, ok = tokens.IDTokenClaims.Claims["groups"].([]interface{}) + if !ok { + rh.c.Log.Info().Msgf("failed to find any 'groups' claim for user %s in IDTokenClaimsToken", email) + } + + for _, group := range val { + groups = append(groups, fmt.Sprint(group)) + } + + slices.Sort(groups) + groups = slices.Compact(groups) + callbackUI, err := OAuth2Callback(rh.c, w, r, state, email, groups) if err != nil { if errors.Is(err, zerr.ErrInvalidStateCookie) { diff --git a/pkg/api/routes_test.go b/pkg/api/routes_test.go index 68f9c3ef..2a218bb5 100644 --- a/pkg/api/routes_test.go +++ b/pkg/api/routes_test.go @@ -117,6 +117,40 @@ func TestRoutes(t *testing.T) { So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) }) + Convey("Test OpenIDCodeExchangeCallback", func() { + callback := rthdlr.OpenIDCodeExchangeCallback() + ctx := context.TODO() + + request, _ := http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) + response := httptest.NewRecorder() + + tokens := &oidc.Tokens[*oidc.IDTokenClaims]{ + IDTokenClaims: &oidc.IDTokenClaims{ + Claims: map[string]any{ + "groups": []interface{}{"group1", "group3"}, + }, + }, + } + relyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{}) + So(err, ShouldBeNil) + + userinfo := &oidc.UserInfo{ + Subject: "sub", + Claims: map[string]any{ + "email": "test@test.com", + "groups": []interface{}{"group1", "group2"}, + }, + UserInfoEmail: oidc.UserInfoEmail{Email: "test@test.com"}, + } + + callback(response, request, tokens, "state", relyingParty, userinfo) + + resp := response.Result() + defer resp.Body.Close() + So(resp, ShouldNotBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusUnauthorized) + }) + Convey("Test OAuth2Callback errors", func() { ctx := context.TODO()