Add DbContextExtension for JUnit 5
https://github.com/jhannes/fluent-jdbc/blob/main/src/main/java/org/fluentjdbc/opt/junit/DbContextRule.java needs to be converted to something like this:
public class DbContextExtension extends DbContext
implements BeforeEachCallback, AfterEachCallback {
val DbContextConnection ignoredConnection
// Constructors is bad most likley, need to check how to provide constructor argument to an Extension class.
// Maybe allow to set datasource and databaseStatementFactory via setters in a @BeforeClass/@Before when using the extension?
public DbContextExtension(DataSource dataSource) {
this(dataSource, new DatabaseStatementFactory(DatabaseReporter.LOGGING_REPORTER));
}
public DbContextExtension(DataSource dataSource, DatabaseStatementFactory factory) {
super(factory);
this.dataSource = dataSource;
}
@Override
public void beforeEach(ExtensionContext context) throws Exception {
ignoredConnection = startConnection(dataSource)
}
@Override
public void afterEach(ExtensionContext context) throws Exception {
if (ignoredConnection != null) ignoredConnection.close()
}
}
Because extensions are used like this:
/** Minimal example of programmatic extension registration with {@code @RegisterExtension}. */
class JUnit5ServerTest {

    // Declared static, so JUnit registers the extension at class level.
    @RegisterExtension
    static ServerExtension extension = new ServerExtension();

    @Test
    void serverIsRunning() {
        boolean running = extension.getServer().isRunning();
        Assertions.assertTrue(running);
    }
}
See https://www.arhohuttunen.com/junit-5-migration/ for helpful introduction on changes between junit4 and junit5.
https://www.baeldung.com/junit-5-extensions has a quick introduction as well. It mentions that we can use a static field to provide arguments to the extension (Ctrl+F for "we should annotate a static field with the @RegisterExtension annotation").
I've created a similar extension in my projects using fluent-jdbc, but I'm stuck on how to provide the data source. Any ideas, @norrs?
@jhannes I do, I made an internal version in Kotlin:
(Kotlin companion objects are roughly the equivalent of static methods/fields in Java.)
/**
 * Chooses how tests obtain their [DataSource]: a Testcontainers-backed database when the
 * `ALP_USE_TESTCONTAINERS` environment variable is exactly `"true"`, otherwise a database on a
 * directly configured PostgreSQL server.
 */
class TestDataSourceFactory {
    companion object {
        /** Picks the implementation based on the `ALP_USE_TESTCONTAINERS` environment variable. */
        fun getDataSource(): DataSource {
            val useTestcontainers = System.getenv("ALP_USE_TESTCONTAINERS") == "true"
            val implementation = if (useTestcontainers) {
                DataSourceImplementation.CONTAINER_DATASOURCE
            } else {
                DataSourceImplementation.DIRECT_DATASOURCE
            }
            return getDataSource(implementation)
        }

        /** Returns a data source of the requested [type]. */
        fun getDataSource(type: DataSourceImplementation): DataSource = when (type) {
            DataSourceImplementation.CONTAINER_DATASOURCE -> TestContainerDataSource().getDataSource()
            DataSourceImplementation.DIRECT_DATASOURCE -> TestDirectDataSource().getDataSource()
        }
    }
}
This implementation has probably seen better days, but it works™ only partly: it relies on rather optimistic locking, and we do experience race conditions with it:
/**
 * Hands out one migrated PostgreSQL [DataSource] per thread. Each one is chosen from a fixed set of
 * databases `testdb_0` .. `testdb_29` on a shared server, which are created on demand.
 *
 * A database is claimed through its `ci_locktable`. An entry younger than 60 seconds means
 * "in use". An older entry is taken over with an optimistic `UPDATE ... WHERE lock_time = ?`.
 *
 * NOTE(review): this locking is known to be racy:
 * - The claim is never refreshed while tests run.
 * - A database without `ci_locktable`, or with an empty one, is handed out unmarked and is only
 *   marked after migration finishes.
 * - `flyway.clean()` presumably drops `ci_locktable` until [markDbAsLocked] recreates it.
 *
 * The server URL and credentials come from `ALP_TEST_JDBC_URL`, `ALP_TEST_JDBC_USER` and
 * `ALP_TEST_JDBC_PASSWORD`.
 */
class TestDirectDataSource {
    /** Returns the data source bound to the calling thread, claiming and migrating one on first use. */
    fun getDataSource(): DataSource {
        return dataSourceForThread()
    }

    private fun dataSourceForThread() = dataSource.getOrSet { getAvailableDatasource() }

    companion object {
        private val logger = LoggerFactory.getLogger(TestDirectDataSource::class.java)
        private val dataSource: ThreadLocal<DataSource> = ThreadLocal()
        private val jdbcUrl = System.getenv("ALP_TEST_JDBC_URL") ?: "jdbc:postgresql://localhost:5432/unitTest"
        private val username = System.getenv("ALP_TEST_JDBC_USER") ?: "xxx"
        private val password = System.getenv("ALP_TEST_JDBC_PASSWORD") ?: "xxx"
        private const val numberOfDbs = 30

        // How many full scans may find every database busy (per claim attempt) before giving up.
        private const val maxExhaustedScans = 6

        // Pause between full scans when every database is busy.
        private const val busyRetryDelayMillis = 5000L

        // A lock younger than this is considered held by someone else.
        private const val lockTimeoutSeconds = 60L

        /**
         * Makes sure the test databases exist, claims a free one and migrates it.
         * Synchronized on the companion object, so threads in this JVM claim databases one at a time.
         */
        @Synchronized
        private fun getAvailableDatasource(): DataSource {
            // The bootstrap pool is only needed to create databases; `use` closes it even on failure.
            hikariDataSource(jdbcUrl, "initialBootstrap").use { requireParallellDbs(it) }
            val selected = selectAvailableDb()
            try {
                migrateDb(selected)
            } catch (e: Exception) {
                selected.close()
                throw e
            }
            return selected
        }

        /** Cleans and migrates [dataSource] with Flyway, then marks it as locked. */
        private fun migrateDb(dataSource: DataSource) {
            try {
                val flyway = Flyway.configure().cleanDisabled(false).dataSource(dataSource).load()
                flyway.clean()
                flyway.migrate()
                markDbAsLocked(dataSource)
            } catch (fe: FlywayException) {
                logger.error("Failed running migrations against dataSource: {}", dataSource.toString())
                throw fe
            }
        }

        /**
         * Tries the databases in random order and returns the first one that can be claimed.
         * When all of them are busy, it waits [busyRetryDelayMillis] and scans again. After more than
         * [maxExhaustedScans] fruitless scans in this call it throws [IllegalStateException].
         */
        private fun selectAvailableDb(): HikariDataSource {
            var exhaustedScans = 0
            while (true) {
                for (dbNumber in (0 until numberOfDbs).shuffled()) {
                    tryClaim(dbNumber)?.let { return it }
                }
                exhaustedScans++
                val errorMessage = "Vi er tomme for ledige db-er!! (altfor mange PR som må behandles =p)"
                logger.error(errorMessage)
                if (exhaustedScans > maxExhaustedScans) {
                    throw IllegalStateException("Avsluttning i mistanke om at db- testoppsett er galt og gir deg evig løkke")
                }
                Thread.sleep(busyRetryDelayMillis)
            }
        }

        /**
         * Opens a pool for `testdb_[dbNumber]` and returns it if the database is free, taking over a
         * stale lock when needed. Otherwise it closes the pool and returns null. The pool is also
         * closed when the lock probes throw.
         */
        private fun tryClaim(dbNumber: Int): HikariDataSource? {
            val candidate = ciDataSource(dbNumber)
            val free = try {
                !lockTableExists(candidate) || available(candidate)
            } catch (e: Exception) {
                candidate.close()
                throw e
            }
            if (free) {
                return candidate
            }
            candidate.close()
            return null
        }

        /** Builds a pool for `testdb_[preferredDb]` on the same server as [jdbcUrl]. */
        private fun ciDataSource(preferredDb: Int): HikariDataSource {
            // Swap the database name (everything after the last '/') of the configured URL.
            val serverUrl = jdbcUrl.substring(0, jdbcUrl.lastIndexOf('/') + 1)
            return hikariDataSource("${serverUrl}testdb_$preferredDb", "testdb_$preferredDb")
        }

        /** Creates any missing `testdb_N` database. It is skipped when the last one already exists. */
        private fun requireParallellDbs(mainHikariDataSource: DataSource) {
            if (isAvailableDb(mainHikariDataSource, numberOfDbs - 1)) {
                return
            }
            logger.warn("Must set up test db")
            mainHikariDataSource.connection?.use { connection ->
                // The pool disables auto-commit, but PostgreSQL refuses CREATE DATABASE inside a transaction.
                connection.autoCommit = true
                for (i in 0 until numberOfDbs) {
                    val exists = connection.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?")
                        .use { statement ->
                            statement.setString(1, "testdb_$i")
                            statement.executeQuery().use { rs -> rs.next() }
                        }
                    if (exists) {
                        continue
                    }
                    connection.createStatement().use { statement ->
                        statement.executeUpdate(
                            """
                            CREATE DATABASE testdb_$i
                                WITH TEMPLATE = template0
                                ENCODING = 'UTF8'
                                LC_COLLATE = 'nb_NO.UTF8'
                                LC_CTYPE = 'nb_NO.UTF8'
                                CONNECTION LIMIT = 100;
                            """.trimIndent()
                        )
                    }
                }
            }
        }

        /** Returns true when a database named `testdb_[dbNumber]` exists on the server behind [datasource]. */
        fun isAvailableDb(datasource: DataSource, dbNumber: Int) =
            datasource.connection.use { connection ->
                connection.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?").use { statement ->
                    statement.setString(1, "testdb_$dbNumber")
                    statement.executeQuery().use { rs -> rs.next() }
                }
            }

        /** Returns true when a table named `ci_locktable` is listed in `pg_tables` for [datasource]. */
        fun lockTableExists(datasource: DataSource) =
            datasource.connection.use { connection ->
                connection.prepareStatement("SELECT * from pg_tables where tablename = 'ci_locktable'")
                    .use { statement -> statement.executeQuery().use { rs -> rs.next() } }
            }

        /**
         * Returns true when the database may be used. That is the case when `ci_locktable` has no
         * rows, or when the first row's lock is older than [lockTimeoutSeconds] and this call managed
         * to take it over. The `WHERE lock_time = ?` guard makes the takeover fail if the timestamp
         * changed in between.
         */
        private fun available(datasource: DataSource): Boolean =
            datasource.connection.use { connection ->
                connection.autoCommit = true
                val existingTimeStamp = connection.prepareStatement("SELECT lock_time FROM ci_locktable").use { statement ->
                    statement.executeQuery().use { rs ->
                        if (!rs.next()) {
                            return true // Nobody has locked it
                        }
                        rs.getTimestamp(1)
                    }
                }
                if (existingTimeStamp.after(Timestamp.from(Instant.now().minusSeconds(lockTimeoutSeconds)))) {
                    return false // Someone else is using it
                }
                val updatedRows =
                    connection.prepareStatement("UPDATE ci_locktable SET lock_time = now() WHERE lock_time = ?").use { statement ->
                        statement.setTimestamp(1, existingTimeStamp)
                        statement.executeUpdate()
                    }
                updatedRows > 0
            }

        /** Creates a Hikari pool with the configured credentials, auto-commit off and leak detection after 4 s. */
        private fun hikariDataSource(jdbcURL: String, poolName: String): HikariDataSource {
            val mainHikariDataSource = HikariDataSource()
            mainHikariDataSource.poolName = poolName
            mainHikariDataSource.jdbcUrl = jdbcURL
            mainHikariDataSource.username = username
            mainHikariDataSource.password = password
            mainHikariDataSource.isAutoCommit = false
            mainHikariDataSource.leakDetectionThreshold = 4000
            return mainHikariDataSource
        }

        /** Creates `ci_locktable` if needed and inserts a lock row stamped with the current time. */
        fun markDbAsLocked(dataSourceForThread: DataSource) {
            dataSourceForThread.connection.use { connection ->
                connection.autoCommit = true
                connection.createStatement().use { statement ->
                    statement.executeUpdate(
                        """
                        CREATE TABLE IF NOT EXISTS ci_locktable(
                            lock_time timestamptz NOT NULL DEFAULT now()
                        )
                        """.trimIndent()
                    )
                }
                connection.createStatement().use { statement ->
                    statement.execute("INSERT INTO ci_locktable VALUES (now())")
                }
            }
        }
    }
}
Hence we also have this version, which uses Testcontainers:
private const val dockerImageTag = "11-nbno"

/**
 * Hands out one migrated PostgreSQL [DataSource] per thread. Each one is backed by a database
 * started through a Testcontainers JDBC URL (`jdbc:tc:postgresql:...`) using the image tag
 * [dockerImageTag]. The database locale is verified via `TC_INITFUNCTION`.
 */
class TestContainerDataSource {
    /** Returns the data source bound to the calling thread, creating and migrating it on first use. */
    fun getDataSource(): DataSource {
        return dataSourceForThread()
    }

    private fun dataSourceForThread() = dataSource.getOrSet { getAvailableDatasource() }

    companion object {
        private val logger = LoggerFactory.getLogger(TestContainerDataSource::class.java)
        private var dataSource: ThreadLocal<DataSource> = ThreadLocal()

        // Random per-thread suffix, so every thread gets its own database name in the JDBC URL.
        private val dbId: ThreadLocal<String> = ThreadLocal.withInitial {
            UUID.randomUUID().toString().replace("-", "")
        }

        /**
         * Init function referenced by `TC_INITFUNCTION` in the JDBC URL. It throws
         * [IllegalStateException] unless both LC_COLLATE and LC_CTYPE are `nb_NO.UTF8`.
         * It must stay a static method (@JvmStatic) for that reference to resolve.
         */
        @JvmStatic
        fun verifyDatabaseLocale(connection: Connection) {
            connection.prepareStatement("SHOW LC_COLLATE").use { statement ->
                verifyNorwegianLocale(statement)
            }
            connection.prepareStatement("SHOW LC_CTYPE").use { statement ->
                verifyNorwegianLocale(statement)
            }
        }

        private fun verifyNorwegianLocale(statement: PreparedStatement) {
            statement.executeQuery().use { rs ->
                if (!rs.next()) {
                    throw IllegalStateException(
                        "Locale query returned no rows. Check image postgres:$dockerImageTag"
                    )
                }
                val result = rs.getString(1)
                if (result != "nb_NO.UTF8") {
                    throw IllegalStateException(
                        "LC_COLLATE and LC_CTYPE must be nb_NO.UTF8. Found $result. Check image postgres:$dockerImageTag"
                    )
                }
            }
        }
    }

    /**
     * Creates the pool for this thread's database and migrates it. If migration fails, the pool is
     * closed before the exception propagates.
     *
     * NOTE(review): @Synchronized locks this instance only. TestDataSourceFactory creates a new
     * instance per call, so this does not serialize different threads.
     */
    @Synchronized
    private fun getAvailableDatasource(): DataSource {
        val id = dbId.get()
        val pool = hikariDataSource(
            "jdbc:tc:postgresql:$dockerImageTag:///testdb$id" +
                "?TC_INITFUNCTION=no.unit.alp.TestContainerDataSource::verifyDatabaseLocale",
            "testdb_$id"
        )
        try {
            migrateDb(pool)
        } catch (e: Exception) {
            pool.close()
            throw e
        }
        return pool
    }

    /** Cleans and migrates [dataSource] with Flyway, logging the data source on failure. */
    private fun migrateDb(dataSource: DataSource) {
        try {
            val flyway = Flyway.configure().cleanDisabled(false).dataSource(dataSource).load()
            flyway.clean()
            flyway.migrate()
        } catch (fe: FlywayException) {
            logger.error("Failed running migrations against dataSource: {}", dataSource.toString())
            throw fe
        }
    }

    /**
     * Creates a Hikari pool with auto-commit off and leak detection after 4 s.
     * The credentials are placeholders: the code states they are not used in this container setup.
     */
    private fun hikariDataSource(jdbcURL: String, poolName: String): HikariDataSource {
        val mainHikariDataSource = HikariDataSource()
        mainHikariDataSource.poolName = poolName
        mainHikariDataSource.jdbcUrl = jdbcURL
        mainHikariDataSource.username = "not_relevant_notused_in_this_container_setup"
        mainHikariDataSource.password = "not_relevant_notused_in_this_container_setup"
        mainHikariDataSource.isAutoCommit = false
        mainHikariDataSource.leakDetectionThreshold = 4000
        return mainHikariDataSource
    }
}
/**
 * DbContextExtension prepares a database connection for tests.
 *
 * In [beforeAll] it registers itself with [DbContextProvider] and opens a connection from
 * [dataSource], which stays open for all tests in the class. Before each test it sets a savepoint
 * on the thread connection, and after each test it rolls back to that savepoint. This is a fast
 * way to undo each test's changes. After all tests in the class have run, it closes the connection.
 *
 * Register it on a static field (in Kotlin: `@JvmField @RegisterExtension` inside a
 * `companion object`). With JUnit's default per-method test instance lifecycle, the
 * [BeforeAllCallback]/[AfterAllCallback] methods of an extension registered on an instance field
 * are not invoked, so no connection would be opened.
 *
 * This should be the default DbContext extension to use!
 *
 * If you for some reason have tests which obtain the threadConnection and issue a commit,
 * you can get them passing by using [DbContextSlowVersionExtension].
 * But please consider rewriting the test!
 */
class DbContextExtension(val dataSource: DataSource) :
    DbContext(DatabaseStatementFactory(DatabaseReporter.LOGGING_REPORTER)),
    BeforeAllCallback,
    AfterAllCallback,
    AfterEachCallback,
    BeforeEachCallback {

    // Connection opened in beforeAll and closed in afterAll; null outside that window.
    var connectionDuringTestcase: DbContextConnection? = null

    // Savepoint set by beforeEach; null when no test is running or beforeEach failed.
    var savepoint: Savepoint? = null

    override fun beforeEach(context: ExtensionContext) {
        savepoint = threadConnection.setSavepoint("beforeEach")
    }

    override fun afterEach(context: ExtensionContext) {
        // If beforeEach failed there is no savepoint. Calling rollback(null) would throw and hide
        // the original failure.
        val current = savepoint ?: return
        try {
            threadConnection.rollback(current)
        } finally {
            savepoint = null
        }
    }

    override fun beforeAll(context: ExtensionContext?) {
        DbContextProvider.setInstance(this)
        connectionDuringTestcase = startConnection(dataSource)
    }

    override fun afterAll(context: ExtensionContext?) {
        val connection = connectionDuringTestcase
        connectionDuringTestcase = null
        connection?.close()
    }
}
I hope this gives you some ideas.
Usage in JUnit 5:
class AktivitetskravRepositoryTest {
    @Test
    fun `should be able to store and fetch `() { }

    companion object {
        // @JvmField makes Kotlin emit a public static field. This is the form the JUnit 5 user guide
        // shows for Kotlin; @JvmStatic would produce a private static field, which older JUnit 5
        // versions reject for @RegisterExtension. Static registration is also what makes JUnit call
        // DbContextExtension's beforeAll/afterAll.
        @JvmField
        @RegisterExtension
        val dbContext = DbContextExtension(TestDataSourceFactory.getDataSource())
    }
}
That wasn't so hard. I'll look into it this week
I haven't been following up on this too much I'm afraid. I've experimented with JUnit 5 extensions and made some progress, but have been too busy to complete it. It's coming Any Day Now