add prepared statement implementation
This commit is contained in:
parent
c62bc61bd0
commit
a70b08e3d7
|
@ -0,0 +1,54 @@
|
||||||
|
package gorqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EscapeString sql-escapes a string.
|
||||||
|
func EscapeString(value string) string {
|
||||||
|
replace := [][2]string{
|
||||||
|
{`\`, `\\`},
|
||||||
|
{`\0`, `\\0`},
|
||||||
|
{`\n`, `\\n`},
|
||||||
|
{`\r`, `\\r`},
|
||||||
|
{`"`, `\"`},
|
||||||
|
{`'`, `\'`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, val := range replace {
|
||||||
|
value = strings.Replace(value, val[0], val[1], -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreparedStatement is a simple wrapper around fmt.Sprintf for prepared SQL
|
||||||
|
// statements.
|
||||||
|
type PreparedStatement struct {
|
||||||
|
body string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPreparedStatement takes a sprintf syntax SQL query for later binding of
|
||||||
|
// parameters.
|
||||||
|
func NewPreparedStatement(body string) PreparedStatement {
|
||||||
|
return PreparedStatement{body: body}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind takes arguments and SQL-escapes them, then calling fmt.Sprintf.
|
||||||
|
func (p PreparedStatement) Bind(args ...interface{}) string {
|
||||||
|
var spargs []interface{}
|
||||||
|
|
||||||
|
for _, arg := range args {
|
||||||
|
switch arg.(type) {
|
||||||
|
case string:
|
||||||
|
spargs = append(spargs, `'`+EscapeString(arg.(string))+`'`)
|
||||||
|
case fmt.Stringer:
|
||||||
|
spargs = append(spargs, `'`+EscapeString(arg.(fmt.Stringer).String())+`'`)
|
||||||
|
default:
|
||||||
|
spargs = append(spargs, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(p.body, spargs...)
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package gorqlite
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestPreparedStatement(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
input string
|
||||||
|
args []interface{}
|
||||||
|
output string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "SELECT * FROM posts WHERE creator=%d",
|
||||||
|
args: []interface{}{42},
|
||||||
|
output: "SELECT * FROM posts WHERE creator=42",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "INSERT INTO posts(body) VALUES(%s)",
|
||||||
|
args: []interface{}{`foo "bar" baz`},
|
||||||
|
output: `INSERT INTO posts(body) VALUES('foo \"bar\" baz')`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cs := range cases {
|
||||||
|
t.Run(cs.input, func(t *testing.T) {
|
||||||
|
p := NewPreparedStatement(cs.input)
|
||||||
|
outp := p.Bind(cs.args...)
|
||||||
|
|
||||||
|
if outp != cs.output {
|
||||||
|
t.Fatalf("expected output to be %s but got: %s", cs.output, outp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue