JWT Security with Azure in Spring Boot
April 2, 2020
I wanted my users to authenticate with Azure and obtain a JWT, which is then passed with each subsequent request to identify them.
First up is a filter that we add to the Spring Security filter chain. This filter:
Picks out a correlation id from the x-correlation-id request header and stores it in the MDC for later logging
Checks for a JWT in the Authorization header and, if one is present, validates it
Skips the token check for OPTIONS requests (such as CORS pre-flight requests)
package com.drumcoder.diary.security; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.slf4j.MDC; import org.springframework.http.HttpStatus; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.filter.GenericFilterBean; public class JwtTokenFilter extends GenericFilterBean { private static final String CORRELATION_ID_HEADER_NAME = "x-correlation-id"; private final JwtTokenProvider jwtTokenProvider; public JwtTokenFilter(JwtTokenProvider jwtTokenProvider) { this.jwtTokenProvider = jwtTokenProvider; } @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { // Store correlation id for use in logging String correlationId = ((HttpServletRequest) (request)).getHeader(JwtTokenFilter.CORRELATION_ID_HEADER_NAME); if (correlationId == null) { correlationId = ""; } MDC.put("correlation-id", correlationId); // Check for JWT String token = null; HttpServletRequest httpRequest = (HttpServletRequest) request; if (!"OPTIONS".equals(httpRequest.getMethod())) { token = this.jwtTokenProvider.resolveToken((HttpServletRequest) request); } if (token == null) { // There is no token - pass request down chain chain.doFilter(request, response); } else { // There is a token try { this.jwtTokenProvider.validateToken(token); // Token is valid, set auth context final Authentication auth = this.jwtTokenProvider.getAuthentication(token); SecurityContextHolder.getContext().setAuthentication(auth); // proceed to next filter in chain chain.doFilter(request, response); } catch (final InvalidJwtAuthenticationException ex) { // Token failed validation ((HttpServletResponse) 
response).setStatus(HttpStatus.FORBIDDEN.value()); final String message = String.format(ExConstants.JSON_EXCEPTION_STRUCTURE, ExConstants.SECURITY_003_TOKEN_INVALID, ex.getMessage(), ex.getClass().getName()); response.getWriter().write(message); } } } }
The filter is added to the Spring Security filter chain by this configurer:
package com.drumcoder.diary.security; import org.springframework.security.config.annotation.SecurityConfigurerAdapter; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; public class JwtConfigurer extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> { private JwtTokenProvider jwtTokenProvider; public JwtConfigurer(JwtTokenProvider jwtTokenProvider) { this.jwtTokenProvider = jwtTokenProvider; } @Override public void configure(HttpSecurity http) throws Exception { JwtTokenFilter customFilter = new JwtTokenFilter(this.jwtTokenProvider); http.addFilterBefore(customFilter, UsernamePasswordAuthenticationFilter.class); } }
The filter uses the JwtTokenProvider class to extract and validate the token:
package com.drumcoder.diary.security; import java.text.ParseException; import java.time.LocalDateTime; import java.time.OffsetDateTime; import javax.servlet.http.HttpServletRequest; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.MessageSource; import org.springframework.security.core.Authentication; import org.springframework.stereotype.Component; import com.nimbusds.jwt.SignedJWT; @Component public class JwtTokenProvider { private final MessageSource messageSource; private final CurrentDate currentDate; @Autowired public JwtTokenProvider(MessageSource messageSource, CurrentDate currentDate) { this.messageSource = messageSource; this.currentDate = currentDate; } public Authentication getAuthentication(String token) { try { final SignedJWT jwt = SignedJWT.parse(token); return new JwtAuthenticationToken(this.messageSource, jwt); } catch (final ParseException ex) { throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID); } } public String resolveToken(HttpServletRequest req) { final String bearerToken = req.getHeader("Authorization"); if (bearerToken != null && bearerToken.startsWith("Bearer ")) { return bearerToken.substring("Bearer ".length(), bearerToken.length()); } return null; } public void validateToken(String token) { try { final SignedJWT jwt = SignedJWT.parse(token); // See if JWT has expired final LocalDateTime expiryDate = jwt.getJWTClaimsSet().getDateClaim("exp") .toInstant().atOffset(OffsetDateTime.now().getOffset()).toLocalDateTime(); final LocalDateTime now = this.currentDate.getCurrentDateTime(); if (expiryDate.isBefore(now)) { throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID); } // Verify JWT final boolean verified = jwt.verify(new RSASSAVerifier(jwk)); if (!verified) { throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.JWT_INVALID_TOKEN); } } catch (final 
ParseException ex) { throw InvalidJwtAuthenticationException.build(this.messageSource, ex, ExConstants.SECURITY_001_JWT_INVALID); } } }
The SecurityConfig class defines which URLs require authentication and installs the JWT filter:
package com.drumcoder.diary.security; import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.config.http.SessionCreationPolicy; @Configuration public class SecurityConfig extends WebSecurityConfigurerAdapter { public SecurityConfig(JwtTokenProvider jwtTokenProvider) { this.jwtTokenProvider = jwtTokenProvider; } private final JwtTokenProvider jwtTokenProvider; @Override protected void configure(HttpSecurity http) throws Exception { // @formatter:off http.httpBasic().disable() .cors().and() .csrf().disable() .sessionManagement().sessionCreationPolicy(SessionCreationPolicy.STATELESS) .and().authorizeRequests() .antMatchers("/").permitAll() .antMatchers("/swagger-ui.html**").permitAll() .antMatchers("/api/**").authenticated() .and() .apply(new JwtConfigurer(this.jwtTokenProvider)); // @formatter:on } }
JwtAuthenticationToken wraps the parsed JWT and exposes its contents through simple accessor methods:
package com.drumcoder.diary.security; import java.text.ParseException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.MessageSource; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import com.nimbusds.jwt.SignedJWT; import net.minidev.json.JSONArray; public class JwtAuthenticationToken implements Authentication { private static final Logger LOGGER = LoggerFactory.getLogger(JwtAuthenticationToken.class.getName()); private final SignedJWT jwt; private final MessageSource messageSource; public JwtAuthenticationToken(MessageSource messageSource, SignedJWT jwt) { this.jwt = jwt; this.messageSource = messageSource; } @Override public String getName() { try { String username = (String) this.jwt.getJWTClaimsSet().getClaim("preferred_username"); if (username == null) { username = (String) this.jwt.getJWTClaimsSet().getClaim("unique_name"); } if (username == null) { username = (String) this.jwt.getJWTClaimsSet().getClaim("sub"); } if (username == null) { throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID); } return username; } catch (final ParseException e) { throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID); } } @Override public Collection<? 
extends GrantedAuthority> getAuthorities() { final List<GrantedAuthority> authorities = new ArrayList<>(); try { JSONArray roles = (JSONArray) this.jwt.getJWTClaimsSet().getClaim("roles"); for (int i = 0; i < roles.size(); i++) { String roleName = (String) roles.get(i); authorities.add(new SimpleGrantedAuthority(roleName)); LOGGER.debug("Added {} to roles for this user", roleName); } } catch (final ParseException e) { throw InvalidJwtAuthenticationException.build(this.messageSource, ExConstants.SECURITY_001_JWT_INVALID); } return authorities; } @Override public Object getCredentials() { return ""; } @Override public Object getDetails() { return ""; } @Override public Object getPrincipal() { return this.getName(); } @Override public boolean isAuthenticated() { return true; } @Override public void setAuthenticated(boolean isAuthenticated) { // Attempting to set a token as Authenticated manually must fail. // This class is only used for internal test tokens - there is a // PreAuthenticatedAuthenticationToken is used for tokens created after logging // into Azure Active Directory. if (isAuthenticated) { throw new UnsupportedOperationException(); } } }