add prepared statement implementation
This commit is contained in:
parent
c62bc61bd0
commit
a70b08e3d7
|
@ -0,0 +1,54 @@
|
|||
package gorqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EscapeString sql-escapes a string.
|
||||
func EscapeString(value string) string {
|
||||
replace := [][2]string{
|
||||
{`\`, `\\`},
|
||||
{`\0`, `\\0`},
|
||||
{`\n`, `\\n`},
|
||||
{`\r`, `\\r`},
|
||||
{`"`, `\"`},
|
||||
{`'`, `\'`},
|
||||
}
|
||||
|
||||
for _, val := range replace {
|
||||
value = strings.Replace(value, val[0], val[1], -1)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// PreparedStatement is a simple wrapper around fmt.Sprintf for prepared SQL
|
||||
// statements.
|
||||
type PreparedStatement struct {
|
||||
body string
|
||||
}
|
||||
|
||||
// NewPreparedStatement takes a sprintf syntax SQL query for later binding of
|
||||
// parameters.
|
||||
func NewPreparedStatement(body string) PreparedStatement {
|
||||
return PreparedStatement{body: body}
|
||||
}
|
||||
|
||||
// Bind takes arguments and SQL-escapes them, then calling fmt.Sprintf.
|
||||
func (p PreparedStatement) Bind(args ...interface{}) string {
|
||||
var spargs []interface{}
|
||||
|
||||
for _, arg := range args {
|
||||
switch arg.(type) {
|
||||
case string:
|
||||
spargs = append(spargs, `'`+EscapeString(arg.(string))+`'`)
|
||||
case fmt.Stringer:
|
||||
spargs = append(spargs, `'`+EscapeString(arg.(fmt.Stringer).String())+`'`)
|
||||
default:
|
||||
spargs = append(spargs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(p.body, spargs...)
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package gorqlite
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPreparedStatement(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
args []interface{}
|
||||
output string
|
||||
}{
|
||||
{
|
||||
input: "SELECT * FROM posts WHERE creator=%d",
|
||||
args: []interface{}{42},
|
||||
output: "SELECT * FROM posts WHERE creator=42",
|
||||
},
|
||||
{
|
||||
input: "INSERT INTO posts(body) VALUES(%s)",
|
||||
args: []interface{}{`foo "bar" baz`},
|
||||
output: `INSERT INTO posts(body) VALUES('foo \"bar\" baz')`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.input, func(t *testing.T) {
|
||||
p := NewPreparedStatement(cs.input)
|
||||
outp := p.Bind(cs.args...)
|
||||
|
||||
if outp != cs.output {
|
||||
t.Fatalf("expected output to be %s but got: %s", cs.output, outp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue