Browse Source

Implement support for passing Go functions as custom functions to SQLite.

Fixes #226.
David Anderson 3 years ago
parent
commit
cf8fa0af80
5 changed files with 342 additions and 6 deletions
  1. 20 0
      callback.go
  2. 20 3
      doc.go
  3. 191 0
      sqlite3.go
  4. 108 0
      sqlite3_test.go
  5. 3 3
      sqlite3_test/sqltest.go

+ 20 - 0
callback.go

@@ -0,0 +1,20 @@
+// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package sqlite3
+
+/*
+#include <sqlite3-binding.h>
+*/
+import "C"
+
+import "unsafe"
+
+//export callbackTrampoline
+func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
+	args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
+	fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
+	fi.Call(ctx, args)
+}

+ 20 - 3
doc.go

@@ -33,7 +33,7 @@ extension for Regexp matcher operation.
     #include <string.h>
     #include <stdio.h>
     #include <sqlite3ext.h>
-    
+
     SQLITE_EXTENSION_INIT1
     static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) {
       if (argc >= 2) {
@@ -44,7 +44,7 @@ extension for Regexp matcher operation.
         int vec[500];
         int n, rc;
         pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL);
-        rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); 
+        rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500);
         if (rc <= 0) {
           sqlite3_result_error(context, errstr, 0);
           return;
@@ -52,7 +52,7 @@ extension for Regexp matcher operation.
         sqlite3_result_int(context, 1);
       }
     }
-    
+
     #ifdef _WIN32
     __declspec(dllexport)
     #endif
@@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
 					},
 			})
 
+Go SQlite3 Extensions
+
+If you want to register Go functions as SQLite extension functions,
+call RegisterFunction from ConnectHook.
+
+	regex = func(re, s string) (bool, error) {
+		return regexp.MatchString(re, s)
+	}
+	sql.Register("sqlite3_with_go_func",
+			&sqlite3.SQLiteDriver{
+					ConnectHook: func(conn *sqlite3.SQLiteConn) error {
+						return conn.RegisterFunc("regex", regex, true)
+					},
+			})
+
+See the documentation of RegisterFunc for more details.
+
 */
 package sqlite3

+ 191 - 0
sqlite3.go

@@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
   return rv;
 }
 
+void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
+  sqlite3_result_text(ctx, s, -1, &free);
+}
+
+void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
+  sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
+}
+
+void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
 */
 import "C"
 import (
@@ -75,6 +84,7 @@ import (
 	"fmt"
 	"io"
 	"net/url"
+	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -120,6 +130,7 @@ type SQLiteConn struct {
 	db     *C.sqlite3
 	loc    *time.Location
 	txlock string
+	funcs  []*functionInfo
 }
 
 // Tx struct.
@@ -153,6 +164,89 @@ type SQLiteRows struct {
 	cls      bool
 }
 
+type functionInfo struct {
+	f             reflect.Value
+	argConverters []func(*C.sqlite3_value) (reflect.Value, error)
+}
+
+func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
+	cstr := C.CString(err.Error())
+	defer C.free(unsafe.Pointer(cstr))
+	C.sqlite3_result_error(ctx, cstr, -1)
+}
+
+func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
+	var args []reflect.Value
+	for i, arg := range argv {
+		v, err := fi.argConverters[i](arg)
+		if err != nil {
+			fi.error(ctx, err)
+			return
+		}
+		args = append(args, v)
+	}
+
+	ret := fi.f.Call(args)
+
+	if len(ret) == 2 && ret[1].Interface() != nil {
+		fi.error(ctx, ret[1].Interface().(error))
+		return
+	}
+
+	res := ret[0].Interface()
+	// Normalize ret to one of the types sqlite knows.
+	switch r := res.(type) {
+	case int64, float64, []byte, string:
+		// Already the right type
+	case bool:
+		if r {
+			res = int64(1)
+		} else {
+			res = int64(0)
+		}
+	case int:
+		res = int64(r)
+	case uint:
+		res = int64(r)
+	case uint8:
+		res = int64(r)
+	case uint16:
+		res = int64(r)
+	case uint32:
+		res = int64(r)
+	case uint64:
+		res = int64(r)
+	case int8:
+		res = int64(r)
+	case int16:
+		res = int64(r)
+	case int32:
+		res = int64(r)
+	case float32:
+		res = float64(r)
+	default:
+		fi.error(ctx, errors.New("cannot convert returned type to sqlite type"))
+		return
+	}
+
+	switch r := res.(type) {
+	case int64:
+		C.sqlite3_result_int64(ctx, C.sqlite3_int64(r))
+	case float64:
+		C.sqlite3_result_double(ctx, C.double(r))
+	case []byte:
+		if len(r) == 0 {
+			C.sqlite3_result_null(ctx)
+		} else {
+			C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r)))
+		}
+	case string:
+		C._sqlite3_result_text(ctx, C.CString(r))
+	default:
+		panic("unreachable")
+	}
+}
+
 // Commit transaction.
 func (tx *SQLiteTx) Commit() error {
 	_, err := tx.c.exec("COMMIT")
@@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error {
 	return err
 }
 
+// RegisterFunc makes a Go function available as a SQLite function.
+//
+// The function must accept only arguments of type int64, float64,
+// []byte or string, and return one value of any numeric type except
+// complex, bool, []byte or string. Optionally, an error can be
+// provided as a second return value.
+//
+// If pure is true. SQLite will assume that the function's return
+// value depends only on its inputs, and make more aggressive
+// optimizations in its queries.
+func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
+	var fi functionInfo
+	fi.f = reflect.ValueOf(impl)
+	t := fi.f.Type()
+	if t.Kind() != reflect.Func {
+		return errors.New("Non-function passed to RegisterFunc")
+	}
+	if t.IsVariadic() {
+		return errors.New("Variadic SQLite functions are not supported")
+	}
+	if t.NumOut() != 1 && t.NumOut() != 2 {
+		return errors.New("SQLite functions must return 1 or 2 values")
+	}
+	if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
+		return errors.New("Second return value of SQLite function must be error")
+	}
+
+	for i := 0; i < t.NumIn(); i++ {
+		arg := t.In(i)
+		var conv func(*C.sqlite3_value) (reflect.Value, error)
+		switch arg.Kind() {
+		case reflect.Int64:
+			conv = func(v *C.sqlite3_value) (reflect.Value, error) {
+				if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
+					return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name)
+				}
+				return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
+			}
+		case reflect.Float64:
+			conv = func(v *C.sqlite3_value) (reflect.Value, error) {
+				if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
+					return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name)
+				}
+				return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
+			}
+		case reflect.Slice:
+			if arg.Elem().Kind() != reflect.Uint8 {
+				return errors.New("The only supported slice type is []byte")
+			}
+			conv = func(v *C.sqlite3_value) (reflect.Value, error) {
+				switch C.sqlite3_value_type(v) {
+				case C.SQLITE_BLOB:
+					l := C.sqlite3_value_bytes(v)
+					p := C.sqlite3_value_blob(v)
+					return reflect.ValueOf(C.GoBytes(p, l)), nil
+				case C.SQLITE_TEXT:
+					l := C.sqlite3_value_bytes(v)
+					c := unsafe.Pointer(C.sqlite3_value_text(v))
+					return reflect.ValueOf(C.GoBytes(c, l)), nil
+				default:
+					return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
+				}
+			}
+		case reflect.String:
+			conv = func(v *C.sqlite3_value) (reflect.Value, error) {
+				switch C.sqlite3_value_type(v) {
+				case C.SQLITE_BLOB:
+					l := C.sqlite3_value_bytes(v)
+					p := (*C.char)(C.sqlite3_value_blob(v))
+					return reflect.ValueOf(C.GoStringN(p, l)), nil
+				case C.SQLITE_TEXT:
+					c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
+					return reflect.ValueOf(C.GoString(c)), nil
+				default:
+					return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
+				}
+			}
+		}
+		fi.argConverters = append(fi.argConverters, conv)
+	}
+
+	// fi must outlast the database connection, or we'll have dangling pointers.
+	c.funcs = append(c.funcs, &fi)
+
+	cname := C.CString(name)
+	defer C.free(unsafe.Pointer(cname))
+	opts := C.SQLITE_UTF8
+	if pure {
+		opts |= C.SQLITE_DETERMINISTIC
+	}
+	rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
+	if rv != C.SQLITE_OK {
+		return c.lastError()
+	}
+	return nil
+}
+
 // AutoCommit return which currently auto commit or not.
 func (c *SQLiteConn) AutoCommit() bool {
 	return int(C.sqlite3_get_autocommit(c.db)) != 0

+ 108 - 0
sqlite3_test.go

@@ -15,7 +15,9 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"regexp"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) {
 		t.Fatal("Failed to scan datetime:", err)
 	}
 }
+
+func TestFunctionRegistration(t *testing.T) {
+	custom_add := func(a, b int64) (int64, error) {
+		return a + b, nil
+	}
+	custom_regex := func(s, re string) bool {
+		matched, err := regexp.MatchString(re, s)
+		if err != nil {
+			// We should really return the error here, but this
+			// function is also testing single return value functions.
+			panic("Bad regexp")
+		}
+		return matched
+	}
+
+	sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
+		ConnectHook: func(conn *SQLiteConn) error {
+			if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
+				return err
+			}
+			if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
+				return err
+			}
+			return nil
+		},
+	})
+	db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
+	if err != nil {
+		t.Fatal("Failed to open database:", err)
+	}
+	defer db.Close()
+
+	additions := []struct {
+		a, b, c int64
+	}{
+		{1, 1, 2},
+		{1, 3, 4},
+		{1, -1, 0},
+	}
+
+	for _, add := range additions {
+		var i int64
+		err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
+		if err != nil {
+			t.Fatal("Failed to call custom_add:", err)
+		}
+		if i != add.c {
+			t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
+		}
+	}
+
+	regexes := []struct {
+		re, in string
+		out    bool
+	}{
+		{".*", "foo", true},
+		{"^foo.*", "foobar", true},
+		{"^foo.*", "barfoo", false},
+	}
+
+	for _, re := range regexes {
+		var b bool
+		err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
+		if err != nil {
+			t.Fatal("Failed to call regexp:", err)
+		}
+		if b != re.out {
+			t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
+		}
+	}
+}
+
+var customFunctionOnce sync.Once
+
+func BenchmarkCustomFunctions(b *testing.B) {
+	customFunctionOnce.Do(func() {
+		custom_add := func(a, b int64) (int64, error) {
+			return a + b, nil
+		}
+
+		sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
+			ConnectHook: func(conn *SQLiteConn) error {
+				// Impure function to force sqlite to reexecute it each time.
+				if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
+					return err
+				}
+				return nil
+			},
+		})
+	})
+
+	db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
+	if err != nil {
+		b.Fatal("Failed to open database:", err)
+	}
+	defer db.Close()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		var i int64
+		err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
+		if err != nil {
+			b.Fatal("Failed to run custom add:", err)
+		}
+	}
+}

+ 3 - 3
sqlite3_test/sqltest.go

@@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) {
 		var i int
 		var f float64
 		var s string
-//		var t time.Time
+		//		var t time.Time
 		if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
 			panic(err)
 		}
@@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) {
 		var i int
 		var f float64
 		var s string
-//		var t time.Time
+		//		var t time.Time
 		if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
 			panic(err)
 		}
@@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) {
 		var i int
 		var f float64
 		var s string
-//		var t time.Time
+		//		var t time.Time
 		if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
 			panic(err)
 		}