Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/main/java/com/uid2/core/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import com.uid2.core.model.Constants;
import com.uid2.core.model.SecretStore;
import com.uid2.core.service.AttestationService;
import com.uid2.core.service.JWTTokenProvider;
import com.uid2.core.service.OperatorJWTTokenProvider;
import software.amazon.awssdk.services.kms.KmsClient;
import com.uid2.core.vertx.CoreVerticle;
import com.uid2.core.vertx.Endpoints;
import com.uid2.shared.Const;
Expand Down Expand Up @@ -136,7 +138,11 @@ public static void main(String[] args) {

attestationService.with("gcp-oidc", new GcpOidcCoreAttestationService(corePublicUrl));

OperatorJWTTokenProvider operatorJWTTokenProvider = new OperatorJWTTokenProvider(config);
KmsClient kmsClient = JWTTokenProvider.buildKmsClient(config);
OperatorJWTTokenProvider operatorJWTTokenProvider = new OperatorJWTTokenProvider(
config.getString(Const.Config.CorePublicUrlProp),
config.getString(Const.Config.OptOutUrlProp),
new JWTTokenProvider(kmsClient));

IAttestationTokenService attestationTokenService = new AttestationTokenService(
SecretStore.Global.get(Constants.AttestationEncryptionKeyName),
Expand Down
56 changes: 19 additions & 37 deletions src/main/java/com/uid2/core/service/JWTTokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
Expand All @@ -33,12 +32,10 @@
public class JWTTokenProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(JWTTokenProvider.class);
private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
private final Supplier<KmsClientBuilder> kmsClientBuilderSupplier;
private final JsonObject config;
private final KmsClient kmsClient;

public JWTTokenProvider(JsonObject config, Supplier<KmsClientBuilder> kmsClientBuilderSupplier) {
this.config = config;
this.kmsClientBuilderSupplier = kmsClientBuilderSupplier;
public JWTTokenProvider(KmsClient kmsClient) {
this.kmsClient = kmsClient;
}

public String getJWT(Instant expiresAt, Instant issuedAt, Map<String, String> customClaims) throws JwtSigningException {
Expand All @@ -62,13 +59,7 @@ public String getJWT(Instant expiresAt, Instant issuedAt, Map<String, String> he
.append(encoder.encodeToString(claimsJson.encode().getBytes(StandardCharsets.UTF_8)))
.toString();

KmsClient client = null;
try {
client = getKmsClient(this.kmsClientBuilderSupplier.get(), this.config);
} catch (URISyntaxException e) {
throw new JwtSigningException(Optional.of("Unable to get KMS Client"), e);
}
String signature = signJwtContent(client, jwtContent);
String signature = signJwtContent(this.kmsClient, jwtContent);
if (signature != null && !signature.isBlank()) {
return new StringBuilder()
.append(jwtContent)
Expand Down Expand Up @@ -128,44 +119,35 @@ private void addMapToJsonObject(JsonObject jsonObject, Map<String, String> map)
}
}

private static KmsClient getKmsClient(KmsClientBuilder kmsClientBuilder, JsonObject config) throws URISyntaxException {
KmsClient client;

public static KmsClient buildKmsClient(JsonObject config) throws URISyntaxException {
String region = config.getString(KmsRegionProp, config.getString(Const.Config.AwsRegionProp));
String accessKeyId = config.getString(KmsAccessKeyIdProp);
String secretAccessKey = config.getString(KmsSecretAccessKeyProp);
String endpoint = config.getString(KmsEndpointProp);

if (accessKeyId != null && !accessKeyId.isBlank() && secretAccessKey != null && !secretAccessKey.isBlank()) {
AwsBasicCredentials basicCredentials = AwsBasicCredentials.create(accessKeyId, secretAccessKey);

StaticCredentialsProvider.create(basicCredentials);
KmsClientBuilder kmsClientBuilder = KmsClient.builder();
if (endpoint != null && !endpoint.isBlank()) {
try {
if (endpoint != null && !endpoint.isBlank()) {
kmsClientBuilder.endpointOverride(new URI(endpoint));
}

client = kmsClientBuilder
.region(Region.of(region))
.credentialsProvider(StaticCredentialsProvider.create(basicCredentials))
.build();
kmsClientBuilder.endpointOverride(new URI(endpoint));
} catch (URISyntaxException e) {
LOGGER.error("Error creating KMS Client Builder using static credentials.", e);
LOGGER.error("Error creating KMS Client Builder.", e);
throw e;
}
} else {
DefaultCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
}

client = kmsClientBuilder
.region(Region.of(region))
.credentialsProvider(credentialsProvider)
.build();
kmsClientBuilder.region(Region.of(region));

if (accessKeyId != null && !accessKeyId.isBlank() && secretAccessKey != null && !secretAccessKey.isBlank()) {
kmsClientBuilder.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(accessKeyId, secretAccessKey)));
} else {
kmsClientBuilder.credentialsProvider(DefaultCredentialsProvider.create());
}

return client;
return kmsClientBuilder.build();
}

public class JwtSigningException extends Exception {
public static class JwtSigningException extends Exception {
public JwtSigningException(Optional<String> message) {
this(message, null);
}
Expand Down
29 changes: 14 additions & 15 deletions src/main/java/com/uid2/core/service/OperatorJWTTokenProvider.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package com.uid2.core.service;

import com.uid2.shared.Const;
import com.uid2.shared.Utils;
import com.uid2.shared.auth.Role;
import io.vertx.core.json.JsonObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kms.KmsClient;

import java.time.Clock;
import java.time.Instant;
Expand All @@ -20,16 +17,18 @@

public class OperatorJWTTokenProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(OperatorJWTTokenProvider.class);
private final JsonObject config;
private final String issuerUrl;
private final String optOutUrl;
private final JWTTokenProvider jwtTokenProvider;
private final Clock clock;

public OperatorJWTTokenProvider(JsonObject config) {
this(config, new JWTTokenProvider(config, KmsClient::builder), Clock.systemUTC());
public OperatorJWTTokenProvider(String issuerUrl, String optOutUrl, JWTTokenProvider jwtTokenProvider) {
this(issuerUrl, optOutUrl, jwtTokenProvider, Clock.systemUTC());
}

public OperatorJWTTokenProvider(JsonObject config, JWTTokenProvider jwtTokenProvider, Clock clock) {
this.config = config;
public OperatorJWTTokenProvider(String issuerUrl, String optOutUrl, JWTTokenProvider jwtTokenProvider, Clock clock) {
this.issuerUrl = issuerUrl;
this.optOutUrl = optOutUrl;
this.jwtTokenProvider = jwtTokenProvider;
this.clock = clock;
}
Expand All @@ -39,27 +38,27 @@ public OperatorJWTTokenProvider(JsonObject config, JWTTokenProvider jwtTokenProv
OptOut when the operator makes calls to OptOut.
The claims we will add are:
"iss" : the config value for issuer, something like https://core-prod.uidapi.com
"sub" : the name of the operator as registered in the Admin site
"sub" : the base64-encoded SHA-512 hash of the operator key
"aud" : the url of the optout service that this token can be used with https://optout-prod.uidapi.com
"exp" : the expiry date time of the token, set to be the same as the expiry of the attestation token
"iat" : the current date time
*/
public String getOptOutJWTToken(String operatorKey, String name, Set<Role> roles, Integer siteId, String enclaveId, String enclaveType, String operatorVersion, Instant expiresAt) throws JWTTokenProvider.JwtSigningException {
return this.getJWTToken(this.config.getString(Const.Config.CorePublicUrlProp), this.config.getString(Const.Config.OptOutUrlProp), operatorKey, name, roles, siteId, enclaveId, enclaveType, operatorVersion, expiresAt);
return this.getJWTToken(this.issuerUrl, this.optOutUrl, operatorKey, name, roles, siteId, enclaveId, enclaveType, operatorVersion, expiresAt);
}

/*
Returns a JWT that is given to the operator. This is then presented by the operator to
OptOut when the operator makes calls to Core.
Core when the operator makes calls to Core.
The claims we will add are:
"iss" : the config value for issuer, something like https://core-prod.uidapi.com
"sub" : the name of the operator as registered in the Admin site
"aud" : the url of the optout service that this token can be used with https://core-prod.uidapi.com
"sub" : the base64-encoded SHA-512 hash of the operator key
"aud" : the url of the core service that this token can be used with https://core-prod.uidapi.com
"exp" : the expiry date time of the token, set to be the same as the expiry of the attestation token
"iat" : the current date time
*/
public String getCoreJWTToken(String operatorKey, String name, Set<Role> roles, Integer siteId, String enclaveId, String enclaveType, String operatorVersion, Instant expiresAt) throws JWTTokenProvider.JwtSigningException {
return this.getJWTToken(this.config.getString(Const.Config.CorePublicUrlProp), this.config.getString(Const.Config.CorePublicUrlProp), operatorKey, name, roles, siteId, enclaveId, enclaveType, operatorVersion, expiresAt);
return this.getJWTToken(this.issuerUrl, this.issuerUrl, operatorKey, name, roles, siteId, enclaveId, enclaveType, operatorVersion, expiresAt);
}

private String getJWTToken(String issuer, String audience, String operatorKey, String name, Set<Role> roles, Integer siteId, String enclaveId, String enclaveType, String operatorVersion, Instant expiresAt) throws JWTTokenProvider.JwtSigningException {
Expand All @@ -83,7 +82,7 @@ private String getJWTToken(String issuer, String audience, String operatorKey, S
claims.put("operatorVersion", operatorVersion);
claims.put("jti", UUID.randomUUID().toString());

LOGGER.debug(String.format("Creating token with: Issuer: %s, Audience: %s, Roles: %s, SiteId: %s, EnclaveId: %s, EnclaveType: %s, OperatorVersion: %s", audience, issuer, roleString, siteId, enclaveId, enclaveType, operatorVersion));
LOGGER.debug(String.format("Creating token with: Issuer: %s, Audience: %s, Roles: %s, SiteId: %s, EnclaveId: %s, EnclaveType: %s, OperatorVersion: %s", issuer, audience, roleString, siteId, enclaveId, enclaveType, operatorVersion));
return this.jwtTokenProvider.getJWT(expiresAt, this.clock.instant(), claims);
}
}
44 changes: 17 additions & 27 deletions src/test/java/com/uid2/core/service/JWTTokenProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.KmsClientBuilder;
import software.amazon.awssdk.services.kms.model.KmsException;
import software.amazon.awssdk.services.kms.model.SignRequest;
import software.amazon.awssdk.services.kms.model.SignResponse;
Expand All @@ -27,7 +24,6 @@

import static com.uid2.shared.Utils.readToEndAsString;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -36,11 +32,10 @@ public class JWTTokenProviderTest {

private KmsClient mockClient;
private ArgumentCaptor<SignRequest> capturedSignRequest;
private JsonObject config;

@BeforeEach
void setUp() throws IOException {
this.config = ((JsonObject) Json.decodeValue(openFile("/com.uid2.core/service/jwt-token-provider-test-config.json")));
JsonObject config = (JsonObject) Json.decodeValue(openFile("/com.uid2.core/service/jwt-token-provider-test-config.json"));
ConfigStore.Global.load(config);
defaultHeaders.put("typ", "JWT");
defaultHeaders.put("alg", "RS256");
Expand All @@ -59,8 +54,8 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException {
content.put("iss", "issuer");
content.put("jti", jti);

var builder = getBuilder(true, "TestSignature");
JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
var kmsClient = getKmsClient(true, "TestSignature");
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);

Instant i = Clock.systemUTC().instant();

Expand All @@ -86,9 +81,9 @@ void getJwtReturnsValidToken() throws JWTTokenProvider.JwtSigningException {

@Test
void getJwtEmptySignatureThrowsException() {
var builder = getBuilder(false, "");
var kmsClient = getKmsClient(false, "");

JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -99,9 +94,9 @@ void getJwtEmptySignatureThrowsException() {

@Test
void getJwtEmptySignatureEmptyResponseText() {
var builder = getBuilder(false, "", Optional.empty());
var kmsClient = getKmsClient(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -112,9 +107,9 @@ void getJwtEmptySignatureEmptyResponseText() {

@Test
void getJwtEmptySignatureNullResponseText() {
var builder = getBuilder(false, "", null);
var kmsClient = getKmsClient(false, "", null);

JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -125,9 +120,9 @@ void getJwtEmptySignatureNullResponseText() {

@Test
void getJwtSignatureThrowsKmsException() {
var builder = getBuilder(false, "", Optional.empty());
var kmsClient = getKmsClient(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);
var ex = KmsException.builder().message("Test Error").build();
when(mockClient.sign(capturedSignRequest.capture())).thenThrow(ex);

Expand All @@ -146,9 +141,9 @@ void getJwtMissingKeyInConfig() throws IOException {

ConfigStore.Global.load(data);

var builder = getBuilder(false, "", Optional.empty());
var kmsClient = getKmsClient(false, "", Optional.empty());

JWTTokenProvider provider = new JWTTokenProvider(config, () -> builder);
JWTTokenProvider provider = new JWTTokenProvider(kmsClient);

JWTTokenProvider.JwtSigningException e = assertThrows(
JWTTokenProvider.JwtSigningException.class,
Expand All @@ -161,11 +156,11 @@ String openFile(String filePath) throws IOException {
return readToEndAsString(JWTTokenProviderTest.class.getResourceAsStream(filePath));
}

private KmsClientBuilder getBuilder(boolean isSuccessful, String signature) {
return getBuilder(isSuccessful, signature, Optional.of("Test status text"));
private KmsClient getKmsClient(boolean isSuccessful, String signature) {
return getKmsClient(isSuccessful, signature, Optional.of("Test status text"));
}

private KmsClientBuilder getBuilder(boolean isSuccessful, String signature, Optional<String> statusText) {
private KmsClient getKmsClient(boolean isSuccessful, String signature, Optional<String> statusText) {
SdkHttpResponse sdkHttpResponse = mock(SdkHttpResponse.class);
when(sdkHttpResponse.isSuccessful()).thenReturn(isSuccessful);
when(sdkHttpResponse.statusText()).thenReturn(statusText);
Expand All @@ -178,12 +173,7 @@ private KmsClientBuilder getBuilder(boolean isSuccessful, String signature, Opti
capturedSignRequest = ArgumentCaptor.forClass(SignRequest.class);
when(mockClient.sign(capturedSignRequest.capture())).thenReturn(response);

KmsClientBuilder builder = mock(KmsClientBuilder.class);
when(builder.region(any(Region.class))).thenReturn(builder);
when(builder.credentialsProvider(any(AwsCredentialsProvider.class))).thenReturn(builder);
when(builder.build()).thenReturn(mockClient);

return builder;
return mockClient;
}

private void assertJWT(String expectedHeader, String expectedContent, String expectedSignature, String jwt) {
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/com/uid2/core/vertx/CoreVerticleTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ void attestOptOutJWTCalledReturns500OnError(Vertx vertx, VertxTestContext testCo
EncryptedAttestationToken encryptedAttestationToken = new EncryptedAttestationToken("test-attestation-token", Instant.ofEpochMilli(111));
when(attestationTokenService.createToken(any())).thenReturn(encryptedAttestationToken);

when(operatorJWTTokenProvider.getCoreJWTToken(anyString(), anyString(), any(), anyInt(), anyString(), any(), anyString(), any())).thenThrow(new JWTTokenProvider(null, null).new JwtSigningException(Optional.of("Test error")));
when(operatorJWTTokenProvider.getCoreJWTToken(anyString(), anyString(), any(), anyInt(), anyString(), any(), anyString(), any())).thenThrow(new JWTTokenProvider.JwtSigningException(Optional.of("Test error")));
post(vertx, "attest", makeAttestationRequestJson("xxx", null), ar -> {
assertTrue(ar.succeeded());
HttpResponse response = ar.result();
Expand Down
Loading