feat: リフレッシュトークンを追加

This commit is contained in:
usbharu 2023-05-01 16:07:22 +09:00
parent 96c54d26fd
commit 6b30fc1f4d
8 changed files with 184 additions and 16 deletions

View File

@ -84,8 +84,9 @@ fun Application.parent() {
single<IPostService> { PostService(get(), get()) }
single<IPostRepository> { PostRepositoryImpl(get(), get()) }
single<IdGenerateService> { TwitterSnowflakeIdGenerateService }
single<IMetaRepository>{ MetaRepositoryImpl(get()) }
single<IMetaRepository> { MetaRepositoryImpl(get()) }
single<IServerInitialiseService> { ServerInitialiseServiceImpl(get()) }
single<IJwtRefreshTokenRepository> { JwtRefreshTokenRepositoryImpl(get()) }
}
configureKoin(module)
runBlocking {
@ -96,7 +97,12 @@ fun Application.parent() {
configureMonitoring()
configureSerialization()
register(inject<IUserService>().value)
configureSecurity(inject<IUserAuthService>().value,inject<IMetaRepository>().value)
configureSecurity(
inject<IUserAuthService>().value,
inject<IMetaRepository>().value,
inject<IJwtRefreshTokenRepository>().value,
inject<IUserRepository>().value
)
configureRouting(
inject<HttpSignatureVerifyService>().value,
inject<ActivityPubService>().value,

View File

@ -0,0 +1,3 @@
package dev.usbharu.hideout.domain.model.hideout.dto
data class JwtToken(val token:String,val refreshToken:String)

View File

@ -0,0 +1,11 @@
package dev.usbharu.hideout.domain.model.hideout.entity
import java.time.Instant
data class JwtRefreshToken(
val id: Long,
val userId: Long,
val refreshToken: String,
val createdAt: Instant,
val expiresAt: Instant
)

View File

@ -0,0 +1,3 @@
package dev.usbharu.hideout.domain.model.hideout.form
data class RefreshToken(val refreshToken:String)

View File

@ -6,10 +6,18 @@ import com.auth0.jwk.JwkProviderBuilder
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import dev.usbharu.hideout.config.Config
import dev.usbharu.hideout.domain.model.hideout.dto.JwtToken
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
import dev.usbharu.hideout.domain.model.hideout.form.RefreshToken
import dev.usbharu.hideout.domain.model.hideout.form.UserLogin
import dev.usbharu.hideout.exception.UserNotFoundException
import dev.usbharu.hideout.property
import dev.usbharu.hideout.repository.IJwtRefreshTokenRepository
import dev.usbharu.hideout.repository.IMetaRepository
import dev.usbharu.hideout.repository.IUserRepository
import dev.usbharu.hideout.service.IUserAuthService
import dev.usbharu.hideout.service.IdGenerateService
import dev.usbharu.hideout.util.Base64Util
import dev.usbharu.hideout.util.JsonWebKeyUtil
import dev.usbharu.hideout.util.RsaUtil
import io.ktor.http.*
@ -20,25 +28,27 @@ import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import kotlinx.coroutines.runBlocking
import java.security.KeyFactory
import java.security.interfaces.RSAPrivateKey
import java.security.spec.PKCS8EncodedKeySpec
import java.time.Instant
import java.util.*
import java.util.concurrent.TimeUnit
const val TOKEN_AUTH = "jwt-auth"
fun Application.configureSecurity(userAuthService: IUserAuthService, metaRepository: IMetaRepository) {
fun Application.configureSecurity(
userAuthService: IUserAuthService,
metaRepository: IMetaRepository,
refreshTokenRepository: IJwtRefreshTokenRepository,
userRepository: IUserRepository,
idGenerateService: IdGenerateService
) {
val privateKeyString = runBlocking {
requireNotNull(metaRepository.get()).jwt.privateKey
val privateKey = runBlocking {
RsaUtil.decodeRsaPrivateKey(Base64Util.decode(requireNotNull(metaRepository.get()).jwt.privateKey))
}
val publicKey = runBlocking {
val publicKey = requireNotNull(metaRepository.get()).jwt.publicKey
println(publicKey)
RsaUtil.decodeRsaPublicKey(Base64.getDecoder().decode(publicKey))
RsaUtil.decodeRsaPublicKey(Base64Util.decode(publicKey))
}
println(privateKeyString)
val issuer = property("hideout.url")
// val audience = property("jwt.audience")
val myRealm = property("jwt.realm")
@ -73,16 +83,57 @@ fun Application.configureSecurity(userAuthService: IUserAuthService, metaReposit
if (check.not()) {
return@post call.respond(HttpStatusCode.Unauthorized)
}
val keySpecPKCS8 = PKCS8EncodedKeySpec(Base64.getDecoder().decode(privateKeyString))
val privateKey = KeyFactory.getInstance("RSA").generatePrivate(keySpecPKCS8)
val findByNameAndDomain = userRepository.findByNameAndDomain(user.username, Config.configData.domain)
?: throw UserNotFoundException("${user.username} was not found.")
val token = JWT.create()
.withAudience("${Config.configData.url}/users/${user.username}")
.withIssuer(issuer)
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
.withClaim("username", user.username)
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
.sign(Algorithm.RSA256(publicKey, privateKey as RSAPrivateKey))
return@post call.respond(token)
.sign(Algorithm.RSA256(publicKey, privateKey))
val refreshToken = UUID.randomUUID().toString()
refreshTokenRepository.save(
JwtRefreshToken(
idGenerateService.generateId(), findByNameAndDomain.id, refreshToken, Instant.now(),
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
)
)
return@post call.respond(JwtToken(token, refreshToken))
}
post("/refresh-token") {
val refreshToken = call.receive<RefreshToken>()
val findByToken = refreshTokenRepository.findByToken(refreshToken.refreshToken)
?: return@post call.respond(HttpStatusCode.Forbidden)
if (findByToken.createdAt.isAfter(Instant.now())) {
return@post call.respond(HttpStatusCode.Forbidden)
}
if (findByToken.expiresAt.isAfter(Instant.now())) {
return@post call.respond(HttpStatusCode.Forbidden)
}
val user = userRepository.findById(findByToken.userId)
?: throw UserNotFoundException("${findByToken.userId} was not found.")
val token = JWT.create()
.withAudience("${Config.configData.url}/users/${user.name}")
.withIssuer(issuer)
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
.withClaim("username", user.name)
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
.sign(Algorithm.RSA256(publicKey, privateKey))
val newRefreshToken = UUID.randomUUID().toString()
refreshTokenRepository.save(
JwtRefreshToken(
idGenerateService.generateId(), user.id, newRefreshToken, Instant.now(),
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
)
)
return@post call.respond(JwtToken(token, newRefreshToken))
}
get("/.well-known/jwks.json") {
@ -90,7 +141,7 @@ fun Application.configureSecurity(userAuthService: IUserAuthService, metaReposit
val meta = requireNotNull(metaRepository.get())
call.respondText(
contentType = ContentType.Application.Json,
text = JsonWebKeyUtil.publicKeyToJwk(meta.jwt.publicKey,meta.jwt.kid.toString())
text = JsonWebKeyUtil.publicKeyToJwk(meta.jwt.publicKey, meta.jwt.kid.toString())
)
}
}

View File

@ -0,0 +1,11 @@
package dev.usbharu.hideout.repository
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
interface IJwtRefreshTokenRepository {
suspend fun save(token: JwtRefreshToken)
suspend fun findById(id:Long):JwtRefreshToken?
suspend fun findByToken(token:String):JwtRefreshToken?
}

View File

@ -0,0 +1,73 @@
package dev.usbharu.hideout.repository
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
import kotlinx.coroutines.Dispatchers
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction
import org.jetbrains.exposed.sql.transactions.transaction
import java.time.Instant
class JwtRefreshTokenRepositoryImpl(private val database: Database) : IJwtRefreshTokenRepository {
init {
transaction(database){
SchemaUtils.create(JwtRefreshTokens)
SchemaUtils.createMissingTablesAndColumns(JwtRefreshTokens)
}
}
suspend fun <T> query(block: suspend () -> T): T =
newSuspendedTransaction(Dispatchers.IO) { block() }
override suspend fun save(token: JwtRefreshToken) {
query {
if (JwtRefreshTokens.select { JwtRefreshTokens.id.eq(token.id) }.empty()) {
JwtRefreshTokens.insert {
it[id] = token.id
it[userId] = token.userId
it[refreshToken] = token.refreshToken
it[createdAt] = token.createdAt.toEpochMilli()
it[expiresAt] = token.expiresAt.toEpochMilli()
}
} else {
JwtRefreshTokens.update({ JwtRefreshTokens.id eq token.id }) {
it[userId] = token.userId
it[refreshToken] = token.refreshToken
it[createdAt] = token.createdAt.toEpochMilli()
it[expiresAt] = token.expiresAt.toEpochMilli()
}
}
}
}
override suspend fun findById(id: Long): JwtRefreshToken? {
return query {
JwtRefreshTokens.select { JwtRefreshTokens.id.eq(id) }.singleOrNull()?.toJwtRefreshToken()
}
}
override suspend fun findByToken(token: String): JwtRefreshToken? {
return query {
JwtRefreshTokens.select { JwtRefreshTokens.refreshToken.eq(token) }.singleOrNull()?.toJwtRefreshToken()
}
}
}
fun ResultRow.toJwtRefreshToken(): JwtRefreshToken {
return JwtRefreshToken(
this[JwtRefreshTokens.id],
this[JwtRefreshTokens.userId],
this[JwtRefreshTokens.refreshToken],
Instant.ofEpochMilli(this[JwtRefreshTokens.createdAt]),
Instant.ofEpochMilli(this[JwtRefreshTokens.expiresAt])
)
}
object JwtRefreshTokens : Table("jwt_refresh_tokens") {
val id = long("id")
val userId = long("user_id")
val refreshToken = varchar("refresh_token", 1000)
val createdAt = long("created_at")
val expiresAt = long("expires_at")
override val primaryKey = PrimaryKey(id)
}

View File

@ -0,0 +1,10 @@
package dev.usbharu.hideout.util
import java.util.*
object Base64Util {
fun decode(str: String): ByteArray = Base64.getDecoder().decode(str)
fun encode(bytes: ByteArray): String = Base64.getEncoder().encodeToString(bytes)
}