[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

tftp diff



Hello,

This diff applies the following changes to tftp:

- replaced setjmp(3) / alarm(3) with poll(2) in the network routines
  sendfile() and recvfile()
- open file descriptors were not closed when a file transfer was
  aborted by ctrl-c; fixed
- added alias 'help' for the '?' command
- style(9)ed the code a bit for better readability

The setjmp(3) in main.c, which is used when ctrl-c is caught (SIGINT),
is kept.  IMO it is legitimate for that use, and a lot of command line
programs (ftp, ed, restore, etc.) do it this way to return to the
command prompt no matter which routine was active (e.g. a blocking
fgets(3)).

I've tested the diff by transferring large files with put and get over
the network.  Further testing welcome.

Regards,
Marcus

-- 
Marcus Glocker, marcus_(_at_)_nazgul_(_dot_)_ch, http://www.nazgul.ch -----------------
diff -urN src/usr.bin/tftp.orig/main.c src/usr.bin/tftp/main.c
--- src/usr.bin/tftp.orig/main.c	Wed Apr 26 12:36:12 2006
+++ src/usr.bin/tftp/main.c	Thu Apr 27 10:29:37 2006
@@ -48,12 +48,12 @@
 /*
  * TFTP User Program -- Command Interface.
  */
+
 #include <sys/param.h>
 #include <sys/socket.h>
 #include <sys/file.h>
 
 #include <netinet/in.h>
-
 #include <arpa/inet.h>
 
 #include <ctype.h>
@@ -72,10 +72,11 @@
 #define	TIMEOUT		5		/* secs between rexmt's */
 #define	LBUFLEN		200		/* size of input buffer */
 #define	MAXARGV		20
+#define	HELPINDENT	(sizeof("connect"))
 
 struct	sockaddr_in peeraddr;
 int	f;
-short   port;
+short	port;
 int	trace;
 int	verbose;
 int	connected;
@@ -87,6 +88,10 @@
 jmp_buf	toplevel;
 void	intr(int);
 struct	servent *sp;
+int	rexmtval = TIMEOUT;
+int	maxtimeout = 5 * TIMEOUT;
+char	hostname[MAXHOSTNAMELEN];
+FILE	*file = NULL;
 
 void	get(int, char **);
 void	help(int, char **);
@@ -109,8 +114,6 @@
 static void putusage(char *);
 static void settftpmode(char *);
 
-#define HELPINDENT (sizeof("connect"))
-
 struct cmd {
 	char	*name;
 	char	*help;
@@ -128,8 +131,8 @@
 char	sthelp[] = "show current status";
 char	xhelp[] = "set per-packet retransmission timeout";
 char	ihelp[] = "set total retransmission timeout";
-char    ashelp[] = "set mode to netascii";
-char    bnhelp[] = "set mode to octet";
+char	ashelp[] = "set mode to netascii";
+char	bnhelp[] = "set mode to octet";
 
 struct cmd cmdtab[] = {
 	{ "connect",	chelp,		setpeer },
@@ -144,10 +147,24 @@
 	{ "ascii",      ashelp,         setascii },
 	{ "rexmt",	xhelp,		setrexmt },
 	{ "timeout",	ihelp,		settimeout },
+	{ "help",	hhelp,		help },
 	{ "?",		hhelp,		help },
 	{ NULL,		NULL,		NULL }
 };
 
+struct	modes {
+	char *m_name;
+	char *m_mode;
+} modes[] = {
+	{ "ascii",	"netascii" },
+	{ "netascii",	"netascii" },
+	{ "binary",	"octet" },
+	{ "image",	"octet" },
+	{ "octet",	"octet" },
+/*	{ "mail",	"mail" }, */
+	{ NULL,		NULL }
+};
+
 struct	cmd *getcmd(char *);
 char	*tail(char *);
 
@@ -156,31 +173,42 @@
 {
 	struct sockaddr_in s_in;
 
+	/* socket, bind */
 	sp = getservbyname("tftp", "udp");
 	if (sp == 0)
 		errx(1, "udp/tftp: unknown service");
 	f = socket(AF_INET, SOCK_DGRAM, 0);
 	if (f < 0)
 		err(3, "socket");
-	bzero((char *)&s_in, sizeof (s_in));
+	bzero((char *)&s_in, sizeof(s_in));
 	s_in.sin_family = AF_INET;
-	if (bind(f, (struct sockaddr *)&s_in, sizeof (s_in)) < 0)
+	if (bind(f, (struct sockaddr *)&s_in, sizeof(s_in)) < 0)
 		err(1, "bind");
-	strlcpy(mode, "netascii", sizeof mode);
-	signal(SIGINT, intr);
-	if (argc > 1) {
-		if (setjmp(toplevel) != 0)
-			exit(0);
+
+	/* set default transfer mode */
+	strlcpy(mode, "netascii", sizeof(mode));
+
+	/* set peer if given */
+	if (argc > 1)
 		setpeer(argc, argv);
-	}
+
+	/* jump here on SIGINT, mostly ctrl-c */
 	if (setjmp(toplevel) != 0)
-		(void)putchar('\n');
+		(void) putchar('\n');
+
+	/* catch SIGINT only once toplevel is valid for intr() */
+	signal(SIGINT, intr);
+
+	/* close a file left open by an interrupted transfer */
+	if (file != NULL) {
+		fclose(file);
+		file = NULL;
+	}
+
 	command();
 	return (0);
 }
 
-char    hostname[MAXHOSTNAMELEN];
-
 void
 setpeer(int argc, char *argv[])
 {
@@ -227,19 +255,6 @@
 	connected = 1;
 }
 
-struct	modes {
-	char *m_name;
-	char *m_mode;
-} modes[] = {
-	{ "ascii",	"netascii" },
-	{ "netascii",   "netascii" },
-	{ "binary",     "octet" },
-	{ "image",      "octet" },
-	{ "octet",     "octet" },
-/*      { "mail",       "mail" },       */
-	{ NULL,		NULL }
-};
-
 void
 modecmd(int argc, char *argv[])
 {
@@ -276,14 +291,12 @@
 void
 setbinary(int argc, char *argv[])
 {
-
 	settftpmode("octet");
 }
 
 void
 setascii(int argc, char *argv[])
 {
-
 	settftpmode("netascii");
 }
 
@@ -295,7 +308,6 @@
 		printf("mode set to %s\n", mode);
 }
 
-
 /*
  * Send file(s).
  */
@@ -309,7 +321,7 @@
 	if (argc < 2) {
 		strlcpy(line, "send ", sizeof line);
 		printf("(file) ");
-		fgets(&line[strlen(line)], LBUFLEN-strlen(line), stdin);
+		fgets(&line[strlen(line)], LBUFLEN - strlen(line), stdin);
 		if (makeargv())
 			return;
 		argc = margc;
@@ -362,8 +374,10 @@
 		return;
 	}
 
-	/* this assumes the target is a directory */
-	/* on a remote unix system.  hmmmm.  */
+	/*
+	 * this assumes the target is a directory on
+	 * a remote unix system.  hmmmm.
+	 */
 	for (n = 1; n < argc - 1; n++) {
 		if (asprintf(&cp, "%s/%s", targ, tail(argv[n])) == -1)
 			err(1, "asprintf");
@@ -386,7 +400,8 @@
 putusage(char *s)
 {
 	printf("usage: %s file [[host:]remotename]\n", s);
-	printf("       %s file1 file2 ... fileN [[host:]remote-directory]\n", s);
+	printf("       %s file1 file2 ... fileN [[host:]remote-directory]\n",
+	    s);
 }
 
 /*
@@ -474,8 +489,6 @@
 	printf("       %s [host1:]file1 [host2:]file2 ... [hostN:]fileN\n", s);
 }
 
-int	rexmtval = TIMEOUT;
-
 void
 setrexmt(int argc, char *argv[])
 {
@@ -501,8 +514,6 @@
 		rexmtval = t;
 }
 
-int	maxtimeout = 5 * TIMEOUT;
-
 void
 settimeout(int argc, char *argv[])
 {
@@ -544,9 +555,6 @@
 void
 intr(int signo)
 {
-
-	signal(SIGALRM, SIG_IGN);
-	alarm(0);
 	longjmp(toplevel, -1);
 }
 
@@ -666,7 +674,6 @@
 void
 quit(int argc, char *argv[])
 {
-
 	exit(0);
 }
 
diff -urN src/usr.bin/tftp.orig/tftp.c src/usr.bin/tftp/tftp.c
--- src/usr.bin/tftp.orig/tftp.c	Tue Apr 25 22:08:37 2006
+++ src/usr.bin/tftp/tftp.c	Thu Apr 27 12:02:19 2006
@@ -42,17 +42,16 @@
 /*
  * TFTP User Program -- Protocol Machines
  */
+
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/time.h>
 
 #include <netinet/in.h>
-
 #include <arpa/tftp.h>
 
 #include <errno.h>
-#include <setjmp.h>
-#include <signal.h>
+#include <poll.h>
 #include <stdio.h>
 #include <stddef.h>
 #include <string.h>
@@ -62,40 +61,57 @@
 #include "extern.h"
 #include "tftpsubs.h"
 
+#define	PKTSIZE	(SEGSIZE + 4)
 
-extern  struct sockaddr_in peeraddr;	/* filled in by main */
-extern  int     f;			/* the opened socket */
-extern  int     trace;
-extern  int     verbose;
-extern  int     rexmtval;
-extern  int     maxtimeout;
+extern struct sockaddr_in	peeraddr;	/* filled in by main */
+extern int			f;		/* the opened socket */
+extern int			trace;
+extern int			verbose;
+extern int			rexmtval;
+extern int			maxtimeout;
+extern FILE			*file;
 
-#define PKTSIZE    SEGSIZE+4
-char    ackbuf[PKTSIZE];
-int	timeout;
-jmp_buf	toplevel;
-jmp_buf	timeoutbuf;
+char	ackbuf[PKTSIZE];
 
-static void nak(int);
-static int makerequest(int, const char *, struct tftphdr *, const char *);
-static void printstats(const char *, unsigned long);
-static void startclock(void);
-static void stopclock(void);
-static void timer(int);
-static void tpacket(const char *, struct tftphdr *, int);
+struct timeval	tstart;
+struct timeval	tstop;
 
+struct errmsg {
+	int	e_code;
+	char	*e_msg;
+} errmsgs[] = {
+	{ EUNDEF,	"Undefined error code" },
+	{ ENOTFOUND,	"File not found" },
+	{ EACCESS,	"Access violation" },
+	{ ENOSPACE,	"Disk full or allocation exceeded" },
+	{ EBADOP,	"Illegal TFTP operation" },
+	{ EBADID,	"Unknown transfer ID" },
+	{ EEXISTS,	"File already exists" },
+	{ ENOUSER,	"No such user" },
+	{ -1,		NULL }
+};
+
+static int	makerequest(int, const char *, struct tftphdr *, const char *);
+static void	nak(int);
+static void	tpacket(const char *, struct tftphdr *, int);
+static void	startclock(void);
+static void	stopclock(void);
+static void	printstats(const char *, unsigned long);
+static void	printtimeout(void);
+
 /*
  * Send the requested file.
  */
 void
 sendfile(int fd, char *name, char *mode)
 {
-	struct tftphdr *dp, *ap;	   /* data and ack packets */
-	volatile int block, size, convert;
-	volatile unsigned long amount;
+	struct tftphdr *dp, *ap;	/* data and ack packets */
 	struct sockaddr_in from;
-	int n, fromlen;
-	FILE *file;
+	struct pollfd pfd[1];
+	volatile int convert;		/* true if converting crlf -> lf */
+	volatile int block, size;
+	volatile unsigned long amount;
+	int n, nfds, error, fromlen, timeouts;
 
 	startclock();		/* start stat's clock */
 	dp = r_init();		/* reset fillbuf/read-ahead code */
@@ -105,12 +121,11 @@
 	block = 0;
 	amount = 0;
 
-	signal(SIGALRM, timer);
 	do {
-		if (block == 0)
+		/* read data from file */
+		if (!block)
 			size = makerequest(WRQ, name, dp, mode) - 4;
 		else {
-		/*	size = read(fd, dp->th_data, SEGSIZE);	 */
 			size = readit(file, &dp, convert);
 			if (size < 0) {
 				nak(errno + 100);
@@ -119,61 +134,85 @@
 			dp->th_opcode = htons((u_short)DATA);
 			dp->th_block = htons((u_short)block);
 		}
-		timeout = 0;
-		(void) setjmp(timeoutbuf);
-send_data:
-		if (trace)
-			tpacket("sent", dp, size + 4);
-		n = sendto(f, dp, size + 4, 0,
-		    (struct sockaddr *)&peeraddr, sizeof(peeraddr));
-		if (n != size + 4) {
-			warn("sendto");
-			goto abort;
-		}
-		read_ahead(file, convert);
-		for ( ; ; ) {
-			alarm(rexmtval);
-			do {
-				fromlen = sizeof(from);
-				n = recvfrom(f, ackbuf, sizeof(ackbuf), 0,
-				    (struct sockaddr *)&from, &fromlen);
-			} while (n <= 0);
-			alarm(0);
-			if (n < 0) {
+
+		/* send data to server and wait for server ACK */
+		for (timeouts = 0, error = 0;;) {
+			if (timeouts >= maxtimeout) {
+				printtimeout();
+				goto abort;
+			}
+
+			if (!error) {
+				if (trace)
+					tpacket("sent", dp, size + 4);
+				if (sendto(f, dp, size + 4, 0,
+				    (struct sockaddr *)&peeraddr,
+				    sizeof(peeraddr)) != size + 4) {
+					warn("sendto");
+					goto abort;
+				}
+				read_ahead(file, convert);
+			}
+			error = 0;
+
+			pfd[0].fd = f;
+			pfd[0].events = POLLIN;
+			nfds = poll(pfd, 1, rexmtval * 1000);
+			if (nfds == 0) {
+				timeouts += rexmtval;	/* secs waited */
+				continue;
+			}
+			if (nfds == -1) {
+				error = 1;
+				if (errno == EINTR)
+					continue;
+				warn("poll");
+				goto abort;
+			}
+			fromlen = sizeof(from);
+			n = recvfrom(f, ackbuf, sizeof(ackbuf), 0,
+			    (struct sockaddr *)&from, &fromlen);
+			if (n == 0) {
 				warn("recvfrom");
 				goto abort;
 			}
+			if (n == -1) {
+				error = 1;
+				if (errno == EINTR)
+					continue;
+				warn("recvfrom");
+				goto abort;
+			}
 			peeraddr.sin_port = from.sin_port;	/* added */
 			if (trace)
 				tpacket("received", ap, n);
 			/* should verify packet came from server */
 			ap->th_opcode = ntohs(ap->th_opcode);
 			ap->th_block = ntohs(ap->th_block);
 			if (ap->th_opcode == ERROR) {
-				printf("Error code %d: %s\n", ap->th_code,
-				    ap->th_msg);
+				printf("Error code %d: %s\n",
+				    ap->th_code, ap->th_msg);
 				goto abort;
 			}
 			if (ap->th_opcode == ACK) {
 				int j;
-
-				if (ap->th_block == block) {
+				if (ap->th_block == block)
 					break;
-				}
-				/* On an error, try to synchronize
-				 * both sides.
-				 */
+				/* re-synchronize with other side */
 				j = synchnet(f);
 				if (j && trace)
 					printf("discarded %d packets\n", j);
-				if (ap->th_block == (block-1))
-					goto send_data;
+				if (ap->th_block == (block - 1))
+					continue;	/* resend data */
 			}
+			error = 1;	/* stray packet: don't resend */
 		}
+
 		if (block > 0)
 			amount += size;
 		block++;
 	} while (size == SEGSIZE || block == 1);
 abort:
 	fclose(file);
+	file = NULL;	/* main() must not close it again */
 	stopclock();
@@ -187,25 +226,25 @@
 void
 recvfile(int fd, char *name, char *mode)
 {
-	struct tftphdr *dp, *ap;
-	volatile int block, size, firsttrip;
-	volatile unsigned long amount;
+	struct tftphdr *dp, *ap;	/* data and ack packets */
 	struct sockaddr_in from;
-	int n, fromlen;
-	FILE *file;
+	struct pollfd pfd[1];
 	volatile int convert;		/* true if converting crlf -> lf */
+	volatile int block, size, firsttrip;
+	volatile unsigned long amount;
+	int n, nfds, error, fromlen, timeouts;
 
-	startclock();
-	dp = w_init();
+	startclock();		/* start stat's clock */
+	dp = w_init();		/* reset fillbuf/read-ahead code */
 	ap = (struct tftphdr *)ackbuf;
 	file = fdopen(fd, "w");
 	convert = !strcmp(mode, "netascii");
 	block = 1;
-	firsttrip = 1;
 	amount = 0;
+	firsttrip = 1;
 
-	signal(SIGALRM, timer);
 	do {
+		/* create new ACK packet */
 		if (firsttrip) {
 			size = makerequest(RRQ, name, ap, mode);
 			firsttrip = 0;
@@ -215,58 +254,81 @@
 			size = 4;
 			block++;
 		}
-		timeout = 0;
-		(void) setjmp(timeoutbuf);
-send_ack:
-		if (trace)
-			tpacket("sent", ap, size);
-		if (sendto(f, ackbuf, size, 0, (struct sockaddr *)&peeraddr,
-		    sizeof(peeraddr)) != size) {
-			alarm(0);
-			warn("sendto");
-			goto abort;
-		}
-		write_behind(file, convert);
-		for ( ; ; ) {
-			alarm(rexmtval);
-			do  {
-				fromlen = sizeof(from);
-				n = recvfrom(f, dp, PKTSIZE, 0,
-				    (struct sockaddr *)&from, &fromlen);
-			} while (n <= 0);
-			alarm(0);
-			if (n < 0) {
+
+		/* send ACK to server and wait for server data */
+		for (timeouts = 0, error = 0;;) {
+			if (timeouts >= maxtimeout) {
+				printtimeout();
+				goto abort;
+			}
+
+			if (!error) {
+				if (trace)
+					tpacket("sent", ap, size);
+				if (sendto(f, ackbuf, size, 0,
+				    (struct sockaddr *)&peeraddr,
+				    sizeof(peeraddr)) != size) {
+					warn("sendto");
+					goto abort;
+				}
+				write_behind(file, convert);
+			}
+			error = 0;
+
+			pfd[0].fd = f;
+			pfd[0].events = POLLIN;
+			nfds = poll(pfd, 1, rexmtval * 1000);
+			if (nfds == 0) {
+				timeouts += rexmtval;	/* secs waited */
+				continue;
+			}
+			if (nfds == -1) {
+				error = 1;
+				if (errno == EINTR)
+					continue;
+				warn("poll");
+				goto abort;
+			}
+			fromlen = sizeof(from);
+			n = recvfrom(f, dp, PKTSIZE, 0,
+			    (struct sockaddr *)&from, &fromlen);
+			if (n == 0) {
 				warn("recvfrom");
 				goto abort;
 			}
+			if (n == -1) {
+				error = 1;
+				if (errno == EINTR)
+					continue;
+				warn("recvfrom");
+				goto abort;
+			}
 			peeraddr.sin_port = from.sin_port;	/* added */
 			if (trace)
 				tpacket("received", dp, n);
-			/* should verify client address */
+			/* should verify packet came from server */
 			dp->th_opcode = ntohs(dp->th_opcode);
 			dp->th_block = ntohs(dp->th_block);
 			if (dp->th_opcode == ERROR) {
-				printf("Error code %d: %s\n", dp->th_code,
-				    dp->th_msg);
+				printf("Error code %d: %s\n",
+				    dp->th_code, dp->th_msg);
 				goto abort;
 			}
 			if (dp->th_opcode == DATA) {
 				int j;
-
-				if (dp->th_block == block) {
-					break;		/* have next packet */
-				}
-				/* On an error, try to synchronize
-				 * both sides.
-				 */
+				if (dp->th_block == block)
+					break;		/* have next packet */
+				/* re-synchronize with other side */
 				j = synchnet(f);
 				if (j && trace)
 					printf("discarded %d packets\n", j);
-				if (dp->th_block == (block-1))
-					goto send_ack;	/* resend ack */
+				if (dp->th_block == (block - 1))
+					continue;	/* resend ack */
 			}
+			error = 1;	/* stray packet: don't resend */
 		}
-	/*	size = write(fd, dp->th_data, n - 4); */
+
+		/* write data to file */
 		size = writeit(file, &dp, n - 4, convert);
 		if (size < 0) {
 			nak(errno + 100);
@@ -274,12 +336,15 @@
 		}
 		amount += size;
 	} while (size == SEGSIZE);
-abort:						/* ok to ack, since user */
-	ap->th_opcode = htons((u_short)ACK);	/* has seen err msg */
+
+abort:
+	/* ok to ack, since user has seen err msg */
+	ap->th_opcode = htons((u_short)ACK);
 	ap->th_block = htons((u_short)block);
 	(void) sendto(f, ackbuf, 4, 0, (struct sockaddr *)&peeraddr,
 	    sizeof(peeraddr));
-	write_behind(file, convert);		/* flush last buffer */
+	write_behind(file, convert);	/* flush last buffer */
 	fclose(file);
+	file = NULL;	/* main() must not close it again */
 	stopclock();
 	if (amount > 0)
@@ -303,21 +368,6 @@
 	return (cp + len - (char *)tp);
 }
 
-struct errmsg {
-	int	e_code;
-	char	*e_msg;
-} errmsgs[] = {
-	{ EUNDEF,	"Undefined error code" },
-	{ ENOTFOUND,	"File not found" },
-	{ EACCESS,	"Access violation" },
-	{ ENOSPACE,	"Disk full or allocation exceeded" },
-	{ EBADOP,	"Illegal TFTP operation" },
-	{ EBADID,	"Unknown transfer ID" },
-	{ EEXISTS,	"File already exists" },
-	{ ENOUSER,	"No such user" },
-	{ -1,		NULL }
-};
-
 /*
  * Send a nak packet (error message).
  * Error code passed in is one of the
@@ -355,7 +405,7 @@
 tpacket(const char *s, struct tftphdr *tp, int n)
 {
 	static char *opcodes[] =
-	   { "#0", "RRQ", "WRQ", "DATA", "ACK", "ERROR" };
+	    { "#0", "RRQ", "WRQ", "DATA", "ACK", "ERROR" };
 	char *cp, *file;
 	u_short op = ntohs(tp->th_opcode);
 
@@ -363,8 +413,8 @@
 		printf("%s opcode=%x ", s, op);
 	else
 		printf("%s %s ", s, opcodes[op]);
-	switch (op) {
 
+	switch (op) {
 	case RRQ:
 	case WRQ:
 		n -= 2;
@@ -372,36 +422,28 @@
 		cp = strchr(cp, '\0');
 		printf("<file=%s, mode=%s>\n", file, cp + 1);
 		break;
-
 	case DATA:
 		printf("<block=%d, %d bytes>\n", ntohs(tp->th_block), n - 4);
 		break;
-
 	case ACK:
 		printf("<block=%d>\n", ntohs(tp->th_block));
 		break;
-
 	case ERROR:
 		printf("<code=%d, msg=%s>\n", ntohs(tp->th_code), tp->th_msg);
 		break;
 	}
 }
 
-struct timeval tstart;
-struct timeval tstop;
-
 static void
 startclock(void)
 {
-
-	(void)gettimeofday(&tstart, NULL);
+	(void) gettimeofday(&tstart, NULL);
 }
 
 static void
 stopclock(void)
 {
-
-	(void)gettimeofday(&tstop, NULL);
+	(void) gettimeofday(&tstop, NULL);
 }
 
 static void
@@ -410,26 +452,17 @@
 	double delta;
 
 	/* compute delta in 1/10's second units */
-	delta = ((tstop.tv_sec*10.)+(tstop.tv_usec/100000)) -
-		((tstart.tv_sec*10.)+(tstart.tv_usec/100000));
-	delta = delta/10.;      /* back to seconds */
+	delta = ((tstop.tv_sec * 10.) + (tstop.tv_usec / 100000)) -
+	    ((tstart.tv_sec * 10.) + (tstart.tv_usec / 100000));
+	delta = delta / 10.;	/* back to seconds */
 	printf("%s %lu bytes in %.1f seconds", direction, amount, delta);
 	if (verbose)
-		printf(" [%.0f bits/sec]", (amount*8.)/delta);
+		printf(" [%.0f bits/sec]", (amount * 8.) / delta);
 	putchar('\n');
 }
 
 static void
-timer(int sig)
+printtimeout(void)
 {
-	int save_errno = errno;
-
-	timeout += rexmtval;
-	if (timeout >= maxtimeout) {
-		printf("Transfer timed out.\n");
-		errno = save_errno;
-		longjmp(toplevel, -1);
-	}
-	errno = save_errno;
-	longjmp(timeoutbuf, 1);
+	printf("Transfer timed out.\n");
 }