diff --git a/src/main/kotlin/dev/usbharu/hideout/application/config/SecurityConfig.kt b/src/main/kotlin/dev/usbharu/hideout/application/config/SecurityConfig.kt index 05d97d94..57edc427 100644 --- a/src/main/kotlin/dev/usbharu/hideout/application/config/SecurityConfig.kt +++ b/src/main/kotlin/dev/usbharu/hideout/application/config/SecurityConfig.kt @@ -26,6 +26,7 @@ import dev.usbharu.hideout.application.external.Transaction import dev.usbharu.hideout.application.infrastructure.springframework.RoleHierarchyAuthorizationManagerFactory import dev.usbharu.hideout.core.domain.model.actor.ActorRepository import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureFilter +import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureHeaderChecker import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureUserDetailsService import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureVerifierComposite import dev.usbharu.hideout.core.infrastructure.springframework.oauth2.UserDetailsImpl @@ -35,6 +36,9 @@ import dev.usbharu.httpsignature.sign.RsaSha256HttpSignatureSigner import dev.usbharu.httpsignature.verify.DefaultSignatureHeaderParser import dev.usbharu.httpsignature.verify.RsaSha256HttpSignatureVerifier import jakarta.annotation.PostConstruct +import jakarta.servlet.* +import org.springframework.beans.factory.support.BeanDefinitionRegistry +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty import org.springframework.boot.autoconfigure.jackson.Jackson2ObjectMapperBuilderCustomizer import org.springframework.boot.context.properties.ConfigurationProperties @@ -58,6 +62,7 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.http.SessionCreationPolicy import org.springframework.security.core.Authentication +import org.springframework.security.core.context.SecurityContextHolderStrategy import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder import org.springframework.security.crypto.password.PasswordEncoder import org.springframework.security.oauth2.core.AuthorizationGrantType @@ -67,20 +72,28 @@ import org.springframework.security.oauth2.server.authorization.config.annotatio import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer +import org.springframework.security.web.FilterChainProxy import org.springframework.security.web.SecurityFilterChain import org.springframework.security.web.access.ExceptionTranslationFilter import org.springframework.security.web.authentication.AuthenticationEntryPointFailureHandler import org.springframework.security.web.authentication.HttpStatusEntryPoint import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationProvider +import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer +import org.springframework.security.web.debug.DebugFilter +import org.springframework.security.web.firewall.HttpFirewall +import org.springframework.security.web.firewall.RequestRejectedHandler import org.springframework.security.web.savedrequest.RequestCacheAwareFilter import org.springframework.security.web.util.matcher.AnyRequestMatcher +import org.springframework.web.filter.CompositeFilter +import java.io.IOException import java.security.KeyPairGenerator import java.security.interfaces.RSAPrivateKey import java.security.interfaces.RSAPublicKey import java.util.* -@EnableWebSecurity(debug = false) + +@EnableWebSecurity(debug = true) @Configuration @Suppress("FunctionMaxLength", "TooManyFunctions", "LongMethod") class SecurityConfig { @@ -94,7 +107,7 @@ class SecurityConfig { @Order(1) fun httpSignatureFilterChain( http: HttpSecurity, - httpSignatureFilter: HttpSignatureFilter + httpSignatureFilter: HttpSignatureFilter, ): SecurityFilterChain { http { securityMatcher("/users/*/posts/*") @@ -122,9 +135,10 @@ class SecurityConfig { @Bean fun getHttpSignatureFilter( authenticationManager: AuthenticationManager, + httpSignatureHeaderChecker: HttpSignatureHeaderChecker, ): HttpSignatureFilter { val httpSignatureFilter = - HttpSignatureFilter(DefaultSignatureHeaderParser()) + HttpSignatureFilter(DefaultSignatureHeaderParser(), httpSignatureHeaderChecker) httpSignatureFilter.setAuthenticationManager(authenticationManager) httpSignatureFilter.setContinueFilterChainOnUnsuccessfulAuthentication(false) val authenticationEntryPointFailureHandler = @@ -147,7 +161,7 @@ class SecurityConfig { @Order(1) fun httpSignatureAuthenticationProvider( transaction: Transaction, - actorRepository: ActorRepository + actorRepository: ActorRepository, ): PreAuthenticatedAuthenticationProvider { val provider = PreAuthenticatedAuthenticationProvider() val signatureHeaderParser = DefaultSignatureHeaderParser() @@ -190,7 +204,7 @@ class SecurityConfig { @Order(4) fun defaultSecurityFilterChain( http: HttpSecurity, - rf: RoleHierarchyAuthorizationManagerFactory + rf: RoleHierarchyAuthorizationManagerFactory, ): SecurityFilterChain { http { authorizeHttpRequests { @@ -401,6 +415,82 @@ class SecurityConfig { return roleHierarchyImpl } + + @Bean + fun beanDefinitionRegistryPostProcessor(): BeanDefinitionRegistryPostProcessor { + return BeanDefinitionRegistryPostProcessor { registry: BeanDefinitionRegistry -> + registry.getBeanDefinition(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME).beanClassName = + CompositeFilterChainProxy::class.java.name + } + } + + internal class CompositeFilterChainProxy(filters: List) : FilterChainProxy() { + private val doFilterDelegate: Filter + + private val springSecurityFilterChain: FilterChainProxy + + init { + this.doFilterDelegate = createDoFilterDelegate(filters) + this.springSecurityFilterChain = findFilterChainProxy(filters) + } + + override fun afterPropertiesSet() { + springSecurityFilterChain.afterPropertiesSet() + } + + @Throws(IOException::class, ServletException::class) + override fun doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain) { + doFilterDelegate.doFilter(request, response, chain) + } + + override fun getFilters(url: String): List { + return springSecurityFilterChain.getFilters(url) + } + + override fun getFilterChains(): List { + return springSecurityFilterChain.filterChains + } + + override fun setSecurityContextHolderStrategy(securityContextHolderStrategy: SecurityContextHolderStrategy) { + springSecurityFilterChain.setSecurityContextHolderStrategy(securityContextHolderStrategy) + } + + override fun setFilterChainValidator(filterChainValidator: FilterChainValidator) { + springSecurityFilterChain.setFilterChainValidator(filterChainValidator) + } + + override fun setFilterChainDecorator(filterChainDecorator: FilterChainDecorator) { + springSecurityFilterChain.setFilterChainDecorator(filterChainDecorator) + } + + override fun setFirewall(firewall: HttpFirewall) { + springSecurityFilterChain.setFirewall(firewall) + } + + override fun setRequestRejectedHandler(requestRejectedHandler: RequestRejectedHandler) { + springSecurityFilterChain.setRequestRejectedHandler(requestRejectedHandler) + } + + companion object { + private fun createDoFilterDelegate(filters: List): Filter { + val delegate: CompositeFilter = CompositeFilter() + delegate.setFilters(filters) + return delegate + } + + private fun findFilterChainProxy(filters: List): FilterChainProxy { + for (filter in filters) { + if (filter is FilterChainProxy) { + return filter + } + if (filter is DebugFilter) { + return filter.filterChainProxy + } + } + throw IllegalStateException("Couldn't find FilterChainProxy in $filters") + } + } + } } @ConfigurationProperties("hideout.security.jwt") @@ -408,14 +498,14 @@ class SecurityConfig { data class JwkConfig( val keyId: String, val publicKey: String, - val privateKey: String + val privateKey: String, ) @Configuration class PostSecurityConfig( val auth: AuthenticationManagerBuilder, val daoAuthenticationProvider: DaoAuthenticationProvider, - val httpSignatureAuthenticationProvider: PreAuthenticatedAuthenticationProvider + val httpSignatureAuthenticationProvider: PreAuthenticatedAuthenticationProvider, ) { @PostConstruct diff --git a/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureFilter.kt b/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureFilter.kt index e115df38..b55388a8 100644 --- a/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureFilter.kt +++ b/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureFilter.kt @@ -24,7 +24,10 @@ import jakarta.servlet.http.HttpServletRequest import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter import java.net.URL -class HttpSignatureFilter(private val httpSignatureHeaderParser: SignatureHeaderParser) : +class HttpSignatureFilter( + private val httpSignatureHeaderParser: SignatureHeaderParser, + private val httpSignatureHeaderChecker: HttpSignatureHeaderChecker, +) : AbstractPreAuthenticatedProcessingFilter() { override fun getPreAuthenticatedPrincipal(request: HttpServletRequest?): Any? { val headersList = request?.headerNames?.toList().orEmpty() @@ -59,6 +62,15 @@ class HttpSignatureFilter(private val httpSignatureHeaderParser: SignatureHeader } } + httpSignatureHeaderChecker.checkDate(request.getHeader("date")) + httpSignatureHeaderChecker.checkHost(request.getHeader("host")) + + + + if (request.method.equals("post", true)) { + httpSignatureHeaderChecker.checkDigest(request.inputStream.readAllBytes(), request.getHeader("digest")) + } + return HttpRequest( URL(url + request.queryString.orEmpty()), HttpHeaders(headers), diff --git a/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureHeaderChecker.kt b/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureHeaderChecker.kt new file mode 100644 index 00000000..28fb90aa --- /dev/null +++ b/src/main/kotlin/dev/usbharu/hideout/core/infrastructure/springframework/httpsignature/HttpSignatureHeaderChecker.kt @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2024 usbharu + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.usbharu.hideout.core.infrastructure.springframework.httpsignature + +import dev.usbharu.hideout.application.config.ApplicationConfig +import dev.usbharu.hideout.util.Base64Util +import org.springframework.stereotype.Component +import java.security.MessageDigest +import java.time.Instant +import java.time.format.DateTimeFormatter +import java.util.* + +@Component +class HttpSignatureHeaderChecker(private val applicationConfig: ApplicationConfig) { + fun checkDate(date: String) { + val from = Instant.from(dateFormat.parse(date)) + + if (from.isAfter(Instant.now()) || from.isBefore(Instant.now().minusSeconds(86400))) { + throw IllegalArgumentException("未来") + } + } + + fun checkHost(host: String) { + if (applicationConfig.url.host.equals(host, true).not()) { + throw IllegalArgumentException("ホスト名が違う") + } + } + + fun checkDigest(byteArray: ByteArray, digest: String) { + val sha256 = MessageDigest.getInstance("SHA-256") + + if (Base64Util.encode(sha256.digest(byteArray)).equals(digest, true).not()) { + throw IllegalArgumentException("リクエストボディが違う") + } + } + + companion object { + private val dateFormat = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US) + } +} \ No newline at end of file