Skip to content
Snippets Groups Projects
Unverified Commit 64383e85 authored by Matthias Urhahn's avatar Matthias Urhahn Committed by GitHub
Browse files

Add coroutine dispatcher injection to fix tests (DEV) (#1299)

* Add coroutine dispatcher injection to deal with flaky tests.
We need to be able to replace Dispatchers.IO in tests, and that requires providing the dispatcher via injection.

* Linting

* Replace remaining static access of Dispatchers.IO.
parent 48d256cb
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ import de.rki.coronawarnapp.diagnosiskeys.storage.legacy.LegacyKeyCacheMigration
import de.rki.coronawarnapp.risk.TimeVariables
import de.rki.coronawarnapp.storage.AppSettings
import de.rki.coronawarnapp.storage.DeviceStorage
import kotlinx.coroutines.Dispatchers
import de.rki.coronawarnapp.util.coroutine.DispatcherProvider
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.withContext
......@@ -28,7 +28,8 @@ class KeyFileDownloader @Inject constructor(
private val keyServer: DiagnosisKeyServer,
private val keyCache: KeyCacheRepository,
private val legacyKeyCache: LegacyKeyCacheMigration,
private val settings: AppSettings
private val settings: AppSettings,
private val dispatcherProvider: DispatcherProvider
) {
private suspend fun requireStorageSpace(data: List<CountryData>): DeviceStorage.CheckResult {
......@@ -68,7 +69,7 @@ class KeyFileDownloader @Inject constructor(
* @return list of all files from both the cache and the diff query
*/
suspend fun asyncFetchKeyFiles(wantedCountries: List<LocationCode>): List<File> =
withContext(Dispatchers.IO) {
withContext(dispatcherProvider.IO) {
val availableCountries = keyServer.getCountryIndex()
val filteredCountries = availableCountries.filter { wantedCountries.contains(it) }
Timber.tag(TAG).v(
......@@ -125,7 +126,7 @@ class KeyFileDownloader @Inject constructor(
*/
private suspend fun syncMissingDays(
availableCountries: List<LocationCode>
) = withContext(Dispatchers.IO) {
) = withContext(dispatcherProvider.IO) {
val countriesWithMissingDays = determineMissingDays(availableCountries)
requireStorageSpace(countriesWithMissingDays)
......@@ -256,7 +257,7 @@ class KeyFileDownloader @Inject constructor(
private suspend fun syncMissing3Hours(
availableCountries: List<LocationCode>,
hourItemLimit: Int
) = withContext(Dispatchers.IO) {
) = withContext(dispatcherProvider.IO) {
Timber.tag(TAG).v(
"asyncHandleLast3HoursFilesFetch(availableCountries=%s, hourLimit=%d)",
availableCountries, hourItemLimit
......
package de.rki.coronawarnapp.util.coroutine
import dagger.Binds
import dagger.Module
@Module
abstract class CoroutineModule {
@Binds
abstract fun dispatcherProvider(defaultProvider: DefaultDispatcherProvider): DispatcherProvider
}
package de.rki.coronawarnapp.util.coroutine
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class DefaultDispatcherProvider @Inject constructor() : DispatcherProvider
package de.rki.coronawarnapp.util.coroutine
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
// Need this to improve testing
// Can currently only replace the main-thread dispatcher.
// https://github.com/Kotlin/kotlinx.coroutines/issues/1365
@Suppress("PropertyName", "VariableNaming")
interface DispatcherProvider {
val Default: CoroutineDispatcher
get() = Dispatchers.Default
val Main: CoroutineDispatcher
get() = Dispatchers.Main
val MainImmediate: CoroutineDispatcher
get() = Dispatchers.Main.immediate
val Unconfined: CoroutineDispatcher
get() = Dispatchers.Unconfined
val IO: CoroutineDispatcher
get() = Dispatchers.IO
}
......@@ -28,6 +28,7 @@ import de.rki.coronawarnapp.transaction.SubmitDiagnosisInjectionHelper
import de.rki.coronawarnapp.ui.ActivityBinder
import de.rki.coronawarnapp.util.ConnectivityHelperInjection
import de.rki.coronawarnapp.util.UtilModule
import de.rki.coronawarnapp.util.coroutine.CoroutineModule
import de.rki.coronawarnapp.util.device.DeviceModule
import de.rki.coronawarnapp.util.security.EncryptedPreferencesFactory
import de.rki.coronawarnapp.util.security.EncryptionErrorResetTool
......@@ -39,6 +40,7 @@ import javax.inject.Singleton
modules = [
AndroidSupportInjectionModule::class,
AssistedInjectModule::class,
CoroutineModule::class,
AndroidModule::class,
ReceiverBinder::class,
ServiceBinder::class,
......
......@@ -27,11 +27,8 @@ import org.joda.time.LocalTime
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import testhelpers.BaseIOTest
import testhelpers.extensions.CoroutinesTestExtension
import testhelpers.extensions.InstantExecutorExtension
import testhelpers.flakyTest
import testhelpers.TestDispatcherProvider
import timber.log.Timber
import java.io.File
import java.io.IOException
......@@ -42,7 +39,6 @@ import kotlin.time.ExperimentalTime
*/
@ExperimentalTime
@ExperimentalCoroutinesApi
@ExtendWith(InstantExecutorExtension::class, CoroutinesTestExtension::class)
class KeyFileDownloaderTest : BaseIOTest() {
@MockK
......@@ -113,8 +109,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
locationCode = arg(0),
day = arg(1),
hour = arg(2),
saveTo = arg(3),
precondition = arg(4)
saveTo = arg(3)
)
}
......@@ -172,7 +167,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
checksum: String
) {
keyRepoData[keyInfo.id] = keyInfo.copy(
isDownloadComplete = checksum != null, checksumMD5 = checksum
isDownloadComplete = true, checksumMD5 = checksum
)
}
......@@ -181,7 +176,6 @@ class KeyFileDownloaderTest : BaseIOTest() {
day: LocalDate,
hour: LocalTime? = null,
saveTo: File,
precondition: suspend (DownloadInfo) -> Boolean = { true },
checksumServerMD5: String? = "serverMD5",
checksumLocalMD5: String? = "localMD5"
): DownloadInfo {
......@@ -218,15 +212,15 @@ class KeyFileDownloaderTest : BaseIOTest() {
keyServer = diagnosisKeyServer,
keyCache = keyCache,
legacyKeyCache = legacyMigration,
settings = settings
settings = settings,
dispatcherProvider = TestDispatcherProvider
)
Timber.i("createDownloader(): %s", downloader)
return downloader
}
@Test
fun `wanted country list is empty, day mode`() = flakyTest {
fun `wanted country list is empty, day mode`() {
val downloader = createDownloader()
runBlocking {
downloader.asyncFetchKeyFiles(emptyList()) shouldBe emptyList()
......@@ -234,7 +228,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `wanted country list is empty, hour mode`() = flakyTest {
fun `wanted country list is empty, hour mode`() {
every { settings.isLast3HourModeEnabled } returns true
val downloader = createDownloader()
......@@ -244,7 +238,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `fetching is aborted in day if not enough free storage`() = flakyTest {
fun `fetching is aborted in day if not enough free storage`() {
coEvery { deviceStorage.requireSpacePrivateStorage(1048576L) } throws InsufficientStorageException(
mockk(relaxed = true)
)
......@@ -259,7 +253,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `fetching is aborted in hour if not enough free storage`() = flakyTest {
fun `fetching is aborted in hour if not enough free storage`() {
every { settings.isLast3HourModeEnabled } returns true
coEvery { deviceStorage.requireSpacePrivateStorage(67584L) } throws InsufficientStorageException(
......@@ -276,7 +270,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `error during country index fetch`() = flakyTest {
fun `error during country index fetch`() {
coEvery { diagnosisKeyServer.getCountryIndex() } throws IOException()
val downloader = createDownloader()
......@@ -289,7 +283,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `day fetch without prior data`() = flakyTest {
fun `day fetch without prior data`() {
val downloader = createDownloader()
runBlocking {
......@@ -330,7 +324,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `day fetch with existing data`() = flakyTest {
fun `day fetch with existing data`() {
mockAddData(
type = CachedKeyInfo.Type.COUNTRY_DAY,
location = LocationCode("DE"),
......@@ -377,7 +371,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `day fetch deletes stale data`() = flakyTest {
fun `day fetch deletes stale data`() {
coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf(
LocalDate.parse("2020-09-02")
)
......@@ -425,7 +419,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `day fetch skips single download failures`() = flakyTest {
fun `day fetch skips single download failures`() {
var dlCounter = 0
coEvery { diagnosisKeyServer.downloadKeyFile(any(), any(), any(), any(), any()) } answers {
dlCounter++
......@@ -434,8 +428,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
locationCode = arg(0),
day = arg(1),
hour = arg(2),
saveTo = arg(3),
precondition = arg(4)
saveTo = arg(3)
)
}
......@@ -452,7 +445,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `last3Hours fetch without prior data`() = flakyTest {
fun `last3Hours fetch without prior data`() {
every { settings.isLast3HourModeEnabled } returns true
val downloader = createDownloader()
......@@ -511,7 +504,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `last3Hours fetch with prior data`() = flakyTest {
fun `last3Hours fetch with prior data`() {
every { settings.isLast3HourModeEnabled } returns true
mockAddData(
......@@ -577,7 +570,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `last3Hours fetch deletes stale data`() = flakyTest {
fun `last3Hours fetch deletes stale data`() {
every { settings.isLast3HourModeEnabled } returns true
val (staleKey1, _) = mockAddData(
......@@ -659,7 +652,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `last3Hours fetch skips single download failures`() = flakyTest {
fun `last3Hours fetch skips single download failures`() {
every { settings.isLast3HourModeEnabled } returns true
var dlCounter = 0
......@@ -670,8 +663,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
locationCode = arg(0),
day = arg(1),
hour = arg(2),
saveTo = arg(3),
precondition = arg(4)
saveTo = arg(3)
)
}
......@@ -688,7 +680,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `not completed cache entries are overwritten`() = flakyTest {
fun `not completed cache entries are overwritten`() {
mockAddData(
type = CachedKeyInfo.Type.COUNTRY_DAY,
location = LocationCode("DE"),
......@@ -716,7 +708,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `database errors do not abort the whole process`() = flakyTest {
fun `database errors do not abort the whole process`() {
var completionCounter = 0
coEvery { keyCache.markKeyComplete(any(), any()) } answers {
completionCounter++
......@@ -744,7 +736,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `store server md5`() = flakyTest {
fun `store server md5`() {
coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE"))
coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf(
LocalDate.parse("2020-09-01")
......@@ -774,7 +766,7 @@ class KeyFileDownloaderTest : BaseIOTest() {
}
@Test
fun `use local MD5 as fallback if there is none available from the server`() = flakyTest {
fun `use local MD5 as fallback if there is none available from the server`() {
coEvery { diagnosisKeyServer.getCountryIndex() } returns listOf(LocationCode("DE"))
coEvery { diagnosisKeyServer.getDayIndex(LocationCode("DE")) } returns listOf(
LocalDate.parse("2020-09-01")
......@@ -785,7 +777,6 @@ class KeyFileDownloaderTest : BaseIOTest() {
day = arg(1),
hour = arg(2),
saveTo = arg(3),
precondition = arg(4),
checksumServerMD5 = null
)
}
......
package testhelpers
import de.rki.coronawarnapp.util.coroutine.DispatcherProvider
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
object TestDispatcherProvider : DispatcherProvider {
override val Default: CoroutineDispatcher
get() = Dispatchers.Unconfined
override val Main: CoroutineDispatcher
get() = Dispatchers.Unconfined
override val MainImmediate: CoroutineDispatcher
get() = Dispatchers.Unconfined
override val Unconfined: CoroutineDispatcher
get() = Dispatchers.Unconfined
override val IO: CoroutineDispatcher
get() = Dispatchers.Unconfined
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment