Checking Postgres availability with pure Ruby

A quick introduction to Postgres wire protocol

#ruby #postgresql

Recently I had to write a simple script whose single job was to wait until Postgres becomes available. This matters when you need to run queries against your database as part of an automated workflow, but you have no guarantee that the database service will be ready in time. My concrete use cases were:

The obvious way to implement such a script is to use either psql or a Postgres library for your language of choice. But because I had played a little with the Postgres protocol in the past, I decided it would be a good idea to write it in pure Ruby and avoid installing psql inside the Docker container.

This article describes a small portion of the Postgres protocol and shows how to write a Ruby program that communicates with Postgres using it.

TL;DR If you are interested just in the script, you can jump straight to the repository. If you want to learn about the protocol or some Ruby tricks, read on.

Protocol basics

The Postgres protocol is quite extensive, with over 50 message types and various modes. For the purpose of this article, though, you need to know only a few of them. Take a look at the diagram below, which describes the message flow when establishing a new connection to the server:

Postgres connection sequence

Technically the format is binary, but most of the data is sent as ASCII. Almost all message types share the same format, made up of three parts: a 1-byte character tag that identifies the message type, a 4-byte message length, and the message payload.

The startup message is the exception here. It does not contain the character tag and starts directly with the message length.
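To make the framing concrete, here is a minimal sketch that builds and parses a regular tagged message. The helper names frame and parse are made up for this illustration and are not part of the script. The sketch assumes that the length field counts itself and the payload but not the tag, which is also what the "- 4" in the script relies on.

```ruby
# Builds a tagged message: 1-byte tag, 4-byte big-endian length, payload.
# The length covers the length field itself (4 bytes) plus the payload.
def frame(tag, payload)
  [tag, payload.bytesize + 4, payload].pack("a1L>a*")
end

# Splits a framed message back into tag, length and payload.
def parse(bytes)
  tag, length = bytes.unpack("a1L>")
  [tag, length, bytes.byteslice(5, length - 4)]
end

# Ready for query carries a single status byte; "I" means idle.
ready = frame("Z", "I")
p ready
p parse(ready)
```

The same two helpers work for any tagged message, because only the payload differs between message types.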

Going back to the diagram, here is the flow again, explained step by step:

  1. Client sends a Startup Message with the required parameters
  2. If authentication is required, the server responds with an Authentication Request message. This message varies based on the server configuration and the selected authentication method. If authentication is not required, the server immediately responds with an Authentication OK message and the flow continues from step 5.
  3. Client sends a Password Message with the password challenge
  4. Server checks the password challenge and responds with Authentication OK or an Error Message
  5. Server sends important parameters as Parameter Status messages
  6. Server sends the cancellation key as a Backend Key Data message
  7. Server sends a Ready for query message. At this point the server is ready to accept queries.
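The server side of this flow can be replayed without a running database. The sketch below hand-builds the messages a server might send after a successful startup with no password and walks through them by tag. The BACKEND_TAGS table and the helpers frame, sample_stream and message_names are names I made up for the example, and the parameter values are placeholders.

```ruby
require "stringio"

# Tags of the backend messages that appear in the connection flow.
BACKEND_TAGS = {
  "R" => "Authentication",
  "S" => "ParameterStatus",
  "K" => "BackendKeyData",
  "Z" => "ReadyForQuery",
  "E" => "ErrorResponse"
}.freeze

def frame(tag, payload)
  [tag, payload.bytesize + 4, payload].pack("a1L>a*")
end

# Hand-built stand-in for a server response (steps 2, 5, 6 and 7).
def sample_stream
  StringIO.new(
    frame("R", [0].pack("L>")) +                            # Authentication OK
    frame("S", ["server_version", "16.0"].pack("Z*Z*")) +   # one parameter
    frame("K", [1234, 5678].pack("L>L>")) +                 # process id, secret key
    frame("Z", "I")                                         # idle, ready for queries
  )
end

# Reads messages until the stream ends and returns their names in order.
def message_names(io)
  names = []
  until io.eof?
    tag = io.read(1)
    length = io.read(4).unpack("L>").first
    io.read(length - 4) # skip the payload
    names << BACKEND_TAGS.fetch(tag, "Unknown (#{tag})")
  end
  names
end

puts message_names(sample_stream).join(" -> ")
```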

The code

Here is the code in all its glory:

require "digest"
require "socket"

module PgTools
  class Status
    def self.ready?(host:, port:, user:, password: nil, database: nil)
      socket = TCPSocket.new(host, port)
      socket.write(build_startup_message(user: user, database: database))

      while true
        char_tag = socket.read(1)
        data = socket.read(4)
        length = data.unpack("L>").first - 4
        payload = socket.read(length)

        case char_tag
        when "E"
          # Received **ErrorResponse**
          break
        when "R"
          # Received **AuthenticationRequest** message
          decoded = payload.unpack("L>").first

          case decoded
          when 3
            # Cleartext password
            packet = [112, password.size + 5, password].pack("CL>Z*")
            socket.write(packet)
          when 5
            # MD5 password
            salt = payload[4..-1]
            hashed_pass = Digest::MD5.hexdigest([password, user].join)
            encoded_password = "md5" + Digest::MD5.hexdigest([hashed_pass, salt].join)
            socket.write([112, encoded_password.size + 5, encoded_password].pack("CL>Z*"))
          end
        when "Z"
          # Received **Ready for query** message
          return true
        end
      end

      false
    rescue Errno::ECONNREFUSED
      false
    end

    def self.build_startup_message(user:, database: nil)
      message = [0, 196608]
      message_size = 4 + 4
      pack_string = "L>L>"

      params = ["user", user]
      params.concat(["database", database]) if database

      params.each do |value|
        message << value
        message_size += value.size + 1
        pack_string << "Z*"
      end

      message << 0
      message_size += 1
      pack_string << "C"

      message[0] = message_size

      message.pack(pack_string)
    end
  end
end

Let’s walk through it step by step, starting with the build_startup_message method.

Its main job is to construct an array of values and translate it into a byte stream that can be sent to the Postgres server. The easiest way to do this in Ruby is to use the Array#pack method. It takes a format string as its parameter and returns a string containing all the array's values encoded according to that format. The method takes care of padding, big / little endian encoding, terminating strings with zero bytes etc. You can read more about it here and here.
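If you have not used Array#pack before, these are the only directives the script needs. The values are arbitrary and chosen just for illustration.

```ruby
# "L>" is a 32-bit unsigned integer, most significant byte first (big endian).
p [1].pack("L>")
# "L<" is the same integer in little endian, shown only for contrast.
p [1].pack("L<")
# "Z*" is a string followed by a terminating zero byte.
p ["abc"].pack("Z*")
# "C" is a single unsigned byte; 112 is the ASCII code of "p".
p [112].pack("C")

# Directives can be combined, and String#unpack reverses the encoding.
packed = [7, "hi"].pack("L>Z*")
p packed.unpack("L>Z*")
```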

The method starts with some variable definitions:

message = [0, 196608]
message_size = 4 + 4
pack_string = "L>L>"

The first two parts of the startup message are its length and the protocol version. At this point we don’t yet know what the message size will be, so we put 0 as a placeholder in the first element. The protocol version is 3.0. The two most significant bytes hold the major version (3), and the two least significant bytes hold the minor version (0). This gives the value 196608 (3 × 65536).

message_size stores the current message size. We initialize it with the value 8, because the message size and protocol version are already part of the message body. Writing 4 + 4 instead of just 8 makes it clearer that this is 4 bytes for the message length and another 4 bytes for the protocol version.

pack_string stores the current format string that will be passed to Array#pack. The final format string depends on the arguments passed to build_startup_message, so we keep it in a variable. This also makes it easy to extend the method with support for other parameters.
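The version number is easy to verify by hand. In this small sketch, protocol_version is a helper name I made up and is not part of the script.

```ruby
# Major version goes into the upper 16 bits, minor version into the lower 16.
def protocol_version(major, minor)
  (major << 16) | minor
end

version = protocol_version(3, 0)
p version

# Packed as a big-endian 32-bit integer it becomes the bytes 00 03 00 00,
# and reading it back as two 16-bit integers yields major and minor again.
p [version].pack("L>").unpack("S>S>")
```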

params = ["user", user]
params.concat(["database", database]) if database

Next we create the params array for storing name / value pairs. The user parameter is required, so it is already in the array. The database name is added if specified.

params.each do |value|
  message << value
  message_size += value.size + 1
  pack_string << "Z*"
end

We iterate over the params array and add each value to the message body. At the same time we need to increase the message size and update the format string. Notice that the message size is increased by one extra byte. This is for the zero byte that terminates each value.
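One caveat worth keeping in mind: String#size counts characters, while the length field counts bytes. For ASCII-only names the two agree, so the script works as written, but for anything else String#bytesize is the safer choice. A small illustration with a made-up user name (packed_size is a hypothetical helper):

```ruby
# Number of bytes a value occupies once encoded with "Z*":
# its bytes plus the terminating zero byte.
def packed_size(value)
  value.bytesize + 1
end

name = "jos\u00e9"  # "josé", where the last character takes two bytes in UTF-8

p name.size                     # characters
p name.bytesize                 # bytes
p packed_size(name)             # what the length field has to account for
p [name].pack("Z*").bytesize    # what Array#pack actually produces
```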

message << 0
message_size += 1
pack_string << "C"

We need to add a zero byte at the end of the message and, of course, update the size and the format string.

message[0] = message_size

At this point the whole message body is ready, so we can set the correct message length inside the message body.

message.pack(pack_string)

And finally we encode the message body as a string. For example, if the user is postgres and the database name is my_database, the final message will look like this:

[44, 196608, "user", "postgres", "database", "my_database", 0]

and encoded string:

"\x00\x00\x00,\x00\x03\x00\x00user\x00postgres\x00database\x00my_database\x00\x00"

Now let’s walk through the ready? method:

socket = TCPSocket.new(host, port)
socket.write(build_startup_message(user: user, database: database))

We start by opening a TCP connection to the server and sending the startup message.

while true
  char_tag = socket.read(1)
  data = socket.read(4)
  length = data.unpack("L>").first - 4
  payload = socket.read(length)

  case char_tag
    # hidden for brevity
  end
end

This is the main event loop. At the beginning of each iteration we read the 1-byte character tag from the server, followed by the message length and the actual message payload. Depending on the received character tag (message type) we can execute the appropriate code.
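The loop assumes the server always sends complete messages. IO#read returns nil at end of stream, so if the server closes the connection early, the call to unpack would raise NoMethodError. Below is a rough sketch of a more defensive reader. read_message is a hypothetical helper, not something the script contains, and StringIO stands in for the socket so the example runs without a server.

```ruby
require "stringio"

# Reads one tagged message from io.
# Returns [tag, payload], or nil when the stream ends before a full header.
def read_message(io)
  header = io.read(5)
  return nil if header.nil? || header.bytesize < 5

  tag, length = header.unpack("a1L>")
  payload = io.read(length - 4).to_s # "" when there is no payload
  [tag, payload]
end

io = StringIO.new("Z\x00\x00\x00\x05I".b)
p read_message(io) # the Ready for query message
p read_message(io) # nil, because the stream is exhausted
```

With a helper like this, the loop can stop as soon as it gets nil instead of relying on an exception.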

case char_tag
when "E"
  # Received **ErrorResponse**
  break

# hidden for brevity
end

This matches the ErrorResponse tag. In that case we break out of the loop and the method returns false.
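The script ignores the payload of the error message, but it can be handy when debugging. The payload is a list of fields, each made of a 1-byte field code followed by a zero-terminated string, with one more zero byte closing the list. Here is a minimal sketch of decoding it. parse_error_fields is a name I made up, and the sample payload is hand-built rather than captured from a real server.

```ruby
# Decodes an ErrorResponse payload into a hash keyed by field code.
# Common codes: "S" severity, "C" SQLSTATE code, "M" human-readable message.
def parse_error_fields(payload)
  payload.split("\x00").each_with_object({}) do |field, fields|
    next if field.empty?

    fields[field[0]] = field[1..-1]
  end
end

# Hand-built payload; 28P01 is the SQLSTATE for an invalid password.
payload = ["SFATAL", "C28P01", "Mpassword authentication failed"].pack("Z*Z*Z*") + "\x00"

fields = parse_error_fields(payload)
puts "#{fields["S"]} #{fields["C"]}: #{fields["M"]}"
```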

case char_tag
# hidden for brevity
when "R"
  # Received **AuthenticationRequest** message
  decoded = payload.unpack("L>").first

  case decoded
  when 3
    # Cleartext password
    packet = [112, password.size + 5, password].pack("CL>Z*")
    socket.write(packet)
  when 5
    # MD5 password
    salt = payload[4..-1]
    hashed_pass = Digest::MD5.hexdigest([password, user].join)
    encoded_password = "md5" + Digest::MD5.hexdigest([hashed_pass, salt].join)
    socket.write([112, encoded_password.size + 5, encoded_password].pack("CL>Z*"))
  end
end
# hidden for brevity

This piece matches the various authentication requests. The first 4 bytes of the payload specify what type of request it is. Value 0 means that authentication was successful and we don’t have to do anything more.

Value 3 means that a cleartext password is required, so we send a Password Message. This packet is encoded in the same way as the startup message, using the Array#pack method, but the format is simpler. The first element is the character tag 112 (lowercase p), followed by the message length and the password itself. The message length is always the password length + 5 bytes: 4 bytes for the message length field and 1 byte for the zero byte at the end of the message.
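The length arithmetic can be checked in isolation. In this sketch password_message is a hypothetical helper that wraps the same pack call, and it uses bytesize on the assumption that the password may contain non-ASCII characters.

```ruby
# Builds a Password Message: tag "p" (112), length, zero-terminated password.
# The length counts the 4-byte length field, the password and its zero byte.
def password_message(secret)
  [112, secret.bytesize + 5, secret].pack("CL>Z*")
end

message = password_message("secret")
p message

# 1 byte of tag + 4 bytes of length + 6 bytes of password + 1 zero byte.
p message.bytesize
# The length field does not count the tag, so it is one less than the total.
p message.unpack("CL>")
```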

Value 5 means that an MD5-encoded password is required. In this case the server also sends a salt value in the message payload, right after the 4-byte request type. We use the user name, the password and the received salt to calculate the correct password challenge, and then we send it to the server. The message format is the same as with the cleartext password. The only difference is how the password is encoded.
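Pulled out of the loop, the MD5 calculation looks like this. md5_password is a helper name made up for the example, and the credentials and salt are placeholders; a real salt is 4 bytes chosen by the server for each connection.

```ruby
require "digest"

# "md5" followed by md5(md5(password + user) + salt), all in lowercase hex.
def md5_password(user:, password:, salt:)
  inner = Digest::MD5.hexdigest(password + user)
  "md5" + Digest::MD5.hexdigest(inner + salt)
end

salt = [0xDE, 0xAD, 0xBE, 0xEF].pack("C4") # placeholder for the server's salt
encoded = md5_password(user: "postgres", password: "password", salt: salt)

puts encoded
puts encoded.size # always 35: the "md5" prefix plus 32 hex digits
```

Because the salt changes with every connection, the same password produces a different challenge each time.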

case char_tag
# hidden code
when "Z"
  # Received **Ready for query** message
  return true
end

The final branch matches the Ready for query message. At this point Postgres is ready to accept queries and we can return true from the method.

And that’s it. Now, to use this code, all you need to do is invoke the ready? method:

PgTools::Status.ready?(
  host: "localhost",
  port: 5432,
  user: "postgres",
  password: "password",
  database: "my-database"
)

or better:

60.times do
  exit(0) if PgTools::Status.ready?(host: "localhost", port: 5432, user: "postgres")

  sleep 1
end

puts "Postgres is not responding"
exit(1)
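If you would rather bound the wait by elapsed time than by the number of attempts, the loop can be wrapped in a small helper. This is only a sketch: wait_until is a name I made up, and the block passed to it here is a stand-in for the call to PgTools::Status.ready?.

```ruby
# Calls the block until it returns a truthy value or the timeout (in seconds)
# passes. Returns true on success and false on timeout.
def wait_until(timeout:, interval: 1)
  deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout

  loop do
    return true if yield
    return false if Process.clock_gettime(Process::CLOCK_MONOTONIC) >= deadline

    sleep interval
  end
end

# Stand-in probe that succeeds on the third attempt.
attempts = 0
ready = wait_until(timeout: 5, interval: 0.1) { (attempts += 1) >= 3 }

puts ready ? "Ready after #{attempts} attempts" : "Postgres is not responding"
```

A monotonic clock is used so that changes to the system time do not shorten or extend the wait.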

Conclusion

I hope you liked the article and learnt something new. I realize that knowing the Postgres protocol is not a skill you will use every day, but it has some practical uses. In the future I plan to write a simple Postgres proxy that will be able to simulate various database problems in tests, such as timeouts, terminated connections, cancelled requests etc. Stay tuned!