1#![deny(trivial_casts, trivial_numeric_casts, unused_import_braces)]
2
3use std::cmp::Ordering;
4use std::collections::BTreeMap;
5use std::fmt::Display;
6use std::fs;
7use std::io;
8use std::marker::PhantomData;
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
10use std::path::Path;
11
12use ipnetwork::{IpNetwork, IpNetworkError};
13use serde::{de, Deserialize, Serialize};
14use thiserror::Error;
15
16#[cfg(feature = "mmap")]
17pub use memmap2::Mmap;
18#[cfg(feature = "mmap")]
19use memmap2::MmapOptions;
20#[cfg(feature = "mmap")]
21use std::fs::File;
22
// These two optional features cannot be combined; fail the build early with a clear message.
#[cfg(all(feature = "unsafe-str-decode", feature = "simdutf8"))]
compile_error!("features `simdutf8` and `unsafe-str-decode` are mutually exclusive");
25
26#[derive(Error, Debug)]
27pub enum MaxMindDbError {
28 #[error("Invalid database: {0}")]
29 InvalidDatabase(String),
30
31 #[error("I/O error: {0}")]
32 Io(
33 #[from]
34 #[source]
35 io::Error,
36 ),
37
38 #[cfg(feature = "mmap")]
39 #[error("Memory map error: {0}")]
40 Mmap(#[source] io::Error),
41
42 #[error("Decoding error: {0}")]
43 Decoding(String),
44
45 #[error("Invalid network: {0}")]
46 InvalidNetwork(
47 #[from]
48 #[source]
49 IpNetworkError,
50 ),
51}
52
53impl de::Error for MaxMindDbError {
54 fn custom<T: Display>(msg: T) -> Self {
55 MaxMindDbError::Decoding(format!("{msg}"))
56 }
57}
58
59#[derive(Deserialize, Serialize, Clone, Debug)]
60pub struct Metadata {
61 pub binary_format_major_version: u16,
62 pub binary_format_minor_version: u16,
63 pub build_epoch: u64,
64 pub database_type: String,
65 pub description: BTreeMap<String, String>,
66 pub ip_version: u16,
67 pub languages: Vec<String>,
68 pub node_count: u32,
69 pub record_size: u16,
70}
71
72#[derive(Debug)]
73struct WithinNode {
74 node: usize,
75 ip_int: IpInt,
76 prefix_len: usize,
77}
78
79#[derive(Debug)]
80pub struct Within<'de, T: Deserialize<'de>, S: AsRef<[u8]>> {
81 reader: &'de Reader<S>,
82 node_count: usize,
83 stack: Vec<WithinNode>,
84 phantom: PhantomData<&'de T>,
85}
86
87#[derive(Debug)]
88pub struct WithinItem<T> {
89 pub ip_net: IpNetwork,
90 pub info: T,
91}
92
/// An IP address held as one unsigned integer so that individual bits can be
/// read while walking the search tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IpInt {
    /// IPv4 address bits, most significant bit first.
    V4(u32),
    /// IPv6 address bits, most significant bit first.
    V6(u128),
}

impl IpInt {
    /// Converts an [`IpAddr`] into its integer form, keeping the address family.
    fn new(ip_addr: IpAddr) -> Self {
        match ip_addr {
            IpAddr::V4(addr) => Self::V4(u32::from(addr)),
            IpAddr::V6(addr) => Self::V6(u128::from(addr)),
        }
    }

    /// Returns bit `index`, counting from the most significant bit (index 0).
    ///
    /// `index` must be below [`IpInt::bit_count`].
    fn get_bit(&self, index: usize) -> bool {
        let shift = self.bit_count() - 1 - index;
        match *self {
            Self::V4(bits) => (bits >> shift) & 1 == 1,
            Self::V6(bits) => (bits >> shift) & 1 == 1,
        }
    }

    /// Width of the address in bits: 32 for IPv4, 128 for IPv6.
    fn bit_count(&self) -> usize {
        if let Self::V4(_) = self {
            32
        } else {
            128
        }
    }

    /// True for an IPv6 value whose upper 96 bits are all zero (an address in `::/96`).
    /// That prefix is where an IPv6 database keeps its IPv4 subtree.
    fn is_ipv4_in_ipv6(&self) -> bool {
        matches!(*self, Self::V6(bits) if bits <= 0xFFFF_FFFF)
    }
}
128
129impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
130 type Item = Result<WithinItem<T>, MaxMindDbError>;
131
132 fn next(&mut self) -> Option<Self::Item> {
133 while let Some(current) = self.stack.pop() {
134 let bit_count = current.ip_int.bit_count();
135
136 if self.reader.ipv4_start != 0
138 && current.node == self.reader.ipv4_start
139 && bit_count == 128
140 && !current.ip_int.is_ipv4_in_ipv6()
141 {
142 continue;
143 }
144
145 match current.node.cmp(&self.node_count) {
146 Ordering::Greater => {
147 let ip_net =
149 match bytes_and_prefix_to_net(¤t.ip_int, current.prefix_len as u8) {
150 Ok(ip_net) => ip_net,
151 Err(e) => return Some(Err(e)),
152 };
153
154 return match self.reader.decode_data_at_pointer(current.node) {
156 Ok(info) => Some(Ok(WithinItem { ip_net, info })),
157 Err(e) => Some(Err(e)),
158 };
159 }
160 Ordering::Equal => {
161 }
163 Ordering::Less => {
164 let mut right_ip_int = current.ip_int;
167
168 if current.prefix_len < bit_count {
169 let bit = current.prefix_len;
170 match &mut right_ip_int {
171 IpInt::V4(ip) => *ip |= 1 << (31 - bit),
172 IpInt::V6(ip) => *ip |= 1 << (127 - bit),
173 };
174 }
175
176 let node = match self.reader.read_node(current.node, 1) {
177 Ok(node) => node,
178 Err(e) => return Some(Err(e)),
179 };
180 self.stack.push(WithinNode {
181 node,
182 ip_int: right_ip_int,
183 prefix_len: current.prefix_len + 1,
184 });
185 let node = match self.reader.read_node(current.node, 0) {
187 Ok(node) => node,
188 Err(e) => return Some(Err(e)),
189 };
190 self.stack.push(WithinNode {
191 node,
192 ip_int: current.ip_int,
193 prefix_len: current.prefix_len + 1,
194 });
195 }
196 }
197 }
198 None
199 }
200}
201
202#[derive(Debug)]
204pub struct Reader<S: AsRef<[u8]>> {
205 buf: S,
206 pub metadata: Metadata,
207 ipv4_start: usize,
208 pointer_base: usize,
209}
210
#[cfg(feature = "mmap")]
impl Reader<Mmap> {
    /// Opens the database at `database` by memory-mapping the file.
    ///
    /// The database file must not be modified or truncated while the returned
    /// reader (or anything borrowed from it) is alive.
    ///
    /// # Errors
    ///
    /// Returns [`MaxMindDbError::Io`] if the file cannot be opened,
    /// [`MaxMindDbError::Mmap`] if mapping fails, and any error from
    /// [`Reader::from_source`].
    pub fn open_mmap<P: AsRef<Path>>(database: P) -> Result<Reader<Mmap>, MaxMindDbError> {
        let file_read = File::open(database)?;
        // SAFETY: memmap2 marks mapping as unsafe because another process modifying or
        // truncating the file while it is mapped is undefined behaviour. This cannot be
        // enforced here; it is a documented requirement on callers of `open_mmap`.
        let mmap = unsafe { MmapOptions::new().map(&file_read) }.map_err(MaxMindDbError::Mmap)?;
        Reader::from_source(mmap)
    }
}
229
230impl Reader<Vec<u8>> {
231 pub fn open_readfile<P: AsRef<Path>>(database: P) -> Result<Reader<Vec<u8>>, MaxMindDbError> {
239 let buf: Vec<u8> = fs::read(&database)?; Reader::from_source(buf)
241 }
242}
243
244impl<'de, S: AsRef<[u8]>> Reader<S> {
245 pub fn from_source(buf: S) -> Result<Reader<S>, MaxMindDbError> {
255 let data_section_separator_size = 16;
256
257 let metadata_start = find_metadata_start(buf.as_ref())?;
258 let mut type_decoder = decoder::Decoder::new(&buf.as_ref()[metadata_start..], 0);
259 let metadata = Metadata::deserialize(&mut type_decoder)?;
260
261 let search_tree_size = (metadata.node_count as usize) * (metadata.record_size as usize) / 4;
262
263 let mut reader = Reader {
264 buf,
265 pointer_base: search_tree_size + data_section_separator_size,
266 metadata,
267 ipv4_start: 0,
268 };
269 reader.ipv4_start = reader.find_ipv4_start()?;
270
271 Ok(reader)
272 }
273
274 pub fn lookup<T>(&'de self, address: IpAddr) -> Result<Option<T>, MaxMindDbError>
296 where
297 T: Deserialize<'de>,
298 {
299 self.lookup_prefix(address)
300 .map(|(option_value, _prefix_len)| option_value)
301 }
302
303 pub fn lookup_prefix<T>(
336 &'de self,
337 address: IpAddr,
338 ) -> Result<(Option<T>, usize), MaxMindDbError>
339 where
340 T: Deserialize<'de>,
341 {
342 let ip_int = IpInt::new(address);
343 let (pointer, prefix_len) = self.find_address_in_tree(&ip_int)?;
345
346 if pointer == 0 {
347 return Ok((None, prefix_len));
350 }
351
352 match self.decode_data_at_pointer(pointer) {
354 Ok(value) => Ok((Some(value), prefix_len)),
355 Err(e) => Err(e),
356 }
357 }
358
359 pub fn within<T>(&'de self, cidr: IpNetwork) -> Result<Within<'de, T, S>, MaxMindDbError>
377 where
378 T: Deserialize<'de>,
379 {
380 let ip_address = cidr.network();
381 let prefix_len = cidr.prefix() as usize;
382 let ip_int = IpInt::new(ip_address);
383 let bit_count = ip_int.bit_count();
384
385 let mut node = self.start_node(bit_count);
386 let node_count = self.metadata.node_count as usize;
387
388 let mut stack: Vec<WithinNode> = Vec::with_capacity(bit_count - prefix_len);
389
390 let mut i = 0_usize;
392 while i < prefix_len {
393 let bit = ip_int.get_bit(i);
394 node = self.read_node(node, bit as usize)?;
395 if node >= node_count {
396 break;
398 }
399
400 i += 1;
401 }
402
403 if node < node_count {
404 stack.push(WithinNode {
407 node,
408 ip_int,
409 prefix_len,
410 });
411 }
412 let within: Within<T, S> = Within {
416 reader: self,
417 node_count,
418 stack,
419 phantom: PhantomData,
420 };
421
422 Ok(within)
423 }
424
425 fn find_address_in_tree(&self, ip_int: &IpInt) -> Result<(usize, usize), MaxMindDbError> {
426 let bit_count = ip_int.bit_count();
427 let mut node = self.start_node(bit_count);
428
429 let node_count = self.metadata.node_count as usize;
430 let mut prefix_len = bit_count;
431
432 for i in 0..bit_count {
433 if node >= node_count {
434 prefix_len = i;
435 break;
436 }
437 let bit = ip_int.get_bit(i);
438 node = self.read_node(node, bit as usize)?;
439 }
440 match node_count {
441 n if n == node => Ok((0, prefix_len)),
444 n if node > n => Ok((node, prefix_len)),
445 _ => Err(MaxMindDbError::InvalidDatabase(
446 "invalid node in search tree".to_owned(),
447 )),
448 }
449 }
450
451 fn start_node(&self, length: usize) -> usize {
452 if length == 128 {
453 0
454 } else {
455 self.ipv4_start
456 }
457 }
458
459 fn find_ipv4_start(&self) -> Result<usize, MaxMindDbError> {
460 if self.metadata.ip_version != 6 {
461 return Ok(0);
462 }
463
464 let mut node: usize = 0_usize;
467 for _ in 0_u8..96 {
468 if node >= self.metadata.node_count as usize {
469 break;
470 }
471 node = self.read_node(node, 0)?;
472 }
473 Ok(node)
474 }
475
476 fn read_node(&self, node_number: usize, index: usize) -> Result<usize, MaxMindDbError> {
477 let buf = self.buf.as_ref();
478 let base_offset = node_number * (self.metadata.record_size as usize) / 4;
479
480 let val = match self.metadata.record_size {
481 24 => {
482 let offset = base_offset + index * 3;
483 to_usize(0, &buf[offset..offset + 3])
484 }
485 28 => {
486 let mut middle = buf[base_offset + 3];
487 if index != 0 {
488 middle &= 0x0F
489 } else {
490 middle = (0xF0 & middle) >> 4
491 }
492 let offset = base_offset + index * 4;
493 to_usize(middle, &buf[offset..offset + 3])
494 }
495 32 => {
496 let offset = base_offset + index * 4;
497 to_usize(0, &buf[offset..offset + 4])
498 }
499 s => {
500 return Err(MaxMindDbError::InvalidDatabase(format!(
501 "unknown record size: \
502 {s:?}"
503 )))
504 }
505 };
506 Ok(val)
507 }
508
509 fn resolve_data_pointer(&self, pointer: usize) -> Result<usize, MaxMindDbError> {
511 let resolved = pointer - (self.metadata.node_count as usize) - 16;
512
513 if resolved >= (self.buf.as_ref().len() - self.pointer_base) {
515 return Err(MaxMindDbError::InvalidDatabase(
516 "the MaxMind DB file's data pointer resolves to an invalid location".to_owned(),
517 ));
518 }
519
520 Ok(resolved)
521 }
522
523 fn decode_data_at_pointer<T>(&'de self, pointer: usize) -> Result<T, MaxMindDbError>
526 where
527 T: Deserialize<'de>,
528 {
529 let resolved_offset = self.resolve_data_pointer(pointer)?;
530 let mut decoder =
531 decoder::Decoder::new(&self.buf.as_ref()[self.pointer_base..], resolved_offset);
532 T::deserialize(&mut decoder)
533 }
534}
535
/// Interprets `bytes` as a big-endian unsigned integer.
///
/// `base` supplies bits that sit above the slice's most significant byte. Reading
/// 28-bit records uses this for the shared nibble.
fn to_usize(base: u8, bytes: &[u8]) -> usize {
    let mut value = usize::from(base);
    for &byte in bytes {
        value = (value << 8) | usize::from(byte);
    }
    value
}
543
544#[inline]
545fn bytes_and_prefix_to_net(bytes: &IpInt, prefix: u8) -> Result<IpNetwork, MaxMindDbError> {
546 let (ip, prefix) = match bytes {
547 IpInt::V4(ip) => (IpAddr::V4(Ipv4Addr::from(*ip)), prefix),
548 IpInt::V6(ip) if bytes.is_ipv4_in_ipv6() => {
549 (IpAddr::V4(Ipv4Addr::from(*ip as u32)), prefix - 96)
550 }
551 IpInt::V6(ip) => (IpAddr::V6(Ipv6Addr::from(*ip)), prefix),
552 };
553 IpNetwork::new(ip, prefix).map_err(MaxMindDbError::InvalidNetwork)
554}
555
556fn find_metadata_start(buf: &[u8]) -> Result<usize, MaxMindDbError> {
557 const METADATA_START_MARKER: &[u8] = b"\xab\xcd\xefMaxMind.com";
558
559 memchr::memmem::rfind(buf, METADATA_START_MARKER)
560 .map(|x| x + METADATA_START_MARKER.len())
561 .ok_or_else(|| {
562 MaxMindDbError::InvalidDatabase(
563 "Could not find MaxMind DB metadata in file.".to_owned(),
564 )
565 })
566}
567
568mod decoder;
569pub mod geoip2;
570
571#[cfg(test)]
572mod reader_test;
573
#[cfg(test)]
mod tests {
    use super::{MaxMindDbError, Reader};
    use crate::geoip2;
    use ipnetwork::IpNetworkError;
    use std::io::{Error, ErrorKind};
    use std::net::IpAddr;

    /// GeoIP2 City fixture shared by the lookup tests.
    const CITY_TEST_DB: &str = "test-data/test-data/GeoIP2-City-Test.mmdb";

    /// Each error variant renders its prefix followed by the wrapped message.
    #[test]
    fn test_error_display() {
        let invalid = MaxMindDbError::InvalidDatabase("something went wrong".to_owned());
        assert_eq!(invalid.to_string(), "Invalid database: something went wrong");

        let io_failure = MaxMindDbError::from(Error::new(ErrorKind::NotFound, "file not found"));
        assert_eq!(io_failure.to_string(), "I/O error: file not found");

        #[cfg(feature = "mmap")]
        {
            let mmap_failure =
                MaxMindDbError::Mmap(Error::new(ErrorKind::PermissionDenied, "mmap failed"));
            assert_eq!(mmap_failure.to_string(), "Memory map error: mmap failed");
        }

        let decoding = MaxMindDbError::Decoding("unexpected type".to_owned());
        assert_eq!(decoding.to_string(), "Decoding error: unexpected type");

        let network = MaxMindDbError::from(IpNetworkError::InvalidPrefix);
        assert_eq!(network.to_string(), "Invalid network: invalid prefix");
    }

    /// A private address absent from the fixture yields no record but still reports
    /// how far the tree walk got.
    #[test]
    fn test_lookup_returns_none_for_unknown_address() {
        let reader = Reader::open_readfile(CITY_TEST_DB).unwrap();
        let ip = "10.0.0.1".parse::<IpAddr>().unwrap();

        let looked_up = reader.lookup::<geoip2::City>(ip);
        assert!(
            matches!(looked_up, Ok(None)),
            "lookup should return Ok(None) for unknown IP"
        );

        let with_prefix = reader.lookup_prefix::<geoip2::City>(ip);
        assert!(
            matches!(with_prefix, Ok((None, 8))),
            "lookup_prefix should return Ok((None, 8)) for unknown IP, got {:?}",
            with_prefix
        );
    }

    /// An address present in the fixture decodes to a City record, and the prefix
    /// length matches the fixture's network.
    #[test]
    fn test_lookup_returns_some_for_known_address() {
        let reader = Reader::open_readfile(CITY_TEST_DB).unwrap();
        let ip = "89.160.20.128".parse::<IpAddr>().unwrap();

        let looked_up = reader.lookup::<geoip2::City>(ip);
        assert!(
            matches!(looked_up, Ok(Some(_))),
            "lookup should return Ok(Some(_)) for known IP"
        );
        let city = looked_up.unwrap().unwrap();
        assert!(city.city.is_some(), "Expected city data");

        let with_prefix = reader.lookup_prefix::<geoip2::City>(ip);
        assert!(
            matches!(with_prefix, Ok((Some(_), _))),
            "lookup_prefix should return Ok(Some(_)) for known IP"
        );
        let (prefix_city, prefix_len) = with_prefix.unwrap();
        assert!(
            prefix_city.unwrap().city.is_some(),
            "Expected city data from prefix lookup"
        );
        assert_eq!(prefix_len, 25, "Expected valid prefix length");
    }
}