Merge branch 'develop' into feature/error-response

This commit is contained in:
usbharu 2024-02-21 16:01:16 +09:00
commit 32a4315952
15 changed files with 943 additions and 132 deletions

View File

@ -197,9 +197,10 @@ dependencies {
implementation("org.springframework.boot:spring-boot-starter-oauth2-authorization-server") implementation("org.springframework.boot:spring-boot-starter-oauth2-authorization-server")
implementation("org.springframework.boot:spring-boot-starter-oauth2-resource-server") implementation("org.springframework.boot:spring-boot-starter-oauth2-resource-server")
implementation("org.springframework.boot:spring-boot-starter-log4j2") implementation("org.springframework.boot:spring-boot-starter-log4j2")
compileOnly("jakarta.validation:jakarta.validation-api") implementation("org.springframework.boot:spring-boot-starter-validation")
compileOnly("jakarta.annotation:jakarta.annotation-api:2.1.0") implementation("jakarta.validation:jakarta.validation-api")
compileOnly("io.swagger.core.v3:swagger-annotations:2.2.6") implementation("jakarta.annotation:jakarta.annotation-api:2.1.0")
implementation("io.swagger.core.v3:swagger-annotations:2.2.6")
implementation("io.swagger.core.v3:swagger-models:2.2.6") implementation("io.swagger.core.v3:swagger-models:2.2.6")
implementation("org.jetbrains.exposed:exposed-java-time:$exposed_version") implementation("org.jetbrains.exposed:exposed-java-time:$exposed_version")
testImplementation("org.springframework.boot:spring-boot-starter-test") testImplementation("org.springframework.boot:spring-boot-starter-test")

View File

@ -27,6 +27,7 @@ import org.flywaydb.core.Flyway
import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.TestFactory import org.junit.jupiter.api.TestFactory
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.test.context.SpringBootTest
@ -38,6 +39,7 @@ import org.springframework.transaction.annotation.Transactional
webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT
) )
@Transactional @Transactional
@Disabled
class InboxCommonTest { class InboxCommonTest {
@LocalServerPort @LocalServerPort
private var port = "" private var port = ""

View File

@ -17,11 +17,13 @@
package activitypub.inbox package activitypub.inbox
import dev.usbharu.hideout.SpringApplication import dev.usbharu.hideout.SpringApplication
import dev.usbharu.hideout.util.Base64Util
import org.flywaydb.core.Flyway import org.flywaydb.core.Flyway
import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.context.TestConfiguration import org.springframework.boot.test.context.TestConfiguration
@ -37,17 +39,25 @@ import org.springframework.transaction.annotation.Transactional
import org.springframework.web.context.WebApplicationContext import org.springframework.web.context.WebApplicationContext
import util.TestTransaction import util.TestTransaction
import util.WithMockHttpSignature import util.WithMockHttpSignature
import java.security.MessageDigest
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
@SpringBootTest(classes = [SpringApplication::class]) @SpringBootTest(classes = [SpringApplication::class])
@AutoConfigureMockMvc @AutoConfigureMockMvc
@Transactional @Transactional
class InboxTest { class InboxTest {
@Autowired
@Qualifier("http")
private lateinit var dateTimeFormatter: DateTimeFormatter
@Autowired @Autowired
private lateinit var context: WebApplicationContext private lateinit var context: WebApplicationContext
private lateinit var mockMvc: MockMvc private lateinit var mockMvc: MockMvc
@BeforeEach @BeforeEach
fun setUp() { fun setUp() {
mockMvc = MockMvcBuilders.webAppContextSetup(context) mockMvc = MockMvcBuilders.webAppContextSetup(context)
@ -62,6 +72,12 @@ class InboxTest {
.post("/inbox") { .post("/inbox") {
content = "{}" content = "{}"
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header(
"Digest",
"SHA-256=" + Base64Util.encode(MessageDigest.getInstance("SHA-256").digest("{}".toByteArray()))
)
} }
.asyncDispatch() .asyncDispatch()
.andExpect { status { isUnauthorized() } } .andExpect { status { isUnauthorized() } }
@ -74,7 +90,13 @@ class InboxTest {
.post("/inbox") { .post("/inbox") {
content = "{}" content = "{}"
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header(
"Digest",
"SHA-256=" + Base64Util.encode(MessageDigest.getInstance("SHA-256").digest("{}".toByteArray()))
)
} }
.asyncDispatch() .asyncDispatch()
.andExpect { status { isAccepted() } } .andExpect { status { isAccepted() } }
@ -87,8 +109,15 @@ class InboxTest {
.post("/users/hoge/inbox") { .post("/users/hoge/inbox") {
content = "{}" content = "{}"
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header(
"Digest",
"SHA-256=" + Base64Util.encode(MessageDigest.getInstance("SHA-256").digest("{}".toByteArray()))
)
} }
.asyncDispatch() .asyncDispatch()
.andDo { print() }
.andExpect { status { isUnauthorized() } } .andExpect { status { isUnauthorized() } }
} }
@ -99,9 +128,16 @@ class InboxTest {
.post("/users/hoge/inbox") { .post("/users/hoge/inbox") {
content = "{}" content = "{}"
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header(
"Digest",
"SHA-256=" + Base64Util.encode(MessageDigest.getInstance("SHA-256").digest("{}".toByteArray()))
)
} }
.asyncDispatch() .asyncDispatch()
.andDo { print() }
.andExpect { status { isAccepted() } } .andExpect { status { isAccepted() } }
} }

View File

@ -22,6 +22,7 @@ import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.beans.factory.annotation.Qualifier
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType import org.springframework.http.MediaType
@ -36,6 +37,8 @@ import org.springframework.transaction.annotation.Transactional
import org.springframework.web.context.WebApplicationContext import org.springframework.web.context.WebApplicationContext
import util.WithHttpSignature import util.WithHttpSignature
import util.WithMockHttpSignature import util.WithMockHttpSignature
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
@SpringBootTest(classes = [SpringApplication::class]) @SpringBootTest(classes = [SpringApplication::class])
@AutoConfigureMockMvc @AutoConfigureMockMvc
@ -46,6 +49,10 @@ class NoteTest {
@Autowired @Autowired
private lateinit var context: WebApplicationContext private lateinit var context: WebApplicationContext
@Autowired
@Qualifier("http")
private lateinit var dateTimeFormatter: DateTimeFormatter
@BeforeEach @BeforeEach
fun setUp() { fun setUp() {
mockMvc = MockMvcBuilders.webAppContextSetup(context).apply<DefaultMockMvcBuilder>(springSecurity()).build() mockMvc = MockMvcBuilders.webAppContextSetup(context).apply<DefaultMockMvcBuilder>(springSecurity()).build()
@ -197,6 +204,29 @@ class NoteTest {
.andExpect { jsonPath("\$.attachment[1].url") { value("https://example.com/media/test-media2.png") } } .andExpect { jsonPath("\$.attachment[1].url") { value("https://example.com/media/test-media2.png") } }
} }
@Test
fun signatureヘッダーがあるのにhostヘッダーがないと401() {
mockMvc
.get("/users/test-user10/posts/9999") {
accept(MediaType("application", "activity+json"))
header("Signature", "a")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
}
.andExpect { status { isUnauthorized() } }
}
@Test
fun signatureヘッダーがあるのにdateヘッダーがないと401() {
mockMvc
.get("/users/test-user10/posts/9999") {
accept(MediaType("application", "activity+json"))
header("Signature", "a")
header("Host", "example.com")
}
.andExpect { status { isUnauthorized() } }
}
companion object { companion object {
@JvmStatic @JvmStatic
@AfterAll @AfterAll

View File

@ -24,6 +24,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.flywaydb.core.Flyway import org.flywaydb.core.Flyway
import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
@ -31,7 +32,6 @@ import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType import org.springframework.http.MediaType
import org.springframework.security.core.authority.SimpleGrantedAuthority import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.test.context.support.WithAnonymousUser import org.springframework.security.test.context.support.WithAnonymousUser
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt
import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity
@ -160,6 +160,7 @@ class AccountApiTest {
} }
@Test @Test
@Disabled("JSONでも作れるようにするため")
@WithAnonymousUser @WithAnonymousUser
fun apiV1AccountsPostでJSONで作ろうとしても400() { fun apiV1AccountsPostでJSONで作ろうとしても400() {
mockMvc mockMvc

View File

@ -16,8 +16,8 @@
package dev.usbharu.hideout.activitypub.interfaces.api.inbox package dev.usbharu.hideout.activitypub.interfaces.api.inbox
import jakarta.servlet.http.HttpServletRequest
import org.springframework.http.ResponseEntity import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestMethod import org.springframework.web.bind.annotation.RequestMethod
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
@ -34,5 +34,5 @@ interface InboxController {
consumes = ["application/json", "application/*+json"], consumes = ["application/json", "application/*+json"],
method = [RequestMethod.POST] method = [RequestMethod.POST]
) )
suspend fun inbox(@RequestBody string: String): ResponseEntity<Unit> suspend fun inbox(httpServletRequest: HttpServletRequest): ResponseEntity<String>
} }

View File

@ -17,67 +17,65 @@
package dev.usbharu.hideout.activitypub.interfaces.api.inbox package dev.usbharu.hideout.activitypub.interfaces.api.inbox
import dev.usbharu.hideout.activitypub.service.common.APService import dev.usbharu.hideout.activitypub.service.common.APService
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureHeaderChecker
import dev.usbharu.httpsignature.common.HttpHeaders import dev.usbharu.httpsignature.common.HttpHeaders
import dev.usbharu.httpsignature.common.HttpMethod import dev.usbharu.httpsignature.common.HttpMethod
import dev.usbharu.httpsignature.common.HttpRequest import dev.usbharu.httpsignature.common.HttpRequest
import jakarta.servlet.http.HttpServletRequest
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.springframework.http.HttpHeaders.WWW_AUTHENTICATE import org.springframework.http.HttpHeaders.WWW_AUTHENTICATE
import org.springframework.http.HttpStatus import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestController
import org.springframework.web.context.request.RequestContextHolder
import org.springframework.web.context.request.ServletRequestAttributes
import java.net.URL import java.net.URL
@RestController @RestController
class InboxControllerImpl(private val apService: APService) : InboxController { class InboxControllerImpl(
private val apService: APService,
private val httpSignatureHeaderChecker: HttpSignatureHeaderChecker,
) : InboxController {
@Suppress("TooGenericExceptionCaught") @Suppress("TooGenericExceptionCaught")
override suspend fun inbox( override suspend fun inbox(
@RequestBody string: String httpServletRequest: HttpServletRequest,
): ResponseEntity<Unit> { ): ResponseEntity<String> {
val request = (requireNotNull(RequestContextHolder.getRequestAttributes()) as ServletRequestAttributes).request val headersList = httpServletRequest.headerNames?.toList().orEmpty()
val headersList = request.headerNames?.toList().orEmpty()
LOGGER.trace("Inbox Headers {}", headersList) LOGGER.trace("Inbox Headers {}", headersList)
if (headersList.map { it.lowercase() }.contains("signature").not()) { val body = withContext(Dispatchers.IO + MDCContext()) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED) httpServletRequest.inputStream.readAllBytes()!!
.header( }
WWW_AUTHENTICATE,
"Signature realm=\"Example\",headers=\"(request-target) date host digest\"" val responseEntity = checkHeader(httpServletRequest, body)
)
.build() if (responseEntity != null) {
return responseEntity
} }
val parseActivity = try { val parseActivity = try {
apService.parseActivity(string) apService.parseActivity(body.decodeToString())
} catch (e: Exception) { } catch (e: Exception) {
LOGGER.warn("FAILED Parse Activity", e) LOGGER.warn("FAILED Parse Activity", e)
return ResponseEntity.accepted().build() return ResponseEntity.accepted().build()
} }
LOGGER.info("INBOX Processing Activity Type: {}", parseActivity) LOGGER.info("INBOX Processing Activity Type: {}", parseActivity)
try { try {
val url = request.requestURL.toString() val url = httpServletRequest.requestURL.toString()
val headers = val headers =
headersList.associateWith { header -> request.getHeaders(header)?.toList().orEmpty() } headersList.associateWith { header ->
httpServletRequest.getHeaders(header)?.toList().orEmpty()
val method = when (val method = request.method.lowercase()) {
"get" -> HttpMethod.GET
"post" -> HttpMethod.POST
else -> {
throw IllegalArgumentException("Unsupported method: $method")
} }
}
apService.processActivity( apService.processActivity(
string, body.decodeToString(),
parseActivity, parseActivity,
HttpRequest( HttpRequest(
URL(url + request.queryString.orEmpty()), URL(url + httpServletRequest.queryString.orEmpty()),
HttpHeaders(headers), HttpHeaders(headers),
method HttpMethod.POST
), ),
headers headers
) )
@ -89,6 +87,46 @@ class InboxControllerImpl(private val apService: APService) : InboxController {
return ResponseEntity(HttpStatus.ACCEPTED) return ResponseEntity(HttpStatus.ACCEPTED)
} }
private fun checkHeader(
httpServletRequest: HttpServletRequest,
body: ByteArray,
): ResponseEntity<String>? {
try {
httpSignatureHeaderChecker.checkDate(httpServletRequest.getHeader("date")!!)
} catch (_: NullPointerException) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body("Required date header")
} catch (_: IllegalArgumentException) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("Request is too old.")
}
try {
httpSignatureHeaderChecker.checkHost(httpServletRequest.getHeader("host")!!)
} catch (_: NullPointerException) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body("Required host header")
} catch (_: IllegalArgumentException) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body("Wrong host for request")
}
try {
httpSignatureHeaderChecker.checkDigest(body, httpServletRequest.getHeader("digest")!!)
} catch (_: NullPointerException) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Required request body digest in digest header (sha256)")
} catch (_: IllegalArgumentException) {
return ResponseEntity
.status(HttpStatus.UNAUTHORIZED)
.body("Wrong digest for request")
}
if (httpServletRequest.getHeader("signature").orEmpty().isBlank()) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
.header(
WWW_AUTHENTICATE,
"Signature realm=\"Example\",headers=\"(request-target) date host digest\""
)
.build()
}
return null
}
companion object { companion object {
private val LOGGER = LoggerFactory.getLogger(InboxControllerImpl::class.java) private val LOGGER = LoggerFactory.getLogger(InboxControllerImpl::class.java)
} }

View File

@ -26,6 +26,7 @@ import dev.usbharu.hideout.application.external.Transaction
import dev.usbharu.hideout.application.infrastructure.springframework.RoleHierarchyAuthorizationManagerFactory import dev.usbharu.hideout.application.infrastructure.springframework.RoleHierarchyAuthorizationManagerFactory
import dev.usbharu.hideout.core.domain.model.actor.ActorRepository import dev.usbharu.hideout.core.domain.model.actor.ActorRepository
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureFilter import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureFilter
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureHeaderChecker
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureUserDetailsService import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureUserDetailsService
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureVerifierComposite import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureVerifierComposite
import dev.usbharu.hideout.core.infrastructure.springframework.oauth2.UserDetailsImpl import dev.usbharu.hideout.core.infrastructure.springframework.oauth2.UserDetailsImpl
@ -35,6 +36,9 @@ import dev.usbharu.httpsignature.sign.RsaSha256HttpSignatureSigner
import dev.usbharu.httpsignature.verify.DefaultSignatureHeaderParser import dev.usbharu.httpsignature.verify.DefaultSignatureHeaderParser
import dev.usbharu.httpsignature.verify.RsaSha256HttpSignatureVerifier import dev.usbharu.httpsignature.verify.RsaSha256HttpSignatureVerifier
import jakarta.annotation.PostConstruct import jakarta.annotation.PostConstruct
import jakarta.servlet.*
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.boot.autoconfigure.jackson.Jackson2ObjectMapperBuilderCustomizer import org.springframework.boot.autoconfigure.jackson.Jackson2ObjectMapperBuilderCustomizer
import org.springframework.boot.context.properties.ConfigurationProperties import org.springframework.boot.context.properties.ConfigurationProperties
@ -58,6 +62,7 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.annotation.web.invoke
import org.springframework.security.config.http.SessionCreationPolicy import org.springframework.security.config.http.SessionCreationPolicy
import org.springframework.security.core.Authentication import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContextHolderStrategy
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
import org.springframework.security.crypto.password.PasswordEncoder import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.security.oauth2.core.AuthorizationGrantType import org.springframework.security.oauth2.core.AuthorizationGrantType
@ -67,14 +72,21 @@ import org.springframework.security.oauth2.server.authorization.config.annotatio
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer
import org.springframework.security.web.FilterChainProxy
import org.springframework.security.web.SecurityFilterChain import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.web.access.ExceptionTranslationFilter import org.springframework.security.web.access.ExceptionTranslationFilter
import org.springframework.security.web.authentication.AuthenticationEntryPointFailureHandler import org.springframework.security.web.authentication.AuthenticationEntryPointFailureHandler
import org.springframework.security.web.authentication.HttpStatusEntryPoint import org.springframework.security.web.authentication.HttpStatusEntryPoint
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationProvider import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationProvider
import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer
import org.springframework.security.web.debug.DebugFilter
import org.springframework.security.web.firewall.HttpFirewall
import org.springframework.security.web.firewall.RequestRejectedHandler
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter import org.springframework.security.web.savedrequest.RequestCacheAwareFilter
import org.springframework.security.web.util.matcher.AnyRequestMatcher import org.springframework.security.web.util.matcher.AnyRequestMatcher
import org.springframework.web.filter.CompositeFilter
import java.io.IOException
import java.security.KeyPairGenerator import java.security.KeyPairGenerator
import java.security.interfaces.RSAPrivateKey import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey import java.security.interfaces.RSAPublicKey
@ -86,15 +98,14 @@ import java.util.*
class SecurityConfig { class SecurityConfig {
@Bean @Bean
fun authenticationManager(authenticationConfiguration: AuthenticationConfiguration): AuthenticationManager? { fun authenticationManager(authenticationConfiguration: AuthenticationConfiguration): AuthenticationManager? =
return authenticationConfiguration.authenticationManager authenticationConfiguration.authenticationManager
}
@Bean @Bean
@Order(1) @Order(1)
fun httpSignatureFilterChain( fun httpSignatureFilterChain(
http: HttpSecurity, http: HttpSecurity,
httpSignatureFilter: HttpSignatureFilter httpSignatureFilter: HttpSignatureFilter,
): SecurityFilterChain { ): SecurityFilterChain {
http { http {
securityMatcher("/users/*/posts/*") securityMatcher("/users/*/posts/*")
@ -122,9 +133,10 @@ class SecurityConfig {
@Bean @Bean
fun getHttpSignatureFilter( fun getHttpSignatureFilter(
authenticationManager: AuthenticationManager, authenticationManager: AuthenticationManager,
httpSignatureHeaderChecker: HttpSignatureHeaderChecker,
): HttpSignatureFilter { ): HttpSignatureFilter {
val httpSignatureFilter = val httpSignatureFilter =
HttpSignatureFilter(DefaultSignatureHeaderParser()) HttpSignatureFilter(DefaultSignatureHeaderParser(), httpSignatureHeaderChecker)
httpSignatureFilter.setAuthenticationManager(authenticationManager) httpSignatureFilter.setAuthenticationManager(authenticationManager)
httpSignatureFilter.setContinueFilterChainOnUnsuccessfulAuthentication(false) httpSignatureFilter.setContinueFilterChainOnUnsuccessfulAuthentication(false)
val authenticationEntryPointFailureHandler = val authenticationEntryPointFailureHandler =
@ -147,7 +159,7 @@ class SecurityConfig {
@Order(1) @Order(1)
fun httpSignatureAuthenticationProvider( fun httpSignatureAuthenticationProvider(
transaction: Transaction, transaction: Transaction,
actorRepository: ActorRepository actorRepository: ActorRepository,
): PreAuthenticatedAuthenticationProvider { ): PreAuthenticatedAuthenticationProvider {
val provider = PreAuthenticatedAuthenticationProvider() val provider = PreAuthenticatedAuthenticationProvider()
val signatureHeaderParser = DefaultSignatureHeaderParser() val signatureHeaderParser = DefaultSignatureHeaderParser()
@ -190,7 +202,7 @@ class SecurityConfig {
@Order(4) @Order(4)
fun defaultSecurityFilterChain( fun defaultSecurityFilterChain(
http: HttpSecurity, http: HttpSecurity,
rf: RoleHierarchyAuthorizationManagerFactory rf: RoleHierarchyAuthorizationManagerFactory,
): SecurityFilterChain { ): SecurityFilterChain {
http { http {
authorizeHttpRequests { authorizeHttpRequests {
@ -401,6 +413,86 @@ class SecurityConfig {
return roleHierarchyImpl return roleHierarchyImpl
} }
// Spring Security 3.2.1 に存在する EnableWebSecurity(debug = true)にすると発生するエラーに対処するためのコード
// trueにしないときはコメントアウト
// @Bean
fun beanDefinitionRegistryPostProcessor(): BeanDefinitionRegistryPostProcessor {
return BeanDefinitionRegistryPostProcessor { registry: BeanDefinitionRegistry ->
registry.getBeanDefinition(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME).beanClassName =
CompositeFilterChainProxy::class.java.name
}
}
@Suppress("ExpressionBodySyntax")
internal class CompositeFilterChainProxy(filters: List<Filter?>) : FilterChainProxy() {
private val doFilterDelegate: Filter
private val springSecurityFilterChain: FilterChainProxy
init {
this.doFilterDelegate = createDoFilterDelegate(filters)
this.springSecurityFilterChain = findFilterChainProxy(filters)
}
override fun afterPropertiesSet() {
springSecurityFilterChain.afterPropertiesSet()
}
@Throws(IOException::class, ServletException::class)
override fun doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain) {
doFilterDelegate.doFilter(request, response, chain)
}
override fun getFilters(url: String): List<Filter> {
return springSecurityFilterChain.getFilters(url)
}
override fun getFilterChains(): List<SecurityFilterChain> {
return springSecurityFilterChain.filterChains
}
override fun setSecurityContextHolderStrategy(securityContextHolderStrategy: SecurityContextHolderStrategy) {
springSecurityFilterChain.setSecurityContextHolderStrategy(securityContextHolderStrategy)
}
override fun setFilterChainValidator(filterChainValidator: FilterChainValidator) {
springSecurityFilterChain.setFilterChainValidator(filterChainValidator)
}
override fun setFilterChainDecorator(filterChainDecorator: FilterChainDecorator) {
springSecurityFilterChain.setFilterChainDecorator(filterChainDecorator)
}
override fun setFirewall(firewall: HttpFirewall) {
springSecurityFilterChain.setFirewall(firewall)
}
override fun setRequestRejectedHandler(requestRejectedHandler: RequestRejectedHandler) {
springSecurityFilterChain.setRequestRejectedHandler(requestRejectedHandler)
}
companion object {
private fun createDoFilterDelegate(filters: List<Filter?>): Filter {
val delegate: CompositeFilter = CompositeFilter()
delegate.setFilters(filters)
return delegate
}
private fun findFilterChainProxy(filters: List<Filter?>): FilterChainProxy {
for (filter in filters) {
if (filter is FilterChainProxy) {
return filter
}
if (filter is DebugFilter) {
return filter.filterChainProxy
}
}
throw IllegalStateException("Couldn't find FilterChainProxy in $filters")
}
}
}
} }
@ConfigurationProperties("hideout.security.jwt") @ConfigurationProperties("hideout.security.jwt")
@ -408,14 +500,14 @@ class SecurityConfig {
data class JwkConfig( data class JwkConfig(
val keyId: String, val keyId: String,
val publicKey: String, val publicKey: String,
val privateKey: String val privateKey: String,
) )
@Configuration @Configuration
class PostSecurityConfig( class PostSecurityConfig(
val auth: AuthenticationManagerBuilder, val auth: AuthenticationManagerBuilder,
val daoAuthenticationProvider: DaoAuthenticationProvider, val daoAuthenticationProvider: DaoAuthenticationProvider,
val httpSignatureAuthenticationProvider: PreAuthenticatedAuthenticationProvider val httpSignatureAuthenticationProvider: PreAuthenticatedAuthenticationProvider,
) { ) {
@PostConstruct @PostConstruct

View File

@ -24,7 +24,10 @@ import jakarta.servlet.http.HttpServletRequest
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter
import java.net.URL import java.net.URL
class HttpSignatureFilter(private val httpSignatureHeaderParser: SignatureHeaderParser) : class HttpSignatureFilter(
private val httpSignatureHeaderParser: SignatureHeaderParser,
private val httpSignatureHeaderChecker: HttpSignatureHeaderChecker,
) :
AbstractPreAuthenticatedProcessingFilter() { AbstractPreAuthenticatedProcessingFilter() {
override fun getPreAuthenticatedPrincipal(request: HttpServletRequest?): Any? { override fun getPreAuthenticatedPrincipal(request: HttpServletRequest?): Any? {
val headersList = request?.headerNames?.toList().orEmpty() val headersList = request?.headerNames?.toList().orEmpty()
@ -42,7 +45,7 @@ class HttpSignatureFilter(private val httpSignatureHeaderParser: SignatureHeader
return signature.keyId return signature.keyId
} }
override fun getPreAuthenticatedCredentials(request: HttpServletRequest?): Any { override fun getPreAuthenticatedCredentials(request: HttpServletRequest?): Any? {
requireNotNull(request) requireNotNull(request)
val url = request.requestURL.toString() val url = request.requestURL.toString()
@ -55,10 +58,26 @@ class HttpSignatureFilter(private val httpSignatureHeaderParser: SignatureHeader
"get" -> HttpMethod.GET "get" -> HttpMethod.GET
"post" -> HttpMethod.POST "post" -> HttpMethod.POST
else -> { else -> {
throw IllegalArgumentException("Unsupported method: $method") // throw IllegalArgumentException("Unsupported method: $method")
return null
} }
} }
try {
httpSignatureHeaderChecker.checkDate(request.getHeader("date")!!)
httpSignatureHeaderChecker.checkHost(request.getHeader("host")!!)
if (request.method.equals("post", true)) {
httpSignatureHeaderChecker.checkDigest(
request.inputStream.readAllBytes()!!,
request.getHeader("digest")!!
)
}
} catch (_: NullPointerException) {
return null
} catch (_: IllegalArgumentException) {
return null
}
return HttpRequest( return HttpRequest(
URL(url + request.queryString.orEmpty()), URL(url + request.queryString.orEmpty()),
HttpHeaders(headers), HttpHeaders(headers),

View File

@ -0,0 +1,58 @@
/*
* Copyright (C) 2024 usbharu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.usbharu.hideout.core.infrastructure.springframework.httpsignature
import dev.usbharu.hideout.application.config.ApplicationConfig
import dev.usbharu.hideout.util.Base64Util
import org.springframework.stereotype.Component
import java.security.MessageDigest
import java.time.Instant
import java.time.format.DateTimeFormatter
import java.util.*
@Component
class HttpSignatureHeaderChecker(private val applicationConfig: ApplicationConfig) {
fun checkDate(date: String) {
val from = Instant.from(dateFormat.parse(date))
if (from.isAfter(Instant.now()) || from.isBefore(Instant.now().minusSeconds(86400))) {
throw IllegalArgumentException("未来")
}
}
fun checkHost(host: String) {
if (applicationConfig.url.host.equals(host, true).not()) {
throw IllegalArgumentException("ホスト名が違う")
}
}
fun checkDigest(byteArray: ByteArray, digest: String) {
val find = regex.find(digest)
val sha256 = MessageDigest.getInstance("SHA-256")
val other = find?.groups?.get(2)?.value.orEmpty()
if (Base64Util.encode(sha256.digest(byteArray)).equals(other, true).not()) {
throw IllegalArgumentException("リクエストボディが違う")
}
}
companion object {
private val dateFormat = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US)
private val regex = Regex("^([a-zA-Z0-9\\-]+)=(.+)$")
}
}

View File

@ -19,6 +19,7 @@ package dev.usbharu.hideout.generate
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.springframework.core.MethodParameter import org.springframework.core.MethodParameter
import org.springframework.validation.BindException
import org.springframework.web.bind.support.WebDataBinderFactory import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.context.request.NativeWebRequest import org.springframework.web.context.request.NativeWebRequest
import org.springframework.web.method.annotation.ModelAttributeMethodProcessor import org.springframework.web.method.annotation.ModelAttributeMethodProcessor
@ -56,12 +57,17 @@ class JsonOrFormModelMethodProcessor(
return try { return try {
modelAttributeMethodProcessor.resolveArgument(parameter, mavContainer, webRequest, binderFactory) modelAttributeMethodProcessor.resolveArgument(parameter, mavContainer, webRequest, binderFactory)
} catch (e: BindException) {
throw e
} catch (exception: Exception) { } catch (exception: Exception) {
try { try {
requestResponseBodyMethodProcessor.resolveArgument(parameter, mavContainer, webRequest, binderFactory) requestResponseBodyMethodProcessor.resolveArgument(parameter, mavContainer, webRequest, binderFactory)
} catch (e: BindException) {
throw e
} catch (e: Exception) { } catch (e: Exception) {
logger.warn("Failed to bind request (1)", exception) logger.warn("Failed to bind request (1)", exception)
logger.warn("Failed to bind request (2)", e) logger.warn("Failed to bind request (2)", e)
throw IllegalArgumentException("Failed to bind request.")
} }
} }
} }

View File

@ -59,19 +59,19 @@ class MastodonAccountApiController(
HttpStatus.OK HttpStatus.OK
) )
override suspend fun apiV1AccountsPost( override suspend fun apiV1AccountsPost(accountsCreateRequest: AccountsCreateRequest): ResponseEntity<Unit> {
username: String,
password: String,
email: String?,
agreement: Boolean?,
locale: Boolean?,
reason: String?
): ResponseEntity<Unit> {
transaction.transaction { transaction.transaction {
accountApiService.registerAccount(UserCreateDto(username, username, "", password)) accountApiService.registerAccount(
UserCreateDto(
accountsCreateRequest.username,
accountsCreateRequest.username,
"",
accountsCreateRequest.password
)
)
} }
val httpHeaders = HttpHeaders() val httpHeaders = HttpHeaders()
httpHeaders.location = URI("/users/$username") httpHeaders.location = URI("/users/${accountsCreateRequest.username}")
return ResponseEntity(Unit, httpHeaders, HttpStatus.FOUND) return ResponseEntity(Unit, httpHeaders, HttpStatus.FOUND)
} }

View File

@ -284,6 +284,9 @@ paths:
requestBody: requestBody:
required: true required: true
content: content:
application/json:
schema:
$ref: "#/components/schemas/AccountsCreateRequest"
application/x-www-form-urlencoded: application/x-www-form-urlencoded:
schema: schema:
$ref: "#/components/schemas/AccountsCreateRequest" $ref: "#/components/schemas/AccountsCreateRequest"
@ -1429,6 +1432,7 @@ components:
format: binary format: binary
description: description:
type: string type: string
maxLength: 4000
focus: focus:
type: string type: string
required: required:
@ -1439,12 +1443,18 @@ components:
properties: properties:
username: username:
type: string type: string
minLength: 1
maxLength: 300
pattern: '^[a-zA-Z0-9_-]{1,300}$'
email: email:
type: string type: string
format: email
password: password:
type: string type: string
format: password
agreement: agreement:
type: boolean type: boolean
default: false
locale: locale:
type: boolean type: boolean
reason: reason:
@ -1473,6 +1483,8 @@ components:
type: string type: string
username: username:
type: string type: string
minLength: 1
pattern: '^[a-zA-Z0-9_-]{1,300}$'
acct: acct:
type: string type: string
url: url:
@ -2042,8 +2054,10 @@ components:
properties: properties:
phrase: phrase:
type: string type: string
maxLength: 1000
context: context:
type: array type: array
maxItems: 10
items: items:
type: string type: string
enum: enum:
@ -2054,8 +2068,10 @@ components:
- account - account
irreversible: irreversible:
type: boolean type: boolean
default: false
whole_word: whole_word:
type: boolean type: boolean
default: false
expires_in: expires_in:
type: integer type: integer
required: required:
@ -2067,8 +2083,10 @@ components:
properties: properties:
phrase: phrase:
type: string type: string
maxLength: 1000
context: context:
type: array type: array
maxItems: 10
items: items:
type: string type: string
enum: enum:
@ -2089,8 +2107,10 @@ components:
properties: properties:
title: title:
type: string type: string
maxLength: 255
context: context:
type: array type: array
maxItems: 10
items: items:
type: string type: string
enum: enum:
@ -2107,6 +2127,7 @@ components:
expires_in: expires_in:
type: integer type: integer
keywords_attributes: keywords_attributes:
maxItems: 1000
type: array type: array
items: items:
$ref: "#/components/schemas/FilterPostRequestKeyword" $ref: "#/components/schemas/FilterPostRequestKeyword"
@ -2119,6 +2140,7 @@ components:
properties: properties:
keyword: keyword:
type: string type: string
maxLength: 1000
whole_word: whole_word:
type: boolean type: boolean
default: false default: false
@ -2133,6 +2155,7 @@ components:
properties: properties:
keyword: keyword:
type: string type: string
maxLength: 1000
whole_word: whole_word:
type: boolean type: boolean
default: false default: false
@ -2147,6 +2170,7 @@ components:
properties: properties:
keyword: keyword:
type: string type: string
maxLength: 1000
whole_word: whole_word:
type: boolean type: boolean
regex: regex:
@ -2157,8 +2181,10 @@ components:
properties: properties:
title: title:
type: string type: string
maxLength: 255
context: context:
type: array type: array
maxItems: 10
items: items:
type: string type: string
enum: enum:
@ -2175,6 +2201,7 @@ components:
expires_in: expires_in:
type: integer type: integer
keywords_attributes: keywords_attributes:
maxItems: 1000
type: array type: array
items: items:
$ref: "#/components/schemas/FilterPubRequestKeyword" $ref: "#/components/schemas/FilterPubRequestKeyword"
@ -2184,6 +2211,7 @@ components:
properties: properties:
keyword: keyword:
type: string type: string
maxLength: 1000
whole_word: whole_word:
type: boolean type: boolean
regex: regex:
@ -2192,6 +2220,9 @@ components:
type: string type: string
_destroy: _destroy:
type: boolean type: boolean
default: false
required:
- id
FilterStatusRequest: FilterStatusRequest:
type: object type: object
@ -2547,18 +2578,22 @@ components:
status: status:
type: string type: string
nullable: true nullable: true
maxLength: 3000
media_ids: media_ids:
type: array type: array
items: items:
type: string type: string
maxItems: 4
poll: poll:
$ref: "#/components/schemas/StatusesRequestPoll" $ref: "#/components/schemas/StatusesRequestPoll"
in_reply_to_id: in_reply_to_id:
type: string type: string
sensitive: sensitive:
type: boolean type: boolean
default: false
spoiler_text: spoiler_text:
type: string type: string
maxLength: 100
visibility: visibility:
type: string type: string
enum: enum:
@ -2568,22 +2603,29 @@ components:
- direct - direct
language: language:
type: string type: string
maxLength: 100
scheduled_at: scheduled_at:
type: string type: string
format: date-time
example: "2019-12-05T12:33:01.000Z"
StatusesRequestPoll: StatusesRequestPoll:
type: object type: object
properties: properties:
options: options:
type: array type: array
maxItems: 10
items: items:
type: string type: string
maxLength: 100
expires_in: expires_in:
type: integer type: integer
multiple: multiple:
type: boolean type: boolean
default: false
hide_totals: hide_totals:
type: boolean type: boolean
default: false
Application: Application:
type: object type: object
@ -2610,12 +2652,16 @@ components:
properties: properties:
client_name: client_name:
type: string type: string
maxLength: 200
redirect_uris: redirect_uris:
type: string type: string
maxLength: 1000
scopes: scopes:
type: string type: string
maxLength: 1000
website: website:
type: string type: string
maxLength: 1000
required: required:
- client_name - client_name
- redirect_uris - redirect_uris
@ -2676,16 +2722,20 @@ components:
default: false default: false
languages: languages:
type: array type: array
maxItems: 10
items: items:
type: string type: string
maxLength: 10
UpdateCredentials: UpdateCredentials:
type: object type: object
properties: properties:
display_name: display_name:
type: string type: string
maxLength: 300
note: note:
type: string type: string
maxLength: 2000
avatar: avatar:
type: string type: string
format: binary format: binary

View File

@ -19,13 +19,17 @@ package dev.usbharu.hideout.activitypub.interfaces.api.inbox
import dev.usbharu.hideout.activitypub.domain.exception.JsonParseException import dev.usbharu.hideout.activitypub.domain.exception.JsonParseException
import dev.usbharu.hideout.activitypub.service.common.APService import dev.usbharu.hideout.activitypub.service.common.APService
import dev.usbharu.hideout.activitypub.service.common.ActivityType import dev.usbharu.hideout.activitypub.service.common.ActivityType
import dev.usbharu.hideout.application.config.ApplicationConfig
import dev.usbharu.hideout.core.domain.exception.FailedToGetResourcesException import dev.usbharu.hideout.core.domain.exception.FailedToGetResourcesException
import dev.usbharu.hideout.core.infrastructure.springframework.httpsignature.HttpSignatureHeaderChecker
import dev.usbharu.hideout.util.Base64Util
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.InjectMocks import org.mockito.InjectMocks
import org.mockito.Mock import org.mockito.Mock
import org.mockito.Spy
import org.mockito.junit.jupiter.MockitoExtension import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.* import org.mockito.kotlin.*
import org.springframework.http.MediaType import org.springframework.http.MediaType
@ -33,12 +37,21 @@ import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.get
import org.springframework.test.web.servlet.post import org.springframework.test.web.servlet.post
import org.springframework.test.web.servlet.setup.MockMvcBuilders import org.springframework.test.web.servlet.setup.MockMvcBuilders
import java.net.URI
import java.security.MessageDigest
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import java.util.*
@ExtendWith(MockitoExtension::class) @ExtendWith(MockitoExtension::class)
class InboxControllerImplTest { class InboxControllerImplTest {
private lateinit var mockMvc: MockMvc private lateinit var mockMvc: MockMvc
@Spy
private val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("https://example.com").toURL()))
@Mock @Mock
private lateinit var apService: APService private lateinit var apService: APService
@ -50,6 +63,10 @@ class InboxControllerImplTest {
mockMvc = MockMvcBuilders.standaloneSetup(inboxController).build() mockMvc = MockMvcBuilders.standaloneSetup(inboxController).build()
} }
private val dateTimeFormatter: DateTimeFormatter =
DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US)
@Test @Test
fun `inbox 正常なPOSTリクエストをしたときAcceptが返ってくる`() = runTest { fun `inbox 正常なPOSTリクエストをしたときAcceptが返ってくる`() = runTest {
@ -58,24 +75,25 @@ class InboxControllerImplTest {
whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow) whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow)
whenever( whenever(
apService.processActivity( apService.processActivity(
eq(json), eq(json), eq(ActivityType.Follow), any(), any()
eq(ActivityType.Follow),
any(),
any()
) )
).doReturn(Unit) ).doReturn(Unit)
mockMvc val sha256 = MessageDigest.getInstance("SHA-256")
.post("/inbox") {
content = json val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
contentType = MediaType.APPLICATION_JSON
header("Signature", "") mockMvc.post("/inbox") {
} content = json
.asyncDispatch() contentType = MediaType.APPLICATION_JSON
.andExpect { header("Signature", "a")
status { isAccepted() } header("Host", "example.com")
} header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=" + digest)
}.asyncDispatch().andExpect {
status { isAccepted() }
}
} }
@ -83,17 +101,19 @@ class InboxControllerImplTest {
fun `inbox parseActivityに失敗したときAcceptが返ってくる`() = runTest { fun `inbox parseActivityに失敗したときAcceptが返ってくる`() = runTest {
val json = """{"type":"Hoge"}""" val json = """{"type":"Hoge"}"""
whenever(apService.parseActivity(eq(json))).doThrow(JsonParseException::class) whenever(apService.parseActivity(eq(json))).doThrow(JsonParseException::class)
val sha256 = MessageDigest.getInstance("SHA-256")
mockMvc val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
.post("/inbox") { mockMvc.post("/inbox") {
content = json content = json
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
} header("Host", "example.com")
.asyncDispatch() header("Date", ZonedDateTime.now().format(dateTimeFormatter))
.andExpect { header("Digest", "SHA-256=$digest")
status { isAccepted() } }.asyncDispatch().andExpect {
} status { isAccepted() }
}
} }
@ -103,23 +123,22 @@ class InboxControllerImplTest {
whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow) whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow)
whenever( whenever(
apService.processActivity( apService.processActivity(
eq(json), eq(json), eq(ActivityType.Follow), any(), any()
eq(ActivityType.Follow),
any(),
any()
) )
).doThrow(FailedToGetResourcesException::class) ).doThrow(FailedToGetResourcesException::class)
val sha256 = MessageDigest.getInstance("SHA-256")
mockMvc val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
.post("/inbox") { mockMvc.post("/inbox") {
content = json content = json
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
} header("Host", "example.com")
.asyncDispatch() header("Date", ZonedDateTime.now().format(dateTimeFormatter))
.andExpect { header("Digest", "SHA-256=$digest")
status { isAccepted() } }.asyncDispatch().andExpect {
} status { isAccepted() }
}
} }
@ -137,17 +156,19 @@ class InboxControllerImplTest {
whenever(apService.processActivity(eq(json), eq(ActivityType.Follow), any(), any())).doReturn( whenever(apService.processActivity(eq(json), eq(ActivityType.Follow), any(), any())).doReturn(
Unit Unit
) )
val sha256 = MessageDigest.getInstance("SHA-256")
mockMvc val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
.post("/users/hoge/inbox") { mockMvc.post("/users/hoge/inbox") {
content = json content = json
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
} header("Host", "example.com")
.asyncDispatch() header("Date", ZonedDateTime.now().format(dateTimeFormatter))
.andExpect { header("Digest", "SHA-256=$digest")
status { isAccepted() } }.asyncDispatch().andExpect {
} status { isAccepted() }
}
} }
@ -155,17 +176,19 @@ class InboxControllerImplTest {
fun `user-inbox parseActivityに失敗したときAcceptが返ってくる`() = runTest { fun `user-inbox parseActivityに失敗したときAcceptが返ってくる`() = runTest {
val json = """{"type":"Hoge"}""" val json = """{"type":"Hoge"}"""
whenever(apService.parseActivity(eq(json))).doThrow(JsonParseException::class) whenever(apService.parseActivity(eq(json))).doThrow(JsonParseException::class)
val sha256 = MessageDigest.getInstance("SHA-256")
mockMvc val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
.post("/users/hoge/inbox") { mockMvc.post("/users/hoge/inbox") {
content = json content = json
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
} header("Host", "example.com")
.asyncDispatch() header("Date", ZonedDateTime.now().format(dateTimeFormatter))
.andExpect { header("Digest", "SHA-256=$digest")
status { isAccepted() } }.asyncDispatch().andExpect {
} status { isAccepted() }
}
} }
@ -175,23 +198,22 @@ class InboxControllerImplTest {
whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow) whenever(apService.parseActivity(eq(json))).doReturn(ActivityType.Follow)
whenever( whenever(
apService.processActivity( apService.processActivity(
eq(json), eq(json), eq(ActivityType.Follow), any(), any()
eq(ActivityType.Follow),
any(),
any()
) )
).doThrow(FailedToGetResourcesException::class) ).doThrow(FailedToGetResourcesException::class)
val sha256 = MessageDigest.getInstance("SHA-256")
mockMvc val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
.post("/users/hoge/inbox") { mockMvc.post("/users/hoge/inbox") {
content = json content = json
contentType = MediaType.APPLICATION_JSON contentType = MediaType.APPLICATION_JSON
header("Signature", "") header("Signature", "a")
} header("Host", "example.com")
.asyncDispatch() header("Date", ZonedDateTime.now().format(dateTimeFormatter))
.andExpect { header("Digest", "SHA-256=$digest")
status { isAccepted() } }.asyncDispatch().andExpect {
} status { isAccepted() }
}
} }
@ -199,4 +221,350 @@ class InboxControllerImplTest {
fun `user-inbox GETリクエストには405を返す`() { fun `user-inbox GETリクエストには405を返す`() {
mockMvc.get("/users/hoge/inbox").andExpect { status { isMethodNotAllowed() } } mockMvc.get("/users/hoge/inbox").andExpect { status { isMethodNotAllowed() } }
} }
@Test
fun `inbox Dateヘッダーが無いと400`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
}
.asyncDispatch()
.andExpect {
status {
isBadRequest()
}
}
}
@Test
fun `user-inbox Dateヘッダーが無いと400`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
}
.asyncDispatch()
.andExpect {
status {
isBadRequest()
}
}
}
@Test
fun `inbox Dateヘッダーが未来だと401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().plusDays(1).format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `user-inbox Dateヘッダーが未来だと401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().plusDays(1).format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `inbox Dateヘッダーが過去過ぎると401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().minusDays(1).format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `user-inbox Dateヘッダーが過去過ぎると401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().minusDays(1).format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `inbox Hostヘッダーが無いと400`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isBadRequest()
}
}
}
@Test
fun `user-inbox Hostヘッダーが無いと400`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status {
isBadRequest()
}
}
}
@Test
fun `inbox Hostヘッダーが間違ってると401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Host", "example.jp")
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `user-inbox Hostヘッダーが間違ってると401`() {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Host", "example.jp")
}
.asyncDispatch()
.andExpect {
status {
isUnauthorized()
}
}
}
@Test
fun `inbox Digestヘッダーがないと400`() = runTest {
val json = """{"type":"Follow"}"""
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status { isBadRequest() }
}
}
@Test
fun `inbox Digestヘッダーが間違ってると401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(("$json aaaaaaaa").toByteArray()))
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
@Test
fun `user-inbox Digestヘッダーがないと400`() = runTest {
val json = """{"type":"Follow"}"""
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
}
.asyncDispatch()
.andExpect {
status { isBadRequest() }
}
}
@Test
fun `user-inbox Digestヘッダーが間違ってると401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(("$json aaaaaaaa").toByteArray()))
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
@Test
fun `inbox Signatureヘッダーがないと401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
@Test
fun `inbox Signatureヘッダーが空だと401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
mockMvc
.post("/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
@Test
fun `user-inbox Digestヘッダーがないと401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
@Test
fun `user-inbox Digestヘッダーが空だと401`() = runTest {
val json = """{"type":"Follow"}"""
val sha256 = MessageDigest.getInstance("SHA-256")
val digest = Base64Util.encode(sha256.digest(json.toByteArray()))
mockMvc
.post("/users/hoge/inbox") {
content = json
contentType = MediaType.APPLICATION_JSON
header("Signature", "")
header("Host", "example.com")
header("Date", ZonedDateTime.now().format(dateTimeFormatter))
header("Digest", "SHA-256=$digest")
}
.asyncDispatch()
.andExpect {
status { isUnauthorized() }
}
}
} }

View File

@ -0,0 +1,110 @@
package dev.usbharu.hideout.core.infrastructure.springframework.httpsignature
import dev.usbharu.hideout.application.config.ApplicationConfig
import dev.usbharu.hideout.util.Base64Util
import org.intellij.lang.annotations.Language
import org.junit.jupiter.api.Assertions.assertDoesNotThrow
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.net.URI
import java.security.MessageDigest
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import java.util.*
class HttpSignatureHeaderCheckerTest {
val format = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US)
@Test
fun `checkDate 未来はダメ`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("http://example.com").toURL()))
val s = ZonedDateTime.now().plusDays(1).format(format)
assertThrows<IllegalArgumentException> {
httpSignatureHeaderChecker.checkDate(s)
}
}
@Test
fun `checkDate 過去はOK`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("http://example.com").toURL()))
val s = ZonedDateTime.now().minusHours(1).format(format)
assertDoesNotThrow {
httpSignatureHeaderChecker.checkDate(s)
}
}
@Test
fun `checkDate 86400秒以上昔はダメ`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("http://example.com").toURL()))
val s = ZonedDateTime.now().minusSeconds(86401).format(format)
assertThrows<IllegalArgumentException> {
httpSignatureHeaderChecker.checkDate(s)
}
}
@Test
fun `checkHost 大文字小文字の違いはセーフ`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("https://example.com").toURL()))
assertDoesNotThrow {
httpSignatureHeaderChecker.checkHost("example.com")
httpSignatureHeaderChecker.checkHost("EXAMPLE.COM")
}
}
@Test
fun `checkHost サブドメインはダメ`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("https://example.com").toURL()))
assertThrows<IllegalArgumentException> {
httpSignatureHeaderChecker.checkHost("follower.example.com")
}
}
@Test
fun `checkDigest リクエストボディが同じなら何もしない`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("https://example.com").toURL()))
val sha256 = MessageDigest.getInstance("SHA-256")
@Language("JSON") val requestBody = """{"@context":"","type":"hoge"}"""
val digest = Base64Util.encode(sha256.digest(requestBody.toByteArray()))
assertDoesNotThrow {
httpSignatureHeaderChecker.checkDigest(requestBody.toByteArray(), "SHA-256=" + digest)
}
}
@Test
fun `checkDigest リクエストボディがちょっとでも違うとダメ`() {
val httpSignatureHeaderChecker =
HttpSignatureHeaderChecker(ApplicationConfig(URI.create("https://example.com").toURL()))
val sha256 = MessageDigest.getInstance("SHA-256")
@Language("JSON") val requestBody = """{"type":"hoge","@context":""}"""
@Language("JSON") val requestBody2 = """{"@context":"","type":"hoge"}"""
val digest = Base64Util.encode(sha256.digest(requestBody.toByteArray()))
assertThrows<IllegalArgumentException> {
httpSignatureHeaderChecker.checkDigest(requestBody2.toByteArray(), digest)
}
}
}