Exposed icon indicating copy to clipboard operation
Exposed copied to clipboard

PostgreSQL `RETURNING` clause for `INSERT` / `UPDATE` / `DELETE`

Open Jakobeha opened this issue 4 years ago • 5 comments
trafficstars

Hello,

I'm not sure if there are other statements too. But PostgreSQL allows you to specify RETURNING on these statements to return data from the inserted / updated / deleted rows. For instance, you can delete rows and return them at the same time.

This feature would really be useful to me, and I didn't see any existing functionality or issues discussing it.

I'm planning to hack together my own implementation first by subclassing Statement, and I'll post my code and maybe create a PR.

Jakobeha avatar Jun 17 '21 01:06 Jakobeha

Just in case: RETURNING can yield either plain columns or aggregates (e.g. `RETURNING count(*)`). Oracle supports that as well.

vlsi avatar Jun 19 '21 12:06 vlsi

Here is the code I have so far. It can surely be cleaned up and might have bugs, but it does the job for me. Just include these 3 files and then use Table#updateReturning and Table#deleteReturningWhere. Remember that PostgreSQL does not support LIMIT on UPDATE or DELETE, so do not pass a `limit` when targeting it.

A lot of it is copy / pasted from jetbrains exposed classes. It would definitely be much cleaner to refactor the base classes to have a returning clause instead. But until then, for anyone who wants this functionality, this is a good workaround.

ReturningStatement.kt

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.statements.Statement
import org.jetbrains.exposed.sql.statements.StatementType
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

/**
 * Base class for statements that end in a `RETURNING` clause and therefore
 * produce rows rather than an affected-row count.
 *
 * Usage: call [exec] once inside a transaction, then iterate the statement to
 * read the returned rows. Iterating before [exec] throws [IllegalStateException].
 */
abstract class ReturningStatement(type: StatementType, targets: List<Table>) :
  Iterable<ResultRow>, Statement<ResultSet>(type, targets) {
  protected val transaction get() = TransactionManager.current()

  /** Fields expected in every returned row; their order defines the column index mapping. */
  abstract val set: FieldSet

  override fun PreparedStatementApi.executeInternal(transaction: Transaction): ResultSet =
    executeQuery()

  // Null until exec() has run.
  private var rows: Iterable<ResultRow>? = null

  /**
   * Runs the statement in the current transaction and makes its rows available
   * through [iterator].
   *
   * @throws IllegalArgumentException if the statement was already executed.
   * @throws IllegalStateException if execution yields no result set.
   */
  fun exec() {
    require(rows == null) { "already executed" }

    val resultSet = checkNotNull(transaction.exec(this)) { "statement execution returned no result set" }
    val resultIterator = ResultIterator(resultSet)
    rows = if (transaction.db.supportsMultipleResultSets) {
      // Streamed straight from the open result set: can only be consumed once.
      Iterable { resultIterator }
    } else {
      // Read eagerly into a list. Keeping the list (not a single iterator over it)
      // lets every iterator() call start from the first row again.
      Iterable { resultIterator }.toList()
    }
  }

  override fun iterator(): Iterator<ResultRow> =
    rows?.iterator() ?: throw IllegalStateException("must call exec() first")

  /** Walks [rs] row by row and closes it once the last row has been passed. */
  protected inner class ResultIterator(val rs: ResultSet) : Iterator<ResultRow> {
    // Tri-state look-ahead: null = unknown, true = row pending, false = exhausted.
    private var hasNext: Boolean? = null

    private val fieldsIndex = set.realFields.toSet().mapIndexed { index, expression -> expression to index }.toMap()

    override operator fun next(): ResultRow {
      if (hasNext == null) hasNext()
      if (hasNext == false) throw NoSuchElementException()
      // Force the following hasNext()/next() call to advance the cursor.
      hasNext = null
      return ResultRow.create(rs, fieldsIndex)
    }

    override fun hasNext(): Boolean {
      val available = hasNext ?: rs.next().also { hasNext = it }
      if (!available) rs.close()
      return available
    }
  }
}

DeleteReturningStatement.kt

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.StatementType

/**
 * `DELETE ... RETURNING` statement.
 *
 * @param table table to delete from.
 * @param where optional row filter; when null no `WHERE` clause is emitted.
 * @param limit optional `LIMIT`; when null (the default) no `LIMIT` clause is
 *   emitted. PostgreSQL rejects `LIMIT` on `DELETE`, so leave it null there.
 * @param returning fields to return; when null `RETURNING *` is emitted and
 *   rows are mapped against [table].
 */
class DeleteReturningStatement(
  private val table: Table,
  private val where: Op<Boolean>? = null,
  // Was 0, which emitted "LIMIT 0" by default; null matches UpdateReturningStatement.
  private val limit: Int? = null,
  private val returning: ColumnSet? = null
) : ReturningStatement(StatementType.DELETE, listOf(table)) {
  override val set: FieldSet = returning ?: table

  override fun prepareSQL(transaction: Transaction): String = buildString {
    append("DELETE FROM ")
    append(transaction.identity(table))
    if (where != null) {
      append(" WHERE ")
      append(QueryBuilder(true).append(where).toString())
    }
    if (limit != null) {
      append(" LIMIT ")
      append(limit)
    }
    append(" RETURNING ")
    if (returning != null) {
      // Emit the column list in realFields order, the same order the result
      // iterator uses to index the returned columns.
      val columns = QueryBuilder(true)
      returning.realFields.appendTo(columns) { it.toQueryBuilder(this) }
      append(columns.toString())
    } else {
      append("*")
    }
  }

  // Only the WHERE clause contributes bound arguments.
  override fun arguments(): Iterable<Iterable<Pair<IColumnType, Any?>>> =
    QueryBuilder(true).run {
      where?.toQueryBuilder(this)
      listOf(args)
    }

  companion object {
    /** Builds the statement and executes it immediately in the current transaction. */
    fun where(
      table: Table,
      op: Op<Boolean>,
      limit: Int? = null,
      returning: ColumnSet? = null
    ): DeleteReturningStatement = DeleteReturningStatement(
      table,
      op,
      limit,
      returning
    ).apply {
      exec()
    }
  }
}

/**
 * Deletes the rows matching [where] and returns the executed statement, which
 * can be iterated to read the deleted rows.
 *
 * @param limit optional `LIMIT`; null (the default) emits no `LIMIT` clause.
 *   The previous default of 0 always emitted `LIMIT 0`.
 * @param returning fields to return; null means `RETURNING *`.
 * @param where builds the row filter.
 */
fun Table.deleteReturningWhere(
  limit: Int? = null,
  returning: ColumnSet? = null,
  where: SqlExpressionBuilder.() -> Op<Boolean>
): DeleteReturningStatement =
  DeleteReturningStatement.where(
    this,
    SqlExpressionBuilder.run(where),
    limit,
    returning
  )

UpdateReturningStatement.kt

import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.StatementType

/**
 * `UPDATE ... RETURNING` statement.
 *
 * Assign column values with the `set` / `update` members before executing.
 *
 * @param table table to update.
 * @param where optional row filter; when null no `WHERE` clause is emitted.
 * @param limit optional `LIMIT`; when null no `LIMIT` clause is emitted.
 * @param returning fields to return; when null `RETURNING *` is emitted and
 *   rows are mapped against [table].
 */
class UpdateReturningStatement(
  private val table: Table,
  private val where: Op<Boolean>? = null,
  private val limit: Int? = null,
  private val returning: ColumnSet? = null
  // Was StatementType.DELETE, a copy/paste slip from the delete variant.
) : ReturningStatement(StatementType.UPDATE, listOf(table)) {
  override val set: FieldSet = returning ?: table

  private val firstDataSet: List<Pair<Column<*>, Any?>>
    get() = values.toList()

  override fun prepareSQL(transaction: Transaction): String =
    with(QueryBuilder(true)) {
      +"UPDATE "
      table.describe(transaction, this)

      firstDataSet.appendTo(this, prefix = " SET ") { (col, value) ->
        append("${transaction.identity(col)}=")
        registerArgument(col, value)
      }

      where?.let {
        +" WHERE "
        +it
      }
      limit?.let {
        +" LIMIT "
        // `+it` on an Int is the arithmetic unary plus and appends nothing,
        // so the number has to be appended explicitly.
        append(it.toString())
      }

      +" RETURNING "
      if (returning != null) {
        // Emit the column list in realFields order, the same order the result
        // iterator uses to index the returned columns.
        returning.realFields.appendTo(this) { it.toQueryBuilder(this) }
      } else {
        +"*"
      }

      toString()
    }

  // Registers arguments in the same order prepareSQL emits them: SET values, then WHERE.
  override fun arguments(): Iterable<Iterable<Pair<IColumnType, Any?>>> =
    QueryBuilder(true).run {
      for ((key, value) in values) {
        registerArgument(key, value)
      }
      where?.toQueryBuilder(this)
      listOf(args)
    }

  // region UpdateBuilder
  // Insertion-ordered so the SET clause follows assignment order.
  private val values: MutableMap<Column<*>, Any?> = LinkedHashMap()

  /** Assigns a literal [value] to [column]; fails if the column was already assigned or null is given for a non-nullable column. */
  operator fun <S> set(column: Column<S>, value: S) {
    when {
      values.containsKey(column) -> error("$column is already initialized")
      !column.columnType.nullable && value == null -> error("Trying to set null to not nullable column $column")
      else -> values[column] = value
    }
  }

  /** Assigns an expression to an entity-id [column]. */
  @JvmName("setWithEntityIdExpression")
  operator fun <S, ID : EntityID<S>, E : Expression<S>> set(
    column: Column<ID>,
    value: E
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  /** Assigns a raw id value to an entity-id [column]. */
  @JvmName("setWithEntityIdValue")
  operator fun <S : Comparable<S>, ID : EntityID<S>, E : S?> set(
    column: Column<ID>,
    value: E
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  /** Assigns an expression to [column]; delegates to [update]. */
  operator fun <T, S : T, E : Expression<S>> set(column: Column<T>, value: E) =
    update(column, value)

  /** Assigns [value] to a composite column by assigning each of its real columns. */
  operator fun <S> set(column: CompositeColumn<S>, value: S) {
    @Suppress("UNCHECKED_CAST")
    column.getRealColumnsWithValues(value).forEach { (realColumn, itsValue) ->
      set(
        realColumn as Column<Any?>,
        itsValue
      )
    }
  }

  /** Assigns the expression [value] to [column]; fails if the column was already assigned. */
  fun <T, S : T?> update(column: Column<T>, value: Expression<S>) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = value
  }

  /** Assigns the expression built by [value] to [column]; fails if the column was already assigned. */
  fun <T, S : T?> update(
    column: Column<T>,
    value: SqlExpressionBuilder.() -> Expression<S>
  ) {
    require(!values.containsKey(column)) { "$column is already initialized" }
    values[column] = SqlExpressionBuilder.value()
  }
  // endregion
}

/**
 * Updates the rows matching [where] and returns the executed statement, which
 * can be iterated to read the updated rows.
 *
 * @param where builds the row filter.
 * @param limit optional `LIMIT`; null emits no `LIMIT` clause.
 * @param returning fields to return; null means `RETURNING *`.
 * @param body assigns the new column values on the statement; runs with the table as receiver.
 */
fun <T : Table> T.updateReturning(
  where: SqlExpressionBuilder.() -> Op<Boolean>,
  limit: Int? = null,
  returning: ColumnSet? = null,
  body: T.(UpdateReturningStatement) -> Unit
): UpdateReturningStatement = UpdateReturningStatement(
  this,
  SqlExpressionBuilder.run(where),
  limit,
  returning
).apply {
  // Inside apply, `this` is the statement; the labeled `this` is the table receiver.
  this@updateReturning.body(this)
  exec()
}

Jakobeha avatar Jun 19 '21 15:06 Jakobeha

If you only need `RETURNING *` on update, there is a shorter implementation.

/**
 * [UpdateStatement] variant that appends a `RETURNING` clause listing every
 * real field of the target set and captures the returned rows in [resultRows].
 */
class UpdateReturningStatement(
    table: Table,
    where: Op<Boolean>? = null,
) : UpdateStatement(table, null, where) {

    /** Rows read during the most recent execution; empty before that. */
    var resultRows: List<ResultRow> = listOf()
        private set

    override fun PreparedStatementApi.executeInternal(transaction: Transaction): Int {
        // Nothing was assigned, so there is nothing to run.
        if (values.isEmpty()) return 0
        // executeUpdate() only reports the affected-row count, so the statement
        // is run as a query to get the returned rows back.
        val fetched = ResultIterator(executeQuery(), targetsSet).asSequence().toList()
        resultRows = fetched
        return fetched.size
    }

    override fun prepareSQL(transaction: Transaction): String {
        val baseSql = super.prepareSQL(transaction)
        val builder = QueryBuilder(prepared = true)
        builder.append(baseSql)
        with(builder) {
            targetsSet.realFields.appendTo(prefix = " RETURNING ") { field ->
                field.toQueryBuilder(this)
            }
        }
        return builder.toString()
    }

    // Adapted from AbstractQuery: walks the result set and closes it once exhausted.
    private class ResultIterator(
        private val rs: ResultSet,
        fieldSet: FieldSet
    ) : Iterator<ResultRow> {
        // Tri-state look-ahead: null = unknown, true = row pending, false = exhausted.
        private var pending: Boolean? = null
        private val fieldsIndex = fieldSet.realFields.toSet()
            .withIndex()
            .associate { (position, expression) -> expression to position }

        override operator fun next(): ResultRow {
            val available = pending ?: hasNext()
            if (!available) throw NoSuchElementException()
            // Force the following hasNext()/next() call to advance the cursor.
            pending = null
            return ResultRow.create(rs, fieldsIndex)
        }

        override fun hasNext(): Boolean {
            val available = pending ?: rs.next().also { pending = it }
            if (!available) rs.close()
            return available
        }
    }
}

/**
 * Updates the rows matching [where] in the current transaction and returns the
 * rows handed back by the `RETURNING` clause.
 *
 * @param where builds the row filter.
 * @param body assigns the new column values on the statement; runs with the table as receiver.
 */
fun <T : Table> T.updateReturning(
    where: SqlExpressionBuilder.() -> Op<Boolean>,
    body: T.(UpdateReturningStatement) -> Unit
): List<ResultRow> {
    val condition = SqlExpressionBuilder.run(where)
    return UpdateReturningStatement(this, condition).run {
        // Inside run, `this` is the statement; the labeled `this` is the table receiver.
        this@updateReturning.body(this)
        execute(TransactionManager.current())!!
        resultRows
    }
}

stengvac avatar Oct 01 '21 16:10 stengvac

Hi @Jakobeha, thanks for submitting this issue and the code snippets. Please go ahead and open a PR for this, including the necessary tests for the functionality, and we will review and get back to you.

joc-a avatar May 09 '23 09:05 joc-a

Hey @Jakobeha @joc-a, what's the status on this? Has a PR been made?

Flaxoos avatar Oct 30 '23 14:10 Flaxoos