JWT Security with Azure in Spring Boot

April 2, 2020

I wanted my users to authenticate with Azure and obtain a JWT, which is then passed with each subsequent request to identify them.

First up, we have a filter which we add to the request chain. This filter:

* Picks out a correlation id from the `x-correlation-id` request header and stores it using MDC for later logging
* Checks the JWT provided in the `Authorization` header and, if it is present, validates it
* Ignores OPTIONS requests

package com.drumcoder.diary.security;

import java.io.IOException;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.MDC;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.GenericFilterBean;

/**
 * Servlet filter that tags each request with a correlation id for logging and, when a bearer
 * JWT is present, validates it and installs the resulting {@link Authentication} in the
 * {@link SecurityContextHolder}.
 *
 * <p>Requests without a token, and all {@code OPTIONS} requests, are passed down the chain
 * without touching the security context. A token that fails validation stops the chain and
 * produces a 403 response with a JSON error body.
 */
public class JwtTokenFilter extends GenericFilterBean {
    private static final String CORRELATION_ID_HEADER_NAME = "x-correlation-id";
    private static final String CORRELATION_ID_MDC_KEY = "correlation-id";

    private final JwtTokenProvider jwtTokenProvider;

    public JwtTokenFilter(JwtTokenProvider jwtTokenProvider) {
        this.jwtTokenProvider = jwtTokenProvider;
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        final HttpServletRequest httpRequest = (HttpServletRequest) request;

        // Store correlation id for use in logging; a missing header is recorded as "".
        final String correlationId = httpRequest.getHeader(CORRELATION_ID_HEADER_NAME);
        MDC.put(CORRELATION_ID_MDC_KEY, correlationId == null ? "" : correlationId);
        try {
            if (this.authenticate(httpRequest, (HttpServletResponse) response)) {
                // Only downstream processing happens here, outside the token-validation
                // try/catch, so exceptions from later filters/controllers are not rewritten
                // into a "token invalid" response after the response may be committed.
                chain.doFilter(request, response);
            }
        } finally {
            // MDC is thread-bound and servlet threads are typically pooled; without this the id
            // would leak into log lines of later, unrelated requests on the same thread.
            MDC.remove(CORRELATION_ID_MDC_KEY);
        }
    }

    /**
     * Validates the bearer token on {@code request}, if any, and populates the security context.
     *
     * @param request  current request
     * @param response current response; receives a 403 JSON error when the token is invalid
     * @return {@code true} if the filter chain should continue, {@code false} if an error
     *         response has already been written
     * @throws IOException if writing the error body fails
     */
    private boolean authenticate(HttpServletRequest request, HttpServletResponse response) throws IOException {
        if ("OPTIONS".equals(request.getMethod())) {
            return true;
        }

        final String token = this.jwtTokenProvider.resolveToken(request);
        if (token == null) {
            // There is no token - let the request continue unauthenticated
            return true;
        }

        try {
            this.jwtTokenProvider.validateToken(token);

            // Token is valid, set auth context
            final Authentication auth = this.jwtTokenProvider.getAuthentication(token);
            SecurityContextHolder.getContext().setAuthentication(auth);
            return true;
        } catch (final InvalidJwtAuthenticationException ex) {
            // Token failed validation
            response.setStatus(HttpStatus.FORBIDDEN.value());
            response.setContentType("application/json");
            final String message = String.format(ExConstants.JSON_EXCEPTION_STRUCTURE,
                                                 ExConstants.SECURITY_003_TOKEN_INVALID,
                                                 ex.getMessage(), ex.getClass().getName());
            response.getWriter().write(message);
            return false;
        }
    }
}

This filter is added to the security filter chain by the following configurer:

package com.drumcoder.diary.security;

import org.springframework.security.config.annotation.SecurityConfigurerAdapter;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;

/**
 * Security configurer that installs a {@link JwtTokenFilter} into the {@link HttpSecurity}
 * filter chain, positioned before {@link UsernamePasswordAuthenticationFilter}.
 */
public class JwtConfigurer extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> {
    // Immutable: the provider is fixed for the lifetime of this configurer.
    private final JwtTokenProvider jwtTokenProvider;

    /**
     * @param jwtTokenProvider provider passed to the {@link JwtTokenFilter} created in
     *                         {@link #configure(HttpSecurity)}
     */
    public JwtConfigurer(JwtTokenProvider jwtTokenProvider) {
        this.jwtTokenProvider = jwtTokenProvider;
    }

    /**
     * Adds a new {@link JwtTokenFilter} ahead of {@link UsernamePasswordAuthenticationFilter}.
     */
    @Override
    public void configure(HttpSecurity http) throws Exception {
        http.addFilterBefore(new JwtTokenFilter(this.jwtTokenProvider), UsernamePasswordAuthenticationFilter.class);
    }
}

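The filter uses the JwtTokenProvider class to extract and validate the token. (Note that the `jwk` used for signature verification is not declared in this listing; it must be supplied separately, e.g. the RSA public key Azure used to sign the token.)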

package com.drumcoder.diary.security;

import java.text.ParseException;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.Date;

import javax.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.MessageSource;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.SignedJWT;

/**
 * Extracts bearer tokens from HTTP requests, validates them as signed JWTs, and wraps them in
 * {@link Authentication} objects.
 */
@Component
public class JwtTokenProvider {

    private static final String AUTHORIZATION_HEADER = "Authorization";
    private static final String BEARER_PREFIX = "Bearer ";

    private final MessageSource messageSource;
    private final CurrentDate currentDate;

    @Autowired
    public JwtTokenProvider(MessageSource messageSource, CurrentDate currentDate) {
        this.messageSource = messageSource;
        this.currentDate = currentDate;
    }

    /**
     * Parses {@code token} and wraps it in a {@link JwtAuthenticationToken}.
     * This does not check signature or expiry; call {@link #validateToken(String)} first.
     *
     * @param token compact-serialized JWT
     * @return authentication backed by the parsed token
     * @throws InvalidJwtAuthenticationException if the token cannot be parsed as a signed JWT
     */
    public Authentication getAuthentication(String token) {
        try {
            final SignedJWT jwt = SignedJWT.parse(token);
            return new JwtAuthenticationToken(this.messageSource, jwt);
        } catch (final ParseException ex) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID);
        }
    }

    /**
     * Extracts the bearer token from the {@code Authorization} header.
     *
     * @param req incoming request
     * @return the text following {@code "Bearer "}, or {@code null} if the header is absent or
     *         uses a different scheme
     */
    public String resolveToken(HttpServletRequest req) {
        final String header = req.getHeader(AUTHORIZATION_HEADER);
        if (header != null && header.startsWith(BEARER_PREFIX)) {
            return header.substring(BEARER_PREFIX.length());
        }
        return null;
    }

    /**
     * Validates {@code token}: parses it, verifies its RSA signature, then checks that it has an
     * {@code exp} claim that is still in the future.
     *
     * <p>NOTE(review): only signature and expiry are checked here; {@code iss}, {@code aud} and
     * {@code nbf} are not validated by this method.
     *
     * @param token compact-serialized JWT
     * @throws InvalidJwtAuthenticationException if any check fails
     */
    public void validateToken(String token) {
        try {
            final SignedJWT jwt = SignedJWT.parse(token);
            // Verify the signature before trusting any claim carried by the token.
            this.verifySignature(jwt);
            this.checkNotExpired(jwt);
        } catch (final ParseException ex) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID);
        }
    }

    /** Rejects the token unless its signature verifies against {@code jwk}. */
    private void verifySignature(SignedJWT jwt) {
        final boolean verified;
        try {
            // NOTE(review): `jwk` is not declared in this listing; presumably a field holding the
            // RSA public key the issuer signs tokens with - TODO confirm how it is loaded.
            verified = jwt.verify(new RSASSAVerifier(jwk));
        } catch (final JOSEException ex) {
            // Verifier construction or signature processing failed; treat as an invalid token.
            // NOTE(review): pass `ex` as the cause if the build(...) overload accepts any Exception.
            throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID);
        }
        if (!verified) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID);
        }
    }

    /** Rejects the token if it has no {@code exp} claim or the current time is not before it. */
    private void checkNotExpired(SignedJWT jwt) throws ParseException {
        final Date expiry = jwt.getJWTClaimsSet().getExpirationTime();
        if (expiry == null) {
            // A token without an expiry would otherwise never expire (and previously caused an NPE).
            throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID);
        }
        // Converted in the system default zone, using the offset in force at the expiry instant.
        // NOTE(review): assumes CurrentDate reports local time in the same zone - TODO confirm.
        final LocalDateTime expiryDate = expiry.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
        final LocalDateTime now = this.currentDate.getCurrentDateTime();
        // RFC 7519: the current time MUST be before "exp", so a token at exactly exp is expired.
        if (!now.isBefore(expiryDate)) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID);
        }
    }
}

The SecurityConfig class makes sure we are checking the right URLs for security:

package com.drumcoder.diary.security;

import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;

/**
 * HTTP security setup: stateless sessions, no HTTP Basic or CSRF, the root and Swagger UI
 * open to all, everything under {@code /api/} requiring authentication, and the JWT filter
 * installed via {@link JwtConfigurer}.
 */
@Configuration
public class SecurityConfig extends WebSecurityConfigurerAdapter {

    private final JwtTokenProvider jwtTokenProvider;

    public SecurityConfig(JwtTokenProvider jwtTokenProvider) {
        this.jwtTokenProvider = jwtTokenProvider;
    }

    @Override
    protected void configure(HttpSecurity http) throws Exception {
        // Authentication mechanics: no Basic auth, CORS enabled, no CSRF, no server session.
        http.httpBasic().disable();
        http.cors();
        http.csrf().disable();
        http.sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS);

        // URL access rules.
        http.authorizeRequests()
            .antMatchers("/").permitAll()
            .antMatchers("/swagger-ui.html**").permitAll()
            .antMatchers("/api/**").authenticated();

        // Bearer-token processing.
        http.apply(new JwtConfigurer(this.jwtTokenProvider));
    }
}

JwtAuthenticationToken is used as a wrapper for the parsed JWT to allow access to the contents using simple methods:

package com.drumcoder.diary.security;

import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.MessageSource;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;

import com.nimbusds.jwt.SignedJWT;
import net.minidev.json.JSONArray;

/**
 * {@link Authentication} backed by a parsed {@link SignedJWT}. Claims are read from the token
 * on each call; unreadable claims or a token with no usable username claim are reported as
 * {@link InvalidJwtAuthenticationException}.
 */
public class JwtAuthenticationToken implements Authentication {

    private static final Logger LOGGER = LoggerFactory.getLogger(JwtAuthenticationToken.class);

    /** Claims consulted for the user's name, in priority order. */
    private static final String[] USERNAME_CLAIMS = {"preferred_username", "unique_name", "sub"};

    private final SignedJWT jwt;
    private final MessageSource messageSource;

    // Starts trusted; may only be downgraded via setAuthenticated(false).
    private boolean authenticated = true;

    public JwtAuthenticationToken(MessageSource messageSource, SignedJWT jwt) {
        this.jwt = jwt;
        this.messageSource = messageSource;
    }

    /**
     * Returns the first present string claim among {@code preferred_username},
     * {@code unique_name} and {@code sub}.
     *
     * @throws InvalidJwtAuthenticationException if the claims cannot be read, one of them is not
     *         a string, or none is present
     */
    @Override
    public String getName() {
        try {
            for (final String claimName : USERNAME_CLAIMS) {
                // getStringClaim raises ParseException (instead of a ClassCastException) for
                // non-string values, which is mapped to the domain exception below.
                final String username = this.jwt.getJWTClaimsSet().getStringClaim(claimName);
                if (username != null) {
                    return username;
                }
            }
        } catch (final ParseException ex) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID);
        }
        throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID);
    }

    /**
     * Maps each entry of the {@code roles} claim to a {@link SimpleGrantedAuthority}.
     *
     * @return the authorities; empty when the token has no {@code roles} claim
     * @throws InvalidJwtAuthenticationException if the claims cannot be read or {@code roles}
     *         is not a list of strings
     */
    @Override
    public Collection<? extends GrantedAuthority> getAuthorities() {
        final List<String> roles;
        try {
            roles = this.jwt.getJWTClaimsSet().getStringListClaim("roles");
        } catch (final ParseException ex) {
            throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID);
        }

        final List<GrantedAuthority> authorities = new ArrayList<>();
        if (roles == null) {
            // No "roles" claim in the token: grant no authorities rather than failing with an NPE.
            return authorities;
        }
        for (final String roleName : roles) {
            authorities.add(new SimpleGrantedAuthority(roleName));
            LOGGER.debug("Added {} to roles for this user", roleName);
        }
        return authorities;
    }

    @Override
    public Object getCredentials() {
        return "";
    }

    @Override
    public Object getDetails() {
        return "";
    }

    /** Returns the same value as {@link #getName()}. */
    @Override
    public Object getPrincipal() {
        return this.getName();
    }

    @Override
    public boolean isAuthenticated() {
        return this.authenticated;
    }

    /**
     * Allows the token to be marked untrusted; marking it trusted is rejected.
     *
     * @throws UnsupportedOperationException if {@code isAuthenticated} is {@code true}
     */
    @Override
    public void setAuthenticated(boolean isAuthenticated) {
        // Instances are created for already-parsed tokens and start authenticated; callers may
        // revoke that (the Authentication contract requires accepting false) but must not
        // promote a token to authenticated manually.
        if (isAuthenticated) {
            throw new UnsupportedOperationException();
        }
        this.authenticated = false;
    }

}