safe

Password protected secret keeper
git clone git://git.z3bra.org/safe.git
Log | Files | Refs | README | LICENSE

safe-agent.c (5473B)


#include <sys/resource.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <poll.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sodium.h>

#include "arg.h"
#include "strlcpy.h"
     22 
     23 #define SOCKDIR "/tmp/safe-XXXXXX"
     24 #define SOCKET  "agent"
     25 
     26 struct safe {
     27 	int loaded;
     28 	uint8_t saltkey[crypto_secretstream_xchacha20poly1305_KEYBYTES + crypto_pwhash_SALTBYTES];
     29 };
     30 
     31 char *argv0;
     32 struct safe s;
     33 char *sockp = NULL;
     34 int verbose = 0;
     35 
     36 void
     37 usage(void)
     38 {
     39 	fprintf(stderr, "usage: %s [-hdv] [-t timeout] [-f socket]\n", argv0);
     40 	exit(1);
     41 }
     42 
     43 char *
     44 dirname(char *path)
     45 {
     46 	static char tmp[PATH_MAX];
     47 	char *p = NULL;
     48 	size_t len;
     49 	snprintf(tmp, sizeof(tmp), "%s", path);
     50 	len = strlen(tmp);
     51 	for(p = tmp + len; p > tmp; p--)
     52 		if(*p == '/')
     53 			break;
     54 
     55 	*p = 0;
     56 	return tmp;
     57 }
     58 
     59 ssize_t
     60 xread(int fd, void *buf, size_t nbytes)
     61 {
     62 	uint8_t *bp = buf;
     63 	ssize_t total = 0;
     64 
     65 	while (nbytes > 0) {
     66 		ssize_t n;
     67 
     68 		n = read(fd, &bp[total], nbytes);
     69 		if (n < 0)
     70 			err(1, "read");
     71 		else if (n == 0)
     72 			return total;
     73 		total += n;
     74 		nbytes -= n;
     75 	}
     76 	return total;
     77 }
     78 
     79 ssize_t
     80 xwrite(int fd, const void *buf, size_t nbytes)
     81 {
     82 	const uint8_t *bp = buf;
     83 	ssize_t total = 0;
     84 
     85 	while (nbytes > 0) {
     86 		ssize_t n;
     87 
     88 		n = write(fd, &bp[total], nbytes);
     89 		if (n < 0)
     90 			err(1, "write");
     91 		else if (n == 0)
     92 			return total;
     93 		total += n;
     94 		nbytes -= n;
     95 	}
     96 	return total;
     97 }
     98 
     99 int
    100 creatsock(char *sockpath)
    101 {
    102 	int sfd;
    103 	struct sockaddr_un addr;
    104 
    105 	sfd = socket(AF_UNIX, SOCK_STREAM, 0);
    106 	if (sfd < 0)
    107 		return -1;
    108 
    109 	umask(0177);
    110 	memset(&addr, 0, sizeof(addr));
    111 	addr.sun_family = AF_UNIX;
    112 	strlcpy(addr.sun_path, sockpath, sizeof(addr.sun_path));
    113 
    114 	if (bind(sfd, (struct sockaddr *) &addr, sizeof(addr)) < 0)
    115 		return -1;
    116 
    117 	if (listen(sfd, 10) < 0)
    118 		return -1;
    119 
    120 	return sfd;
    121 }
    122 
    123 void
    124 forgetkey()
    125 {
    126 	sodium_memzero(s.saltkey, sizeof(s.saltkey));
    127 	s.loaded = 0;
    128 	alarm(0);
    129 }
    130 
    131 void
    132 sighandler(int signal)
    133 {
    134 	switch (signal) {
    135 	case SIGINT:
    136 	case SIGTERM:
    137 		if (verbose)
    138 			fprintf(stderr, "unlocking key from memory\n");
    139 		sodium_munlock(s.saltkey, sizeof(s.saltkey));
    140 
    141 		if (verbose)
    142 			fprintf(stderr, "removing socket %s\n", sockp);
    143 		unlink(sockp);
    144 		rmdir(dirname(sockp));
    145 		exit(0);
    146 		/* NOTREACHED */
    147 	case SIGALRM:
    148 	case SIGUSR1:
    149 		if (verbose)
    150 			fprintf(stderr, "clearing key from memory\n");
    151 		forgetkey();
    152 		break;
    153 	}
    154 }
    155 
    156 int
    157 servekey(int timeout)
    158 {
    159 	int r, sfd;
    160 	ssize_t n;
    161 	struct pollfd pfd;
    162 
    163 	if (verbose)
    164 		fprintf(stderr, "listening on %s\n", sockp);
    165 	sfd = creatsock(sockp);
    166 	if (sfd < 0)
    167 		err(1, "%s", sockp);
    168 
    169 	s.loaded = 0;
    170 
    171 	for (;;) {
    172 		pfd.fd = accept(sfd, NULL, NULL);
    173 		pfd.revents = 0;
    174 		pfd.events = POLLIN;
    175 
    176 		if (s.loaded)
    177 			pfd.events |= POLLOUT;
    178 
    179 		if (pfd.fd < 0)
    180 			err(1, "%s", sockp);
    181 
    182 		if ((r = poll(&pfd, 1, 100)) < 0)
    183 			return r;
    184 
    185 		if (pfd.revents & POLLIN) {
    186 			if (verbose)
    187 				fprintf(stderr, "reading key from client fd %d\n", pfd.fd);
    188 
    189 			n = xread(pfd.fd, s.saltkey, sizeof(s.saltkey));
    190 			if (n == sizeof(s.saltkey)) {
    191 				s.loaded = 1;
    192 				if (verbose) {
    193 					fprintf(stderr, "key loaded in memory\n");
    194 					if (timeout > 0)
    195 						fprintf(stderr, "setting timeout to %d seconds\n", timeout);
    196 				}
    197 				alarm(timeout);
    198 			} else {
    199 				forgetkey();
    200 				if (verbose)
    201 					fprintf(stderr, "failed to load key in memory\n");
    202 			}
    203 		} else if (pfd.revents & POLLOUT) {
    204 			if (verbose)
    205 				fprintf(stderr, "sending key to client fd %d\n", pfd.fd);
    206 
    207 			xwrite(pfd.fd, s.saltkey, sizeof(s.saltkey));
    208 		}
    209 
    210 		close(pfd.fd);
    211 	}
    212 
    213 	/* NOTREACHED */
    214 	close(sfd);
    215 	return -1;
    216 }
    217 
    218 int
    219 main(int argc, char *argv[])
    220 {
    221 	pid_t pid;
    222 	int fd, timeout = 0, dflag = 0;
    223 	size_t dirlen;
    224 	char path[PATH_MAX] = SOCKDIR;
    225 	struct rlimit rlim;
    226 
    227 	pid = getpid();
    228 
    229 	ARGBEGIN {
    230 	case 'd':
    231 		dflag = 1;
    232 		break;
    233 	case 'f':
    234 		sockp = EARGF(usage());
    235 		break;
    236 	case 't':
    237 		timeout = atoi(EARGF(usage()));
    238 		break;
    239 	case 'v':
    240 		verbose = 1;
    241 		break;
    242 	default:
    243 		usage();
    244 	} ARGEND
    245 
    246 	if (sockp) {
    247 		strlcpy(path, sockp, sizeof(path));
    248 	} else {
    249 		if (!mkdtemp(path))
    250 			err(1, "mkdtemp: %s", path);
    251 
    252 		dirlen = strnlen(path, sizeof(path));
    253 		snprintf(path + dirlen, PATH_MAX - dirlen, "/%s.%d", SOCKET, pid);
    254 		sockp = path;
    255 	}
    256 
    257 	/* deny core dump as memory contains derivated key */
    258 	rlim.rlim_cur = rlim.rlim_max = 0;
    259 	if (setrlimit(RLIMIT_CORE, &rlim) < 0)
    260 		err(1, "setrlimit RLIMIT_CORE");
    261 
    262 	if (dflag) {
    263 		printf("SAFE_PID=%d; export SAFE_PID\n", pid);
    264 		printf("SAFE_SOCK=%s; export SAFE_SOCK\n", sockp);
    265 		fflush(stdout);
    266 		goto skip;
    267 	}
    268 
    269 	if (verbose)
    270 		fprintf(stderr, "forking agent to the background\n");
    271 
    272 	pid = fork();
    273 	if (pid < 0)
    274 		err(1, "fork");
    275 
    276 	if (pid) {
    277 		if (verbose)
    278 			fprintf(stderr, "agent pid is %d\n", pid);
    279 
    280 		printf("SAFE_PID=%d; export SAFE_PID\n", pid);
    281 		printf("SAFE_SOCK=%s; export SAFE_SOCK\n", sockp);
    282 		return 0;
    283 	}
    284 
    285 	if (setsid() < 0)
    286 		err(1, "setsid");
    287 
    288 	chdir("/");
    289 	if ((fd = open("/dev/null", O_RDWR, 0)) != -1) {
    290 		(void)dup2(fd, STDIN_FILENO);
    291 		(void)dup2(fd, STDOUT_FILENO);
    292 		(void)dup2(fd, STDERR_FILENO);
    293 		if (fd > 2)
    294 			close(fd);
    295 	}
    296 
    297 skip:
    298 	pid = getpid();
    299 	signal(SIGINT, sighandler);
    300 	signal(SIGTERM, sighandler);
    301 	signal(SIGUSR1, sighandler);
    302 	signal(SIGALRM, sighandler);
    303 
    304 	if (sodium_init() < 0)
    305 		return -1;
    306 
    307 	if (verbose)
    308 		fprintf(stderr, "locking key in memory\n");
    309 	sodium_mlock(s.saltkey, sizeof(s.saltkey));
    310 
    311 	return servekey(timeout);
    312 }