refactor: JWT関係のコードをリファクタリング

This commit is contained in:
usbharu 2023-05-02 15:53:04 +09:00
parent 5c82bfd532
commit dd30728548
Signed by: usbharu
GPG Key ID: 6556747BF94EEBC8
11 changed files with 233 additions and 84 deletions

View File

@ -86,7 +86,9 @@ fun Application.parent() {
single<IdGenerateService> { TwitterSnowflakeIdGenerateService }
single<IMetaRepository> { MetaRepositoryImpl(get()) }
single<IServerInitialiseService> { ServerInitialiseServiceImpl(get()) }
single<IJwtRefreshTokenRepository> { JwtRefreshTokenRepositoryImpl(get()) }
single<IJwtRefreshTokenRepository> { JwtRefreshTokenRepositoryImpl(get(),get()) }
single<IMetaService> { MetaServiceImpl(get()) }
single<IJwtService> { JwtServiceImpl(get(),get(),get()) }
}
configureKoin(module)
runBlocking {
@ -99,10 +101,9 @@ fun Application.parent() {
register(inject<IUserService>().value)
configureSecurity(
inject<IUserAuthService>().value,
inject<IMetaRepository>().value,
inject<IJwtRefreshTokenRepository>().value,
inject<IMetaService>().value,
inject<IUserRepository>().value,
inject<IdGenerateService>().value
inject<IJwtService>().value
)
configureRouting(
inject<HttpSignatureVerifyService>().value,

View File

@ -0,0 +1,8 @@
package dev.usbharu.hideout.exception
class InvalidRefreshTokenException : IllegalArgumentException{
constructor() : super()
constructor(s: String?) : super(s)
constructor(message: String?, cause: Throwable?) : super(message, cause)
constructor(cause: Throwable?) : super(cause)
}

View File

@ -0,0 +1,14 @@
package dev.usbharu.hideout.exception
class NotInitException : Exception {
constructor() : super()
constructor(message: String?) : super(message)
constructor(message: String?, cause: Throwable?) : super(message, cause)
constructor(cause: Throwable?) : super(cause)
constructor(message: String?, cause: Throwable?, enableSuppression: Boolean, writableStackTrace: Boolean) : super(
message,
cause,
enableSuppression,
writableStackTrace
)
}

View File

@ -3,23 +3,16 @@
package dev.usbharu.hideout.plugins
import com.auth0.jwk.JwkProviderBuilder
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import dev.usbharu.hideout.config.Config
import dev.usbharu.hideout.domain.model.hideout.dto.JwtToken
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
import dev.usbharu.hideout.domain.model.hideout.form.RefreshToken
import dev.usbharu.hideout.domain.model.hideout.form.UserLogin
import dev.usbharu.hideout.exception.UserNotFoundException
import dev.usbharu.hideout.property
import dev.usbharu.hideout.repository.IJwtRefreshTokenRepository
import dev.usbharu.hideout.repository.IMetaRepository
import dev.usbharu.hideout.repository.IUserRepository
import dev.usbharu.hideout.service.IJwtService
import dev.usbharu.hideout.service.IMetaService
import dev.usbharu.hideout.service.IUserAuthService
import dev.usbharu.hideout.service.IdGenerateService
import dev.usbharu.hideout.util.Base64Util
import dev.usbharu.hideout.util.JsonWebKeyUtil
import dev.usbharu.hideout.util.RsaUtil
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
@ -27,38 +20,23 @@ import io.ktor.server.auth.jwt.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import kotlinx.coroutines.runBlocking
import java.time.Instant
import java.util.*
import java.util.concurrent.TimeUnit
const val TOKEN_AUTH = "jwt-auth"
fun Application.configureSecurity(
userAuthService: IUserAuthService,
metaRepository: IMetaRepository,
refreshTokenRepository: IJwtRefreshTokenRepository,
metaService: IMetaService,
userRepository: IUserRepository,
idGenerateService: IdGenerateService
jwtService: IJwtService
) {
val privateKey = runBlocking {
RsaUtil.decodeRsaPrivateKey(Base64Util.decode(requireNotNull(metaRepository.get()).jwt.privateKey))
}
val publicKey = runBlocking {
val publicKey = requireNotNull(metaRepository.get()).jwt.publicKey
RsaUtil.decodeRsaPublicKey(Base64Util.decode(publicKey))
}
val issuer = property("hideout.url")
// val audience = property("jwt.audience")
val myRealm = property("jwt.realm")
val jwkProvider = JwkProviderBuilder(issuer)
.cached(10, 24, TimeUnit.HOURS)
.rateLimited(10, 1, TimeUnit.MINUTES)
.build()
install(Authentication) {
jwt(TOKEN_AUTH) {
realm = myRealm
verifier(jwkProvider, issuer) {
acceptLeeway(3)
@ -70,78 +48,34 @@ fun Application.configureSecurity(
null
}
}
challenge { defaultScheme, realm ->
call.respondRedirect("/login")
}
}
}
routing {
post("/login") {
val user = call.receive<UserLogin>()
val check = userAuthService.verifyAccount(user.username, user.password)
val loginUser = call.receive<UserLogin>()
val check = userAuthService.verifyAccount(loginUser.username, loginUser.password)
if (check.not()) {
return@post call.respond(HttpStatusCode.Unauthorized)
}
val findByNameAndDomain = userRepository.findByNameAndDomain(user.username, Config.configData.domain)
?: throw UserNotFoundException("${user.username} was not found.")
val user = userRepository.findByNameAndDomain(loginUser.username, Config.configData.domain)
?: throw UserNotFoundException("${loginUser.username} was not found.")
val token = JWT.create()
.withAudience("${Config.configData.url}/users/${user.username}")
.withIssuer(issuer)
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
.withClaim("username", user.username)
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
.sign(Algorithm.RSA256(publicKey, privateKey))
val refreshToken = UUID.randomUUID().toString()
refreshTokenRepository.save(
JwtRefreshToken(
idGenerateService.generateId(), findByNameAndDomain.id, refreshToken, Instant.now(),
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
)
)
return@post call.respond(JwtToken(token, refreshToken))
return@post call.respond(jwtService.createToken(user))
}
post("/refresh-token") {
val refreshToken = call.receive<RefreshToken>()
val findByToken = refreshTokenRepository.findByToken(refreshToken.refreshToken)
?: return@post call.respondText("token not found",status = HttpStatusCode.Forbidden)
if (findByToken.createdAt.isAfter(Instant.now())) {
return@post call.respondText("created_at", status = HttpStatusCode.Forbidden)
}
if (findByToken.expiresAt.isBefore(Instant.now())) {
return@post call.respondText( "expires_at", status = HttpStatusCode.Forbidden)
}
val user = userRepository.findById(findByToken.userId)
?: throw UserNotFoundException("${findByToken.userId} was not found.")
val token = JWT.create()
.withAudience("${Config.configData.url}/users/${user.name}")
.withIssuer(issuer)
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
.withClaim("username", user.name)
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
.sign(Algorithm.RSA256(publicKey, privateKey))
val newRefreshToken = UUID.randomUUID().toString()
refreshTokenRepository.save(
JwtRefreshToken(
idGenerateService.generateId(), user.id, newRefreshToken, Instant.now(),
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
)
)
return@post call.respond(JwtToken(token, newRefreshToken))
return@post call.respond(jwtService.refreshToken(refreshToken))
}
get("/.well-known/jwks.json") {
//language=JSON
val meta = requireNotNull(metaRepository.get())
val jwt = metaService.getJwtMeta()
call.respondText(
contentType = ContentType.Application.Json,
text = JsonWebKeyUtil.publicKeyToJwk(meta.jwt.publicKey, meta.jwt.kid.toString())
text = JsonWebKeyUtil.publicKeyToJwk(jwt.publicKey, jwt.kid.toString())
)
}
}

View File

@ -3,9 +3,18 @@ package dev.usbharu.hideout.repository
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
interface IJwtRefreshTokenRepository {
suspend fun generateId():Long
suspend fun save(token: JwtRefreshToken)
suspend fun findById(id:Long):JwtRefreshToken?
suspend fun findByToken(token:String):JwtRefreshToken?
suspend fun findByUserId(userId:Long):JwtRefreshToken?
suspend fun delete(token:JwtRefreshToken)
suspend fun deleteById(id:Long)
suspend fun deleteByToken(token:String)
suspend fun deleteByUserId(userId:Long)
suspend fun deleteAll()
}

View File

@ -1,16 +1,22 @@
package dev.usbharu.hideout.repository
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
import dev.usbharu.hideout.service.IdGenerateService
import kotlinx.coroutines.Dispatchers
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq
import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction
import org.jetbrains.exposed.sql.transactions.transaction
import java.time.Instant
class JwtRefreshTokenRepositoryImpl(private val database: Database) : IJwtRefreshTokenRepository {
class JwtRefreshTokenRepositoryImpl(
private val database: Database,
private val idGenerateService: IdGenerateService
) :
IJwtRefreshTokenRepository {
init {
transaction(database){
transaction(database) {
SchemaUtils.create(JwtRefreshTokens)
SchemaUtils.createMissingTablesAndColumns(JwtRefreshTokens)
}
@ -19,6 +25,8 @@ class JwtRefreshTokenRepositoryImpl(private val database: Database) : IJwtRefres
suspend fun <T> query(block: suspend () -> T): T =
newSuspendedTransaction(Dispatchers.IO) { block() }
override suspend fun generateId(): Long = idGenerateService.generateId()
override suspend fun save(token: JwtRefreshToken) {
query {
if (JwtRefreshTokens.select { JwtRefreshTokens.id.eq(token.id) }.empty()) {
@ -51,6 +59,42 @@ class JwtRefreshTokenRepositoryImpl(private val database: Database) : IJwtRefres
JwtRefreshTokens.select { JwtRefreshTokens.refreshToken.eq(token) }.singleOrNull()?.toJwtRefreshToken()
}
}
override suspend fun findByUserId(userId: Long): JwtRefreshToken? {
return query {
JwtRefreshTokens.select { JwtRefreshTokens.userId.eq(userId) }.singleOrNull()?.toJwtRefreshToken()
}
}
override suspend fun delete(token: JwtRefreshToken) {
return query {
JwtRefreshTokens.deleteWhere { JwtRefreshTokens.id eq token.id }
}
}
override suspend fun deleteById(id: Long) {
return query {
JwtRefreshTokens.deleteWhere { JwtRefreshTokens.id eq id }
}
}
override suspend fun deleteByToken(token: String) {
return query {
JwtRefreshTokens.deleteWhere { JwtRefreshTokens.refreshToken eq token }
}
}
override suspend fun deleteByUserId(userId: Long) {
return query {
JwtRefreshTokens.deleteWhere { JwtRefreshTokens.userId eq userId }
}
}
override suspend fun deleteAll() {
return query {
JwtRefreshTokens.deleteAll()
}
}
}
fun ResultRow.toJwtRefreshToken(): JwtRefreshToken {

View File

@ -0,0 +1,14 @@
package dev.usbharu.hideout.service
import dev.usbharu.hideout.domain.model.hideout.dto.JwtToken
import dev.usbharu.hideout.domain.model.hideout.entity.User
import dev.usbharu.hideout.domain.model.hideout.form.RefreshToken
interface IJwtService {
suspend fun createToken(user:User):JwtToken
suspend fun refreshToken(refreshToken: RefreshToken):JwtToken
suspend fun revokeToken(refreshToken: RefreshToken)
suspend fun revokeToken(user:User)
suspend fun revokeAll()
}

View File

@ -0,0 +1,10 @@
package dev.usbharu.hideout.service
import dev.usbharu.hideout.domain.model.hideout.entity.Jwt
import dev.usbharu.hideout.domain.model.hideout.entity.Meta
interface IMetaService {
suspend fun getMeta(): Meta
suspend fun updateMeta(meta: Meta)
suspend fun getJwtMeta(): Jwt
}

View File

@ -0,0 +1,95 @@
package dev.usbharu.hideout.service
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import dev.usbharu.hideout.config.Config
import dev.usbharu.hideout.domain.model.hideout.dto.JwtToken
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
import dev.usbharu.hideout.domain.model.hideout.entity.User
import dev.usbharu.hideout.domain.model.hideout.form.RefreshToken
import dev.usbharu.hideout.exception.InvalidRefreshTokenException
import dev.usbharu.hideout.repository.IJwtRefreshTokenRepository
import dev.usbharu.hideout.service.impl.IUserService
import dev.usbharu.hideout.util.RsaUtil
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.*
class JwtServiceImpl(
private val metaService: IMetaService,
private val refreshTokenRepository: IJwtRefreshTokenRepository,
private val userService: IUserService
) : IJwtService {
private val privateKey by lazy {
CoroutineScope(Dispatchers.IO).async {
RsaUtil.decodeRsaPrivateKey(metaService.getJwtMeta().privateKey)
}
}
private val publicKey by lazy {
CoroutineScope(Dispatchers.IO).async {
RsaUtil.decodeRsaPublicKey(metaService.getJwtMeta().publicKey)
}
}
private val keyId by lazy {
CoroutineScope(Dispatchers.IO).async {
metaService.getJwtMeta().kid
}
}
override suspend fun createToken(user: User): JwtToken {
val now = Instant.now()
val token = JWT.create()
.withAudience("${Config.configData.url}/users/${user.id}")
.withIssuer(Config.configData.url)
.withKeyId(keyId.await().toString())
.withClaim("username", user.name)
.withExpiresAt(now.plus(30, ChronoUnit.MINUTES))
.sign(Algorithm.RSA256(publicKey.await(), privateKey.await()))
val jwtRefreshToken = JwtRefreshToken(
id = refreshTokenRepository.generateId(),
userId = user.id,
refreshToken = UUID.randomUUID().toString(),
createdAt = now,
expiresAt = now.plus(14, ChronoUnit.DAYS)
)
refreshTokenRepository.save(jwtRefreshToken)
return JwtToken(token, jwtRefreshToken.refreshToken)
}
override suspend fun refreshToken(refreshToken: RefreshToken): JwtToken {
val token = refreshTokenRepository.findByToken(refreshToken.refreshToken)
?: throw InvalidRefreshTokenException("Invalid Refresh Token")
val user = userService.findById(token.userId)
val now = Instant.now()
if (token.createdAt.isAfter(now)) {
throw InvalidRefreshTokenException("Invalid Refresh Token")
}
if (token.expiresAt.isBefore(now)) {
throw InvalidRefreshTokenException("Refresh Token Expired")
}
return createToken(user)
}
override suspend fun revokeToken(refreshToken: RefreshToken) {
refreshTokenRepository.deleteByToken(refreshToken.refreshToken)
}
override suspend fun revokeToken(user: User) {
refreshTokenRepository.deleteByUserId(user.id)
}
override suspend fun revokeAll() {
refreshTokenRepository.deleteAll()
}
}

View File

@ -0,0 +1,16 @@
package dev.usbharu.hideout.service
import dev.usbharu.hideout.domain.model.hideout.entity.Jwt
import dev.usbharu.hideout.domain.model.hideout.entity.Meta
import dev.usbharu.hideout.exception.NotInitException
import dev.usbharu.hideout.repository.IMetaRepository
class MetaServiceImpl(private val metaRepository: IMetaRepository) : IMetaService {
override suspend fun getMeta(): Meta = metaRepository.get() ?: throw NotInitException("Meta is null")
override suspend fun updateMeta(meta: Meta) {
metaRepository.save(meta)
}
override suspend fun getJwtMeta(): Jwt = getMeta().jwt
}

View File

@ -12,8 +12,12 @@ object RsaUtil {
return KeyFactory.getInstance("RSA").generatePublic(x509EncodedKeySpec) as RSAPublicKey
}
fun decodeRsaPublicKey(encoded: String): RSAPublicKey = decodeRsaPublicKey(Base64Util.decode(encoded))
fun decodeRsaPrivateKey(byteArray: ByteArray):RSAPrivateKey{
val pkcS8EncodedKeySpec = PKCS8EncodedKeySpec(byteArray)
return KeyFactory.getInstance("RSA").generatePrivate(pkcS8EncodedKeySpec) as RSAPrivateKey
}
fun decodeRsaPrivateKey(encoded: String):RSAPrivateKey = decodeRsaPrivateKey(Base64Util.decode(encoded))
}