azure-signalr icon indicating copy to clipboard operation
azure-signalr copied to clipboard

get rid of System.IdentityModel.Tokens.Jwt package when generating the JWT token

Open xingsy97 opened this issue 2 years ago • 2 comments

get rid of System.IdentityModel.Tokens.Jwt package when generating the JWT token

Summary of the changes (Less than 80 chars)

  • Implement our own JWT token generator, modeled on JwtBuilder.cs and JwtPayload.cs, to remove the dependency on the System.IdentityModel.Tokens.Jwt package
  • Remove the unit test TestGenerateJwtBearerCaching, which is no longer needed

Fix Issue #1606

xingsy97 avatar May 30 '22 04:05 xingsy97

  • Submitted a rewritten version, simplified from the System.Security.Claims package; the previous version was modified from this package
  • Why the previous version was abandoned:
    • It uses a complicated implementation to convert between decoded JWT and encoded JWT.
    • It doesn't support multiple ClaimValueTypes

xingsy97 avatar Jun 07 '22 13:06 xingsy97

Here is the code path the previous SDK uses to generate a JWT. Pay attention to the comments starting with ATTENTION. I used version 6.22.0 of Microsoft.IdentityModel.JsonWebTokens.

1

Entry point: method GenerateJwtBearer in the SDK's AuthUtility.cs

/// <summary>
/// Entry point of the previous SDK's token generation. Wraps the inputs into a
/// <see cref="ClaimsIdentity"/> and optional <see cref="SigningCredentials"/>, asks
/// <c>JwtTokenHandler</c> to build the token (step 1.1), then serializes it (step 1.2).
/// </summary>
/// <remarks>
/// <c>JwtTokenHandler</c> is declared outside this excerpt. It is presumably a
/// <c>JwtSecurityTokenHandler</c> instance with default settings — confirm.
/// </remarks>
/// <param name="issuer">Emitted as the <c>iss</c> claim when non-empty (see 1.1.1.1).</param>
/// <param name="audience">Emitted as the <c>aud</c> claim when non-empty (see 1.1.1.1).</param>
/// <param name="claims">Payload claims. <c>null</c> produces a <c>null</c> subject.</param>
/// <param name="expires">Expiry time. If <c>null</c>, the handler may fill in a default (see 1.1.1).</param>
/// <param name="signingKey">
/// Key used for signing. When <c>null</c>, no credentials are passed, so the header
/// gets <c>alg: none</c> (1.1.1.2) and the signature segment is empty (1.1.1).
/// </param>
/// <param name="issuedAt">Issued-at time. If <c>null</c>, the handler may fill in a default (see 1.1.1).</param>
/// <param name="notBefore">Not-before time. If <c>null</c>, the handler may fill in a default (see 1.1.1).</param>
/// <param name="algorithm">Translated by <c>GetSecurityAlgorithm</c> (defined elsewhere) into the credential algorithm.</param>
/// <returns>The string produced by <c>JwtTokenHandler.WriteToken</c>.</returns>
public static string GenerateJwtBearer(
    string issuer = null,
    string audience = null,
    IEnumerable<Claim> claims = null,
    DateTime? expires = null,
    AccessKey signingKey = null,
    DateTime? issuedAt = null,
    DateTime? notBefore = null,
    AccessTokenAlgorithm algorithm = AccessTokenAlgorithm.HS256)
{
    // A null claim sequence yields a null subject, so no user claims reach the payload.
    var subject = claims == null ? null : new ClaimsIdentity(claims);
    // Stays null when there is no key, which produces an unsigned token downstream.
    SigningCredentials credentials = null;
    if (signingKey != null)
    {
        // Refer: https://github.com/AzureAD/azure-activedirectory-identitymodel-extensions-for-dotnet/releases/tag/5.5.0
        // From version 5.5.0, SignatureProvider caching is turned On by default, assign KeyId to enable correct cache for same SigningKey
        // The KeyId is also written to the token header as `kid` (see 1.1.1.2).
        var securityKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(signingKey.Value))
        {
            KeyId = signingKey.Id
        };

        if (signingKey is AadAccessKey)
        {
            // disable cache when using AadAccessKey
            securityKey.CryptoProviderFactory.CacheSignatureProviders = false;
        }
        credentials = new SigningCredentials(securityKey, GetSecurityAlgorithm(algorithm));
    }

    // NOTE(review): this call passes no encryptingCredentials, while the overload quoted
    // in 1.1 declares one. It presumably binds to a sibling overload that forwards null — confirm.
    var token = JwtTokenHandler.CreateJwtSecurityToken(          // ATTENTION: 1.1
        issuer: issuer,
        audience: audience,
        subject: subject,
        notBefore: notBefore,
        expires: expires,
        issuedAt: issuedAt,
        signingCredentials: credentials);
    return JwtTokenHandler.WriteToken(token);                    // ATTENTION: 1.2
}

1.1

method CreateJwtSecurityToken in JwtSecurityTokenHandler.cs

/// <summary>
/// Public overload that forwards to <c>CreateJwtSecurityTokenPrivate</c> (step 1.1.1).
/// It passes <c>null</c> for the claim collection, the token type and both
/// additional-header dictionaries.
/// </summary>
public virtual JwtSecurityToken CreateJwtSecurityToken(
	string issuer,
	string audience,
	ClaimsIdentity subject,
	DateTime? notBefore,
	DateTime? expires,
	DateTime? issuedAt,
	SigningCredentials signingCredentials,
	EncryptingCredentials encryptingCredentials)
{
	return CreateJwtSecurityTokenPrivate(                     // ATTENTION: 1.1.1
		issuer,
		audience,
		subject,
		notBefore,
		expires,
		issuedAt,
		signingCredentials,
		// Trailing nulls: claimCollection, tokenType, additionalHeaderClaims, additionalInnerHeaderClaims.
		encryptingCredentials, null, null, null, null);
}

1.1.1

method CreateJwtSecurityTokenPrivate in JwtSecurityTokenHandler.cs

/// <summary>
/// Builds the header and payload, signs "header.payload" when credentials are present,
/// and returns a <c>JwtSecurityToken</c>. The token is encrypted first if encrypting
/// credentials are supplied.
/// </summary>
/// <exception cref="ArgumentException">
/// Thrown by <c>JwtPayload</c> when notBefore is not earlier than expires (see 1.1.1.1).
/// </exception>
private JwtSecurityToken CreateJwtSecurityTokenPrivate(
	string issuer,
	string audience,
	ClaimsIdentity subject,
	DateTime? notBefore,
	DateTime? expires,
	DateTime? issuedAt,
	SigningCredentials signingCredentials,
	EncryptingCredentials encryptingCredentials,
	IDictionary<string, object> claimCollection,
	string tokenType,
	IDictionary<string, object> additionalHeaderClaims,
	IDictionary<string, object> additionalInnerHeaderClaims)
{
	// ATTENTION: variable SetDefaultTimesOnTokenCreation (true by default, see TokenHandler.cs).
	// If any of exp/iat/nbf is missing, only the missing ones are filled from one UtcNow snapshot.
	if (SetDefaultTimesOnTokenCreation && (!expires.HasValue || !issuedAt.HasValue || !notBefore.HasValue))
	{
		DateTime now = DateTime.UtcNow;
		// ATTENTION: variable TokenLifetimeInMinutes (60 by default): default expiry is now + lifetime.
		if (!expires.HasValue)
			expires = now + TimeSpan.FromMinutes(TokenLifetimeInMinutes);

		if (!issuedAt.HasValue)
			issuedAt = now;

		if (!notBefore.HasValue)
			notBefore = now;
	}

	// Verbose log of audience/issuer; null values are rendered as the text "null".
	LogHelper.LogVerbose(LogMessages.IDX12721, (audience ?? "null"), (issuer ?? "null"));
	
	// ATTENTION: 1.1.1.1
	JwtPayload payload = new JwtPayload(issuer, audience, (subject == null ? null : OutboundClaimTypeTransform(subject.Claims)), (claimCollection == null ? null : OutboundClaimTypeTransform(claimCollection)), notBefore, expires, issuedAt);
	
	// ATTENTION: 1.1.1.2
	JwtHeader header = new JwtHeader(signingCredentials, OutboundAlgorithmMap, tokenType, additionalInnerHeaderClaims);

	// An actor identity, if present, is serialized into an "actort" claim before the payload is encoded.
	if (subject?.Actor != null)
		payload.AddClaim(new Claim(JwtRegisteredClaimNames.Actort, CreateActorValue(subject.Actor)));

	string rawHeader = header.Base64UrlEncode();
	string rawPayload = payload.Base64UrlEncode();
	// NOTE(review): re-encodes header and payload instead of reusing rawHeader/rawPayload.
	// The signed input is still "<header>.<payload>".
	string message = string.Concat(header.Base64UrlEncode(), ".", payload.Base64UrlEncode());
	// Unsigned tokens get an empty (not null) signature segment.
	string rawSignature = signingCredentials == null ? string.Empty : JwtTokenUtilities.CreateEncodedSignature(message, signingCredentials);

	LogHelper.LogInformation(LogMessages.IDX12722, rawHeader, rawPayload, rawSignature);

	// Encryption branch. GenerateJwtBearer supplies no encrypting credentials,
	// so this branch is presumably not taken on the SDK path — confirm.
	if (encryptingCredentials != null)
	{
		return EncryptToken(
				new JwtSecurityToken(header, payload, rawHeader, rawPayload, rawSignature),
				encryptingCredentials,
				tokenType,
				additionalHeaderClaims);
	}

	// ATTENTION: 1.1.1.3
	return new JwtSecurityToken(header, payload, rawHeader, rawPayload, rawSignature);  
}

We also need to know the exact values of SetDefaultTimesOnTokenCreation and TokenLifetimeInMinutes.

Class inheritance chain: JwtSecurityTokenHandler <- SecurityTokenHandler <- TokenHandler and ISecurityTokenValidator

For SetDefaultTimesOnTokenCreation, we find the following code in TokenHandler.cs, so it defaults to true:

// Defaults to true: when exp, iat or nbf is not supplied, the missing values are
// filled in at token creation (see 1.1.1).
[DefaultValue(true)]
public bool SetDefaultTimesOnTokenCreation { get; set; } = true;

For TokenLifetimeInMinutes, we find the following code in TokenHandler.cs, so it defaults to 60 minutes:

// Backing field, initialized from the static default declared below.
private int _defaultTokenLifetimeInMinutes = DefaultTokenLifetimeInMinutes;
// Default token lifetime: 60 minutes.
public static readonly int DefaultTokenLifetimeInMinutes = 60;
...
...
// Lifetime used to compute `expires` when the caller does not supply one (see 1.1.1).
public int TokenLifetimeInMinutes
{
	get => _defaultTokenLifetimeInMinutes;
	// Values below 1 are rejected with ArgumentOutOfRangeException.
	set => _defaultTokenLifetimeInMinutes = (value < 1) ? throw LogExceptionMessage(new ArgumentOutOfRangeException(nameof(value), FormatInvariant(LogMessages.IDX10104, LogHelper.MarkAsNonPII(value)))) : value;
}

1.1.1.1

constructor JwtPayload in JwtPayload.cs

/// <summary>
/// Builds the payload in this order: user claims, the optional claim dictionary, and
/// finally the registered claims (nbf/exp/iat/iss/aud) via <c>AddFirstPriorityClaims</c>.
/// </summary>
/// <remarks>
/// Because <c>AddFirstPriorityClaims</c> runs last and assigns nbf/exp/iat/iss through
/// the indexer, those values presumably override same-named user claims — confirm.
/// </remarks>
public JwtPayload(string issuer, string audience, IEnumerable<Claim> claims, IDictionary<string, object> claimsCollection, DateTime? notBefore, DateTime? expires, DateTime? issuedAt)
	// Base is constructed with an ordinal (case-sensitive) comparer, presumably for claim-name keys.
	: base(StringComparer.Ordinal)
{
	if (claims != null)
		AddClaims(claims);

	// ATTENTION: According to the code path above, we can ensure `claimsCollection` is `null`, so the next two lines are ignored
	if (claimsCollection != null && claimsCollection.Any())
		AddDictionaryClaims(claimsCollection);

	AddFirstPriorityClaims(issuer, audience, notBefore, expires, issuedAt);
}

And method AddFirstPriorityClaims in JwtPayload.cs

/// <summary>
/// Writes the registered claims nbf, exp, iat, iss and aud from the given arguments.
/// </summary>
/// <remarks>
/// <c>nbf</c> is only written when <c>expires</c> also has a value. A notBefore without
/// an expiry is silently dropped. Times are converted to UTC before
/// <c>EpochTime.GetIntDate</c> (presumably seconds since the Unix epoch — confirm).
/// </remarks>
/// <exception cref="ArgumentException">When both are set and notBefore &gt;= expires.</exception>
internal void AddFirstPriorityClaims(string issuer, string audience, DateTime? notBefore, DateTime? expires, DateTime? issuedAt)
{
	if (expires.HasValue)
	{
		// nbf is nested under the expires check, so it needs an expiry to be emitted.
		if (notBefore.HasValue)
		{
			if (notBefore.Value >= expires.Value)
			{
				throw LogHelper.LogExceptionMessage(new ArgumentException(LogHelper.FormatInvariant(LogMessages.IDX12401, LogHelper.MarkAsNonPII(expires.Value), LogHelper.MarkAsNonPII(notBefore.Value))));
			}

			this[JwtRegisteredClaimNames.Nbf] = EpochTime.GetIntDate(notBefore.Value.ToUniversalTime());
		}

		this[JwtRegisteredClaimNames.Exp] = EpochTime.GetIntDate(expires.Value.ToUniversalTime());
	}

	if (issuedAt.HasValue)
		this[JwtRegisteredClaimNames.Iat] = EpochTime.GetIntDate(issuedAt.Value.ToUniversalTime());

	if (!string.IsNullOrEmpty(issuer))
		this[JwtRegisteredClaimNames.Iss] = issuer;

	// It could be the case that some of the claims above already had an 'aud' claim.
	// AddClaim is used here rather than the indexer, presumably so an existing value is not overwritten — confirm.
	if (!string.IsNullOrEmpty(audience))
		AddClaim(new Claim(JwtRegisteredClaimNames.Aud, audience, ClaimValueTypes.String));
}

1.1.1.2

constructor JwtHeader in JwtHeader.cs

/// <summary>
/// Builds the JOSE header: alg, kid (and x5t for certificate credentials), typ, and any
/// additional inner header claims.
/// </summary>
/// <remarks>
/// On the SDK path the <c>kid</c> is the access key id, because GenerateJwtBearer sets
/// <c>KeyId = signingKey.Id</c>. <c>typ</c> defaults to <c>JwtConstants.HeaderType</c>
/// (presumably "JWT" — confirm).
/// </remarks>
public JwtHeader(SigningCredentials signingCredentials, IDictionary<string, string> outboundAlgorithmMap, string tokenType, IDictionary<string, object> additionalInnerHeaderClaims)
	: base(StringComparer.Ordinal)
{
	// Without credentials the token is unsigned and advertises alg "none".
	if (signingCredentials == null)
		this[JwtHeaderParameterNames.Alg] = SecurityAlgorithms.None;

	else
	{
		// Prefer the outbound short name when the map has one; otherwise use the credential's algorithm as-is.
		if (outboundAlgorithmMap != null && outboundAlgorithmMap.TryGetValue(signingCredentials.Algorithm, out string outboundAlg))
			Alg = outboundAlg;
		else
			Alg = signingCredentials.Algorithm;

		if (!string.IsNullOrEmpty(signingCredentials.Key.KeyId))
			Kid = signingCredentials.Key.KeyId;

		if (signingCredentials is X509SigningCredentials x509SigningCredentials)
			this[JwtHeaderParameterNames.X5t] = Base64UrlEncoder.Encode(x509SigningCredentials.Certificate.GetCertHash());
	}

	if (string.IsNullOrEmpty(tokenType))
		Typ = JwtConstants.HeaderType;
	else
		Typ = tokenType;

	// setDefaultCtyClaim is false here, so no cty header is added by default.
	AddAdditionalClaims(additionalInnerHeaderClaims, false);
	SigningCredentials = signingCredentials;
}

and we can find method AddAdditionalClaims in JwtHeader.cs

/// <summary>
/// Copies caller-supplied header claims into this header. It can optionally default
/// <c>cty</c> when the caller did not provide one.
/// </summary>
/// <remarks>
/// On the SDK path this receives <c>null</c> with <c>setDefaultCtyClaim == false</c>
/// (see 1.1 and 1.1.1.2), so it does nothing there.
/// </remarks>
/// <exception cref="SecurityTokenException">
/// When any supplied key matches one of <c>DefaultHeaderParameters</c>, compared
/// case-insensitively.
/// </exception>
internal void AddAdditionalClaims(IDictionary<string, object> additionalHeaderClaims, bool setDefaultCtyClaim)
{
	if (additionalHeaderClaims?.Count > 0 && additionalHeaderClaims.Keys.Intersect(DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any())
		throw LogHelper.LogExceptionMessage(new SecurityTokenException(LogHelper.FormatInvariant(LogMessages.IDX12742, nameof(additionalHeaderClaims), string.Join(", ", DefaultHeaderParameters))));

	if (additionalHeaderClaims != null)
	{
		// Default cty only if the caller did not supply one.
		if (!additionalHeaderClaims.TryGetValue(JwtHeaderParameterNames.Cty, out _) && setDefaultCtyClaim)
			Cty = JwtConstants.HeaderType;

		foreach (string claim in additionalHeaderClaims.Keys)
			this[claim] = additionalHeaderClaims[claim];
	}
	else if (setDefaultCtyClaim)
		Cty = JwtConstants.HeaderType;
}

1.1.1.3

constructor JwtSecurityToken in JwtSecurityToken.cs

/// <summary>
/// Wraps an already-encoded header, payload and signature. <c>RawData</c> holds the
/// compact form "header.payload.signature".
/// </summary>
/// <remarks>
/// Only a <c>null</c> signature is rejected. An empty signature is accepted, which is
/// what unsigned tokens produce (see 1.1.1).
/// </remarks>
/// <exception cref="ArgumentNullException">
/// When header or payload is null, rawHeader or rawPayload is null/whitespace, or rawSignature is null.
/// </exception>
public JwtSecurityToken(JwtHeader header, JwtPayload payload, string rawHeader, string rawPayload, string rawSignature)
{
	if (header == null)
		throw LogHelper.LogArgumentNullException(nameof(header));

	if (payload == null)
		throw LogHelper.LogArgumentNullException(nameof(payload));

	if (string.IsNullOrWhiteSpace(rawHeader))
		throw LogHelper.LogArgumentNullException(nameof(rawHeader));

	if (string.IsNullOrWhiteSpace(rawPayload))
		throw LogHelper.LogArgumentNullException(nameof(rawPayload));

	// Null check only: string.Empty is a valid (unsigned) signature segment.
	if (rawSignature == null)
		throw LogHelper.LogArgumentNullException(nameof(rawSignature));

	Header = header;
	Payload = payload;
	RawData = string.Concat(rawHeader, ".", rawPayload, ".", rawSignature);

	RawHeader = rawHeader;
	RawPayload = rawPayload;
	RawSignature = rawSignature;
}

1.2

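WriteToken in JwtSecurityTokenHandler.cs. Note: the snippet below appears to have been pasted from SamlSecurityTokenHandler.WriteToken (it casts to SamlSecurityToken and serializes XML). It is not the JwtSecurityTokenHandler override that step 1 actually calls. For a JwtSecurityToken, the JWT handler presumably returns the compact header.payload.signature string (cf. RawData in 1.1.1.3), so this snippet should be replaced with the correct one.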

// NOTE(review): this body is SAML handler code (it requires a SamlSecurityToken and writes
// XML). It is not the JwtSecurityTokenHandler.WriteToken reached from step 1. Passing a
// JwtSecurityToken to this exact method would hit the ArgumentException branch below.
// Replace it with the JWT handler's implementation.
/// <summary>
/// Serializes a <c>SamlSecurityToken</c> to a UTF-8 XML string.
/// </summary>
/// <exception cref="ArgumentNullException">When <paramref name="token"/> is null.</exception>
/// <exception cref="ArgumentException">When <paramref name="token"/> is not a <c>SamlSecurityToken</c>.</exception>
public override string WriteToken(SecurityToken token)
{
	if (token == null)
		throw LogArgumentNullException(nameof(token));

	var samlToken = token as SamlSecurityToken;
	if (samlToken == null)
		throw LogExceptionMessage(new ArgumentException(FormatInvariant(LogMessages.IDX11400, LogHelper.MarkAsNonPII(_className), LogHelper.MarkAsNonPII(typeof(SamlSecurityToken)), LogHelper.MarkAsNonPII(token.GetType()))));

	using (var memoryStream = new MemoryStream())
	{
		using (var writer = XmlDictionaryWriter.CreateTextWriter(memoryStream, Encoding.UTF8, false))
		{
			WriteToken(writer, samlToken);
			// Flush before reading so that all XML has reached the stream.
			writer.Flush();
			// GetBuffer() may be larger than the written data, so decode only the first Length bytes.
			return Encoding.UTF8.GetString(memoryStream.GetBuffer(), 0, (int)memoryStream.Length);
		}
	}
}

xingsy97 avatar Aug 10 '22 05:08 xingsy97