From fc5e57c33e17c8730b40f1de2ef0105012b0de84 Mon Sep 17 00:00:00 2001
From: Matthias Urhahn <darken@darken.eu>
Date: Fri, 18 Sep 2020 21:03:46 +0200
Subject: [PATCH] Prefer server MD5 over local one for downloaded files.
 (#1181)

Refactor MD5 extraction to be part of `DiagnosisKeyServer`'s responsibility.

Co-authored-by: Matthias Urhahn <matthias.urhahn@sap.com>
---
 .../download/KeyFileDownloader.kt             |  27 ++--
 .../server/DiagnosisKeyServer.kt              |  28 +++--
 .../{KeyFileHeaderHook.kt => DownloadInfo.kt} |  16 +--
 .../download/KeyFileDownloaderTest.kt         | 115 ++++++++++++++++--
 .../diagnosiskeys/server/DownloadInfoTest.kt  |  19 +++
 5 files changed, 162 insertions(+), 43 deletions(-)
 rename Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/{KeyFileHeaderHook.kt => DownloadInfo.kt} (52%)
 create mode 100644 Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt

diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt
index 4cf842ea9..686a07972 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloader.kt
@@ -2,7 +2,7 @@ package de.rki.coronawarnapp.diagnosiskeys.download
 
 import dagger.Reusable
 import de.rki.coronawarnapp.diagnosiskeys.server.DiagnosisKeyServer
-import de.rki.coronawarnapp.diagnosiskeys.server.KeyFileHeaderHook
+import de.rki.coronawarnapp.diagnosiskeys.server.DownloadInfo
 import de.rki.coronawarnapp.diagnosiskeys.server.LocationCode
 import de.rki.coronawarnapp.diagnosiskeys.storage.CachedKeyInfo
 import de.rki.coronawarnapp.diagnosiskeys.storage.KeyCacheRepository
@@ -12,8 +12,6 @@ import de.rki.coronawarnapp.storage.AppSettings
 import de.rki.coronawarnapp.storage.DeviceStorage
 import de.rki.coronawarnapp.storage.InsufficientStorageException
 import de.rki.coronawarnapp.util.CWADebug
-import de.rki.coronawarnapp.util.HashExtensions.hashToMD5
-import de.rki.coronawarnapp.util.debug.measureTimeMillisWithResult
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.async
 import kotlinx.coroutines.awaitAll
@@ -300,28 +298,25 @@ class KeyFileDownloader @Inject constructor(
         keyInfo: CachedKeyInfo,
         saveTo: File
     ): Pair<CachedKeyInfo, File>? = try {
-        val validation = KeyFileHeaderHook { headers ->
-            // tryMigration returns true when a file was migrated, meaning, no download necessary
-            return@KeyFileHeaderHook !legacyKeyCache.tryMigration(
-                headers.getPayloadChecksumMD5(),
-                saveTo
-            )
-        }
+        val preconditionHook: suspend (DownloadInfo) -> Boolean =
+            { downloadInfo ->
+                val continueDownload = !legacyKeyCache.tryMigration(
+                    downloadInfo.serverMD5, saveTo
+                )
+                continueDownload // Continue download if no migration happened
+            }
 
-        keyServer.downloadKeyFile(
+        val dlInfo = keyServer.downloadKeyFile(
             locationCode = keyInfo.location,
             day = keyInfo.day,
             hour = keyInfo.hour,
             saveTo = saveTo,
-            headerHook = validation
+            precondition = preconditionHook
         )
 
         Timber.tag(TAG).v("Dowwnload finished: %s -> %s", keyInfo, saveTo)
 
-        val (downloadedMD5, duration) = measureTimeMillisWithResult { saveTo.hashToMD5() }
-        Timber.tag(TAG).v("Hashed to MD5 in %dms: %s", duration, saveTo)
-
-        keyCache.markKeyComplete(keyInfo, downloadedMD5)
+        keyCache.markKeyComplete(keyInfo, dlInfo.serverMD5 ?: dlInfo.localMD5!!)
         keyInfo to saveTo
     } catch (e: Exception) {
         Timber.tag(TAG).e(e, "Download failed: %s", keyInfo)
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt
index 8995751a9..acf0c392b 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DiagnosisKeyServer.kt
@@ -1,9 +1,10 @@
 package de.rki.coronawarnapp.diagnosiskeys.server
 
 import dagger.Lazy
+import de.rki.coronawarnapp.util.HashExtensions.hashToMD5
+import de.rki.coronawarnapp.util.debug.measureTimeMillisWithResult
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.withContext
-import okhttp3.Headers
 import org.joda.time.LocalDate
 import org.joda.time.LocalTime
 import org.joda.time.format.DateTimeFormat
@@ -44,10 +45,6 @@ class DiagnosisKeyServer @Inject constructor(
                 .map { hourString -> LocalTime.parse(hourString, HOUR_FORMATTER) }
         }
 
-    interface HeaderHook {
-        suspend fun validate(headers: Headers): Boolean = true
-    }
-
     /**
      * Retrieves Key Files from the Server
      * Leave **[hour]** null to download a day package
@@ -57,8 +54,8 @@ class DiagnosisKeyServer @Inject constructor(
         day: LocalDate,
         hour: LocalTime? = null,
         saveTo: File,
-        headerHook: HeaderHook = object : HeaderHook {}
-    ) = withContext(Dispatchers.IO) {
+        precondition: suspend (DownloadInfo) -> Boolean = { true }
+    ): DownloadInfo = withContext(Dispatchers.IO) {
         Timber.tag(TAG).v(
             "Starting download: country=%s, day=%s, hour=%s -> %s.",
             locationCode, day, hour, saveTo
@@ -84,9 +81,11 @@ class DiagnosisKeyServer @Inject constructor(
             )
         }
 
-        if (!headerHook.validate(response.headers())) {
-            Timber.tag(TAG).d("validateHeaders() told us to abort.")
-            return@withContext
+        var downloadInfo = DownloadInfo(response.headers())
+
+        if (!precondition(downloadInfo)) {
+            Timber.tag(TAG).d("Precondition is not met, aborting.")
+            return@withContext downloadInfo
         }
         if (response.isSuccessful) {
             saveTo.outputStream().use { target ->
@@ -94,7 +93,14 @@ class DiagnosisKeyServer @Inject constructor(
                     source.copyTo(target, DEFAULT_BUFFER_SIZE)
                 }
             }
-            Timber.tag(TAG).v("Key file download successful: %s", saveTo)
+
+            val (localMD5, duration) = measureTimeMillisWithResult { saveTo.hashToMD5() }
+            Timber.v("Hashed to MD5 in %dms: %s", duration, saveTo)
+
+            downloadInfo = downloadInfo.copy(localMD5 = localMD5)
+            Timber.tag(TAG).v("Key file download successful: %s", downloadInfo)
+
+            return@withContext downloadInfo
         } else {
             throw HttpException(response)
         }
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt
similarity index 52%
rename from Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt
rename to Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt
index 51b89d5f8..96f0047da 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/KeyFileHeaderHook.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfo.kt
@@ -2,16 +2,16 @@ package de.rki.coronawarnapp.diagnosiskeys.server
 
 import okhttp3.Headers
 
-class KeyFileHeaderHook(
-    private val onEval: suspend KeyFileHeaderHook.(Headers) -> Boolean
-) : DiagnosisKeyServer.HeaderHook {
+data class DownloadInfo(
+    val headers: Headers,
+    val localMD5: String? = null
+) {
 
-    override suspend fun validate(headers: Headers): Boolean = onEval(headers)
+    val serverMD5 by lazy { headers.getPayloadChecksumMD5() }
 
-    fun Headers.getPayloadChecksumMD5(): String? {
-        // TODO Ping backend regarding alternative checksum sources
-        var fileMD5 = values("ETag").singleOrNull()
-        // The hash from these headers doesn't match, TODO EXPOSUREBACK-178
+    private fun Headers.getPayloadChecksumMD5(): String? {
+        // TODO EXPOSUREBACK-178
+        val fileMD5 = values("ETag").singleOrNull()
 //                var fileMD5 = headers.values("x-amz-meta-cwa-hash-md5").singleOrNull()
 //                if (fileMD5 == null) {
 //                    headers.values("x-amz-meta-cwa-hash").singleOrNull()
diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt
index 455788367..f63d93257 100644
--- a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt
+++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/download/KeyFileDownloaderTest.kt
@@ -2,6 +2,7 @@ package de.rki.coronawarnapp.diagnosiskeys.download
 
 import android.database.SQLException
 import de.rki.coronawarnapp.diagnosiskeys.server.DiagnosisKeyServer
+import de.rki.coronawarnapp.diagnosiskeys.server.DownloadInfo
 import de.rki.coronawarnapp.diagnosiskeys.server.LocationCode
 import de.rki.coronawarnapp.diagnosiskeys.storage.CachedKeyInfo
 import de.rki.coronawarnapp.diagnosiskeys.storage.KeyCacheRepository
@@ -103,7 +104,13 @@ class KeyFileDownloaderTest : BaseIOTest() {
             LocalTime.parse("22"), LocalTime.parse("23")
         )
         coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers {
-            mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4))
+            mockDownloadServerDownload(
+                locationCode = arg(0),
+                day = arg(1),
+                hour = arg(2),
+                saveTo = arg(3),
+                precondition = arg(4)
+            )
         }
 
         coEvery { keyCache.createCacheEntry(any(), any(), any(), any()) } answers {
@@ -169,10 +176,15 @@ class KeyFileDownloaderTest : BaseIOTest() {
         day: LocalDate,
         hour: LocalTime? = null,
         saveTo: File,
-        validator: DiagnosisKeyServer.HeaderHook = object :
-            DiagnosisKeyServer.HeaderHook {}
-    ) {
+        precondition: suspend (DownloadInfo) -> Boolean = { true },
+        checksumServerMD5: String? = "serverMD5",
+        checksumLocalMD5: String? = "localMD5"
+    ): DownloadInfo {
         saveTo.writeText("$locationCode.$day.$hour")
+        return mockk<DownloadInfo>().apply {
+            every { serverMD5 } returns checksumServerMD5
+            every { localMD5 } returns checksumLocalMD5
+        }
     }
 
     private fun mockAddData(
@@ -184,8 +196,13 @@ class KeyFileDownloaderTest : BaseIOTest() {
     ): Pair<CachedKeyInfo, File> {
         val (keyInfo, file) = mockKeyCacheCreateEntry(type, location, day, hour)
         if (isCompleted) {
-            mockDownloadServerDownload(location, day, hour, file)
-            mockKeyCacheUpdateComplete(keyInfo, "checksum")
+            mockDownloadServerDownload(
+                locationCode = location,
+                day = day,
+                hour = hour,
+                saveTo = file
+            )
+            mockKeyCacheUpdateComplete(keyInfo, "serverMD5")
         }
         return keyRepoData[keyInfo.id]!! to file
     }
@@ -379,7 +396,13 @@ class KeyFileDownloaderTest : BaseIOTest() {
         coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers {
             dlCounter++
             if (dlCounter == 2) throw IOException("Timeout")
-            mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4))
+            mockDownloadServerDownload(
+                locationCode = arg(0),
+                day = arg(1),
+                hour = arg(2),
+                saveTo = arg(3),
+                precondition = arg(4)
+            )
         }
 
         val downloader = createDownloader()
@@ -610,7 +633,13 @@ class KeyFileDownloaderTest : BaseIOTest() {
         coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers {
             dlCounter++
             if (dlCounter == 2) throw IOException("Timeout")
-            mockDownloadServerDownload(arg(0), arg(1), arg(2), arg(3), arg(4))
+            mockDownloadServerDownload(
+                locationCode = arg(0),
+                day = arg(1),
+                hour = arg(2),
+                saveTo = arg(3),
+                precondition = arg(4)
+            )
         }
 
         val downloader = createDownloader()
@@ -680,4 +709,74 @@ class KeyFileDownloaderTest : BaseIOTest() {
             )
         }
     }
+
+    @Test
+    fun `store server md5`() {
+        coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE"))
+        coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf(
+            LocalDate.parse("2020-09-01")
+        )
+
+        val downloader = createDownloader()
+
+        runBlocking {
+            downloader.asyncFetchKeyFiles(
+                listOf(LocationCode("DE"))
+            ).size shouldBe 1
+        }
+
+        coVerify {
+            keyCache.createCacheEntry(
+                type = CachedKeyInfo.Type.COUNTRY_DAY,
+                location = LocationCode("DE"),
+                dayIdentifier = LocalDate.parse("2020-09-01"),
+                hourIdentifier = null
+            )
+        }
+        keyRepoData.size shouldBe 1
+        keyRepoData.values.forEach {
+            it.isDownloadComplete shouldBe true
+            it.checksumMD5 shouldBe "serverMD5"
+        }
+    }
+
+    @Test
+    fun `use local MD5 as fallback if there is none available from the server`() {
+        coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE"))
+        coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf(
+            LocalDate.parse("2020-09-01")
+        )
+        coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers {
+            mockDownloadServerDownload(
+                locationCode = arg(0),
+                day = arg(1),
+                hour = arg(2),
+                saveTo = arg(3),
+                precondition = arg(4),
+                checksumServerMD5 = null
+            )
+        }
+
+        val downloader = createDownloader()
+
+        runBlocking {
+            downloader.asyncFetchKeyFiles(
+                listOf(LocationCode("DE"))
+            ).size shouldBe 1
+        }
+
+        coVerify {
+            keyCache.createCacheEntry(
+                type = CachedKeyInfo.Type.COUNTRY_DAY,
+                location = LocationCode("DE"),
+                dayIdentifier = LocalDate.parse("2020-09-01"),
+                hourIdentifier = null
+            )
+        }
+        keyRepoData.size shouldBe 1
+        keyRepoData.values.forEach {
+            it.isDownloadComplete shouldBe true
+            it.checksumMD5 shouldBe "localMD5"
+        }
+    }
 }
diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt
new file mode 100644
index 000000000..cceb7f2d5
--- /dev/null
+++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/diagnosiskeys/server/DownloadInfoTest.kt
@@ -0,0 +1,19 @@
+package de.rki.coronawarnapp.diagnosiskeys.server
+
+import io.kotest.matchers.shouldBe
+import okhttp3.Headers
+import org.junit.jupiter.api.Test
+import testhelpers.BaseTest
+
+class DownloadInfoTest : BaseTest() {
+
+    @Test
+    fun `extract server MD5`() {
+        val info = DownloadInfo(
+            headers = Headers.headersOf("ETAG", "serverMD5"),
+            localMD5 = "localMD5"
+        )
+        info.serverMD5 shouldBe "serverMD5"
+        info.localMD5 shouldBe "localMD5"
+    }
+}
-- 
GitLab