mirror of https://github.com/usbharu/Hideout.git
feat: リフレッシュトークンを追加
This commit is contained in:
parent
96c54d26fd
commit
6b30fc1f4d
|
@ -86,6 +86,7 @@ fun Application.parent() {
|
|||
single<IdGenerateService> { TwitterSnowflakeIdGenerateService }
|
||||
single<IMetaRepository> { MetaRepositoryImpl(get()) }
|
||||
single<IServerInitialiseService> { ServerInitialiseServiceImpl(get()) }
|
||||
single<IJwtRefreshTokenRepository> { JwtRefreshTokenRepositoryImpl(get()) }
|
||||
}
|
||||
configureKoin(module)
|
||||
runBlocking {
|
||||
|
@ -96,7 +97,12 @@ fun Application.parent() {
|
|||
configureMonitoring()
|
||||
configureSerialization()
|
||||
register(inject<IUserService>().value)
|
||||
configureSecurity(inject<IUserAuthService>().value,inject<IMetaRepository>().value)
|
||||
configureSecurity(
|
||||
inject<IUserAuthService>().value,
|
||||
inject<IMetaRepository>().value,
|
||||
inject<IJwtRefreshTokenRepository>().value,
|
||||
inject<IUserRepository>().value
|
||||
)
|
||||
configureRouting(
|
||||
inject<HttpSignatureVerifyService>().value,
|
||||
inject<ActivityPubService>().value,
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
package dev.usbharu.hideout.domain.model.hideout.dto
|
||||
|
||||
data class JwtToken(val token:String,val refreshToken:String)
|
|
@ -0,0 +1,11 @@
|
|||
package dev.usbharu.hideout.domain.model.hideout.entity
|
||||
|
||||
import java.time.Instant
|
||||
|
||||
data class JwtRefreshToken(
|
||||
val id: Long,
|
||||
val userId: Long,
|
||||
val refreshToken: String,
|
||||
val createdAt: Instant,
|
||||
val expiresAt: Instant
|
||||
)
|
|
@ -0,0 +1,3 @@
|
|||
package dev.usbharu.hideout.domain.model.hideout.form
|
||||
|
||||
data class RefreshToken(val refreshToken:String)
|
|
@ -6,10 +6,18 @@ import com.auth0.jwk.JwkProviderBuilder
|
|||
import com.auth0.jwt.JWT
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
import dev.usbharu.hideout.config.Config
|
||||
import dev.usbharu.hideout.domain.model.hideout.dto.JwtToken
|
||||
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
|
||||
import dev.usbharu.hideout.domain.model.hideout.form.RefreshToken
|
||||
import dev.usbharu.hideout.domain.model.hideout.form.UserLogin
|
||||
import dev.usbharu.hideout.exception.UserNotFoundException
|
||||
import dev.usbharu.hideout.property
|
||||
import dev.usbharu.hideout.repository.IJwtRefreshTokenRepository
|
||||
import dev.usbharu.hideout.repository.IMetaRepository
|
||||
import dev.usbharu.hideout.repository.IUserRepository
|
||||
import dev.usbharu.hideout.service.IUserAuthService
|
||||
import dev.usbharu.hideout.service.IdGenerateService
|
||||
import dev.usbharu.hideout.util.Base64Util
|
||||
import dev.usbharu.hideout.util.JsonWebKeyUtil
|
||||
import dev.usbharu.hideout.util.RsaUtil
|
||||
import io.ktor.http.*
|
||||
|
@ -20,25 +28,27 @@ import io.ktor.server.request.*
|
|||
import io.ktor.server.response.*
|
||||
import io.ktor.server.routing.*
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import java.security.KeyFactory
|
||||
import java.security.interfaces.RSAPrivateKey
|
||||
import java.security.spec.PKCS8EncodedKeySpec
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
const val TOKEN_AUTH = "jwt-auth"
|
||||
|
||||
fun Application.configureSecurity(userAuthService: IUserAuthService, metaRepository: IMetaRepository) {
|
||||
fun Application.configureSecurity(
|
||||
userAuthService: IUserAuthService,
|
||||
metaRepository: IMetaRepository,
|
||||
refreshTokenRepository: IJwtRefreshTokenRepository,
|
||||
userRepository: IUserRepository,
|
||||
idGenerateService: IdGenerateService
|
||||
) {
|
||||
|
||||
val privateKeyString = runBlocking {
|
||||
requireNotNull(metaRepository.get()).jwt.privateKey
|
||||
val privateKey = runBlocking {
|
||||
RsaUtil.decodeRsaPrivateKey(Base64Util.decode(requireNotNull(metaRepository.get()).jwt.privateKey))
|
||||
}
|
||||
val publicKey = runBlocking {
|
||||
val publicKey = requireNotNull(metaRepository.get()).jwt.publicKey
|
||||
println(publicKey)
|
||||
RsaUtil.decodeRsaPublicKey(Base64.getDecoder().decode(publicKey))
|
||||
RsaUtil.decodeRsaPublicKey(Base64Util.decode(publicKey))
|
||||
}
|
||||
println(privateKeyString)
|
||||
val issuer = property("hideout.url")
|
||||
// val audience = property("jwt.audience")
|
||||
val myRealm = property("jwt.realm")
|
||||
|
@ -73,16 +83,57 @@ fun Application.configureSecurity(userAuthService: IUserAuthService, metaReposit
|
|||
if (check.not()) {
|
||||
return@post call.respond(HttpStatusCode.Unauthorized)
|
||||
}
|
||||
val keySpecPKCS8 = PKCS8EncodedKeySpec(Base64.getDecoder().decode(privateKeyString))
|
||||
val privateKey = KeyFactory.getInstance("RSA").generatePrivate(keySpecPKCS8)
|
||||
|
||||
val findByNameAndDomain = userRepository.findByNameAndDomain(user.username, Config.configData.domain)
|
||||
?: throw UserNotFoundException("${user.username} was not found.")
|
||||
|
||||
val token = JWT.create()
|
||||
.withAudience("${Config.configData.url}/users/${user.username}")
|
||||
.withIssuer(issuer)
|
||||
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
|
||||
.withClaim("username", user.username)
|
||||
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
|
||||
.sign(Algorithm.RSA256(publicKey, privateKey as RSAPrivateKey))
|
||||
return@post call.respond(token)
|
||||
.sign(Algorithm.RSA256(publicKey, privateKey))
|
||||
val refreshToken = UUID.randomUUID().toString()
|
||||
refreshTokenRepository.save(
|
||||
JwtRefreshToken(
|
||||
idGenerateService.generateId(), findByNameAndDomain.id, refreshToken, Instant.now(),
|
||||
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
|
||||
)
|
||||
)
|
||||
return@post call.respond(JwtToken(token, refreshToken))
|
||||
}
|
||||
|
||||
post("/refresh-token") {
|
||||
val refreshToken = call.receive<RefreshToken>()
|
||||
val findByToken = refreshTokenRepository.findByToken(refreshToken.refreshToken)
|
||||
?: return@post call.respond(HttpStatusCode.Forbidden)
|
||||
|
||||
if (findByToken.createdAt.isAfter(Instant.now())) {
|
||||
return@post call.respond(HttpStatusCode.Forbidden)
|
||||
}
|
||||
|
||||
if (findByToken.expiresAt.isAfter(Instant.now())) {
|
||||
return@post call.respond(HttpStatusCode.Forbidden)
|
||||
}
|
||||
|
||||
val user = userRepository.findById(findByToken.userId)
|
||||
?: throw UserNotFoundException("${findByToken.userId} was not found.")
|
||||
val token = JWT.create()
|
||||
.withAudience("${Config.configData.url}/users/${user.name}")
|
||||
.withIssuer(issuer)
|
||||
.withKeyId(metaRepository.get()?.jwt?.kid.toString())
|
||||
.withClaim("username", user.name)
|
||||
.withExpiresAt(Date(System.currentTimeMillis() + 60000))
|
||||
.sign(Algorithm.RSA256(publicKey, privateKey))
|
||||
val newRefreshToken = UUID.randomUUID().toString()
|
||||
refreshTokenRepository.save(
|
||||
JwtRefreshToken(
|
||||
idGenerateService.generateId(), user.id, newRefreshToken, Instant.now(),
|
||||
Instant.ofEpochMilli(Instant.now().toEpochMilli() + 1209600033)
|
||||
)
|
||||
)
|
||||
return@post call.respond(JwtToken(token, newRefreshToken))
|
||||
}
|
||||
|
||||
get("/.well-known/jwks.json") {
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
package dev.usbharu.hideout.repository
|
||||
|
||||
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
|
||||
|
||||
interface IJwtRefreshTokenRepository {
|
||||
suspend fun save(token: JwtRefreshToken)
|
||||
|
||||
suspend fun findById(id:Long):JwtRefreshToken?
|
||||
suspend fun findByToken(token:String):JwtRefreshToken?
|
||||
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package dev.usbharu.hideout.repository
|
||||
|
||||
import dev.usbharu.hideout.domain.model.hideout.entity.JwtRefreshToken
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import org.jetbrains.exposed.sql.*
|
||||
import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransaction
|
||||
import org.jetbrains.exposed.sql.transactions.transaction
|
||||
import java.time.Instant
|
||||
|
||||
class JwtRefreshTokenRepositoryImpl(private val database: Database) : IJwtRefreshTokenRepository {
|
||||
|
||||
init {
|
||||
transaction(database){
|
||||
SchemaUtils.create(JwtRefreshTokens)
|
||||
SchemaUtils.createMissingTablesAndColumns(JwtRefreshTokens)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun <T> query(block: suspend () -> T): T =
|
||||
newSuspendedTransaction(Dispatchers.IO) { block() }
|
||||
|
||||
override suspend fun save(token: JwtRefreshToken) {
|
||||
query {
|
||||
if (JwtRefreshTokens.select { JwtRefreshTokens.id.eq(token.id) }.empty()) {
|
||||
JwtRefreshTokens.insert {
|
||||
it[id] = token.id
|
||||
it[userId] = token.userId
|
||||
it[refreshToken] = token.refreshToken
|
||||
it[createdAt] = token.createdAt.toEpochMilli()
|
||||
it[expiresAt] = token.expiresAt.toEpochMilli()
|
||||
}
|
||||
} else {
|
||||
JwtRefreshTokens.update({ JwtRefreshTokens.id eq token.id }) {
|
||||
it[userId] = token.userId
|
||||
it[refreshToken] = token.refreshToken
|
||||
it[createdAt] = token.createdAt.toEpochMilli()
|
||||
it[expiresAt] = token.expiresAt.toEpochMilli()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun findById(id: Long): JwtRefreshToken? {
|
||||
return query {
|
||||
JwtRefreshTokens.select { JwtRefreshTokens.id.eq(id) }.singleOrNull()?.toJwtRefreshToken()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun findByToken(token: String): JwtRefreshToken? {
|
||||
return query {
|
||||
JwtRefreshTokens.select { JwtRefreshTokens.refreshToken.eq(token) }.singleOrNull()?.toJwtRefreshToken()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun ResultRow.toJwtRefreshToken(): JwtRefreshToken {
|
||||
return JwtRefreshToken(
|
||||
this[JwtRefreshTokens.id],
|
||||
this[JwtRefreshTokens.userId],
|
||||
this[JwtRefreshTokens.refreshToken],
|
||||
Instant.ofEpochMilli(this[JwtRefreshTokens.createdAt]),
|
||||
Instant.ofEpochMilli(this[JwtRefreshTokens.expiresAt])
|
||||
)
|
||||
}
|
||||
|
||||
object JwtRefreshTokens : Table("jwt_refresh_tokens") {
|
||||
val id = long("id")
|
||||
val userId = long("user_id")
|
||||
val refreshToken = varchar("refresh_token", 1000)
|
||||
val createdAt = long("created_at")
|
||||
val expiresAt = long("expires_at")
|
||||
override val primaryKey = PrimaryKey(id)
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package dev.usbharu.hideout.util
|
||||
|
||||
import java.util.*
|
||||
|
||||
object Base64Util {
|
||||
fun decode(str: String): ByteArray = Base64.getDecoder().decode(str)
|
||||
|
||||
fun encode(bytes: ByteArray): String = Base64.getEncoder().encodeToString(bytes)
|
||||
|
||||
}
|
Loading…
Reference in New Issue