/* modprobe.c: insert a module into the kernel, intelligently.
    Copyright (C) 2001  Rusty Russell.
    Copyright (C) 2002  Rusty Russell, IBM Corporation.

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/
#include <sys/utsname.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <dirent.h>
#include <limits.h>
#include <elf.h>
#include <getopt.h>
#include <asm/unistd.h>

#include "backwards_compat.c"

/* Directory scanned for modules; %s is replaced by the kernel release
   string (utsname.release, as passed in by main()). */
#define MODULE_DIR "/lib/modules/%s/kernel/"
/* Module file suffix: used to pick modules out of MODULE_DIR and appended
   to the module name given on the command line. */
#define MODULE_EXTENSION ".o"

/* Layout of one entry of a module's __ksymtab section as read raw from the
   file: an ELF_TYPE/8-byte value followed by a fixed-size name buffer,
   64 bytes in total.  ELF_TYPE (32 or 64) is presumably supplied by the
   build to match the target kernel.
   NOTE(review): confirm this matches the kernel's struct kernel_symbol. */
#define MODULE_NAME_LEN (64 - ELF_TYPE / 8)
struct kernel_symbol
{
	char value[ELF_TYPE / 8]; /* We don't care about this bit */
	char name[MODULE_NAME_LEN];
};

/* Select the ELF structure types matching ELF_TYPE; any other value is a
   build configuration error. */
#if ELF_TYPE == 32
#define Elf_Shdr Elf32_Shdr
#define Elf_Sym Elf32_Sym
#define Elf_Ehdr Elf32_Ehdr
#elif ELF_TYPE == 64
#define Elf_Shdr Elf64_Shdr
#define Elf_Sym Elf64_Sym
#define Elf_Ehdr Elf64_Ehdr
#else
#error Unknown ELF_TYPE setting.
#endif

static void fatal(const char *fmt, ...)
__attribute__ ((noreturn, format (printf, 1, 2)));

/* Report an unrecoverable error: write "FATAL: " and the printf-style
   message to stderr, then terminate the process with exit status 1.
   Never returns. */
static void fatal(const char *fmt, ...)
{
	va_list ap;

	fputs("FATAL: ", stderr);
	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	exit(1);
}

static void warn(const char *fmt, ...)
__attribute__ ((format (printf, 1, 2)));

/* Report a recoverable problem: write "WARNING: " and the printf-style
   message to stderr, then return so the caller can carry on. */
static void warn(const char *fmt, ...)
{
	va_list ap;

	fputs("WARNING: ", stderr);
	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
}

/* Show the command-line synopsis on stderr and exit with status 1.
   PROGNAME is the name the program was invoked as (argv[0]). */
static void print_usage(const char *progname)
{
	fprintf(stderr, "Usage: %s [--verbose] filename\n", progname);
	exit(1);
}

/* One candidate module found in the module directory.  All of them are
   chained into a singly linked list through `next'. */
struct module
{
	/* Next module in the list; NULL terminates it. */
	struct module *next;

	/* Exported symbols, copied from the module's __ksymtab section into
	   heap memory.  num_exports is 0 when the section is missing or could
	   not be read. */
	unsigned int num_exports;
	struct kernel_symbol *exports;

	/* What order it has to be loaded (0 = never).  Higher orders are
	   loaded first, since a dependency gets a larger order than the
	   module that needs it. */
	unsigned int order;

	/* Full path name.  C99 flexible array member: the storage is
	   allocated together with the struct. */
	char name[];
};

/* Return 1 if NAME ends with the suffix EXT (e.g. ".o"), 0 otherwise.
   An empty EXT matches every NAME. */
static int ends_in(const char *name, const char *ext)
{
	size_t namelen = strlen(name);
	size_t extlen = strlen(ext);

	if (namelen < extlen)
		return 0;

	/* Compare the last extlen characters of name with all of ext.  (The
	   old backwards loop compared the two NUL terminators and never looked
	   at ext[0], so "fooxo" was accepted as ending in ".o".) */
	return strcmp(name + namelen - extlen, ext) == 0;
}

/* Read exactly LEN bytes at file offset OFF of FD into BUF.
   Returns 0 on success, -1 on seek error, read error or short read. */
static int read_exact_at(int fd, unsigned long off, void *buf, size_t len)
{
	ssize_t got;

	if (lseek(fd, off, SEEK_SET) == (off_t)-1)
		return -1;
	got = read(fd, buf, len);
	return (got >= 0 && (size_t)got == len) ? 0 : -1;
}

/* Copy the section called SECNAME out of the ELF file open on FD.
   SHDROFF, NUM_SECS and SECNAMESEC are the section header table offset,
   entry count and section-name string table index from the ELF header.
   They come straight from the file and are validated here.
   Returns a malloc'd copy of the section (caller frees) and stores its
   length in *SIZE.  Returns NULL if no section has that name, and
   (void *)-1 if the headers are malformed, memory runs out, or the file
   cannot be read. */
static void *load_section(int fd, unsigned long shdroff,
			  unsigned int num_secs,
			  unsigned int secnamesec,
			  const char *secname,
			  unsigned long *size)
{
	Elf_Shdr *sechdrs;
	char *secnames = NULL;
	unsigned long namesize;
	void *result = (void *)-1;
	unsigned int i;

	/* secnamesec indexes sechdrs[], so reject it when out of range.  This
	   also rejects num_secs == 0. */
	if (secnamesec >= num_secs)
		return (void *)-1;

	/* num_secs is file-controlled, so keep the header table on the heap
	   rather than in a VLA. */
	sechdrs = calloc(num_secs, sizeof(*sechdrs));
	if (!sechdrs)
		return (void *)-1;

	/* Grab headers. */
	if (read_exact_at(fd, shdroff, sechdrs,
			  (size_t)num_secs * sizeof(*sechdrs)) != 0)
		goto out;

	/* Grab strings so we can tell who is who.  The extra byte guarantees
	   NUL termination even if the table in the file lacks it. */
	namesize = sechdrs[secnamesec].sh_size;
	if (namesize == (unsigned long)-1)
		goto out;
	secnames = malloc(namesize + 1);
	if (!secnames
	    || read_exact_at(fd, sechdrs[secnamesec].sh_offset,
			     secnames, namesize) != 0)
		goto out;
	secnames[namesize] = '\0';

	/* Find the section they want (index 0 is the reserved null entry). */
	result = NULL;
	for (i = 1; i < num_secs; i++) {
		void *buf;

		/* A name offset outside the string table names nothing. */
		if (sechdrs[i].sh_name >= namesize
		    || strcmp(secnames + sechdrs[i].sh_name, secname) != 0)
			continue;

		*size = sechdrs[i].sh_size;
		/* Never hand back NULL for an empty section: NULL means
		   "not found" to callers. */
		buf = malloc(*size ? *size : 1);
		if (buf && read_exact_at(fd, sechdrs[i].sh_offset,
					 buf, *size) == 0) {
			result = buf;
		} else {
			free(buf);
			result = (void *)-1;
		}
		break;
	}

out:
	free(secnames);
	free(sechdrs);
	return result;
}

/* Read the __ksymtab section (the module's exported symbol table) of the
   module open on FD.  SHDROFF, NUM_SECS and SECNAMESEC are passed through
   to load_section(); NAME is used only in the warning message.
   Returns a malloc'd array and stores its element count in *NUM_EXPORTS.
   Returns NULL with *num_exports == 0 if the section is absent, or could
   not be read (after a warning).  Despite its name, nothing is mmap'd. */
static struct kernel_symbol *map_exports(int fd,
					 unsigned long shdroff,
					 unsigned int num_secs,
					 unsigned int secnamesec,
					 const char *name,
					 unsigned int *num_exports)
{
	struct kernel_symbol *syms;
	unsigned long size;

	*num_exports = 0;
	syms = load_section(fd, shdroff, num_secs, secnamesec, "__ksymtab",
			    &size);
	if (syms == (void *)-1) {
		warn("Error finding exports for module %s\n", name);
		/* Don't return the error sentinel: the caller stores this
		   pointer, and (void *)-1 must never be freed or dereferenced. */
		return NULL;
	}
	if (syms)
		*num_exports = size / sizeof(struct kernel_symbol);
	return syms;
}

static struct module *add_module(const char *dirname, const char *entry,
				 struct module *last)
{
	struct module *new;
	Elf_Ehdr hdr;
	int fd;

	new = malloc(sizeof(*new) + strlen(dirname) + strlen(entry) + 1);
	sprintf(new->name, "%s%s", dirname, entry);
	new->next = last;
	new->order = 0;
	fd = open(new->name, O_RDONLY);
	if (fd < 0) {
		warn("Can't read module %s: %s\n", new->name, strerror(errno));
		free(new);
		return last;
	}

	if (read(fd, &hdr, sizeof(hdr)) != sizeof(hdr)) {
		warn("Error reading module %s\n", new->name);
		free(new);
		close(fd);
		return last;
	}

	/* Map the section table */
	new->exports = map_exports(fd, hdr.e_shoff, hdr.e_shnum,
				   hdr.e_shstrndx, new->name,
				   &new->num_exports);
	close(fd);
	return new;
}

/* Build the list of candidate modules for kernel release REVISION: every
   file in the release's module directory whose name ends in
   MODULE_EXTENSION, with its exported symbols loaded.  Returns NULL if the
   directory cannot be opened or no module could be added. */
static struct module *load_all_exports(const char *revision)
{
	char dirname[strlen(revision) + sizeof(MODULE_DIR)];
	struct module *list = NULL;
	struct dirent *ent;
	DIR *dir;

	sprintf(dirname, MODULE_DIR, revision);
	dir = opendir(dirname);
	if (dir == NULL)
		return NULL;

	for (ent = readdir(dir); ent != NULL; ent = readdir(dir)) {
		/* Skip anything that isn't a module file. */
		if (!ends_in(ent->d_name, MODULE_EXTENSION))
			continue;
		list = add_module(dirname, ent->d_name, list);
	}
	closedir(dir);
	return list;
}

/* Find the module in MODULES that exports the symbol NAME.
   If exactly one module does: set its order to ORDER, so it is loaded
   before the module needing the symbol, and return 1.  When MODNAME is
   non-NULL (verbose mode), the match is also printed.
   If no module exports NAME, return 0.  If several do, warn and return 0
   without marking any of them. */
static int need_symbol(unsigned int order,
		       const char *name,
		       struct module *modules,
		       const char *modname)
{
	struct module *m;
	struct module *found = NULL;

	for (m = modules; m; m = m->next) {
		unsigned int i;

		for (i = 0; i < m->num_exports; i++) {
			/* Bounded compare: the name field is raw file data and
			   may lack a terminating NUL. */
			if (strncmp(m->exports[i].name, name,
				    sizeof(m->exports[i].name)) != 0)
				continue;
			if (found) {
				warn("%s supplied by %s and %s:"
				     " picking neither\n",
				     name, m->name, found->name);
				/* No one chosen */
				return 0;
			}
			found = m;
			/* One match per module is enough; a duplicate entry
			   inside the same module is not an ambiguity. */
			break;
		}
	}
	if (!found)
		return 0;

	/* Commit only now that the provider is known to be unique.  Marking
	   it inside the loop left the first provider flagged for loading even
	   when a second provider made the symbol ambiguous. */
	if (modname)
		printf("%s needs %s: found in %s\n", modname, name, found->name);
	/* If we didn't need to load it already, we do now. */
	found->order = order;
	return 1;
}

/* Analyse this module to see if it needs others. */
static int get_deps(unsigned int order,
		    const char *modpath,
		    struct module *modules,
		    int verbose)
{
	unsigned int i;
	unsigned long size;
	Elf_Ehdr hdr;
	int fd;
	char *strings;
	Elf_Sym *syms;
	int needed = 0;

	fd = open(modpath, O_RDONLY);
	if (fd < 0)
		fatal("Can't open module %s: %s\n", modpath, strerror(errno));

	if (read(fd, &hdr, sizeof(hdr)) != sizeof(hdr))
		fatal("Error reading module %s\n", modpath);

	strings = load_section(fd, hdr.e_shoff, hdr.e_shnum,
			       hdr.e_shstrndx, ".strtab", &size);
	syms = load_section(fd, hdr.e_shoff, hdr.e_shnum,
			    hdr.e_shstrndx, ".symtab", &size);
	if (!strings || strings == (void *)-1
	    || !syms || syms == (void *)-1) {
		fatal("Could not load strings and symbol table from %s\n",
		      modpath);
	}

	/* Now establish which modules we need */
	for (i = 0; i < size / sizeof(syms[0]); i++) {
		if (syms[i].st_shndx == SHN_UNDEF) {
			/* Look for symbol */
			const char *name = strings + syms[i].st_name;

			if (strcmp(name, "") == 0)
				continue;

			/* Did this pull in a new module? */
			if (need_symbol(order, name, modules,
					verbose ? modpath : NULL))
				needed = 1;
		}
	}
	close(fd);
	return needed;
}

/* Turn an errno value left by the init_module syscall into a message.
   The syscall reuses a few errno codes loosely, so those get
   module-specific text; anything else falls back to strerror(). */
static const char *moderror(int err)
{
	if (err == ENOEXEC)
		return "Invalid module format";
	if (err == ENOENT)
		return "Unknown symbol in module";
	return strerror(err);
}

/* Derive the module name the kernel uses from the file path FILENAME:
   drop any directory part and the last '.' suffix, then map '-' to '_'.
   Returns a malloc'd string that the caller frees. */
static char *module_name_from_path(const char *filename)
{
	const char *base = strrchr(filename, '/');
	char *name, *ext, *p;

	/* FIXME: Look in module for name. --RR */
	name = strdup(base ? base + 1 : filename);
	if (!name)
		fatal("Out of memory\n");
	ext = strrchr(name, '.');
	if (ext)
		*ext = '\0';
	/* Convert to underscores */
	for (p = name; *p; p++)
		if (*p == '-')
			*p = '_';
	return name;
}

/* Return 1 if a module called NAME is listed in /proc/modules, else 0.
   If /proc/modules cannot be opened, warns and assumes nothing is loaded.
   Exits via fatal() if it opens but cannot be read. */
static int module_is_loaded(const char *name)
{
	size_t namelen = strlen(name);
	size_t fill, size = 1024;
	char *buf, *ptr;
	ssize_t ret;
	int fd, found = 0;

	fd = open("/proc/modules", O_RDONLY);
	if (fd < 0) {
		warn("Cannot open /proc/modules:"
		     " assuming no modules loaded.\n");
		return 0;
	}

	/* buf[0] is a sentinel newline so the first line is matched like the
	   others.  buf always has size + 1 bytes, leaving room for the NUL. */
	buf = malloc(size + 1);
	if (!buf)
		fatal("Out of memory reading /proc/modules\n");
	buf[0] = '\n';
	fill = 1;
	while ((ret = read(fd, buf + fill, size - fill)) > 0) {
		fill += ret;
		if (fill == size) {
			char *bigger = realloc(buf, size * 2 + 1);

			if (!bigger)
				fatal("Out of memory reading /proc/modules\n");
			buf = bigger;
			size *= 2;
		}
	}
	if (ret < 0)
		fatal("Error reading /proc/modules: %s\n", strerror(errno));
	close(fd);
	/* Terminate right after the data (was buf[fill + 1], which left one
	   uninitialized byte inside the searched string). */
	buf[fill] = '\0';

	/* Must appear at start of line, followed by whitespace. */
	for (ptr = buf; (ptr = strchr(ptr, '\n')) != NULL; ptr++) {
		if (strncmp(ptr + 1, name, namelen) == 0
		    && isspace((unsigned char)ptr[1 + namelen])) {
			found = 1;
			break;
		}
	}
	free(buf);
	return found;
}

/* Actually do the insert: load the module file FILENAME into the kernel
   with the init_module syscall.
   If a module of the same name is already listed in /proc/modules, nothing
   is done, except that with DONT_FAIL set this is fatal.  An insertion
   error is fatal when DONT_FAIL is set and only a warning otherwise. */
static void insmod(const char *filename, int dont_fail)
{
	char *name;
	int fd, ret, saved_errno;
	struct stat st;
	unsigned long len;
	void *map;

	name = module_name_from_path(filename);

	/* Now, it may already be loaded: check /proc/modules */
	if (module_is_loaded(name)) {
		/* Found: don't try to load again */
		if (dont_fail)
			fatal("Module %s already loaded\n", name);
		free(name);
		return;
	}

	fd = open(filename, O_RDONLY);
	if (fd < 0)
		fatal("Could not open `%s': %s\n", filename, strerror(errno));

	if (fstat(fd, &st) != 0)
		fatal("Could not stat `%s': %s\n", filename, strerror(errno));
	len = st.st_size;
	map = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
	if (map == MAP_FAILED)
		fatal("Can't map `%s': %s\n", filename, strerror(errno));

	ret = syscall(__NR_init_module, map, len, "");
	/* Keep init_module's errno: the cleanup calls may overwrite it. */
	saved_errno = errno;
	munmap(map, len);
	close(fd);
	if (ret != 0) {
		if (dont_fail)
			fatal("Error inserting %s: %s\n", name,
			      moderror(saved_errno));
		else
			warn("Error inserting %s: %s\n", name,
			     moderror(saved_errno));
	}
	free(name);
}

/* Load module MODNAME for kernel release REVISION, first inserting every
   module it depends on.  Dependencies are found by matching each module's
   undefined symbols against the exports of all modules in the release's
   module directory.  If VERBOSE, report matches and each module as it is
   loaded.  Exits via fatal() on a circular dependency. */
static void load(const char *revision, const char *modname, int verbose)
{
	unsigned int order, num_modules = 0;
	struct module *modules;
	struct module *i;
	char pathname[strlen(revision) + sizeof(MODULE_DIR) + strlen(modname)
		     + sizeof(MODULE_EXTENSION)];

	/* Create path name */
	sprintf(pathname, MODULE_DIR "%s" MODULE_EXTENSION, revision, modname);

	modules = load_all_exports(revision);
	for (i = modules; i; i = i->next)
		num_modules++;

	/* Level by level: the modules marked with ORDER are scanned for their
	   own dependencies, which are marked ORDER + 1.  A module marked k
	   therefore ends a dependency chain of length k. */
	order = 1;
	if (get_deps(order, pathname, modules, verbose)) {
		/* We need some other modules. */
		int more_needed;

		do {
			more_needed = 0;
			for (i = modules; i; i = i->next) {
				if (i->order == order) {
					if (get_deps(order + 1, i->name,
						     modules, verbose))
						more_needed = 1;
				}
			}
			order++;
			/* An acyclic chain can't be longer than the number of
			   modules.  A longer one means a dependency loop, which
			   used to make this loop spin forever. */
			if (more_needed && order > num_modules)
				fatal("Circular dependency loading %s\n",
				      modname);
		} while (more_needed);
	}

	/* Now, walk back through orders, loading the deepest dependencies
	   first so every module's providers are already in the kernel. */
	for (; order > 0; order--) {
		for (i = modules; i; i = i->next) {
			if (i->order == order) {
				if (verbose)
					printf("Loading %s\n", i->name);
				insmod(i->name, 0);
			}
		}
	}
	if (verbose)
		printf("Loading %s\n", pathname);
	insmod(pathname, 1);
}

/* Long options for getopt_long() in main().  --version has no short form.
   getopt_long() takes a const table, so this can be read-only. */
static const struct option options[] = {
	{ "verbose", no_argument, NULL, 'v' },
	{ "version", no_argument, NULL, 'V' },
	{ NULL, 0, NULL, 0 }
};

/* Entry point: parse the command line, then load the named module and
   its dependencies for the running kernel's release. */
int main(int argc, char *argv[])
{
	struct utsname buf;
	int opt;
	int verbose = 0;

	try_old_version("modprobe", argv);

	while ((opt = getopt_long(argc, argv, "v", options, NULL)) != -1) {
		switch (opt) {
		case 'v':
			verbose = 1;
			break;
		case 'V':
			printf("0.4\n");
			exit(0);
		default:
			/* getopt_long has already printed a diagnostic.  The
			   old message printed argv[optind], but optind has
			   usually moved past the bad option by now and can be
			   argc, making argv[optind] the NULL terminator. */
			print_usage(argv[0]);
		}
	}

	if (argc != optind + 1)
		print_usage(argv[0]);

	if (uname(&buf) != 0)
		fatal("Cannot determine kernel release: %s\n",
		      strerror(errno));
	load(buf.release, argv[optind], verbose);
	exit(0);
}
