
[FEATURE] Replace Databricks Go SDK with Arrow ODBC: 55 rows/sec → 99k rows/sec!

caldempsey opened this issue 10 months ago • 3 comments

Problem Statement

Our Go application using the official Databricks SQL Go SDK processes only 100K rows in 30 minutes (~55 rows/second). This creates a significant bottleneck in our data pipeline and increases warehouse costs due to long-running queries.

After some experimentation, our Rust prototype using the arrow-odbc crate processes 21M rows in 212 seconds (~99,000 rows/second), a ~1,800x improvement!

Proposed Solution

  • Switch from github.com/databricks/databricks-sql-go to an Arrow ODBC driver
  • Update the client to handle Arrow record batches (see the sketch after this list)
  • Maintain existing interfaces
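
To make the second bullet concrete, here's a minimal sketch of batch-level consumption using the Apache Arrow Go library. The v15 module path and the results.arrow dump are placeholders for illustration – the SDK exposes nothing like this today:

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/apache/arrow/go/v15/arrow/ipc"
)

func main() {
	// Hypothetical: an Arrow IPC stream dumped from a warehouse fetch.
	f, err := os.Open("results.arrow")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	rdr, err := ipc.NewReader(f)
	if err != nil {
		log.Fatal(err)
	}
	defer rdr.Release()

	for rdr.Next() {
		rec := rdr.Record() // one columnar batch – thousands of rows, not one
		fmt.Printf("batch: %d rows x %d cols\n", rec.NumRows(), rec.NumCols())
	}
	if err := rdr.Err(); err != nil {
		log.Fatal(err)
	}
}

Expanding each batch into database/sql rows in memory (per the third bullet) would keep the existing interfaces intact while amortising the per-row overhead.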

Additional Context

  • Databricks recommends Arrow-based protocols for high-performance data transfer, but I can't find evidence of this being implemented in the Go SDK.
  • Both implementations use identical application structure, SQL queries, network conditions, CPU etc.
  • Are there real reasons why the Go SDK can't implement Arrow under the hood and expand the returned rows per batch in-memory?
  • The Go application was implemented with db, err := sql.Open("databricks", dsn). Worth confirming whether there are configuration options we're missing that could improve performance (a connector-based sketch follows this list).
  • I'm familiar with the ecosystem so might be able to help implement this if you can point me in the right direction.
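
For completeness, here's roughly what the connector-based setup looks like instead of the raw DSN. The hostname and HTTP path below are placeholders, and WithMaxRows is, as far as I can tell from the driver docs, the main client-side throughput knob:

package main

import (
	"database/sql"
	"log"
	"os"

	dbsql "github.com/databricks/databricks-sql-go"
)

func main() {
	// Connector options instead of encoding everything in a DSN string;
	// WithMaxRows raises the rows fetched per round trip (default 10,000).
	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname("myworkspace.cloud.databricks.com"), // placeholder
		dbsql.WithPort(443),
		dbsql.WithHTTPPath("/sql/1.0/warehouses/abc123"), // placeholder
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_TOKEN")),
		dbsql.WithMaxRows(100000),
	)
	if err != nil {
		log.Fatal(err)
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
	log.Println("connected")
}

Even with a larger MaxRows the row-at-a-time Scan path stays, which is why I suspect the ceiling is in the protocol handling rather than in our configuration.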

To all passers-by, here's your free app.

package main

import (
	"bufio"
	"bytes"
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"runtime"
	"sync"
	"time"

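	// cgo-based ODBC bridge – registers itself with database/sql as "odbc"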
	_ "github.com/alexbrainman/odbc"
)

// RowWrapper represents a single row with metadata
type RowWrapper struct {
	RowNum    int64                  `json:"row_num"`
	FetchedAt time.Time              `json:"fetched_at"`
	Payload   map[string]interface{} `json:"payload"`
}

// QueryRequest represents the incoming query request
type QueryRequest struct {
	Query string `json:"query"`
}

// DatabricksClient handles database connections and queries
type DatabricksClient struct {
	db *sql.DB
}

// NewDatabricksClient creates a new client with the connection string
func NewDatabricksClient(connectionString string) (*DatabricksClient, error) {
	db, err := sql.Open("odbc", connectionString)
	if err != nil {
		return nil, fmt.Errorf("failed to open database connection: %w", err)
	}

	// Performance tuning: Set connection pool settings – feel free to tweak further
	db.SetMaxOpenConns(64) // allow more concurrency for wide scans
	db.SetMaxIdleConns(32)
	db.SetConnMaxLifetime(5 * time.Minute)
	db.SetConnMaxIdleTime(2 * time.Minute)

	if err := db.Ping(); err != nil {
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	return &DatabricksClient{db: db}, nil
}

// StreamQuery executes a query and streams results with heavy-weight optimisations
// – minimal allocations per row (the string copies out of RawBytes are unavoidable)
// – buffered NDJSON writer to cut syscalls
// – sync.Pool-backed map + bytes.Buffer reuse to relieve the GC
// – RawBytes scanning to skip reflection-heavy interface{} conversions
func (c *DatabricksClient) StreamQuery(query string, w http.ResponseWriter) error {
	rows, err := c.db.Query(query)
	if err != nil {
		return fmt.Errorf("failed to execute query: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("failed to get columns: %w", err)
	}

	// HTTP headers – net/http switches to chunked transfer automatically once
	// we write without a Content-Length, so only type and cache policy are set
	w.Header().Set("Content-Type", "application/x-ndjson")
	w.Header().Set("Cache-Control", "no-cache")
	w.WriteHeader(http.StatusOK)

	// Buffered writer (64 KiB) to minimise Write syscalls
	bw := bufio.NewWriterSize(w, 64*1024)
	// Make sure we push out whatever is left at the end
	defer bw.Flush()

	// Pre-build scanning slice of sql.RawBytes (cheaper than interface{})
	values := make([]sql.RawBytes, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	// Small helper pools to recycle allocations
	var (
		mapPool = sync.Pool{New: func() any { return make(map[string]interface{}, len(columns)) }}
		bufPool = sync.Pool{New: func() any { return new(bytes.Buffer) }}
	)

	const flushEvery = 16_384 // tune – bigger reduces flushes, lower reduces tail latency

	var (
		rowNum     int64
		flushedCnt int64
	)

	for rows.Next() {
		rowNum++

		if err := rows.Scan(scanArgs...); err != nil {
			log.Printf("scan error on row %d: %v", rowNum, err)
			continue
		}

		// payload map recycle
		payload := mapPool.Get().(map[string]interface{})
		// fast clear – keeps capacity
		for k := range payload {
			delete(payload, k)
		}
		for i, col := range columns {
			// The RawBytes slice is re-used by the driver on the next Scan, so we *must* copy or convert now
			payload[col] = string(values[i])
		}

		// Row wrapper – FetchedAt per row to keep parity with the Rust impl
		row := RowWrapper{
			RowNum:    rowNum,
			FetchedAt: time.Now().UTC(),
			Payload:   payload,
		}

		// Encode using a pooled bytes.Buffer to avoid fresh allocations
		buf := bufPool.Get().(*bytes.Buffer)
		buf.Reset()
		if err := json.NewEncoder(buf).Encode(&row); err != nil {
			mapPool.Put(payload)
			bufPool.Put(buf)
			return fmt.Errorf("json encode error: %w", err)
		}

		if _, err := bw.Write(buf.Bytes()); err != nil {
			mapPool.Put(payload)
			bufPool.Put(buf)
			return err
		}

		// newline already written by Encoder's Encode()

		// recycle
		mapPool.Put(payload)
		bufPool.Put(buf)

		// adaptive flush – every N rows
		if rowNum-flushedCnt >= flushEvery {
			if err := bw.Flush(); err != nil {
				return err
			}
			flushedCnt = rowNum
		}
	}

	if err := rows.Err(); err != nil {
		return fmt.Errorf("row iteration error: %w", err)
	}

	// final flush handled by defer
	return nil
}

// Close closes the DB connection
func (c *DatabricksClient) Close() error { return c.db.Close() }

func main() {
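	// GOMAXPROCS has defaulted to NumCPU since Go 1.5, so this call is
	// redundant – kept only so the log line below reflects the setting.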
	runtime.GOMAXPROCS(runtime.NumCPU())
	log.Printf("GOMAXPROCS=%d", runtime.GOMAXPROCS(0))

	// ----- Same env parsing as before (unchanged) -----
	driverPath := os.Getenv("DATABRICKS_ODBC_DRIVER_PATH")
	if driverPath == "" {
		log.Fatal("DATABRICKS_ODBC_DRIVER_PATH must be set")
	}
	workspaceURL := os.Getenv("DATABRICKS_WORKSPACE_URL")
	if workspaceURL == "" {
		log.Fatal("DATABRICKS_WORKSPACE_URL must be set")
	}
	env := os.Getenv("ENV")
	if env == "" {
		log.Fatal("ENV must be set")
	}
	warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
	if warehouseID == "" {
		log.Fatal("DATABRICKS_WAREHOUSE_ID must be set")
	}
	clientID := os.Getenv("DATABRICKS_OAUTH_CLIENT_ID")
	if clientID == "" {
		log.Fatal("DATABRICKS_OAUTH_CLIENT_ID must be set")
	}
	clientSecret := os.Getenv("DATABRICKS_OAUTH_CLIENT_SECRET")
	if clientSecret == "" {
		log.Fatal("DATABRICKS_OAUTH_CLIENT_SECRET must be set")
	}

	// clean workspace hostname – strip scheme and trailing slash
	host := strings.TrimSuffix(strings.TrimPrefix(workspaceURL, "https://"), "/")

	databaseName := fmt.Sprintf("rd-%s-orca-db", env)
	restAddr := os.Getenv("REST_ADDR")
	if restAddr == "" {
		restAddr = "0.0.0.0:8080"
	}

	numThreads := runtime.NumCPU()

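	// ODBC driver knobs, per the Databricks/Simba driver docs (verify the
	// exact spellings against your driver version): RowsFetchedPerBlock sets
	// the fetch batch size; the Enable*/UseArrowNativeReader flags opt in to
	// Arrow-formatted result sets and async execution.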
	connStr := fmt.Sprintf(
		"Driver=%s;Host=%s;Port=443;HTTPPath=/sql/1.0/warehouses/%s;"+
			"Catalog=%s;SSL=1;ThriftTransport=2;AuthMech=11;Auth_Flow=1;"+
			"Auth_Client_ID=%s;Auth_Client_Secret=%s;RowsFetchedPerBlock=65536;"+
			"EnableQueryResultDownload=1;EnableAsyncExec=1;EnableArrow=1;NumThreads=%d;UseArrowNativeReader=1;",
		driverPath, host, warehouseID, databaseName, clientID, clientSecret, numThreads,
	)

	client, err := NewDatabricksClient(connStr)
	if err != nil {
		log.Fatalf("client init failed: %v", err)
	}
	defer client.Close()

	srv := &http.Server{
		Addr:           restAddr,
		ReadTimeout:    30 * time.Second,
		WriteTimeout:   0, // unlimited – we are streaming
		IdleTimeout:    120 * time.Second,
		MaxHeaderBytes: 1 << 20,
	}

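	// srv.Handler is nil, so requests are served by http.DefaultServeMux –
	// the same mux the HandleFunc registrations below go to.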
	http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
	})

	http.HandleFunc("/api/v1/query", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			q := r.URL.Query().Get("q")
			if q == "" {
				http.Error(w, `{"error":"q required"}`, http.StatusBadRequest)
				return
			}
			execQuery(client, q, w)
		case http.MethodPost:
			if r.Header.Get("Content-Type") != "application/json" {
				http.Error(w, `{"error":"Content-Type must be application/json"}`, http.StatusBadRequest)
				return
			}
			var req QueryRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, `{"error":"invalid JSON body"}`, http.StatusBadRequest)
				return
			}
			if req.Query == "" {
				http.Error(w, `{"error":"query cannot be empty"}`, http.StatusBadRequest)
				return
			}
			execQuery(client, req.Query, w)
		default:
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		}
	})

	log.Printf("🚀 Databricks ODBC Query Service listening on %s", restAddr)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}

func execQuery(c *DatabricksClient, q string, w http.ResponseWriter) {
	log.Printf("query: %s", q)
	rec := &statusRecorder{ResponseWriter: w}
	if err := c.StreamQuery(q, rec); err != nil {
		log.Printf("query error: %v", err)
		if !rec.wrote {
			// json.Marshal escapes the message – Sprintf interpolation would break on quotes
			body, _ := json.Marshal(map[string]string{"error": err.Error()})
			http.Error(w, string(body), http.StatusInternalServerError)
		}
	}
}

// statusRecorder tracks whether the response has started streaming –
// net/http gives no official way to ask, so we wrap the writer ourselves.
type statusRecorder struct {
	http.ResponseWriter
	wrote bool
}

func (s *statusRecorder) WriteHeader(code int) { s.wrote = true; s.ResponseWriter.WriteHeader(code) }
func (s *statusRecorder) Write(b []byte) (int, error) { s.wrote = true; return s.ResponseWriter.Write(b) }

caldempsey · Jan 20 '25