diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt index 4cf842ea9291f4a09f2c37309dc2f2d0f1cb379a..686a079728ce13c8df4e061a417949c251841286 100644 --- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt +++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt @@ -2,7 +2,7 @@ package de.rki.coronawarnapp.diagnosiskeys.download import dagger.Reusable import de.rki.coronawarnapp.diagnosiskeys.server.DiagnosisKeyServer -import de.rki.coronawarnapp.diagnosiskeys.server.KeyFileHeaderHook +import de.rki.coronawarnapp.diagnosiskeys.server.DownloadInfo import de.rki.coronawarnapp.diagnosiskeys.server.LocationCode import de.rki.coronawarnapp.diagnosiskeys.storage.CachedKeyInfo import de.rki.coronawarnapp.diagnosiskeys.storage.KeyCacheRepository @@ -12,8 +12,6 @@ import de.rki.coronawarnapp.storage.AppSettings import de.rki.coronawarnapp.storage.DeviceStorage import de.rki.coronawarnapp.storage.InsufficientStorageException import de.rki.coronawarnapp.util.CWADebug -import de.rki.coronawarnapp.util.HashExtensions.hashToMD5 -import de.rki.coronawarnapp.util.debug.measureTimeMillisWithResult import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll @@ -300,28 +298,25 @@ class KeyFileDownloader @Inject constructor( keyInfo: CachedKeyInfo, saveTo: File ): Pair<CachedKeyInfo, File>? = try { - val validation = KeyFileHeaderHook { headers -> - // tryMigration returns true when a file was migrated, meaning, no download necessary - return@KeyFileHeaderHook !legacyKeyCache.tryMigration( - headers.getPayloadChecksumMD5(), - saveTo - ) - } + val preconditionHook: suspend (DownloadInfo) -> Boolean = + { downloadInfo -> + val continueDownload = !legacyKeyCache.tryMigration( + downloadInfo.serverMD5, saveTo + ) + continueDownload // Continue download if no migration happened + } - keyServer.downloadKeyFile( + val dlInfo = keyServer.downloadKeyFile( locationCode = keyInfo.location, day = keyInfo.day, hour = keyInfo.hour, saveTo = saveTo, - headerHook = validation + precondition = preconditionHook ) Timber.tag(TAG).v("Dowwnload finished: %s -> %s", keyInfo, saveTo) - val (downloadedMD5, duration) = measureTimeMillisWithResult { saveTo.hashToMD5() } - Timber.tag(TAG).v("Hashed to MD5 in %dms: %s", duration, saveTo) - - keyCache.markKeyComplete(keyInfo, downloadedMD5) + keyCache.markKeyComplete(keyInfo, dlInfo.serverMD5 ?: dlInfo.localMD5!!) keyInfo to saveTo } catch (e: Exception) { Timber.tag(TAG).e(e, "Download failed: %s", keyInfo) diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt index 8995751a9eeb31770d4bdc2444b947361d9b967d..acf0c392b270e0f710bbaf9335e8c3dba16a0b2d 100644 --- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt +++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt @@ -1,9 +1,10 @@ package de.rki.coronawarnapp.diagnosiskeys.server import dagger.Lazy +import de.rki.coronawarnapp.util.HashExtensions.hashToMD5 +import de.rki.coronawarnapp.util.debug.measureTimeMillisWithResult import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import okhttp3.Headers import org.joda.time.LocalDate import org.joda.time.LocalTime import org.joda.time.format.DateTimeFormat @@ -44,10 +45,6 @@ class DiagnosisKeyServer @Inject constructor( .map { hourString -> LocalTime.parse(hourString, HOUR_FORMATTER) } } - interface HeaderHook { - suspend fun validate(headers: Headers): Boolean = true - } - /** * Retrieves Key Files from the Server * Leave **[hour]** null to download a day package @@ -57,8 +54,8 @@ class DiagnosisKeyServer @Inject constructor( day: LocalDate, hour: LocalTime? = null, saveTo: File, - headerHook: HeaderHook = object : HeaderHook {} - ) = withContext(Dispatchers.IO) { + precondition: suspend (DownloadInfo) -> Boolean = { true } + ): DownloadInfo = withContext(Dispatchers.IO) { Timber.tag(TAG).v( "Starting download: country=%s, day=%s, hour=%s -> %s.", locationCode, day, hour, saveTo @@ -84,9 +81,11 @@ class DiagnosisKeyServer @Inject constructor( ) } - if (!headerHook.validate(response.headers())) { - Timber.tag(TAG).d("validateHeaders() told us to abort.") - return@withContext + var downloadInfo = DownloadInfo(response.headers()) + + if (!precondition(downloadInfo)) { + Timber.tag(TAG).d("Precondition is not met, aborting.") + return@withContext downloadInfo } if (response.isSuccessful) { saveTo.outputStream().use { target -> @@ -94,7 +93,14 @@ class DiagnosisKeyServer @Inject constructor( source.copyTo(target, DEFAULT_BUFFER_SIZE) } } - Timber.tag(TAG).v("Key file download successful: %s", saveTo) + + val (localMD5, duration) = measureTimeMillisWithResult { saveTo.hashToMD5() } + Timber.v("Hashed to MD5 in %dms: %s", duration, saveTo) + + downloadInfo = downloadInfo.copy(localMD5 = localMD5) + Timber.tag(TAG).v("Key file download successful: %s", downloadInfo) + + return@withContext downloadInfo } else { throw HttpException(response) } diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt similarity index 52% rename from Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt rename to Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt index 51b89d5f891f21cb2f6210f978eed21f522a2a68..96f0047daf54c66dc5e6637ac07d2247eb54cf49 100644 --- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt +++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt @@ -2,16 +2,16 @@ package de.rki.coronawarnapp.diagnosiskeys.server import okhttp3.Headers -class KeyFileHeaderHook( - private val onEval: suspend KeyFileHeaderHook.(Headers) -> Boolean -) : DiagnosisKeyServer.HeaderHook { +data class DownloadInfo( + val headers: Headers, + val localMD5: String? = null +) { - override suspend fun validate(headers: Headers): Boolean = onEval(headers) + val serverMD5 by lazy { headers.getPayloadChecksumMD5() } - fun Headers.getPayloadChecksumMD5(): String? { - // TODO Ping backend regarding alternative checksum sources - var fileMD5 = values("ETag").singleOrNull() - // The hash from these headers doesn't match, TODO EXPOSUREBACK-178 + private fun Headers.getPayloadChecksumMD5(): String? { + // TODO EXPOSUREBACK-178 + val fileMD5 = values("ETag").singleOrNull() // var fileMD5 = headers.values("x-amz-meta-cwa-hash-md5").singleOrNull() // if (fileMD5 == null) { // headers.values("x-amz-meta-cwa-hash").singleOrNull() diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt index 4557883674743c15d6c6545486aefce57a5d90b4..f63d93257d4f8b3775a972d7b14a29e905b1df1f 100644 --- a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt +++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt @@ -2,6 +2,7 @@ package de.rki.coronawarnapp.diagnosiskeys.download import android.database.SQLException import de.rki.coronawarnapp.diagnosiskeys.server.DiagnosisKeyServer +import de.rki.coronawarnapp.diagnosiskeys.server.DownloadInfo import de.rki.coronawarnapp.diagnosiskeys.server.LocationCode import de.rki.coronawarnapp.diagnosiskeys.storage.CachedKeyInfo import de.rki.coronawarnapp.diagnosiskeys.storage.KeyCacheRepository @@ -103,7 +104,13 @@ class KeyFileDownloaderTest : BaseIOTest() { LocalTime.parse("22"), LocalTime.parse("23") ) coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers { - mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4)) + mockDownloadServerDownload( + locationCode = arg(0), + day = arg(1), + hour = arg(2), + saveTo = arg(3), + precondition = arg(4) + ) } coEvery { keyCache.createCacheEntry(any(), any(), any(), any()) } answers { @@ -169,10 +176,15 @@ class KeyFileDownloaderTest : BaseIOTest() { day: LocalDate, hour: LocalTime? = null, saveTo: File, - validator: DiagnosisKeyServer.HeaderHook = object : - DiagnosisKeyServer.HeaderHook {} - ) { + precondition: suspend (DownloadInfo) -> Boolean = { true }, + checksumServerMD5: String? = "serverMD5", + checksumLocalMD5: String? = "localMD5" + ): DownloadInfo { saveTo.writeText("$locationCode.$day.$hour") + return mockk<DownloadInfo>().apply { + every { serverMD5 } returns checksumServerMD5 + every { localMD5 } returns checksumLocalMD5 + } } private fun mockAddData( @@ -184,8 +196,13 @@ class KeyFileDownloaderTest : BaseIOTest() { ): Pair<CachedKeyInfo, File> { val (keyInfo, file) = mockKeyCacheCreateEntry(type, location, day, hour) if (isCompleted) { - mockDownloadServerDownload(location, day, hour, file) - mockKeyCacheUpdateComplete(keyInfo, "checksum") + mockDownloadServerDownload( + locationCode = location, + day = day, + hour = hour, + saveTo = file + ) + mockKeyCacheUpdateComplete(keyInfo, "serverMD5") } return keyRepoData[keyInfo.id]!! to file } @@ -379,7 +396,13 @@ class KeyFileDownloaderTest : BaseIOTest() { coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers { dlCounter++ if (dlCounter == 2) throw IOException("Timeout") - mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4)) + mockDownloadServerDownload( + locationCode = arg(0), + day = arg(1), + hour = arg(2), + saveTo = arg(3), + precondition = arg(4) + ) } val downloader = createDownloader() @@ -610,7 +633,13 @@ class KeyFileDownloaderTest : BaseIOTest() { coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers { dlCounter++ if (dlCounter == 2) throw IOException("Timeout") - mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4)) + mockDownloadServerDownload( + locationCode = arg(0), + day = arg(1), + hour = arg(2), + saveTo = arg(3), + precondition = arg(4) + ) } val downloader = createDownloader() @@ -680,4 +709,74 @@ class KeyFileDownloaderTest : BaseIOTest() { ) } } + + @Test + fun `store server md5`() { + coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE")) + coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf( + LocalDate.parse("2020-09-01") + ) + + val downloader = createDownloader() + + runBlocking { + downloader.asyncFetchKeyFiles( + listOf(LocationCode("DE")) + ).size shouldBe 1 + } + + coVerify { + keyCache.createCacheEntry( + type = CachedKeyInfo.Type.COUNTRY_DAY, + location = LocationCode("DE"), + dayIdentifier = LocalDate.parse("2020-09-01"), + hourIdentifier = null + ) + } + keyRepoData.size shouldBe 1 + keyRepoData.values.forEach { + it.isDownloadComplete shouldBe true + it.checksumMD5 shouldBe "serverMD5" + } + } + + @Test + fun `use local MD5 as fallback if there is none available from the server`() { + coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE")) + coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf( + LocalDate.parse("2020-09-01") + ) + coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers { + mockDownloadServerDownload( + locationCode = arg(0), + day = arg(1), + hour = arg(2), + saveTo = arg(3), + precondition = arg(4), + checksumServerMD5 = null + ) + } + + val downloader = createDownloader() + + runBlocking { + downloader.asyncFetchKeyFiles( + listOf(LocationCode("DE")) + ).size shouldBe 1 + } + + coVerify { + keyCache.createCacheEntry( + type = CachedKeyInfo.Type.COUNTRY_DAY, + location = LocationCode("DE"), + dayIdentifier = LocalDate.parse("2020-09-01"), + hourIdentifier = null + ) + } + keyRepoData.size shouldBe 1 + keyRepoData.values.forEach { + it.isDownloadComplete shouldBe true + it.checksumMD5 shouldBe "localMD5" + } + } } diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..cceb7f2d5481cccdb7f1e6c3edd1afbaf38bf0e4 --- /dev/null +++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt @@ -0,0 +1,19 @@ +package de.rki.coronawarnapp.diagnosiskeys.server + +import io.kotest.matchers.shouldBe +import okhttp3.Headers +import org.junit.jupiter.api.Test +import testhelpers.BaseTest + +class DownloadInfoTest : BaseTest() { + + @Test + fun `extract server MD5`() { + val info = DownloadInfo( + headers = Headers.headersOf("ETAG", "serverMD5"), + localMD5 = "localMD5" + ) + info.serverMD5 shouldBe "serverMD5" + info.localMD5 shouldBe "localMD5" + } +}