Implement simple query support

This commit is contained in:
luaneko 2025-01-11 02:00:05 +11:00
parent 3793e14f50
commit 137422601b
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
3 changed files with 109 additions and 26 deletions

View File

@ -345,6 +345,7 @@ export interface Row extends Iterable<unknown, void, void> {
} }
export interface QueryOptions { export interface QueryOptions {
readonly simple: boolean;
readonly chunk_size: number; readonly chunk_size: number;
readonly stdin: ReadableStream<Uint8Array> | null; readonly stdin: ReadableStream<Uint8Array> | null;
readonly stdout: WritableStream<Uint8Array> | null; readonly stdout: WritableStream<Uint8Array> | null;
@ -359,6 +360,11 @@ export class Query<T = Row>
this.#f = f; this.#f = f;
} }
simple(simple = true) {
const f = this.#f;
return new Query((o) => f({ simple, ...o }));
}
chunked(chunk_size = 1) { chunked(chunk_size = 1) {
const f = this.#f; const f = this.#f;
return new Query((o) => f({ chunk_size, ...o })); return new Query((o) => f({ chunk_size, ...o }));

36
test.ts
View File

@ -2,9 +2,9 @@ import pglue, { PostgresError, SqlTypeError } from "./mod.ts";
import { expect } from "jsr:@std/expect"; import { expect } from "jsr:@std/expect";
import { toText } from "jsr:@std/streams"; import { toText } from "jsr:@std/streams";
async function connect() { async function connect(params?: Record<string, string>) {
const pg = await pglue.connect(`postgres://test:test@localhost:5432/test`, { const pg = await pglue.connect(`postgres://test:test@localhost:5432/test`, {
runtime_params: { client_min_messages: "INFO" }, runtime_params: { client_min_messages: "INFO", ...params },
}); });
return pg.on("log", (_level, ctx, msg) => { return pg.on("log", (_level, ctx, msg) => {
@ -139,7 +139,7 @@ Deno.test(`sql injection`, async () => {
expect(name).toBe(input); expect(name).toBe(input);
}); });
Deno.test(`pubsub`, async () => { Deno.test(`listen/notify`, async () => {
await using pg = await connect(); await using pg = await connect();
const sent: string[] = []; const sent: string[] = [];
@ -152,6 +152,8 @@ Deno.test(`pubsub`, async () => {
sent.push(payload); sent.push(payload);
await ch.notify(payload); await ch.notify(payload);
} }
expect(sent.length).toBe(0);
}); });
Deno.test(`transactions`, async () => { Deno.test(`transactions`, async () => {
@ -195,15 +197,35 @@ Deno.test(`streaming`, async () => {
await pg.query`create table my_table (field text not null)`; await pg.query`create table my_table (field text not null)`;
for (let i = 0; i < 100; i++) { for (let i = 0; i < 20; i++) {
await pg.query`insert into my_table (field) values (${i})`; await pg.query`insert into my_table (field) values (${i})`;
} }
let i = 0; let i = 0;
for await (const chunk of pg.query`select * from my_table`.chunked(10)) { for await (const chunk of pg.query`select * from my_table`.chunked(5)) {
expect(chunk.length).toBe(10); expect(chunk.length).toBe(5);
for (const row of chunk) expect(row.field).toBe(`${i++}`); for (const row of chunk) expect(row.field).toBe(`${i++}`);
} }
expect(i).toBe(100); expect(i).toBe(20);
});
Deno.test(`simple`, async () => {
await using pg = await connect();
await using _tx = await pg.begin();
const rows = await pg.query`
create table my_table (field text not null);
insert into my_table (field) values ('one'), ('two'), ('three');
select * from my_table;
select * from my_table where field = 'two';
`.simple();
expect(rows.length).toBe(4);
const [{ field: a }, { field: b }, { field: c }, { field: d }] = rows;
expect(a).toBe("one");
expect(b).toBe("two");
expect(c).toBe("three");
expect(d).toBe("two");
}); });

91
wire.ts
View File

@ -34,6 +34,7 @@ import {
} from "./ser.ts"; } from "./ser.ts";
import { import {
type CommandResult, type CommandResult,
format,
is_sql, is_sql,
Query, Query,
type ResultStream, type ResultStream,
@ -916,18 +917,15 @@ function wire_impl(
}), }),
async () => { async () => {
await read(ParseComplete); await read(ParseComplete);
const param_desc = await read(ParameterDescription); const ser_params = make_param_ser(await read(ParameterDescription));
const msg = msg_check_err(await read_raw()); const msg = msg_check_err(await read_raw());
const row_desc = const Row =
msg_type(msg) === NoData.type msg_type(msg) === NoData.type
? { columns: [] } ? EmptyRow
: ser_decode(RowDescription, msg); : make_row_ctor(ser_decode(RowDescription, msg));
return { return { ser_params, Row };
ser_params: make_param_ser(param_desc),
Row: make_row_ctor(row_desc),
};
} }
); );
} catch (e) { } catch (e) {
@ -962,6 +960,7 @@ function wire_impl(
new (columns: (BinaryLike | null)[]): Row; new (columns: (BinaryLike | null)[]): Row;
} }
const EmptyRow = make_row_ctor({ columns: [] });
function make_row_ctor({ columns }: RowDescription) { function make_row_ctor({ columns }: RowDescription) {
const Row = jit.compiled<RowConstructor>`function Row(xs) { const Row = jit.compiled<RowConstructor>`function Row(xs) {
${jit.map(" ", columns, ({ name, type_oid }, i) => { ${jit.map(" ", columns, ({ name, type_oid }, i) => {
@ -1017,11 +1016,15 @@ function wire_impl(
case EmptyQueryResponse.type: case EmptyQueryResponse.type:
return { done: true as const, rows, tag: "" }; return { done: true as const, rows, tag: "" };
case RowDescription.type:
Row = make_row_ctor(ser_decode(RowDescription, msg));
continue;
case CopyInResponse.type: case CopyInResponse.type:
continue; continue;
case CopyOutResponse.type: case CopyOutResponse.type:
await read_copy_out(stdout); await read_copy_out(stdout), (stdout = null);
continue; continue;
} }
} }
@ -1065,6 +1068,53 @@ function wire_impl(
} }
} }
async function* execute_simple(
query: string,
stdin: ReadableStream<Uint8Array> | null,
stdout: WritableStream<Uint8Array> | null
): ResultStream<Row> {
log("debug", { query: query }, `executing simple query`);
const { chunks, err } = await pipeline(
async () => {
await write(QueryMessage, { query });
return write_copy_in(stdin);
},
async () => {
for (let chunks = [], err; ; ) {
const msg = await read_raw();
switch (msg_type(msg)) {
default:
case ReadyForQuery.type:
return { chunks, err };
case RowDescription.type: {
const Row = make_row_ctor(ser_decode(RowDescription, msg));
const { rows } = await read_rows(Row, stdout);
chunks.push(rows);
stdout = null;
continue;
}
case EmptyQueryResponse.type:
case CommandComplete.type:
continue;
case ErrorResponse.type: {
const { fields } = ser_decode(ErrorResponse, msg);
err = new PostgresError(fields);
continue;
}
}
}
}
);
yield* chunks;
if (err) throw err;
return { tag: "" };
}
async function* execute_fast( async function* execute_fast(
st: Statement, st: Statement,
params: unknown[], params: unknown[],
@ -1097,7 +1147,7 @@ function wire_impl(
if (stdin !== null) { if (stdin !== null) {
await write(msg_BE, { B, E }); await write(msg_BE, { B, E });
await write_copy_in(stdin); await write_copy_in(stdin);
await write(Close, C); return write(Close, C);
} else { } else {
return write(msg_BEcC, { B, E, c: {}, C }); return write(msg_BEcC, { B, E, c: {}, C });
} }
@ -1155,7 +1205,7 @@ function wire_impl(
if (stdin !== null) { if (stdin !== null) {
await write(msg_BE, { B, E }); await write(msg_BE, { B, E });
await write_copy_in(stdin); return write_copy_in(stdin);
} else { } else {
return write(msg_BEc, { B, E, c: {} }); return write(msg_BEc, { B, E, c: {} });
} }
@ -1186,15 +1236,20 @@ function wire_impl(
} }
} }
function query(s: SqlFragment) { function query(sql: SqlFragment) {
const { query, params } = sql.format(s); return new Query(
({ simple = false, chunk_size = 0, stdin = null, stdout = null }) => {
const { query, params } = format(sql);
if (simple) {
if (!params.length) return execute_simple(query, stdin, stdout);
else throw new WireError(`simple query cannot be parameterised`);
}
let st = st_cache.get(query); let st = st_cache.get(query);
if (!st) st_cache.set(query, (st = new Statement(query))); if (!st) st_cache.set(query, (st = new Statement(query)));
if (!chunk_size) return execute_fast(st, params, stdin, stdout);
return new Query(({ chunk_size = 0, stdin = null, stdout = null }) => else return execute_chunked(st, params, chunk_size, stdin, stdout);
chunk_size !== 0 }
? execute_chunked(st, params, chunk_size, stdin, stdout)
: execute_fast(st, params, stdin, stdout)
); );
} }