/* bagmatch_mt.c
 *
 * Multithreaded version of bagmatch: print every line of DICTIONARY_FILE whose
 * bytes, once sorted, equal the sorted bytes of KEYWORD (same multiset of bytes).
 *
 * - Thread count comes from the BAGMATCH_THREADS environment variable
 *   (default DEFAULT_THREADS, clamped to [1, MAX_THREADS], and never more
 *   threads than bytes in the file).
 * - The file is mmap'd once and partitioned into N contiguous, non-overlapping
 *   chunks described by chunk_starts[0..N], where chunk_starts[0] == map and
 *   chunk_starts[N] == map + filesize.
 * - Every chunk begins at a line boundary (chunk 0 at the start of the file,
 *   the others just after a '\n'), so no line is split between threads.
 * - Each thread scans its chunk, compares sorted-byte canonical forms, and
 *   writes each match with a single fwrite of a small stack buffer
 *   (word + '\n'). POSIX stdio locks the stream for each call, so matched
 *   lines are never interleaved, but the order of matches coming from
 *   different chunks is not deterministic.
 *
 * Constraints preserved:
 * - MAXWORD = 64; keyword length checked <= MAXWORD.
 * - insertion_sort_chars used.
 * - No dynamic allocation for per-line buffers.
 */
#define _POSIX_C_SOURCE 200809L

#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#define MAXWORD 64

/* Thread count used when BAGMATCH_THREADS is unset or not a positive number. */
#define DEFAULT_THREADS 8

/* Upper bound on worker threads. Keeps allocation sizes small and guarantees
 * the chunk-offset arithmetic in split_into_chunks cannot overflow. */
#define MAX_THREADS 1024

/* Sort the n bytes of a[] in ascending order, in place. */
static void insertion_sort_chars(unsigned char *a, size_t n)
{
    for (size_t i = 1; i < n; ++i) {
        unsigned char key = a[i];
        size_t j = i;
        while (j > 0 && a[j - 1] > key) {
            a[j] = a[j - 1];
            --j;
        }
        a[j] = key;
    }
}

/* Per-thread work description. All pointers refer to memory owned by
 * bagmatch()/run_workers(), which outlives every worker (threads are joined
 * before those frames return). */
struct thread_arg {
    const unsigned char *start;  /* first byte of this chunk (a line start) */
    const unsigned char *end;    /* exclusive end of this chunk */
    const unsigned char *keybuf; /* sorted keyword bytes */
    size_t keylen;               /* number of bytes in keybuf (<= MAXWORD) */
    atomic_int *err_flag;        /* shared: set to 1 to make all workers stop */
};

/*
 * Thread entry point: scan the lines in [arg->start, arg->end) and write
 * every line whose sorted bytes equal arg->keybuf to stdout, followed by '\n'
 * (a final line without a trailing newline still gets one in the output).
 *
 * On a short fwrite it sets *arg->err_flag and returns. It also returns early,
 * between lines, as soon as any thread (or the creator) has set the flag.
 * Always returns NULL.
 */
static void *worker(void *varg)
{
    const struct thread_arg *arg = varg;
    const unsigned char *ptr = arg->start;
    const unsigned char *const end = arg->end;
    const unsigned char *const key = arg->keybuf;
    const size_t keylen = arg->keylen;
    unsigned char work[MAXWORD];        /* scratch copy of the line for sorting */
    unsigned char outbuf[MAXWORD + 1];  /* matched line + '\n' */

    while (ptr < end) {
        if (atomic_load(arg->err_flag))
            return NULL;

        /* Find the end of the current line, staying inside this chunk. */
        const unsigned char *nl = memchr(ptr, '\n', (size_t)(end - ptr));
        if (nl == NULL)
            nl = end;

        /* Only lines exactly keylen bytes long can match; this also bounds
         * the copies below by MAXWORD. */
        if ((size_t)(nl - ptr) == keylen) {
            memcpy(work, ptr, keylen);
            insertion_sort_chars(work, keylen);
            if (memcmp(work, key, keylen) == 0) {
                memcpy(outbuf, ptr, keylen);
                outbuf[keylen] = '\n';
                /* One fwrite per match so concurrent lines are not interleaved. */
                if (fwrite(outbuf, 1, keylen + 1, stdout) != keylen + 1) {
                    atomic_store(arg->err_flag, 1);
                    return NULL;
                }
            }
        }

        /* Step past the newline, or stop at the end of the chunk. */
        ptr = (nl < end) ? nl + 1 : end;
    }
    return NULL;
}

/* Read BAGMATCH_THREADS. Returns DEFAULT_THREADS when it is unset or not a
 * positive integer, and MAX_THREADS when it is larger than that. */
static int thread_count_from_env(void)
{
    const char *env = getenv("BAGMATCH_THREADS");
    if (env == NULL)
        return DEFAULT_THREADS;

    char *endptr = NULL;
    errno = 0;
    long v = strtol(env, &endptr, 10);
    if (endptr == env || v < 1)
        return DEFAULT_THREADS;
    if (errno == ERANGE || v > MAX_THREADS)
        return MAX_THREADS;
    return (int)v;
}

/*
 * Fill starts[0..n] with the boundaries of n contiguous chunks covering
 * map[0..size). Chunk i nominally begins at floor(size * i / n) and is moved
 * forward to just past the next '\n' (or to the end of the file), so every
 * chunk begins at a line start. Requires 1 <= n <= MAX_THREADS and n <= size.
 * The boundaries are non-decreasing, and empty chunks are possible.
 */
static void split_into_chunks(const unsigned char *map, size_t size, int n,
                              const unsigned char **starts)
{
    const unsigned char *const limit = map + size;
    const size_t un = (size_t)n;
    const size_t q = size / un;
    const size_t r = size % un;

    starts[0] = map;
    for (int i = 1; i < n; ++i) {
        /* Exact floor(size * i / n) without a wide multiply: since
         * size = q*n + r, it equals q*i + floor(r*i/n), and r*i < n*n is small. */
        size_t off = q * (size_t)i + (r * (size_t)i) / un;
        const unsigned char *nl = memchr(map + off, '\n', size - off);
        starts[i] = nl ? nl + 1 : limit;
    }
    starts[n] = limit;
}

/*
 * Start one worker per chunk, then join every thread that was started.
 * Returns 0 on success, 7 if pthread_create failed (already-running threads
 * are told to stop early), or 2 if a worker failed to write to stdout.
 */
static int run_workers(const unsigned char *const *chunk_starts, int nthreads,
                       pthread_t *tids, struct thread_arg *targs,
                       const unsigned char *keybuf, size_t keylen)
{
    atomic_int err_flag;
    atomic_init(&err_flag, 0);

    int status = 0;
    int started = 0;
    while (started < nthreads) {
        targs[started] = (struct thread_arg){
            .start = chunk_starts[started],
            .end = chunk_starts[started + 1],
            .keybuf = keybuf,
            .keylen = keylen,
            .err_flag = &err_flag,
        };
        int rc = pthread_create(&tids[started], NULL, worker, &targs[started]);
        if (rc != 0) {
            fprintf(stderr, "pthread_create: %s\n", strerror(rc));
            atomic_store(&err_flag, 1); /* ask running workers to stop */
            status = 7;
            break;
        }
        ++started;
    }

    /* err_flag lives in this frame, so every started thread must be joined
     * before returning. */
    for (int i = 0; i < started; ++i)
        pthread_join(tids[i], NULL);

    if (status == 0 && atomic_load(&err_flag)) {
        fputs("Error: write to standard output failed\n", stderr);
        status = 2;
    }
    return status;
}

/*
 * Print every line of the file at `path` that uses exactly the bytes of
 * `keyword` (in any order), scanning the file with several threads.
 *
 * Returns 0 on success (including an empty file), or:
 *   2 keyword longer than MAXWORD bytes, or a write to stdout failed
 *   3 open failed            4 fstat failed
 *   5 file too large for the address space, or mmap failed
 *   6 allocation failed      7 pthread_create failed
 *   8 munmap failed (only reported if nothing else failed first)
 */
int bagmatch(const char *keyword, const char *path)
{
    unsigned char keybuf[MAXWORD];
    size_t keylen = strlen(keyword);
    if (keylen > MAXWORD) {
        fputs("Error: keyword too long (max 64 bytes)\n", stderr);
        return 2;
    }
    memcpy(keybuf, keyword, keylen);
    insertion_sort_chars(keybuf, keylen);

    int fd = open(path, O_RDONLY);
    if (fd == -1) {
        perror(path);
        return 3;
    }

    struct stat st;
    if (fstat(fd, &st) == -1) {
        perror("fstat");
        close(fd);
        return 4;
    }
    if (st.st_size == 0) {
        close(fd);
        return 0;
    }
    /* off_t may be wider than size_t (e.g. 32-bit builds with large-file support). */
    if (st.st_size < 0 || (uintmax_t)st.st_size > SIZE_MAX) {
        fprintf(stderr, "%s: file too large to map\n", path);
        close(fd);
        return 5;
    }
    const size_t filesize = (size_t)st.st_size;

    void *mem = mmap(NULL, filesize, PROT_READ, MAP_PRIVATE, fd, 0);
    if (mem == MAP_FAILED) {
        perror("mmap"); /* report before close() can clobber errno */
        close(fd);
        return 5;
    }
    /* The mapping stays valid after its descriptor is closed. */
    close(fd);
    const unsigned char *map = mem;

    int nthreads = thread_count_from_env();
    if ((size_t)nthreads > filesize)
        nthreads = (int)filesize; /* at most one byte per thread */

    int status = 0;
    pthread_t *tids = NULL;
    struct thread_arg *targs = NULL;
    const unsigned char **chunk_starts =
        calloc((size_t)nthreads + 1, sizeof *chunk_starts);
    if (chunk_starts != NULL)
        tids = calloc((size_t)nthreads, sizeof *tids);
    if (tids != NULL)
        targs = calloc((size_t)nthreads, sizeof *targs);
    if (targs == NULL) {
        perror("malloc"); /* errno comes from the allocation that just failed */
        status = 6;
        goto cleanup;
    }

    split_into_chunks(map, filesize, nthreads, chunk_starts);
    status = run_workers(chunk_starts, nthreads, tids, targs, keybuf, keylen);

cleanup:
    free(targs);
    free(tids);
    free(chunk_starts);
    if (munmap(mem, filesize) == -1) {
        perror("munmap");
        if (status == 0)
            status = 8;
    }
    return status;
}

int main(int argc, char *argv[])
{
    if (argc != 3) {
        fputs("Usage: bagmatch KEYWORD DICTIONARY_FILE\n", stderr);
        return 2;
    }

    /* Optionally repeat the whole search (e.g. for benchmarking). A missing,
     * malformed, negative or out-of-range BAGMATCH_REPS means one run. */
    size_t n_repetitions = 1;
    const char *env = getenv("BAGMATCH_REPS");
    if (env != NULL) {
        char *endptr = NULL;
        errno = 0;
        long long v = strtoll(env, &endptr, 10);
        if (endptr != env && errno == 0 && v >= 0 &&
            (unsigned long long)v <= SIZE_MAX)
            n_repetitions = (size_t)v;
    }

    fprintf(stderr, "Searching %zu time%s for %s\n", n_repetitions,
            n_repetitions == 1 ? "" : "s", argv[1]);

    int rv = 0;
    for (size_t i = 0; i < n_repetitions; i++) {
        int status = bagmatch(argv[1], argv[2]);
        if (status)
            rv = status;
    }

    /* Buffered output errors (e.g. a full disk) may only surface when stdout
     * is flushed; check here so they are reflected in the exit status. */
    if (fflush(stdout) == EOF) {
        perror("stdout");
        if (rv == 0)
            rv = 2;
    }
    return rv;
}