Skip to content

Commit 531cebe

Browse files
Merge pull request #18 from DoctorLai/multithread-sum
Add multithreaded example
2 parents ba9733e + 2dc72c9 commit 531cebe

File tree

4 files changed

+126
-1
lines changed

4 files changed

+126
-1
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ This repository is intended as:
1212

1313
Examples include (and will expand to):
1414

15-
* [thread-safe-queue](./thread-safe-queue/)
15+
* Multithreading
16+
* [thread-safe-queue](./thread-safe-queue/)
17+
* [multithread-sum](./multithread-sum/)
1618
* Smart pointers
1719
* [unique-ptr-basics](./unique-ptr-basics/)
1820
* [smart-ptr](./smart-ptr/)

multithread-sum/Makefile

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# pull in shared compiler settings
2+
include ../common.mk
3+
4+
# per-example flags
5+
# CXXFLAGS += -pthread
6+
7+
## get it from the folder name
8+
TARGET := $(notdir $(CURDIR))
9+
SRCS := $(wildcard *.cpp)
10+
OBJS := $(SRCS:.cpp=.o)
11+
12+
all: $(TARGET)
13+
14+
$(TARGET): $(OBJS)
15+
$(CXX) $(CXXFLAGS) -o $@ $^
16+
17+
%.o: %.cpp
18+
$(CXX) $(CXXFLAGS) -c $< -o $@
19+
20+
run: $(TARGET)
21+
./$(TARGET) $(ARGS)
22+
23+
clean:
24+
rm -f $(OBJS) $(TARGET)
25+
26+
# Delegates to top-level Makefile
27+
check-format:
28+
$(MAKE) -f ../Makefile check-format DIR=$(CURDIR)
29+
30+
.PHONY: all clean run check-format

multithread-sum/main.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
2+
This demonstrates a simple multithreaded sum calculation using C++11 threads.
3+
*/
4+
5+
#include <cassert>
6+
#include <iostream>
7+
#include <vector>
8+
#include <thread>
9+
#include <numeric>
10+
#include <stdexcept> // for exceptions
11+
#include <limits>
12+
13+
using SumT = unsigned long long;
14+
15+
int
16+
main(int argc, char* argv[])
17+
{
18+
std::size_t data_size = 1'000'000;
19+
std::size_t num_threads = 4;
20+
21+
if (argc == 3) {
22+
try {
23+
long long ds = std::stoll(argv[1]);
24+
long long nt = std::stoll(argv[2]);
25+
if (ds < 0 || nt <= 0) {
26+
throw std::invalid_argument("data_size must be >= 0 and num_threads must be > 0");
27+
}
28+
data_size = static_cast<std::size_t>(ds);
29+
num_threads = static_cast<std::size_t>(nt);
30+
} catch (const std::exception& e) {
31+
std::cerr << "Invalid arguments: " << e.what() << "\n"
32+
<< "Usage: " << argv[0] << " <data_size> <num_threads>\n";
33+
return 1;
34+
}
35+
} else if (argc != 1) {
36+
std::cerr << "Usage: " << argv[0] << " <data_size> <num_threads>\n";
37+
return 1;
38+
}
39+
40+
if (num_threads == 0)
41+
return 1;
42+
if (data_size == 0) {
43+
std::cout << "Total Sum: 0\nExpected Sum: 0\n";
44+
return 0;
45+
}
46+
47+
// Avoid spawning more threads than elements (optional but sensible).
48+
if (num_threads > data_size)
49+
num_threads = data_size;
50+
51+
// Guard against int overflow in iota values.
52+
if (data_size > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
53+
std::cerr << "data_size too large for vector<int> initialization via iota.\n";
54+
return 1;
55+
}
56+
57+
std::vector<int> data(data_size);
58+
std::iota(data.begin(), data.end(), 1);
59+
60+
std::vector<SumT> partial_sums(num_threads, 0);
61+
std::vector<std::thread> threads;
62+
threads.reserve(num_threads);
63+
64+
const std::size_t block_size = data_size / num_threads;
65+
66+
for (std::size_t i = 0; i < num_threads; ++i) {
67+
const std::size_t start = i * block_size;
68+
const std::size_t end = (i == num_threads - 1) ? data_size : start + block_size;
69+
70+
threads.emplace_back([&, i, start, end] {
71+
partial_sums[i] = std::accumulate(data.begin() + start, data.begin() + end, SumT{0});
72+
std::cout << "Thread processing range [" << start << ", " << end
73+
<< ") computed local sum: " << partial_sums[i] << "\n";
74+
});
75+
}
76+
77+
for (auto& t : threads)
78+
t.join();
79+
80+
const SumT global_sum = std::accumulate(partial_sums.begin(), partial_sums.end(), SumT{0});
81+
std::cout << "Total Sum: " << global_sum << "\n";
82+
83+
const SumT expected_sum = static_cast<SumT>(data_size) * (static_cast<SumT>(data_size) + 1) / 2;
84+
std::cout << "Expected Sum: " << expected_sum << "\n";
85+
86+
assert(expected_sum == global_sum);
87+
return 0;
88+
}

multithread-sum/tests.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
3+
set -ex
4+
5+
./multithread-sum 1000000 4

0 commit comments

Comments
 (0)