From a70b08e3d72f40023cd0c42c80857de6186b14c7 Mon Sep 17 00:00:00 2001 From: Christine Dodrill Date: Thu, 13 Jul 2017 23:29:18 -0700 Subject: [PATCH] add prepared statement implementation --- prepared_statement.go | 54 ++++++++++++++++++++++++++++++++++++++ prepared_statement_test.go | 33 +++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 prepared_statement.go create mode 100644 prepared_statement_test.go diff --git a/prepared_statement.go b/prepared_statement.go new file mode 100644 index 0000000..2ef43ca --- /dev/null +++ b/prepared_statement.go @@ -0,0 +1,54 @@ +package gorqlite + +import ( + "fmt" + "strings" +) + +// EscapeString sql-escapes a string. +func EscapeString(value string) string { + replace := [][2]string{ + {`\`, `\\`}, + {`\0`, `\\0`}, + {`\n`, `\\n`}, + {`\r`, `\\r`}, + {`"`, `\"`}, + {`'`, `\'`}, + } + + for _, val := range replace { + value = strings.Replace(value, val[0], val[1], -1) + } + + return value +} + +// PreparedStatement is a simple wrapper around fmt.Sprintf for prepared SQL +// statements. +type PreparedStatement struct { + body string +} + +// NewPreparedStatement takes a sprintf syntax SQL query for later binding of +// parameters. +func NewPreparedStatement(body string) PreparedStatement { + return PreparedStatement{body: body} +} + +// Bind takes arguments and SQL-escapes them, then calling fmt.Sprintf. +func (p PreparedStatement) Bind(args ...interface{}) string { + var spargs []interface{} + + for _, arg := range args { + switch arg.(type) { + case string: + spargs = append(spargs, `'`+EscapeString(arg.(string))+`'`) + case fmt.Stringer: + spargs = append(spargs, `'`+EscapeString(arg.(fmt.Stringer).String())+`'`) + default: + spargs = append(spargs, arg) + } + } + + return fmt.Sprintf(p.body, spargs...) +} diff --git a/prepared_statement_test.go b/prepared_statement_test.go new file mode 100644 index 0000000..1b2650d --- /dev/null +++ b/prepared_statement_test.go @@ -0,0 +1,33 @@ +package gorqlite + +import "testing" + +func TestPreparedStatement(t *testing.T) { + cases := []struct { + input string + args []interface{} + output string + }{ + { + input: "SELECT * FROM posts WHERE creator=%d", + args: []interface{}{42}, + output: "SELECT * FROM posts WHERE creator=42", + }, + { + input: "INSERT INTO posts(body) VALUES(%s)", + args: []interface{}{`foo "bar" baz`}, + output: `INSERT INTO posts(body) VALUES('foo \"bar\" baz')`, + }, + } + + for _, cs := range cases { + t.Run(cs.input, func(t *testing.T) { + p := NewPreparedStatement(cs.input) + outp := p.Bind(cs.args...) + + if outp != cs.output { + t.Fatalf("expected output to be %s but got: %s", cs.output, outp) + } + }) + } +}