From bae7bf74418c42fa45e79dc487262b16eb9662f9 Mon Sep 17 00:00:00 2001
From: Chilja Gossow <49635654+chiljamgossow@users.noreply.github.com>
Date: Thu, 15 Apr 2021 16:02:09 +0200
Subject: [PATCH] Fix risk calc (EXPOSUREAPP-6363) (#2836)

* prevent deletion of matches on config changes
reset config on config changes

* revert arguments
use mutex to guard config refresh

* more logging

* null save

Co-authored-by: harambasicluka <64483219+harambasicluka@users.noreply.github.com>
---
 .../appconfig/ConfigChangeDetector.kt         | 13 +++++-
 .../calculation/PresenceTracingRiskMapper.kt  | 43 +++++++++++++------
 .../execution/PresenceTracingWarningTask.kt   |  5 +++
 .../storage/PresenceTracingRiskRepository.kt  |  8 +++-
 .../risk/storage/BaseRiskLevelStorage.kt      |  8 ++++
 .../risk/storage/RiskLevelStorage.kt          |  2 +
 .../appconfig/ConfigChangeDetectorTest.kt     | 24 ++++++++---
 .../PresenceTracingRiskMapperTest.kt          | 24 ++++++++++-
 .../PresenceTracingWarningTaskTest.kt         | 26 +++++++++++
 9 files changed, 131 insertions(+), 22 deletions(-)

diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetector.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetector.kt
index 9ba3aef29..4e51b2cbb 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetector.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetector.kt
@@ -1,6 +1,7 @@
 package de.rki.coronawarnapp.appconfig
 
 import androidx.annotation.VisibleForTesting
+import de.rki.coronawarnapp.presencetracing.risk.execution.PresenceTracingWarningTask
 import de.rki.coronawarnapp.risk.RiskLevelSettings
 import de.rki.coronawarnapp.risk.RiskLevelTask
 import de.rki.coronawarnapp.risk.storage.RiskLevelStorage
@@ -46,8 +47,16 @@ class ConfigChangeDetector @Inject constructor(
         val oldConfigId = riskLevelSettings.lastUsedConfigIdentifier
         if (newIdentifier != oldConfigId) {
             Timber.tag(TAG).i("New config id ($newIdentifier) differs from last one ($oldConfigId), resetting.")
-            riskLevelStorage.clear()
-            taskController.submit(DefaultTaskRequest(RiskLevelTask::class, originTag = "ConfigChangeDetector"))
+            riskLevelStorage.clearResults()
+            taskController.submit(
+                DefaultTaskRequest(RiskLevelTask::class, originTag = "ConfigChangeDetector")
+            )
+            taskController.submit(
+                DefaultTaskRequest(
+                    PresenceTracingWarningTask::class,
+                    originTag = "ConfigChangeDetector"
+                )
+            )
         } else {
             Timber.tag(TAG).v("Config identifier ($oldConfigId) didn't change, NOOP.")
         }
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapper.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapper.kt
index 06b72ed46..f906c47bb 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapper.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapper.kt
@@ -6,22 +6,35 @@ import de.rki.coronawarnapp.risk.DefaultRiskLevels.Companion.inRange
 import de.rki.coronawarnapp.risk.RiskState
 import de.rki.coronawarnapp.risk.mapToRiskState
 import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
 import timber.log.Timber
 import javax.inject.Inject
+import javax.inject.Singleton
 
+@Singleton
 class PresenceTracingRiskMapper @Inject constructor(
     private val configProvider: AppConfigProvider
 ) {
     private var presenceTracingRiskCalculationParamContainer: PresenceTracingRiskCalculationParamContainer? = null
 
+    private val mutex = Mutex()
+
+    suspend fun clearConfig() {
+        mutex.withLock {
+            Timber.tag(TAG).i("Clearing config params.")
+            presenceTracingRiskCalculationParamContainer = null
+        }
+    }
+
     suspend fun lookupTransmissionRiskValue(transmissionRiskLevel: Int): Double {
-        return getTransmissionRiskValueMapping()?.find {
+        return getTransmissionRiskValueMapping().find {
             (it.transmissionRiskLevel == transmissionRiskLevel)
         }?.transmissionRiskValue ?: 0.0
     }
 
     suspend fun lookupRiskStatePerDay(normalizedTime: Double): RiskState {
-        return getNormalizedTimePerDayToRiskLevelMapping()?.find {
+        return getNormalizedTimePerDayToRiskLevelMapping().find {
             it.normalizedTimeRange.inRange(normalizedTime)
         }
             ?.riskLevel
@@ -29,7 +42,7 @@ class PresenceTracingRiskMapper @Inject constructor(
     }
 
     suspend fun lookupRiskStatePerCheckIn(normalizedTime: Double): RiskState {
-        return getNormalizedTimePerCheckInToRiskLevelMapping()?.find {
+        return getNormalizedTimePerCheckInToRiskLevelMapping().find {
             it.normalizedTimeRange.inRange(normalizedTime)
         }
             ?.riskLevel
@@ -37,20 +50,26 @@ class PresenceTracingRiskMapper @Inject constructor(
     }
 
     private suspend fun getTransmissionRiskValueMapping() =
-        getRiskCalculationParameters()?.transmissionRiskValueMapping
+        getRiskCalculationParameters().transmissionRiskValueMapping
 
     private suspend fun getNormalizedTimePerDayToRiskLevelMapping() =
-        getRiskCalculationParameters()?.normalizedTimePerDayToRiskLevelMapping
+        getRiskCalculationParameters().normalizedTimePerDayToRiskLevelMapping
 
     private suspend fun getNormalizedTimePerCheckInToRiskLevelMapping() =
-        getRiskCalculationParameters()?.normalizedTimePerCheckInToRiskLevelMapping
+        getRiskCalculationParameters().normalizedTimePerCheckInToRiskLevelMapping
 
-    private suspend fun getRiskCalculationParameters(): PresenceTracingRiskCalculationParamContainer? {
-        if (presenceTracingRiskCalculationParamContainer == null) {
-            presenceTracingRiskCalculationParamContainer =
-                configProvider.currentConfig.first().presenceTracing.riskCalculationParameters
-            Timber.d(presenceTracingRiskCalculationParamContainer.toString())
+    private suspend fun getRiskCalculationParameters(): PresenceTracingRiskCalculationParamContainer = mutex.withLock {
+        presenceTracingRiskCalculationParamContainer.let {
+            if (it == null) {
+                val newParams = configProvider.currentConfig.first().presenceTracing.riskCalculationParameters
+                Timber.d("New params %s", newParams)
+                presenceTracingRiskCalculationParamContainer = newParams
+                newParams
+            } else {
+                it
+            }
         }
-        return presenceTracingRiskCalculationParamContainer
     }
 }
+
+private const val TAG = "PtRiskMapper"
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTask.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTask.kt
index 8000745d6..2c8e77307 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTask.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTask.kt
@@ -6,6 +6,7 @@ import de.rki.coronawarnapp.eventregistration.checkins.CheckInRepository
 import de.rki.coronawarnapp.exception.ExceptionCategory
 import de.rki.coronawarnapp.exception.reporting.report
 import de.rki.coronawarnapp.presencetracing.risk.calculation.CheckInWarningMatcher
+import de.rki.coronawarnapp.presencetracing.risk.calculation.PresenceTracingRiskMapper
 import de.rki.coronawarnapp.presencetracing.risk.storage.PresenceTracingRiskRepository
 import de.rki.coronawarnapp.presencetracing.warning.download.TraceWarningPackageSyncTool
 import de.rki.coronawarnapp.presencetracing.warning.storage.TraceWarningRepository
@@ -30,6 +31,7 @@ class PresenceTracingWarningTask @Inject constructor(
     private val presenceTracingRiskRepository: PresenceTracingRiskRepository,
     private val traceWarningRepository: TraceWarningRepository,
     private val checkInsRepository: CheckInRepository,
+    private val presenceTracingRiskMapper: PresenceTracingRiskMapper
 ) : Task<PresenceTracingWarningTaskProgress, PresenceTracingWarningTask.Result> {
 
     private val internalProgress = ConflatedBroadcastChannel<PresenceTracingWarningTaskProgress>()
@@ -60,6 +62,9 @@ class PresenceTracingWarningTask @Inject constructor(
         val nowUTC = timeStamper.nowUTC
         checkCancel()
 
+        Timber.tag(TAG).d("Resetting config to make sure latest changes are considered.")
+        presenceTracingRiskMapper.clearConfig()
+
         Timber.tag(TAG).d("Syncing packages.")
         internalProgress.send(PresenceTracingWarningTaskProgress.Downloading())
 
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/storage/PresenceTracingRiskRepository.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/storage/PresenceTracingRiskRepository.kt
index 1b356f35b..f4b13532a 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/storage/PresenceTracingRiskRepository.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/presencetracing/risk/storage/PresenceTracingRiskRepository.kt
@@ -31,7 +31,7 @@ import javax.inject.Singleton
 class PresenceTracingRiskRepository @Inject constructor(
     private val presenceTracingRiskCalculator: PresenceTracingRiskCalculator,
     private val databaseFactory: PresenceTracingRiskDatabase.Factory,
-    private val timeStamper: TimeStamper,
+    private val timeStamper: TimeStamper
 ) {
 
     private val database by lazy {
@@ -144,9 +144,15 @@ class PresenceTracingRiskRepository @Inject constructor(
         get() = timeStamper.nowUTC.minus(Days.days(15).toStandardDuration())
 
     suspend fun clearAllTables() {
+        Timber.i("Deleting all matches and results.")
         traceTimeIntervalMatchDao.deleteAll()
         riskLevelResultDao.deleteAll()
     }
+
+    suspend fun clearResults() {
+        Timber.i("Deleting all results.")
+        riskLevelResultDao.deleteAll()
+    }
 }
 
 /*
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/BaseRiskLevelStorage.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/BaseRiskLevelStorage.kt
index 0582cc3c9..800757c75 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/BaseRiskLevelStorage.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/BaseRiskLevelStorage.kt
@@ -224,9 +224,17 @@ abstract class BaseRiskLevelStorage constructor(
     override suspend fun clear() {
         Timber.w("clear() - Clearing stored risklevel/exposure-detection results.")
         database.clearAllTables()
+        Timber.w("clear() - Clearing stored presence tracing matches and results.")
         presenceTracingRiskRepository.clearAllTables()
     }
 
+    override suspend fun clearResults() {
+        Timber.w("clearResults() - Clearing stored risklevel/exposure-detection results.")
+        database.clearAllTables()
+        Timber.w("clearResults() - Clearing stored presence tracing results.")
+        presenceTracingRiskRepository.clearResults()
+    }
+
     companion object {
         private const val TAG = "RiskLevelStorage"
     }
diff --git a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/RiskLevelStorage.kt b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/RiskLevelStorage.kt
index a72b458ba..e0d42b72d 100644
--- a/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/RiskLevelStorage.kt
+++ b/Corona-Warn-App/src/main/java/de/rki/coronawarnapp/risk/storage/RiskLevelStorage.kt
@@ -79,4 +79,6 @@ interface RiskLevelStorage {
     suspend fun storeResult(resultEw: EwRiskLevelResult)
 
     suspend fun clear()
+
+    suspend fun clearResults()
 }
diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetectorTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetectorTest.kt
index d8199e642..39f8b0fb3 100644
--- a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetectorTest.kt
+++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/appconfig/ConfigChangeDetectorTest.kt
@@ -34,6 +34,7 @@ class ConfigChangeDetectorTest : BaseTest() {
         every { taskController.submit(any()) } just Runs
         every { appConfigProvider.currentConfig } returns currentConfigFake
         coEvery { riskLevelStorage.clear() } just Runs
+        coEvery { riskLevelStorage.clearResults() } just Runs
     }
 
     private fun mockConfigId(id: String): ConfigData {
@@ -59,7 +60,7 @@ class ConfigChangeDetectorTest : BaseTest() {
 
         coVerify(exactly = 0) {
             taskController.submit(any())
-            riskLevelStorage.clear()
+            riskLevelStorage.clearResults()
         }
     }
 
@@ -70,20 +71,25 @@ class ConfigChangeDetectorTest : BaseTest() {
         createInstance().launch()
 
         coVerifySequence {
-            riskLevelStorage.clear()
+            riskLevelStorage.clearResults()
             taskController.submit(any())
+            taskController.submit(any())
+        }
+
+        coVerify(exactly = 0) {
+            riskLevelStorage.clear()
         }
     }
 
     @Test
-    fun `same idetifier results in no op`() {
+    fun `same identifier results in no op`() {
         every { riskLevelSettings.lastUsedConfigIdentifier } returns "initial"
 
         createInstance().launch()
 
         coVerify(exactly = 0) {
             taskController.submit(any())
-            riskLevelStorage.clear()
+            riskLevelStorage.clearResults()
         }
     }
 
@@ -96,10 +102,16 @@ class ConfigChangeDetectorTest : BaseTest() {
         currentConfigFake.value = mockConfigId("berry")
 
         coVerifySequence {
-            riskLevelStorage.clear()
+            riskLevelStorage.clearResults()
+            taskController.submit(any())
+            taskController.submit(any())
+            riskLevelStorage.clearResults()
             taskController.submit(any())
-            riskLevelStorage.clear()
             taskController.submit(any())
         }
+
+        coVerify(exactly = 0) {
+            riskLevelStorage.clear()
+        }
     }
 }
diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapperTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapperTest.kt
index dd07536e3..341d5b9bb 100644
--- a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapperTest.kt
+++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/calculation/PresenceTracingRiskMapperTest.kt
@@ -8,6 +8,7 @@ import de.rki.coronawarnapp.server.protocols.internal.v2.RiskCalculationParamete
 import io.kotest.matchers.shouldBe
 import io.mockk.MockKAnnotations
 import io.mockk.coEvery
+import io.mockk.coVerify
 import io.mockk.every
 import io.mockk.impl.annotations.MockK
 import kotlinx.coroutines.flow.flowOf
@@ -50,7 +51,7 @@ class PresenceTracingRiskMapperTest : BaseTest() {
             .setRiskLevel(RiskCalculationParametersOuterClass.NormalizedTimeToRiskLevelMapping.RiskLevel.HIGH)
             .build()
 
-    val container = PresenceTracingRiskCalculationParamContainer(
+    private val container = PresenceTracingRiskCalculationParamContainer(
         transmissionRiskValueMapping = listOf(transmissionRiskValueMapping),
         normalizedTimePerCheckInToRiskLevelMapping = listOf(normalizedTimeMappingLow, normalizedTimeMappingHigh),
         normalizedTimePerDayToRiskLevelMapping = listOf(normalizedTimeMappingLow, normalizedTimeMappingHigh)
@@ -100,5 +101,26 @@ class PresenceTracingRiskMapperTest : BaseTest() {
         }
     }
 
+    @Test
+    fun `config is requested only once`() {
+        runBlockingTest {
+            val mapper = createInstance()
+            mapper.lookupRiskStatePerDay(30.0)
+            mapper.lookupRiskStatePerDay(60.0)
+            coVerify(exactly = 1) { configProvider.currentConfig }
+        }
+    }
+
+    @Test
+    fun `config is requested again after reset`() {
+        runBlockingTest {
+            val mapper = createInstance()
+            mapper.lookupRiskStatePerDay(30.0)
+            mapper.clearConfig()
+            mapper.lookupRiskStatePerDay(60.0)
+            coVerify(exactly = 2) { configProvider.currentConfig }
+        }
+    }
+
     private fun createInstance() = PresenceTracingRiskMapper(configProvider)
 }
diff --git a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTaskTest.kt b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTaskTest.kt
index 918257658..2c0b73a9c 100644
--- a/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTaskTest.kt
+++ b/Corona-Warn-App/src/test/java/de/rki/coronawarnapp/presencetracing/risk/execution/PresenceTracingWarningTaskTest.kt
@@ -2,6 +2,7 @@ package de.rki.coronawarnapp.presencetracing.risk.execution
 
 import de.rki.coronawarnapp.eventregistration.checkins.CheckInRepository
 import de.rki.coronawarnapp.presencetracing.risk.calculation.CheckInWarningMatcher
+import de.rki.coronawarnapp.presencetracing.risk.calculation.PresenceTracingRiskMapper
 import de.rki.coronawarnapp.presencetracing.risk.calculation.createCheckIn
 import de.rki.coronawarnapp.presencetracing.risk.calculation.createWarning
 import de.rki.coronawarnapp.presencetracing.risk.storage.PresenceTracingRiskRepository
@@ -40,6 +41,7 @@ class PresenceTracingWarningTaskTest : BaseTest() {
     @MockK lateinit var presenceTracingRiskRepository: PresenceTracingRiskRepository
     @MockK lateinit var traceWarningRepository: TraceWarningRepository
     @MockK lateinit var checkInsRepository: CheckInRepository
+    @MockK lateinit var presenceTracingRiskMapper: PresenceTracingRiskMapper
 
     @BeforeEach
     fun setup() {
@@ -71,6 +73,8 @@ class PresenceTracingWarningTaskTest : BaseTest() {
             coEvery { deleteStaleData() } just Runs
             coEvery { reportCalculation(any(), any()) } just Runs
         }
+
+        coEvery { presenceTracingRiskMapper.clearConfig() } just Runs
     }
 
     private fun createInstance() = PresenceTracingWarningTask(
@@ -80,6 +84,7 @@ class PresenceTracingWarningTaskTest : BaseTest() {
         presenceTracingRiskRepository = presenceTracingRiskRepository,
         traceWarningRepository = traceWarningRepository,
         checkInsRepository = checkInsRepository,
+        presenceTracingRiskMapper = presenceTracingRiskMapper
     )
 
     @Test
@@ -102,6 +107,27 @@ class PresenceTracingWarningTaskTest : BaseTest() {
         }
     }
 
+    @Test
+    fun `happy path with config change`() = runBlockingTest {
+        createInstance().run(mockk()) shouldNotBe null
+
+        coVerifySequence {
+            presenceTracingRiskMapper.clearConfig()
+            syncTool.syncPackages()
+            presenceTracingRiskRepository.deleteStaleData()
+            checkInsRepository.checkInsWithinRetention
+            traceWarningRepository.unprocessedWarningPackages
+
+            checkInWarningMatcher.process(any(), any())
+
+            presenceTracingRiskRepository.reportCalculation(
+                successful = true,
+                overlaps = any()
+            )
+            traceWarningRepository.markPackagesProcessed(listOf(WARNING_PKG.packageId))
+        }
+    }
+
     @Test
     fun `overall task errors lead to a reported failed calculation`() = runBlockingTest {
         coEvery { syncTool.syncPackages() } throws IOException("Unexpected")
-- 
GitLab