Add DbContextExtension for junit5

Open: norrs opened this issue 3 years ago • 7 comments

https://github.com/jhannes/fluent-jdbc/blob/main/src/main/java/org/fluentjdbc/opt/junit/DbContextRule.java needs to be converted to something like this:

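import javax.sql.DataSource;

import org.fluentjdbc.DatabaseReporter;
import org.fluentjdbc.DatabaseStatementFactory;
import org.fluentjdbc.DbContext;
import org.fluentjdbc.DbContextConnection;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

public class DbContextExtension extends DbContext
        implements BeforeEachCallback, AfterEachCallback {
    private final DataSource dataSource;
    // Connection bound to the current thread while a single test runs
    private DbContextConnection ignoredConnection;

    // Constructors are probably a bad idea; need to check how to pass constructor arguments to an Extension class.
    // Maybe allow setting dataSource and databaseStatementFactory via setters in @BeforeAll/@BeforeEach when using the extension?

    public DbContextExtension(DataSource dataSource) {
        this(dataSource, new DatabaseStatementFactory(DatabaseReporter.LOGGING_REPORTER));
    }

    public DbContextExtension(DataSource dataSource, DatabaseStatementFactory factory) {
        super(factory);
        this.dataSource = dataSource;
    }

    @Override
    public void beforeEach(ExtensionContext context) throws Exception {
        // Open a connection and bind it to this thread, as DbContextRule does around each test
        ignoredConnection = startConnection(dataSource);
    }

    @Override
    public void afterEach(ExtensionContext context) throws Exception {
        // Release the thread-bound connection so the next test starts clean
        if (ignoredConnection != null) {
            ignoredConnection.close();
            ignoredConnection = null;
        }
    }
}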
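What the proposed extension has to do is the JUnit 5 equivalent of DbContextRule: bind a connection to the current thread before each test and release it afterwards, so that code under test can use the DbContext's thread connection. The toy model below sketches only that lifecycle in plain Java. ToyDbContext, ToyConnection, ToyConnectionHandle and ToyDbContextExtension are hypothetical stand-ins, not fluent-jdbc or JUnit types, and the callbacks are called by hand instead of by JUnit.

```java
// Hypothetical handle returned by startConnection; close() unbinds the connection
// and, unlike AutoCloseable.close(), throws no checked exception.
interface ToyConnectionHandle extends AutoCloseable {
    @Override
    void close();
}

// Hypothetical stand-in for a pooled JDBC connection that only tracks whether it is open.
final class ToyConnection {
    private boolean open = true;

    boolean isOpen() { return open; }

    void close() { open = false; }
}

// Hypothetical stand-in for DbContext: at most one connection bound to each thread.
class ToyDbContext {
    private final ThreadLocal<ToyConnection> current = new ThreadLocal<>();

    ToyConnectionHandle startConnection() {
        ToyConnection connection = new ToyConnection();
        current.set(connection);
        return () -> {
            connection.close();
            current.remove();
        };
    }

    ToyConnection getThreadConnection() {
        ToyConnection connection = current.get();
        if (connection == null) {
            throw new IllegalStateException("No connection bound to this thread");
        }
        return connection;
    }
}

// Mirrors the proposed DbContextExtension: open in beforeEach, close in afterEach.
class ToyDbContextExtension extends ToyDbContext {
    private ToyConnectionHandle ignoredConnection;

    void beforeEach() {
        ignoredConnection = startConnection();
    }

    void afterEach() {
        if (ignoredConnection != null) {
            ignoredConnection.close();
            ignoredConnection = null;
        }
    }
}
```

Clearing the field in afterEach means that a second afterEach call, or a test that never got a connection, is harmless.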

norrs, Aug 01 '22 11:08

Because extensions are used like this:

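import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

class JUnit5ServerTest {
    // The test class creates the extension itself, so it could pass constructor arguments here
    @RegisterExtension
    static ServerExtension extension = new ServerExtension();

    @Test
    void serverIsRunning() {
        Assertions.assertTrue(extension.getServer().isRunning());
    }
}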

See https://www.arhohuttunen.com/junit-5-migration/ for a helpful introduction to the changes between JUnit 4 and JUnit 5.

norrs, Aug 01 '22 11:08

https://www.baeldung.com/junit-5-extensions has a quick introduction as well, and mentions that we can use a static field to provide arguments to the extension (Ctrl+F for "we should annotate a static field with the @RegisterExtension annotation").
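The point of @RegisterExtension is that the test class constructs the extension itself, so it can pass any constructor arguments it needs (such as a DataSource). The much-simplified model below shows the idea with reflection: an "engine" finds static fields carrying a marker annotation, reads the already-constructed extension objects and wraps a test body in their callbacks. ToyRegisterExtension, ToyEachCallback, ToyEngine, UrlExtension and ExampleRepositoryTest are hypothetical names, and JUnit's real discovery and lifecycle handling is considerably more involved.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for JUnit's @RegisterExtension.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface ToyRegisterExtension {}

// Hypothetical stand-in for BeforeEachCallback + AfterEachCallback.
interface ToyEachCallback {
    void beforeEach();
    void afterEach();
}

final class ToyEngine {
    // Reads the values of static fields annotated with @ToyRegisterExtension.
    static List<ToyEachCallback> staticExtensions(Class<?> testClass) {
        List<ToyEachCallback> extensions = new ArrayList<>();
        for (Field field : testClass.getDeclaredFields()) {
            if (field.isAnnotationPresent(ToyRegisterExtension.class)
                    && Modifier.isStatic(field.getModifiers())
                    && ToyEachCallback.class.isAssignableFrom(field.getType())) {
                try {
                    field.setAccessible(true);
                    extensions.add((ToyEachCallback) field.get(null)); // null receiver: static field
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        return extensions;
    }

    // Wraps one test body in the registered callbacks.
    static void runTest(Class<?> testClass, Runnable testBody) {
        List<ToyEachCallback> extensions = staticExtensions(testClass);
        extensions.forEach(ToyEachCallback::beforeEach);
        try {
            testBody.run();
        } finally {
            extensions.forEach(ToyEachCallback::afterEach);
        }
    }
}

// Example extension that needs a constructor argument (a made-up JDBC URL).
class UrlExtension implements ToyEachCallback {
    final String jdbcUrl;
    final List<String> events = new ArrayList<>();

    UrlExtension(String jdbcUrl) { this.jdbcUrl = jdbcUrl; }

    @Override public void beforeEach() { events.add("open " + jdbcUrl); }
    @Override public void afterEach() { events.add("close " + jdbcUrl); }
}

// The "test class" builds the extension itself, so it decides the arguments.
class ExampleRepositoryTest {
    @ToyRegisterExtension
    static UrlExtension db = new UrlExtension("jdbc:example:testdb");
}
```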

norrs, Aug 01 '22 11:08

I've created a similar extension in my projects that use fluent-jdbc, but I'm stuck on how to provide the data source. Any ideas, @norrs?

jhannes, Aug 17 '22 06:08

@jhannes I do. I made an internal version in Kotlin:

Companion objects are basically static methods/fields.

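import javax.sql.DataSource

enum class DataSourceImplementation { DIRECT_DATASOURCE, CONTAINER_DATASOURCE }

class TestDataSourceFactory {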
    companion object {

        fun getDataSource(): DataSource =
            when (System.getenv("ALP_USE_TESTCONTAINERS")) {
                "true" -> getDataSource(DataSourceImplementation.CONTAINER_DATASOURCE)
                else -> getDataSource(DataSourceImplementation.DIRECT_DATASOURCE)
            }

        fun getDataSource(type: DataSourceImplementation): DataSource =
            when (type) {
                DataSourceImplementation.DIRECT_DATASOURCE -> TestDirectDataSource().getDataSource()
                DataSourceImplementation.CONTAINER_DATASOURCE -> TestContainerDataSource().getDataSource()
            }
    }
}
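TestDataSourceFactory picks the implementation from the ALP_USE_TESTCONTAINERS environment variable: only the exact value "true" selects Testcontainers, anything else (including an unset variable) falls back to the direct data source. A Java sketch of the same rule follows, with the pure decision split from the environment lookup so it can be tested without setting environment variables. DataSourceSelection is a hypothetical name; the enum mirrors DataSourceImplementation as used above.

```java
// Mirrors the DataSourceImplementation enum used by TestDataSourceFactory.
enum DataSourceImplementation { DIRECT_DATASOURCE, CONTAINER_DATASOURCE }

// Hypothetical helper: same decision as the Kotlin `when` on ALP_USE_TESTCONTAINERS.
final class DataSourceSelection {
    static final String ENV_VAR = "ALP_USE_TESTCONTAINERS";

    private DataSourceSelection() {}

    // Pure function: exact, case-sensitive match on "true"; null means "not set".
    static DataSourceImplementation select(String envValue) {
        return "true".equals(envValue)
                ? DataSourceImplementation.CONTAINER_DATASOURCE
                : DataSourceImplementation.DIRECT_DATASOURCE;
    }

    // Reads the real environment only here.
    static DataSourceImplementation fromEnvironment() {
        return select(System.getenv(ENV_VAR));
    }
}
```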


This implementation has probably seen better days, but it works(TM), partly: it relies on rather optimistic locking, and we do experience race conditions with it:

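import com.zaxxer.hikari.HikariDataSource
import org.flywaydb.core.Flyway
import org.flywaydb.core.api.FlywayException
import org.slf4j.LoggerFactory
import java.sql.Timestamp
import java.time.Instant
import java.util.Random
import javax.sql.DataSource
import kotlin.concurrent.getOrSet

class TestDirectDataSource {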
    fun getDataSource(): DataSource {
        return dataSourceForThread()
    }

    private fun dataSourceForThread() = dataSource.getOrSet { getAvailableDatasource() }

    companion object {
        private val logger = LoggerFactory.getLogger(TestDirectDataSource::class.java)
        private var dataSource: ThreadLocal<DataSource> = ThreadLocal()
        private val jdbcUrl = System.getenv("ALP_TEST_JDBC_URL") ?: "jdbc:postgresql://localhost:5432/unitTest"
        private val username = System.getenv("ALP_TEST_JDBC_USER") ?: "xxx"
        private val password = System.getenv("ALP_TEST_JDBC_PASSWORD") ?: "xxx"
        private val numberOfDbs = 30
        private var loopCounter = 0

        @Synchronized
        private fun getAvailableDatasource(): DataSource {
            val mainHikariDataSource = hikariDataSource(jdbcUrl, "initialBootstrap")
            requireParallellDbs(mainHikariDataSource)
            mainHikariDataSource.close()
            val selectAvailableDb = selectAvailableDb()
            migrateDb(selectAvailableDb)
            return selectAvailableDb
        }

        private fun migrateDb(dataSource: DataSource) {
            try {
                val flyway = Flyway.configure().cleanDisabled(false).dataSource(dataSource).load()
                flyway.clean()
                flyway.migrate()
                markDbAsLocked(dataSource)
            } catch (fe: FlywayException) {
                logger.error("Failed running migrations against dataSource: {}", dataSource.toString())
                throw fe
            }
        }

        private fun selectAvailableDb(blacklist: List<Int> = emptyList()): DataSource {
            val dbs = (0..(numberOfDbs - 1)).toList().filter { it !in blacklist }
            if (dbs.isEmpty()) {
                loopCounter++
                val errorMessage = "We have run out of available test DBs!! (far too many PRs to process =p)"
                logger.error(errorMessage)
                if (loopCounter > 6) {
                    throw IllegalStateException("Aborting: suspecting that the DB test setup is wrong and would loop forever")
                }

                Thread.sleep(5000)
                return selectAvailableDb(emptyList())
            }

            val preferredDb = dbs[Random().nextInt(dbs.size)]
            val datasource = ciDataSource(preferredDb)

            if (!lockTableExists(datasource) || available(datasource)) {
                return datasource
            }
            datasource.close()

            return selectAvailableDb(blacklist + preferredDb)
        }

        private fun ciDataSource(preferredDb: Int): HikariDataSource {
            val jdbcUrlToPreferredDb = jdbcUrl.substring(0, jdbcUrl.indexOfLast { it == '/' } + 1) + "testdb_$preferredDb"
            return hikariDataSource(jdbcUrlToPreferredDb, "testdb_$preferredDb")
        }

        private fun requireParallellDbs(mainHikariDataSource: DataSource) {
            val isPreviouslySetup = isAvailableDb(mainHikariDataSource, numberOfDbs - 1)
            if (!isPreviouslySetup) {
                logger.warn("Must set up test db")
                mainHikariDataSource.connection?.use {
                    it.autoCommit = true
                    for (i in 0..(numberOfDbs - 1)) {
                        val exists = it.prepareStatement("SELECT 1 FROM pg_database WHERE datname = 'testdb_$i'")
                            .use { it.executeQuery().next() }
                        if (!exists) {
                            it.createStatement().use {
                                it.executeUpdate(
                                    """
                                        CREATE DATABASE testdb_$i
                                         WITH TEMPLATE = template0
                                         ENCODING = 'UTF8'
                                         LC_COLLATE = 'nb_NO.UTF8'
                                         LC_CTYPE = 'nb_NO.UTF8'
                                         CONNECTION LIMIT = 100;
                                    """.trimIndent()
                                )
                            }
                        }
                    }
                }
            }
        }

        fun isAvailableDb(datasource: DataSource, dbNumber: Int) =
            datasource.connection!!.use {
                it.prepareStatement("SELECT 1 FROM pg_database WHERE datname = 'testdb_$dbNumber'")
                    .use { it.executeQuery().next() }
            }

        fun lockTableExists(datasource: DataSource) =
            datasource.connection!!.use {
                it.prepareStatement("SELECT * from pg_tables where tablename = 'ci_locktable'")
                    .use { it.executeQuery().next() }
            }

        private fun available(datasource: DataSource): Boolean =
            datasource.connection!!.use {
                it.autoCommit = true
                val existingTimeStamp = it.prepareStatement("SELECT lock_time FROM ci_locktable").use {
                    val rs = it.executeQuery()
                    if (!rs.next()) {
                        return true // Nobody has locked it
                    }
                    rs.getTimestamp(1)
                }

                if (existingTimeStamp.after(Timestamp.from(Instant.now().minusSeconds(60)))) {
                    return false // Someone else is using it
                }

                val markAsLockedStatement =
                    it.prepareStatement("UPDATE ci_locktable SET lock_time = now() WHERE lock_time = ?").use {
                        it.setTimestamp(1, existingTimeStamp)
                        it.executeUpdate()
                    }
                markAsLockedStatement > 0
            }

        private fun hikariDataSource(jdbcURL: String, poolName: String): HikariDataSource {
            val mainHikariDataSource = HikariDataSource()
            mainHikariDataSource.poolName = poolName
            mainHikariDataSource.jdbcUrl = jdbcURL
            mainHikariDataSource.username = username
            mainHikariDataSource.password = password
            mainHikariDataSource.isAutoCommit = false
            mainHikariDataSource.leakDetectionThreshold = 4000
            return mainHikariDataSource
        }

        fun markDbAsLocked(dataSourceForThread: DataSource) {
            dataSourceForThread.connection!!.use {
                it.autoCommit = true
                it.createStatement().use {
                    it.executeUpdate(
                        """
                    CREATE TABLE IF NOT EXISTS ci_locktable(
                        lock_time timestamptz NOT NULL DEFAULT now()
                    )
                        """.trimIndent()
                    )
                }
                it.createStatement().use { it.execute("INSERT INTO ci_locktable VALUES (now())") }
            }
        }
    }
}
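The selection logic in TestDirectDataSource amounts to an optimistic lock per database: a database is free if it has never been locked or its lock_time is more than 60 seconds old, and it is claimed with `UPDATE ci_locktable SET lock_time = now() WHERE lock_time = ?`, which only succeeds if nobody changed the row in between. The toy model below sketches that rule in memory, with AtomicReference.compareAndSet standing in for the conditional UPDATE. ToyLockTable and ToyDbPool are hypothetical names; the model scans databases in order rather than at random so the outcome is deterministic, and it deliberately differs from the Kotlin code in one place: a never-locked database is claimed atomically instead of being returned first and marked later.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

// In-memory stand-in for one test database's ci_locktable: a single lock_time, or none.
final class ToyLockTable {
    final AtomicReference<Instant> lockTime = new AtomicReference<>();
}

final class ToyDbPool {
    static final Duration LOCK_TTL = Duration.ofSeconds(60);
    private final List<ToyLockTable> dbs = new ArrayList<>();

    ToyDbPool(int numberOfDbs) {
        for (int i = 0; i < numberOfDbs; i++) {
            dbs.add(new ToyLockTable());
        }
    }

    // Same rule as available(): taken if locked within the TTL, otherwise claim it
    // with a compare-and-set, like UPDATE ... WHERE lock_time = ?
    boolean tryClaim(int db, Instant now) {
        AtomicReference<Instant> lock = dbs.get(db).lockTime;
        Instant existing = lock.get();
        if (existing == null) {
            // Claim a never-locked database atomically (the Kotlin code returns it unclaimed).
            return lock.compareAndSet(null, now);
        }
        if (existing.isAfter(now.minus(LOCK_TTL))) {
            return false; // someone else is using it
        }
        return lock.compareAndSet(existing, now);
    }

    // Returns the first database this caller managed to claim, if any.
    Optional<Integer> claimAny(Instant now) {
        for (int db = 0; db < dbs.size(); db++) {
            if (tryClaim(db, now)) {
                return Optional.of(db);
            }
        }
        return Optional.empty();
    }
}
```

In the Kotlin version, a database whose ci_locktable is missing or empty is handed out before markDbAsLocked runs, and, assuming ci_locktable lives in a schema Flyway cleans, flyway.clean() drops it again before it is recreated. `@Synchronized` only serializes threads within one JVM, so builds running in separate processes can pick the same database in that window; this may be one source of the race conditions mentioned above.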

So we also have this variant, using Testcontainers:

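import com.zaxxer.hikari.HikariDataSource
import org.flywaydb.core.Flyway
import org.flywaydb.core.api.FlywayException
import org.slf4j.LoggerFactory
import java.sql.Connection
import java.sql.PreparedStatement
import java.util.UUID
import javax.sql.DataSource
import kotlin.concurrent.getOrSet

private const val dockerImageTag = "11-nbno"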

class TestContainerDataSource {
    fun getDataSource(): DataSource {
        return dataSourceForThread()
    }

    private fun dataSourceForThread() = dataSource.getOrSet { getAvailableDatasource() }

    companion object {
        private val logger = LoggerFactory.getLogger(TestContainerDataSource::class.java)
        private var dataSource: ThreadLocal<DataSource> = ThreadLocal()
        private val dbId: ThreadLocal<String> = ThreadLocal.withInitial {
            UUID.randomUUID().toString().replace("-", "")
        }

        @JvmStatic
        fun verifyDatabaseLocale(connection: Connection) {
            connection.prepareStatement("SHOW LC_COLLATE").use {
                verifyNorwegianLocale(it)
            }
            connection.prepareStatement("SHOW LC_CTYPE").use {
                verifyNorwegianLocale(it)
            }
        }

        private fun verifyNorwegianLocale(it: PreparedStatement) {
            val executeQuery = it.executeQuery()
            executeQuery.next()
            val result = executeQuery.getString(1)
            if (result != "nb_NO.UTF8") {
                throw IllegalStateException(
                    "LC_COLLATE and LC_CTYPE must be nb_NO.UTF8. Found $result. Check image postgres:$dockerImageTag"
                )
            }
        }
    }

    @Synchronized
    private fun getAvailableDatasource(): DataSource {
        val mainHikariDataSource = hikariDataSource(
            "jdbc:tc:postgresql:$dockerImageTag:///testdb${dbId.get()}" +
                "?TC_INITFUNCTION=no.unit.alp.TestContainerDataSource::verifyDatabaseLocale",
            "testdb_${dbId.get()}"
        )
        migrateDb(mainHikariDataSource)
        return mainHikariDataSource
    }

    private fun migrateDb(dataSource: DataSource) {
        try {
            val flyway = Flyway.configure().cleanDisabled(false).dataSource(dataSource).load()
            flyway.clean()
            flyway.migrate()
        } catch (fe: FlywayException) {
            logger.error("Failed running migrations against dataSource: {}", dataSource.toString())
            throw fe
        }
    }

    private fun hikariDataSource(jdbcURL: String, poolName: String): HikariDataSource {
        val mainHikariDataSource = HikariDataSource()
        mainHikariDataSource.poolName = poolName
        mainHikariDataSource.jdbcUrl = jdbcURL
        mainHikariDataSource.username = "not_relevant_notused_in_this_container_setup"
        mainHikariDataSource.password = "not_relevant_notused_in_this_container_setup"
        mainHikariDataSource.isAutoCommit = false
        mainHikariDataSource.leakDetectionThreshold = 4000
        return mainHikariDataSource
    }
}
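TestContainerDataSource gives every test thread its own database by combining a per-thread random id with the Testcontainers JDBC URL scheme (`jdbc:tc:postgresql:<tag>:///<database>`). TC_INITFUNCTION names a public static `Class::method` that receives the new java.sql.Connection, which is why verifyDatabaseLocale is annotated @JvmStatic. The sketch below reproduces only the per-thread naming and URL building in plain Java; ThreadDbNames is a hypothetical helper, and nothing here starts a container.

```java
import java.util.UUID;

// Hypothetical helper mirroring dbId and the URL built in getAvailableDatasource().
final class ThreadDbNames {
    // One random id per thread, dashes removed, like dbId in TestContainerDataSource.
    private static final ThreadLocal<String> DB_ID =
            ThreadLocal.withInitial(() -> UUID.randomUUID().toString().replace("-", ""));

    private ThreadDbNames() {}

    static String dbId() {
        return DB_ID.get();
    }

    // Same URL shape as the Kotlin code: image tag, per-thread database name, init function.
    static String jdbcUrl(String imageTag, String initFunction) {
        return "jdbc:tc:postgresql:" + imageTag + ":///testdb" + dbId()
                + "?TC_INITFUNCTION=" + initFunction;
    }
}
```

For example, `ThreadDbNames.jdbcUrl("11-nbno", "no.unit.alp.TestContainerDataSource::verifyDatabaseLocale")` yields the same URL shape the Kotlin code passes to Hikari, with a different database name on each thread.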
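
import org.fluentjdbc.DatabaseReporter
import org.fluentjdbc.DatabaseStatementFactory
import org.fluentjdbc.DbContext
import org.fluentjdbc.DbContextConnection
import org.junit.jupiter.api.extension.AfterAllCallback
import org.junit.jupiter.api.extension.AfterEachCallback
import org.junit.jupiter.api.extension.BeforeAllCallback
import org.junit.jupiter.api.extension.BeforeEachCallback
import org.junit.jupiter.api.extension.ExtensionContext
import java.sql.Savepoint
import javax.sql.DataSource

/**
 * DbContextExtension prepares a database connection for use in tests.
 *
 * It uses a savepoint for fast rollbacks between tests.
 * Before all tests in a test class, it updates [DbContextProvider] and opens the database connection;
 * after all tests in the class have run, it closes the connection.
 *
 * This should be the default DbContext extension to use!
 *
 * If for some reason you have tests which obtain the threadConnection and issue a commit,
 * you can get them passing by using [DbContextSlowVersionExtension].
 * But please consider rewriting the test!
 */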
class DbContextExtension(val dataSource: DataSource) :
    DbContext(DatabaseStatementFactory(DatabaseReporter.LOGGING_REPORTER)),
    BeforeAllCallback,
    AfterAllCallback,
    AfterEachCallback,
    BeforeEachCallback {
    var connectionDuringTestcase: DbContextConnection? = null
    var savepoint: Savepoint? = null

    override fun beforeEach(context: ExtensionContext) {
        savepoint = threadConnection.setSavepoint("beforeEach")
    }

    override fun afterEach(context: ExtensionContext) {
        threadConnection.rollback(savepoint)
        savepoint = null
    }

    override fun beforeAll(context: ExtensionContext?) {
        DbContextProvider.setInstance(this)
        connectionDuringTestcase = startConnection(dataSource)
    }

    override fun afterAll(context: ExtensionContext?) {
        connectionDuringTestcase?.close()
    }
}
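DbContextExtension trades one connection per test for one connection per test class plus one savepoint per test. The toy model below sketches that callback order in plain Java. ToyTxConnection and ToySavepointExtension are hypothetical, and representing a savepoint as a row count only models inserts, not updates or deletes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a connection with autoCommit = false: rows are appended
// to a list, and a "savepoint" is just the row count at the time it was taken.
final class ToyTxConnection {
    final List<String> rows = new ArrayList<>();
    boolean open = true;

    int setSavepoint() { return rows.size(); }

    void rollback(int savepoint) { rows.subList(savepoint, rows.size()).clear(); }

    void close() { open = false; }
}

// Same callback order as the Kotlin DbContextExtension:
// connection per test class, savepoint per test.
final class ToySavepointExtension {
    ToyTxConnection connection;
    private Integer savepoint;

    void beforeAll() { connection = new ToyTxConnection(); }

    void beforeEach() { savepoint = connection.setSavepoint(); }

    void afterEach() {
        connection.rollback(savepoint);
        savepoint = null;
    }

    void afterAll() { connection.close(); }
}
```

Because beforeAll opens a thread-bound connection that beforeEach and afterEach later use, this design assumes JUnit runs class-level and method-level callbacks on the same thread; that holds for default sequential execution but may not with parallel execution enabled. It also explains the KDoc's warning: a test that commits ends the transaction, so its savepoint can no longer be rolled back to and its data stays behind.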

I hope this gives you some ideas.

norrs, Aug 17 '22 08:08

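Usage in JUnit 5:

import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.RegisterExtension

class AktivitetskravRepositoryTest {
    companion object {
        // Registered statically (companion object) so JUnit also calls the beforeAll/afterAll callbacks
        @RegisterExtension
        @JvmStatic
        val dbContext = DbContextExtension(TestDataSourceFactory.getDataSource())
    }

    @Test
    fun `should be able to store and fetch`() { }
}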

norrs, Aug 17 '22 08:08

That wasn't so hard. I'll look into it this week.

jhannes, Aug 17 '22 14:08

I haven't been following up on this much, I'm afraid. I've experimented with JUnit 5 extensions and made some progress, but have been too busy to complete it. It's coming Any Day Now.

jhannes, Jan 22 '23 11:01