1 /**
 * Represents a connection to the PostgreSQL server
 *
 * Most functions correspond to those in the documentation of Postgres:
5  * $(HTTPS https://www.postgresql.org/docs/current/static/libpq.html)
6  */
7 module dpq2.connection;
8 
9 import dpq2.query;
10 import dpq2.args: QueryParams;
11 import dpq2.result;
12 import dpq2.exception;
13 
14 import derelict.pq.pq;
15 import std.conv: to;
16 import std.string: toStringz, fromStringz;
17 import std.exception: enforce;
18 import std.range;
19 import std.stdio: File;
20 import std.socket;
21 import core.exception;
22 import core.time: Duration;
23 
/*
 * Bugs: On Unix, a connection is not thread-safe.
 *
 * On Unix, forking a process with open libpq connections can lead
 * to unpredictable results because the parent and child processes share
 * the same sockets and operating system resources. For this reason,
 * such usage is not recommended, though doing an exec from the child
 * process to load a new executable is safe.
 *
 * See also the libpq function `int PQisthreadsafe();`, which returns 1
 * if the libpq library is thread-safe and 0 if it is not.
 */
38 
/// Constructors shared by both build flavours of Connection (mixed into the class).
///
/// Each constructor obtains a PGconn* from libpq, takes a dynamic loader
/// reference when built with version Dpq2_Dynamic, and then validates the new
/// handle via checkCreatedConnection().
private mixin template ConnectionCtors()
{
    /// Makes a new connection to the database server
    this(string connString)
    {
        const cString = connString.toStringz;
        conn = PQconnectdb(cString);
        version(Dpq2_Dynamic) dynLoaderRefCnt = ReferenceCounter(true);
        checkCreatedConnection();
    }

    /// ditto
    this(in string[string] keyValueParams)
    {
        // Both arrays always hold at least the terminating null, so .ptr is valid
        auto params = keyValToPQparamsArrays(keyValueParams);
        conn = PQconnectdbParams(params.keys.ptr, params.vals.ptr, 0);
        version(Dpq2_Dynamic) dynLoaderRefCnt = ReferenceCounter(true);
        checkCreatedConnection();
    }

    /// Starts creation of a connection to the database server in a nonblocking manner
    this(ConnectionStart, string connString)
    {
        const cString = connString.toStringz;
        conn = PQconnectStart(cString);
        version(Dpq2_Dynamic) dynLoaderRefCnt = ReferenceCounter(true);
        checkCreatedConnection();
    }

    /// ditto
    this(ConnectionStart, in string[string] keyValueParams)
    {
        auto params = keyValToPQparamsArrays(keyValueParams);
        conn = PQconnectStartParams(params.keys.ptr, params.vals.ptr, 0);
        version(Dpq2_Dynamic) dynLoaderRefCnt = ReferenceCounter(true);
        checkCreatedConnection();
    }
}
78 
/// Tag type: pass `ConnectionStart()` as the first Connection constructor
/// argument to select the nonblocking (PQconnectStart*) constructors.
struct ConnectionStart {}
81 
/// Connection
///
/// Owns a libpq PGconn handle; the handle is released with PQfinish when the
/// object is destroyed.
class Connection
{
    package PGconn* conn;

    invariant
    {
        assert(conn !is null);
    }

    version(Dpq2_Static)
        mixin ConnectionCtors;
    else
    {
        import dpq2.dynloader: ReferenceCounter;

        private immutable ReferenceCounter dynLoaderRefCnt;

        package mixin ConnectionCtors;
    }

    /// Validates a freshly created PGconn handle.
    ///
    /// Throws: OutOfMemoryError if libpq returned a null handle,
    /// ConnectionException if the connection status is CONNECTION_BAD.
    private void checkCreatedConnection()
    {
        enforce!OutOfMemoryError(conn, "Unable to allocate libpq connection data");

        if( status == CONNECTION_BAD )
            throw new ConnectionException(this, __FILE__, __LINE__);
    }

    /// Closes the connection with PQfinish and, with Dpq2_Dynamic, releases
    /// the dynamic loader reference taken by the constructor.
    ~this()
    {
        PQfinish( conn );

        version(Dpq2_Dynamic) dynLoaderRefCnt.__custom_dtor();
    }

    mixin Queries;

    /// Returns true if the connection is in nonblocking mode (PQisnonblocking)
    bool isNonBlocking()
    {
        return PQisnonblocking(conn) == 1;
    }

    /// Sets the nonblocking status of the connection
    ///
    /// Throws: ConnectionException if PQsetnonblocking reports an error.
    private void setNonBlocking(bool state)
    {
        if( PQsetnonblocking(conn, state ? 1 : 0 ) == -1 )
            throw new ConnectionException(this, __FILE__, __LINE__);
    }

    /// Begins resetting the communication channel to the server in a nonblocking manner
    ///
    /// Useful only for non-blocking operations.
    ///
    /// Throws: ConnectionException if PQresetStart reports failure.
    void resetStart()
    {
        if(PQresetStart(conn) == 0)
            throw new ConnectionException(this, __FILE__, __LINE__);
    }

    /// Polls a nonblocking connection attempt (PQconnectPoll)
    ///
    /// Useful only for non-blocking operations.
    PostgresPollingStatusType poll() nothrow
    {
        assert(conn);

        return PQconnectPoll(conn);
    }

    /// Polls a nonblocking reset started by resetStart (PQresetPoll)
    ///
    /// Useful only for non-blocking operations.
    PostgresPollingStatusType resetPoll() nothrow
    {
        assert(conn);

        return PQresetPoll(conn);
    }

    /// Returns the status of the connection
    ConnStatusType status() nothrow
    {
        return PQstatus(conn);
    }

    /**
        Returns the current in-transaction status of the server.
        The status can be:
            * PQTRANS_IDLE    - currently idle
            * PQTRANS_ACTIVE  - a command is in progress (reported only when a query has been sent to the server and not yet completed)
            * PQTRANS_INTRANS - idle, in a valid transaction block
            * PQTRANS_INERROR - idle, in a failed transaction block
            * PQTRANS_UNKNOWN - reported if the connection is bad
     */
    PGTransactionStatusType transactionStatus() nothrow
    {
        return PQtransactionStatus(conn);
    }

    /// If input is available from the server, consume it
    ///
    /// Useful only for non-blocking operations.
    ///
    /// Throws: ConnectionException if PQconsumeInput reports an error.
    void consumeInput()
    {
        assert(conn);

        const int r = PQconsumeInput( conn );
        if( r != 1 ) throw new ConnectionException(this, __FILE__, __LINE__);
    }

    /// Attempts to flush queued output data to the server.
    ///
    /// Returns: true if all data was sent, false if some data is still queued.
    /// Throws: ConnectionException if PQflush reports an error.
    package bool flush()
    {
        assert(conn);

        auto r = PQflush(conn);
        if( r == -1 ) throw new ConnectionException(this, __FILE__, __LINE__);
        return r == 0;
    }

    /// Obtains the file descriptor number of the connection socket to the server
    ///
    /// Throws: ConnectionException if PQsocket returns -1 (no server connection is open).
    int posixSocket()
    {
        int r = PQsocket(conn);

        if(r == -1)
            throw new ConnectionException(this, __FILE__, __LINE__);

        return r;
    }

    /// Obtains duplicate file descriptor number of the connection socket to the server
    ///
    /// The returned descriptor is a new one created by dup(); the caller is
    /// responsible for closing it.
    ///
    /// Throws: ConnectionException if the socket is unavailable or dup() fails.
    version(Posix)
    socket_t posixSocketDuplicate()
    {
        import core.sys.posix.unistd: dup;
        import core.stdc.errno: errno;

        static assert(socket_t.sizeof == int.sizeof);

        const int fd = dup(posixSocket);

        if(fd == -1)
        {
            // Capture errno immediately, before anything else can overwrite it
            const err = errno;
            throw new ConnectionException("dup error, code "~err.to!string);
        }

        return cast(socket_t) fd;
    }

    /// Obtains duplicate file descriptor number of the connection socket to the server
    ///
    /// Throws: ConnectionException if WSADuplicateSocketW or WSASocketW fails.
    version(Windows)
    SOCKET posixSocketDuplicate()
    {
        import core.sys.windows.winbase: GetCurrentProcessId;

        // Stack storage: the protocol info is only needed for the duration of
        // this call (previously it was malloc'ed and leaked on success)
        WSAPROTOCOL_INFOW protocolInfo;

        int dupStatus = WSADuplicateSocketW(posixSocket, GetCurrentProcessId, &protocolInfo);

        if(dupStatus)
            throw new ConnectionException("WSADuplicateSocketW error, code "~WSAGetLastError().to!string);

        SOCKET s = WSASocketW(
                FROM_PROTOCOL_INFO,
                FROM_PROTOCOL_INFO,
                FROM_PROTOCOL_INFO,
                &protocolInfo,
                0,
                0
            );

        if(s == INVALID_SOCKET)
            throw new ConnectionException("WSASocket error, code "~WSAGetLastError().to!string);

        return s;
    }

    /// Obtains std.socket.Socket of the connection to the server
    ///
    /// Due to a limitation of D's Socket class, the Socket is created from a
    /// duplicate of the internal posix socket rather than the socket itself.
    Socket socket()
    {
        version(Windows) static assert(SOCKET.sizeof == socket_t.sizeof);

        return new Socket(cast(socket_t) posixSocketDuplicate, AddressFamily.UNSPEC);
    }

    /// Returns the error message most recently generated by an operation on the connection
    string errorMessage() const nothrow
    {
        return PQerrorMessage(conn).to!string;
    }

    /**
     * Sets or examines the current notice processor
     *
     * Returns the previous notice receiver or processor function pointer, and sets the new value.
     * If you supply a null function pointer, no action is taken, but the current pointer is returned.
     */
    PQnoticeProcessor setNoticeProcessor(PQnoticeProcessor proc, void* arg) nothrow
    {
        assert(conn);

        return PQsetNoticeProcessor(conn, proc, arg);
    }

    /// Gets the next result after sending non-blocking commands; returns null
    /// when PQgetResult has no more results.
    ///
    /// Useful only for non-blocking operations.
    immutable(Result) getResult()
    {
        // libpq guarantees that a PGresult is not modified until it is freed,
        // so treating it as immutable is safe
        auto r = cast(immutable) PQgetResult(conn);

        if(r)
        {
            auto container = new immutable ResultContainer(r);
            return new immutable Result(container);
        }

        return null;
    }

    /// Wraps the result of a PQexec* call into a ResultContainer
    ///
    /// Throws: ConnectionException if r is null (libpq returned no result).
    package immutable(ResultContainer) createResultContainer(immutable PGresult* r) const
    {
        if(r is null) throw new ConnectionException(this, __FILE__, __LINE__);

        return new immutable ResultContainer(r);
    }

    /// Select single-row mode for the currently-executing query
    ///
    /// Returns: true if PQsetSingleRowMode accepted the request.
    bool setSingleRowMode()
    {
        return PQsetSingleRowMode(conn) == 1;
    }

    /**
     Tries to cancel the query currently being processed

     If the cancellation is effective, the current command will
     terminate early and return an error result or exception. If the
     cancellation fails (say, because the server was already done
     processing the command) there will be no visible result at all.

     Throws: ConnectionException if a cancel object can't be created,
     CancellationException if the cancel request can't be dispatched.
    */
    void cancel()
    {
        auto c = new Cancellation(this);
        // Free the PGcancel object now rather than waiting for the GC finalizer
        scope(exit) destroy(c);

        c.doCancel;
    }

    /// Returns true if PQisBusy reports that a command is still being processed
    bool isBusy() nothrow
    {
        assert(conn);

        return PQisBusy(conn) == 1;
    }

    /// Looks up a current parameter setting of the server (PQparameterStatus)
    ///
    /// Throws: ConnectionException if libpq returns null for paramName.
    string parameterStatus(string paramName)
    {
        assert(conn);

        auto res = PQparameterStatus(conn, toStringz(paramName));

        if(res is null)
            throw new ConnectionException(this, __FILE__, __LINE__);

        return to!string(fromStringz(res));
    }

    /// Escapes a string for use as an SQL literal (PQescapeLiteral)
    ///
    /// Throws: ConnectionException if libpq fails to escape the string.
    string escapeLiteral(string msg)
    {
        assert(conn);

        return takeLibpqString(PQescapeLiteral(conn, msg.toStringz, msg.length));
    }

    /// Escapes a string for use as an SQL identifier (PQescapeIdentifier)
    ///
    /// Throws: ConnectionException if libpq fails to escape the string.
    string escapeIdentifier(string msg)
    {
        assert(conn);

        return takeLibpqString(PQescapeIdentifier(conn, msg.toStringz, msg.length));
    }

    /// Copies a libpq-allocated C string into a D string and frees it with PQfreemem.
    ///
    /// Throws: ConnectionException if buf is null (the libpq call failed).
    private string takeLibpqString(char* buf)
    {
        if(buf is null)
            throw new ConnectionException(this, __FILE__, __LINE__);

        // Free the libpq buffer even if copying it throws
        scope(exit) PQfreemem(buf);

        return buf.fromStringz.to!string;
    }

    /// Returns the database name of the connection (PQdb)
    string dbName() const nothrow
    {
        assert(conn);

        return PQdb(conn).fromStringz.to!string;
    }

    /// Returns the server host name of the connection (PQhost)
    string host() const nothrow
    {
        assert(conn);

        return PQhost(conn).fromStringz.to!string;
    }

    /// Returns the frontend/backend protocol version in use (PQprotocolVersion)
    int protocolVersion() const nothrow
    {
        assert(conn);

        return PQprotocolVersion(conn);
    }

    /// Returns the server version as an integer (PQserverVersion)
    int serverVersion() const nothrow
    {
        assert(conn);

        return PQserverVersion(conn);
    }

    /// Enables tracing of the client/server communication to the given stream (PQtrace)
    void trace(ref File stream)
    {
        PQtrace(conn, stream.getFP);
    }

    /// Disables tracing started by trace() (PQuntrace)
    void untrace()
    {
        PQuntrace(conn);
    }

    /// Sets the client encoding (PQsetClientEncoding)
    ///
    /// Throws: ConnectionException if libpq rejects the encoding.
    void setClientEncoding(string encoding)
    {
        if(PQsetClientEncoding(conn, encoding.toStringz) != 0)
            throw new ConnectionException(this, __FILE__, __LINE__);
    }
}
431 
// Socket duplication stuff for Win32
version(Windows)
private
{
    import core.sys.windows.windef;
    import core.sys.windows.basetyps: GUID;

    alias GROUP = uint;

    // Winsock defines INVALID_SOCKET as (SOCKET)(~0), i.e. all bits set.
    // The previous value 0 never matched the failure value of WSASocketW.
    enum SOCKET INVALID_SOCKET = ~cast(SOCKET) 0;
    enum FROM_PROTOCOL_INFO =-1;
    enum MAX_PROTOCOL_CHAIN = 7;
    enum WSAPROTOCOL_LEN = 255;

    /// Mirror of the Winsock WSAPROTOCOLCHAIN structure
    struct WSAPROTOCOLCHAIN
    {
        int ChainLen;
        DWORD[MAX_PROTOCOL_CHAIN] ChainEntries;
    }

    /// Mirror of the Winsock WSAPROTOCOL_INFOW structure (filled by WSADuplicateSocketW)
    struct WSAPROTOCOL_INFOW
    {
        DWORD dwServiceFlags1;
        DWORD dwServiceFlags2;
        DWORD dwServiceFlags3;
        DWORD dwServiceFlags4;
        DWORD dwProviderFlags;
        GUID ProviderId;
        DWORD dwCatalogEntryId;
        WSAPROTOCOLCHAIN ProtocolChain;
        int iVersion;
        int iAddressFamily;
        int iMaxSockAddr;
        int iMinSockAddr;
        int iSocketType;
        int iProtocol;
        int iProtocolMaxOffset;
        int iNetworkByteOrder;
        int iSecurityScheme;
        DWORD dwMessageSize;
        DWORD dwProviderReserved;
        WCHAR[WSAPROTOCOL_LEN+1] szProtocol;
    }

    extern(Windows) nothrow @nogc
    {
        import core.sys.windows.winsock2: WSAGetLastError;
        int WSADuplicateSocketW(SOCKET s, DWORD dwProcessId, WSAPROTOCOL_INFOW* lpProtocolInfo);
        SOCKET WSASocketW(int af, int type, int protocol, WSAPROTOCOL_INFOW*, GROUP, DWORD dwFlags);
    }
}
483 
/// Converts a key/value map into the two parallel, null-terminated C string
/// arrays expected by PQconnectdbParams / PQconnectStartParams.
///
/// Both arrays have `keyValueParams.length + 1` elements; entry i of `keys`
/// pairs with entry i of `vals`, and the last entry of each is null.
private auto keyValToPQparamsArrays(in string[string] keyValueParams)
{
    static struct PQparamsArrays
    {
        immutable(char)*[] keys;
        immutable(char)*[] vals;
    }

    immutable count = keyValueParams.length;

    // Freshly allocated pointer arrays are null-filled, providing the terminators
    PQparamsArrays result;
    result.keys = new immutable(char)*[count + 1];
    result.vals = new immutable(char)*[count + 1];

    size_t pos = 0;
    foreach(key, value; keyValueParams)
    {
        result.keys[pos] = key.toStringz;
        result.vals[pos] = value.toStringz;
        ++pos;
    }

    assert(pos == count);

    return result;
}
509 
/// Checks the connection options in the provided connection string
///
/// Throws: ConnectionException if the connection string doesn't pass the check.
version(Dpq2_Static)
void connStringCheck(string connString)
{
    return _connStringCheck(connString);
}
518 
/// ditto
package void _connStringCheck(string connString)
{
    char* parseError = null;
    auto options = PQconninfoParse(connString.toStringz, &parseError);

    // A null result without an error message means libpq ran out of memory
    if(options is null)
        enforce!OutOfMemoryError(parseError, "Unable to allocate libpq conninfo data");
    else
        PQconninfoFree(options);

    if(parseError is null)
        return;

    // Copy the libpq-owned message before releasing it
    string message = parseError.fromStringz.to!string;
    PQfreemem(cast(void*) parseError);

    throw new ConnectionException(message, __FILE__, __LINE__);
}
542 
/// Represents query cancellation process
///
/// Holds a libpq PGcancel object obtained from a Connection; the object is
/// released with PQfreeCancel when this instance is destroyed.
class Cancellation
{
    version(Dpq2_Dynamic)
    {
        import dpq2.dynloader: ReferenceCounter;
        private immutable ReferenceCounter dynLoaderRefCnt;
    }

    private PGcancel* cancel;

    /// Creates a cancellation object for the given connection
    ///
    /// Throws: ConnectionException if PQgetCancel returns null.
    this(Connection c)
    {
        version(Dpq2_Dynamic) dynLoaderRefCnt = ReferenceCounter(true);

        cancel = PQgetCancel(c.conn);

        // The exception is only constructed (lazily) when cancel is null
        enforce(cancel, new ConnectionException(c, __FILE__, __LINE__));
    }

    /// Frees the PGcancel object and, with Dpq2_Dynamic, releases the loader reference
    ~this()
    {
        PQfreeCancel(cancel);

        version(Dpq2_Dynamic) dynLoaderRefCnt.__custom_dtor();
    }

    /**
     Requests that the server abandon processing of the current command

     Throws: CancellationException (carrying libpq's error text) if the
     cancel request was not successfully dispatched.

     Successful dispatch is no guarantee that the request will have any
     effect, however. If the cancellation is effective, the current
     command will terminate early and return an error result
     (exception). If the cancellation fails (say, because the server
     was already done processing the command), then there will be no
     visible result at all.
    */
    void doCancel()
    {
        char[256] errbuf;
        immutable bool dispatched = PQcancel(cancel, errbuf.ptr, errbuf.length) == 1;

        if(!dispatched)
        {
            string reason = to!string(errbuf.ptr.fromStringz);
            throw new CancellationException(reason, __FILE__, __LINE__);
        }
    }
}
594 
/// Thrown by Cancellation.doCancel when a cancel request could not be dispatched
class CancellationException : Dpq2Exception
{
    /// Params: msg = error text; file, line = origin of the exception
    this(string msg, string file = __FILE__, size_t line = __LINE__)
    { super(msg, file, line); }
}
603 
/// Connection exception
///
/// Thrown for failures reported by libpq or by socket operations on a Connection.
class ConnectionException : Dpq2Exception
{
    /// Builds the exception from the connection's most recent libpq error message
    this(in Connection c, string file = __FILE__, size_t line = __LINE__)
    {
        string message = c.errorMessage;
        super(message, file, line);
    }

    /// Builds the exception from an explicit message
    this(string msg, string file = __FILE__, size_t line = __LINE__)
    {
        super(msg, file, line);
    }
}
617 
/// Creates a Connection for integration tests, forwarding params to the
/// constructor (static build) or to the dynamic loader's factory.
version (integration_tests)
Connection createTestConn(T...)(T params)
{
    version(Dpq2_Static)
    {
        return new Connection(params);
    }
    else
    {
        import dpq2.dynloader: connFactory;

        return connFactory.createConnection(params);
    }
}
632 
/// Integration tests for this module; requires a reachable server described by connParam.
version (integration_tests)
void _integration_test( string connParam )
{
    import std.exception: assertThrown, collectException;

    // Basic connection properties
    {
        debug import std.experimental.logger;

        auto conn = createTestConn(connParam);

        assert( PQlibVersion() >= 9_0100 );

        immutable dbname = conn.dbName();
        immutable pver = conn.protocolVersion();
        immutable sver = conn.serverVersion();

        debug
        {
            trace("DB name: ", dbname);
            trace("Protocol version: ", pver);
            trace("Server version: ", sver);
        }

        destroy(conn);
    }

    // Connection string validation
    {
        version(Dpq2_Dynamic)
        {
            void csc(string s)
            {
                import dpq2.dynloader: connFactory;

                connFactory.connStringCheck(s);
            }
        }
        else
            void csc(string s){ connStringCheck(s); }

        csc("dbname=postgres user=postgres");

        assertThrown!ConnectionException(csc("wrong conninfo string"));
    }

    // A malformed string passed to the nonblocking constructor must fail with a descriptive message
    {
        auto e = collectException!ConnectionException(
            createTestConn(ConnectionStart(), "!!!some incorrect connection string!!!")
        );

        assert(e !is null);
        assert(e.msg.length > 40); // error message check
    }

    // Escaping and client encoding
    {
        auto conn = createTestConn(connParam);

        assert(conn.escapeLiteral("abc'def") == "'abc''def'");
        assert(conn.escapeIdentifier("abc'def") == "\"abc'def\"");

        conn.setClientEncoding("WIN866");
        assert(conn.exec("show client_encoding")[0][0].as!string == "WIN866");
    }

    // Transaction status tracking
    {
        auto conn = createTestConn(connParam);

        assert(conn.transactionStatus == PQTRANS_IDLE);

        conn.exec("BEGIN");
        assert(conn.transactionStatus == PQTRANS_INTRANS);

        // The failing DISCARD ALL is expected to leave the transaction in the failed state
        try
            conn.exec("DISCARD ALL");
        catch(Exception)
        {
            // expected
        }
        assert(conn.transactionStatus == PQTRANS_INERROR);

        conn.exec("ROLLBACK");
        assert(conn.transactionStatus == PQTRANS_IDLE);
    }

    // Key/value constructors must reject an unreachable host
    {
        string[string] kv = [
            "host": "wrong-host",
            "dbname": "wrong-db-name",
        ];

        assertThrown!ConnectionException(createTestConn(kv));
        assertThrown!ConnectionException(createTestConn(ConnectionStart(), kv));
    }
}